class_labs's People

Contributors: bpark738, tomthetrainer, turambar

class_labs's Issues

Comment the heck out of things in a Gist for me please

@turambar @bpark738
I was attempting to write the lab today, and to do that I like to describe each step.

I struggled to grasp the dataset and the transformation. Perhaps if one of you did a pass through the example and commented almost everything for me in a gist, that would help.

Here is what I mean

My comments/questions will start with //TH so you can grep for them

public static void main(String[] args) throws IOException, InterruptedException {

        // STEP 0: Flags controlling which data to load

        // 0 for removing Time and Elapsed columns; 1 for removing Time; 2 for removing Elapsed
        //TH what does the data look like before we do this, what is Time and Elapsed, maybe data sample
        int remove = 0;
        int numLabelClasses = 2;
        boolean resampled = false; // If true use resampled data
       //TH I get iterators, no need to explain these
        DataSetIterator trainData;
        DataSetIterator validData;
        DataSetIterator testData;

        if(resampled){
            //TH Grandpa needs some help here: resampled in what way? I assume simplified by using fewer timesteps?
            NB_INPUTS-=2;
            featuresDir = new File(baseDir, "resampled");

            // CSVSequenceRecordReader(1, ","): skip 1 header line, comma-delimited
            SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(1, ",");
            // NumberedFileInputSplit substitutes %d with each index from 0 to NB_TRAIN_EXAMPLES - 1 (inclusive); one sequence per file
            trainFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
            SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
            trainLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));

            // ALIGN_END aligns the (single) label with the last time step of each sequence
            trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
                    BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            // Load validation data
            //TH I could review numbered file input split and all that, but a comment from either of you might be quicker
            SequenceRecordReader validFeatures = new CSVSequenceRecordReader(1, ",");
            validFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES  - 1));
            SequenceRecordReader validLabels = new CSVSequenceRecordReader();
            validLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES  - 1));


            validData = new SequenceRecordReaderDataSetIterator(validFeatures, validLabels,
                    BATCH_SIZE, numLabelClasses, false,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            // Load test data
            SequenceRecordReader testFeatures = new CSVSequenceRecordReader(1, ",");
            testFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
            SequenceRecordReader testLabels = new CSVSequenceRecordReader();
            testLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));


            testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
                    BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        }

        else {

            //TH I get the schema, no need to dive too deep for me on this.
            Schema schema = new SequenceSchema.Builder()
                    .addColumnsDouble("Time","Elapsed","ALP").addColumnCategorical("ALPMissing")
                    .addColumnDouble("ALT").addColumnCategorical("ALTMissing").addColumnDouble("AST")
                    .addColumnCategorical("ASTMissing").addColumnDouble("Age").addColumnCategorical("AgeMissing")
                    .addColumnDouble("Albumin").addColumnCategorical("AlbuminMissing").addColumnDouble("BUN")
                    .addColumnCategorical("BUNMissing").addColumnDouble("Bilirubin").addColumnCategorical("BilirubinMissing")
                    .addColumnDouble("Cholesterol").addColumnCategorical("CholesterolMissing").addColumnDouble("Creatinine")
                    .addColumnCategorical("CreatinineMissing").addColumnDouble("DiasABP").addColumnCategorical("DiasABPMissing")
                    .addColumnDouble("FiO2").addColumnCategorical("FiO2Missing").addColumnDouble("GCS")
                    .addColumnCategorical("GCSMissing").addColumnCategorical("Gender0").addColumnCategorical("Gender1")
                    .addColumnDouble("Glucose").addColumnCategorical("GlucoseMissing").addColumnDouble("HCO3")
                    .addColumnCategorical("HCO3Missing").addColumnDouble("HCT").addColumnCategorical("HCTMissing")
                    .addColumnDouble("HR").addColumnCategorical("HRMissing").addColumnDouble("Height")
                    .addColumnCategorical("HeightMissing").addColumnCategorical("ICUType1").addColumnCategorical("ICUType2")
                    .addColumnCategorical("ICUType3").addColumnCategorical("ICUType4").addColumnDouble("K")
                    .addColumnCategorical("KMissing").addColumnDouble("Lactate").addColumnCategorical("LactateMissing")
                    .addColumnDouble("MAP").addColumnCategorical("MAPMissing").addColumnDouble("MechVent")
                    .addColumnCategorical("MechVentMissing").addColumnDouble("Mg").addColumnCategorical("MgMissing")
                    .addColumnDouble("NIDiasABP").addColumnCategorical("NIDiasABPMissing").addColumnDouble("NIMAP")
                    .addColumnCategorical("NIMAPMissing").addColumnDouble("NISysABP").addColumnCategorical("NISysABPMissing")
                    .addColumnDouble("Na").addColumnCategorical("NaMissing").addColumnDouble("PaCO2")
                    .addColumnCategorical("PaCO2Missing").addColumnDouble("PaO2").addColumnCategorical("PaO2Missing")
                    .addColumnDouble("Platelets").addColumnCategorical("PlateletsMissing").addColumnDouble("RespRate")
                    .addColumnCategorical("RespRateMissing").addColumnDouble("SaO2").addColumnCategorical("SaO2Missing")
                    .addColumnDouble("SysABP").addColumnCategorical("SysABPMissing").addColumnDouble("Temp")
                    .addColumnCategorical("TempMissing").addColumnDouble("TroponinI").addColumnCategorical("TroponinIMissing")
                    .addColumnDouble("TroponinT").addColumnCategorical("TroponinTMissing").addColumnDouble("Urine")
                    .addColumnCategorical("UrineMissing").addColumnDouble("WBC").addColumnCategorical("WBCMissing")
                    .addColumnDouble("Weight").addColumnCategorical("WeightMissing").addColumnDouble("pH")
                    .addColumnCategorical("pHMissing").build();
            
            TransformProcess transformProcess;
            //TH So we are removing one or more columns depending on a variable that is set; please comment that for me, or split it into separate classes. I know you hate redundant code; I sort of love it. Keep it simple for the masses, I say. Either way, do what is efficient and workable.
            if(remove == 0){
                transformProcess = new TransformProcess.Builder(schema).removeColumns("Time", "Elapsed").build();
                NB_INPUTS-=2;
            }
            else if(remove == 1){
                transformProcess = new TransformProcess.Builder(schema).removeColumns("Time").build();
                NB_INPUTS-=1;
            }
            else if(remove == 2){
                transformProcess = new TransformProcess.Builder(schema).removeColumns("Elapsed").build();
                NB_INPUTS-=1;
            }
            else{
                transformProcess = new TransformProcess.Builder(schema).build();
            }

            // Load training data
            SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(1, ",");
            trainFeatures.initialize( new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
            SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
            trainLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));

            // Wrap the reader so the TransformProcess (column removal) is applied on the fly as records are read
            TransformProcessSequenceRecordReader trainRemovedFeatures = new TransformProcessSequenceRecordReader(trainFeatures, transformProcess);

            trainData = new SequenceRecordReaderDataSetIterator(trainRemovedFeatures, trainLabels,
                    BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            // Load validation data
            SequenceRecordReader validFeatures = new CSVSequenceRecordReader(1, ",");
            validFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES  - 1));
            SequenceRecordReader validLabels = new CSVSequenceRecordReader();
            validLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES  - 1));
            TransformProcessSequenceRecordReader validRemovedFeatures = new TransformProcessSequenceRecordReader(validFeatures, transformProcess);


            validData = new SequenceRecordReaderDataSetIterator(validRemovedFeatures, validLabels,
                    BATCH_SIZE, numLabelClasses, false,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            // Load test data
            SequenceRecordReader testFeatures = new CSVSequenceRecordReader(1, ",");
            testFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
            SequenceRecordReader testLabels = new CSVSequenceRecordReader();
            testLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
            TransformProcessSequenceRecordReader testRemovedFeatures = new TransformProcessSequenceRecordReader(testFeatures, transformProcess);


            testData = new SequenceRecordReaderDataSetIterator(testRemovedFeatures, testLabels,
                    BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        }

        // STEP 1: ETL/vectorization


        // STEP 2: Model configuration and initialization

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(RANDOM_SEED)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(LEARNING_RATE)
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.ADAM)
                .graphBuilder()
                .addInputs("trainFeatures")
                .setOutputs("predictMortality")
                .addLayer("L1", new GravesLSTM.Builder()
                                .nIn(NB_INPUTS)
                                .nOut(lstmLayerSize)
                                .activation(Activation.TANH)
                                .build(),
                        "trainFeatures")
                .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .nIn(lstmLayerSize).nOut(numLabelClasses).build(),"L1")
                .pretrain(false).backprop(true)
                .build();

        // STEP 3: Performance monitoring

        ComputationGraph model = new ComputationGraph(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));

        // STEP 4: Model training

        for( int i=0; i<NB_EPOCHS; i++ ){

            model.fit(trainData); // implicit inner loop over minibatches

            // loop over batches in training data to compute training AUC

            //TH I would like to define AUC and ROC here, and maybe point to docs
            //TH Not exactly the correct place for this, but they will want to discuss what the output is. Step back from the output layer, which I assume is softmax or whatever classification needs: what is the output at each neuron one layer back? It gets a collection of sequences in and emits a collection of sequences?
            //TH Thanks
            ROC roc = new ROC(100);
            trainData.reset();

            while(trainData.hasNext()){
                DataSet batch = trainData.next();
                INDArray[] output = model.output(batch.getFeatures());
                roc.evalTimeSeries(batch.getLabels(), output[0]);
            }

            log.info("EPOCH " + i + " TRAIN AUC: " + roc.calculateAUC());

            roc = new ROC(100);
            while (validData.hasNext()) {
                DataSet batch = validData.next();
                INDArray[] output = model.output(batch.getFeatures());
                roc.evalTimeSeries(batch.getLabels(), output[0]);
            }

            log.info("EPOCH " + i + " VALID AUC: " + roc.calculateAUC());

            trainData.reset();
            validData.reset();
        }

        ROC roc = new ROC(100);

        while (testData.hasNext()) {
            DataSet batch = testData.next();
            INDArray[] output = model.output(batch.getFeatures());
            roc.evalTimeSeries(batch.getLabels(), output[0]);
        }
        log.info("***** Test Evaluation *****");
        log.info("{}", roc.calculateAUC());
    }
}
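On the //TH request to define ROC and AUC: the ROC curve traces true-positive rate against false-positive rate as the classification threshold sweeps over the model's scores, and AUC is the area under that curve. Equivalently, AUC is the probability that a randomly chosen positive example is scored above a randomly chosen negative one. Below is a minimal, dependency-free Java sketch of that second definition, for lab discussion only; it is not DL4J's `ROC` implementation, which approximates the curve using the threshold-step count passed to its constructor (`new ROC(100)` above).

```java
public class AucSketch {
    // AUC as the probability that a random positive scores higher than a
    // random negative (ties count half). Equivalent to the area under the ROC curve.
    static double auc(double[] scores, int[] labels) {
        double pairs = 0, wins = 0;
        for (int i = 0; i < scores.length; i++) {
            if (labels[i] != 1) continue;           // i ranges over positives
            for (int j = 0; j < scores.length; j++) {
                if (labels[j] != 0) continue;       // j ranges over negatives
                pairs++;
                if (scores[i] > scores[j]) wins++;
                else if (scores[i] == scores[j]) wins += 0.5;
            }
        }
        return wins / pairs;
    }

    public static void main(String[] args) {
        // Perfectly separated scores: every positive outranks every negative
        double[] scores = {0.9, 0.8, 0.3, 0.1};
        int[] labels    = {1,   1,   0,   0};
        System.out.println(auc(scores, labels));
    }
}
```

A useful anchor for the lab: AUC = 1.0 means perfect ranking, 0.5 means the model ranks no better than chance.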


Extend physionet LSTM example

  • add example with time columns removed (or add config options to existing example)
  • add example using resampled data (or add config options to existing example)
  • once code is ready, add to physionet branch, make a PR, and request reviews from @DaveKale and @tomthetrainer
  • @turambar: upload new data including resampled time series
  • run experiments without time columns to see how performance affected
  • run experiments with resampled time series, compare performance vs. raw time series
  • target replication, a la Learning to Diagnose
  • (maybe) multitask LSTM, a la Multitask Clinical Time Series
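For the "add config options to existing example" items, one way to replace the magic-int `remove` flag (0/1/2/other) with something self-documenting is an enum that carries both the columns to strip and the resulting input-count adjustment. This is a hypothetical sketch, with names of my own invention, not code from the repo:

```java
public class ColumnConfig {
    // Self-documenting replacement for the `remove` int flag in the example.
    enum TimeColumns {
        DROP_BOTH, DROP_TIME, DROP_ELAPSED, KEEP_BOTH;

        // Columns this setting strips from the schema.
        String[] columnsToRemove() {
            switch (this) {
                case DROP_BOTH:    return new String[]{"Time", "Elapsed"};
                case DROP_TIME:    return new String[]{"Time"};
                case DROP_ELAPSED: return new String[]{"Elapsed"};
                default:           return new String[]{};
            }
        }

        // How many features the removal costs (subtract from NB_INPUTS).
        int inputsRemoved() { return columnsToRemove().length; }
    }

    public static void main(String[] args) {
        TimeColumns mode = TimeColumns.DROP_BOTH;
        System.out.println(String.join(",", mode.columnsToRemove()));
        System.out.println(mode.inputsRemoved());
    }
}
```

The if/else chain in the example would then collapse to `new TransformProcess.Builder(schema).removeColumns(mode.columnsToRemove()).build()` plus `NB_INPUTS -= mode.inputsRemoved()`, keeping the column list and the input-count bookkeeping in one place.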
