
Comments (5)

nok commented on May 30, 2024

Yes, you are right. In general, the estimators exported to every target language return the index of the predicted label (y) rather than the label itself, because reimplementing the index-to-label mapping for each programming language would add overhead. Nevertheless, I have noted this requirement for a future release.
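Because the ported model returns an index into the classifier's `classes_` array, the calling code can restore the original labels with a plain lookup. A minimal Python sketch, assuming the label list was copied from `clf.classes_` when the model was exported (the `CLASSES` values and the `idx_to_label` helper are hypothetical):

```python
from typing import Sequence

# Hypothetical label list, assumed to be copied from clf.classes_ at export time.
CLASSES: Sequence[str] = ("cat", "dog", "fish")


def idx_to_label(class_idx: int, classes: Sequence[str] = CLASSES) -> str:
    """Map an index returned by a ported estimator back to its original label."""
    if not 0 <= class_idx < len(classes):
        raise IndexError(f"class index {class_idx} is out of range")
    return classes[class_idx]


print([idx_to_label(i) for i in [2, 0, 1]])  # ['fish', 'cat', 'dog']
```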

from sklearn-porter.

vijaykilledar commented on May 30, 2024

For the JSON export option, could we add the array of class names to the JSON data?
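Until something like this exists, one way to approximate it is to post-process the exported JSON and append the class list yourself. The sketch below assumes the export is a single JSON object; the `"classes"` key and the `add_class_names` helper are hypothetical names, not part of sklearn-porter's format:

```python
import json
from typing import Any, Iterable


def add_class_names(json_text: str, class_names: Iterable[Any]) -> str:
    """Return the exported model JSON with an extra (hypothetical) "classes" array."""
    data = json.loads(json_text)
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object at the top level")
    # numpy scalars (e.g. from clf.classes_) are not JSON-serializable, so
    # convert them to plain Python values via .item() when available.
    data["classes"] = [c.item() if hasattr(c, "item") else c for c in class_names]
    return json.dumps(data)


exported = '{"n_features": 4}'  # stand-in for real exported model data
print(add_class_names(exported, [0, 2, 3, 4, 5]))
# {"n_features": 4, "classes": [0, 2, 3, 4, 5]}
```

Code in the target language could then read `classes[idx]` to translate a predicted index back into a label.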


HTCode commented on May 30, 2024

A hack for this would be to include in the exported JSON data a few labelled training samples ((x_i, y_i), ...) from each class whose classes the Python classifier predicts correctly (i.e. the most confident training samples of each class). Then, in the target language, one can match the indexes returned by clf.predict(x_i) to their actual labels y_i.
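The idea can be sketched in Python as follows. `predict_index` stands for whatever function calls the ported model and returns a class index, and the reference samples are assumed to be ones the original classifier gets right; all names here are illustrative. Note that an index only ends up in the map if at least one reference sample of that class is included.

```python
from typing import Callable, Dict, Hashable, Iterable, Sequence, Tuple


def recover_label_map(
    predict_index: Callable[[Sequence[float]], int],
    reference: Iterable[Tuple[Sequence[float], Hashable]],
) -> Dict[int, Hashable]:
    """Build an index -> label map from reference samples with known labels."""
    mapping: Dict[int, Hashable] = {}
    for x, y in reference:
        idx = predict_index(x)
        # setdefault returns the already stored label if idx was seen before
        if mapping.setdefault(idx, y) != y:
            raise ValueError(f"index {idx} maps to both {mapping[idx]!r} and {y!r}")
    return mapping


def toy_predict(x: Sequence[float]) -> int:
    """Stand-in for a ported model: thresholds the first feature."""
    return 0 if x[0] < 1 else (1 if x[0] < 2 else 2)


samples = [([0.5], 1), ([1.5], 3), ([2.5], 5), ([0.2], 1)]
print(recover_label_map(toy_predict, samples))  # {0: 1, 1: 3, 2: 5}
```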


skjerns commented on May 30, 2024

I just came across this problem as well:

In my case, my class labels range over [1,2,3,4,5], but there are no training examples for class 2. As a result, the C version of my random forest outputs classes [1,2,3,4], where 2, 3 and 4 actually stand for 3, 4 and 5. Is there any way to prevent this, or any idea how to fix it without tampering with the C code?

I have a semi-production pipeline in which some classes are occasionally missing from the training set, and I would be glad to have a way to correct this automatically, without manually putting class labels into the C code.

(see also BayesWitnesses/m2cgen#77 where I outline this problem in more detail)
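The shift comes from how scikit-learn classifiers such as `RandomForestClassifier` build `classes_`: it holds the sorted unique labels seen in `y` during `fit` (essentially what `numpy.unique` returns), and a ported model returns positions in that array. A label that never appears in the training data therefore gets no index, and every later label moves down by one. This pure-Python sketch mimics that derivation with made-up training labels; `classes_from_training` is a hypothetical helper:

```python
from typing import Hashable, Iterable, List


def classes_from_training(y_train: Iterable[Hashable]) -> List[Hashable]:
    """Mimic classes_: the sorted unique labels that occur in the training targets."""
    return sorted(set(y_train))


y_train = [1, 3, 3, 4, 5, 1, 4]  # label 2 never occurs
classes = classes_from_training(y_train)
print(classes)  # [1, 3, 4, 5]
for idx, label in enumerate(classes):
    print(f"index {idx} -> label {label}")
```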


skjerns commented on May 30, 2024

I solved it temporarily by writing a small wrapper. It adds a conversion function like the following to the C code and embeds it:

```c
int idx2label(int class_idx) {
    int labels[5] = {0,2,3,4,5}; // your original ints
    return labels[class_idx];
}
```

The wrapper itself:

```python
import sklearn_porter


def save_model_sklearn_porter(clf, file):
    """
    Save an sklearn model as C code that returns the original class labels,
    even if they are not consecutive.
    """
    porter = sklearn_porter.Porter(clf, language='C')
    output = porter.export(embed_data=True)

    # see which labels are in the classifier; so far only ints are supported
    labels = [str(int(i)) for i in clf.classes_]

    # create the label array and the conversion function
    labels_code = 'int labels[{}] = {{{}}}'.format(len(labels), ','.join(labels))
    convert_func = ('\n\nint idx2label(int class_idx) {\n'
                    '    ' + labels_code + ';\n'
                    '    return labels[class_idx];\n'
                    '}\n\n')

    # insert the function after the last preprocessor line (e.g. #include),
    # or at the very top of the file if there is none
    lines = output.splitlines()
    position = -1
    for idx, line in enumerate(lines):
        if line.strip().startswith('#'):
            position = idx
    lines.insert(position + 1, convert_func)
    output = '\n'.join(lines)

    # route the last occurrence of `return class_idx` through the conversion function
    head, sep, tail = output.rpartition('return class_idx')
    if not sep:
        raise ValueError('no `return class_idx` found in the exported C code')
    output = head + 'return idx2label(class_idx)' + tail

    with open(file, 'w') as f:
        f.write(output)
    return output
```
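Because the exported file and the classifier can drift apart (for example after retraining), it may be worth checking that the array embedded by the wrapper still matches `clf.classes_`. A small sketch; `embedded_labels` is a hypothetical helper that assumes the `int labels[N] = {...}` form the wrapper writes:

```python
import re
from typing import List


def embedded_labels(c_source: str) -> List[int]:
    """Extract the integers from the first `labels[N] = {...}` initializer in C source."""
    match = re.search(r"labels\[(\d+)\]\s*=\s*\{([^}]*)\}", c_source)
    if match is None:
        raise ValueError("no labels[] initializer found")
    declared = int(match.group(1))
    values = [int(v) for v in match.group(2).split(",") if v.strip()]
    if len(values) != declared:
        raise ValueError(f"declared {declared} labels but found {len(values)}")
    return values


snippet = "int labels[5] = {0,2,3,4,5};"
print(embedded_labels(snippet))  # [0, 2, 3, 4, 5]
```

With a fitted classifier, the check would be `embedded_labels(output) == [int(c) for c in clf.classes_]`, where `output` is the string returned by `save_model_sklearn_porter`.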
