cnn_spectrogram_algorithm's People

Forkers

eden-kramer-lab

cnn_spectrogram_algorithm's Issues

Changes to SpectrogramClassificationAlgorithm.ipynb

  1. Add or change the two marked lines:
# train the model, run it on the test data and get the output DataFrame
NEWPATH = set_patient(0)
arch = resnet34                    # <--- ADD THIS LINE
data = get_the_data(arch)          # <--- ADD THIS INPUT TO FUNCTION
dat = train_the_model(arch,data)
  2. Add the input arch to the definition: def get_the_data(arch):

  3. Use this definition of train_the_model:

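def train_the_model(arch,data):
    # Assumes the notebook's existing fastai 0.7 imports (e.g. from fastai.conv_learner import *)
    # Stage 1: train only the new head on precomputed activations (1 epoch)
    learn = ConvLearner.pretrained(arch,data,precompute=True)
    lr = 1e-2
    learn.fit(lr,1)
    # Stage 2: turn precompute off so batches run through the full network
    # (and any data augmentation takes effect); 3 cycles of 1 epoch at 1e-3
    learn.precompute = False
    learn.fit(1e-3,3,cycle_len=1)
    # Stage 3: unfreeze all layers and use differential learning rates,
    # one per layer group (early layers, later layers, head)
    learn.unfreeze()
    lr = np.array([1e-4,1e-3,1e-2])
    # stop when validation loss has not improved for `patience` epochs; best model saved as 'best_mod'
    cb = [EarlyStopping(learn,save_path='best_mod',patience = 6)]
    learn.fit(lr,6,cycle_len=1,cycle_mult=2,callbacks=cb)
    torch.save(learn.model.state_dict(),'test_saved_model.pkl')
    
    #get output predictions and probabilities
    log_preds_test = learn.predict(is_test=True)     # log-probabilities, one row per test image
    preds_test = np.argmax(log_preds_test,axis=1)    # predicted class index
    probs_test = np.exp(log_preds_test[:,1])         # probability of class 1
    
    #make test: a dataframe of test image names, predictions, and probabilities
    test_names = np.empty_like(data.test_ds.fnames)
    for i in range(len(data.test_ds.fnames)):
        test_names[i] = data.test_ds.fnames[i]
        #temp = data.test_ds.fnames[i]
        #matchobj = re.search('.*im.*',temp)
        #test_names[i] = matchobj.group()
    test = pd.DataFrame(data = test_names,columns = ['image_number'])
    test['prediction'] = preds_test
    test['probability'] = probs_test
    
    return test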
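Two parts of train_the_model are easy to misread: how many epochs the three fit calls run, and how the test-set log-probabilities turn into the prediction and probability columns. The plain-Python sketch below mirrors both under stated assumptions: it assumes fastai 0.7 semantics, where cycle i of fit(..., cycle_len, cycle_mult) lasts cycle_len * cycle_mult**i epochs with cosine annealing inside each cycle. The helpers sgdr_total_epochs, cosine_lr and summarize_predictions, and the image names and probabilities, are hypothetical toy values, not part of fastai or this repo.

```python
import math

def sgdr_total_epochs(n_cycle, cycle_len=1, cycle_mult=1):
    # Hypothetical helper: cycle i lasts cycle_len * cycle_mult**i epochs (assumed fastai 0.7 rule)
    return sum(cycle_len * cycle_mult ** i for i in range(n_cycle))

def cosine_lr(max_lr, it, iters_per_cycle):
    # Cosine annealing: max_lr at the start of a cycle, decaying toward 0 by its end
    return max_lr / 2 * (math.cos(math.pi * it / iters_per_cycle) + 1)

def summarize_predictions(names, log_preds):
    # Hypothetical helper: one (name, predicted_class, P(class 1)) row per image,
    # mirroring np.argmax(log_preds, axis=1) and np.exp(log_preds[:, 1])
    rows = []
    for name, lp in zip(names, log_preds):
        pred = max(range(len(lp)), key=lambda k: lp[k])  # argmax over classes
        rows.append((name, pred, math.exp(lp[1])))        # log-prob -> prob of class 1
    return rows

# Epochs for fit(lr,1), fit(1e-3,3,cycle_len=1) and fit(lr,6,cycle_len=1,cycle_mult=2)
print([sgdr_total_epochs(1), sgdr_total_epochs(3), sgdr_total_epochs(6, 1, 2)])  # [1, 3, 63]

# Toy log-probabilities for two hypothetical test images
rows = summarize_predictions(
    ["im1.png", "im2.png"],
    [[math.log(0.9), math.log(0.1)], [math.log(0.2), math.log(0.8)]],
)
print([r[1] for r in rows])  # [0, 1]
```

Under these assumptions the schedule runs at most 1 + 3 + 63 = 67 epochs, and EarlyStopping with patience=6 can end the last fit sooner. Note also that the probability column is P(class 1) for every image, not the probability of the predicted class.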
