tomthetrainer / class_labs Goto Github PK
View Code? Open in Web Editor NEW — Class labs to share with other trainers only
Class labs to share with other trainers only
@turambar @bpark738
I was attempting to write the lab today, and in order to do that I like to describe each step.
I struggled to grasp the dataset and the transformation; perhaps if one of you did a pass through the example and commented almost everything for me in a gist, that would help.
Here is what I mean
My comments/questions will start with //TH so you can grep for them
public static void main(String[] args) throws IOException, InterruptedException {
// STEP 0: Flags controlling which data
// 0 for removing Time and Elapsed columns; 1 for removing Time; 2 for removing Elapsed
//TH what does the data look like before we do this, what is Time and Elapsed, maybe data sample
int remove = 0;
int numLabelClasses = 2;
boolean resampled = false; // If true use resampled data
//TH I get iterators, no need to explain these
DataSetIterator trainData;
DataSetIterator validData;
DataSetIterator testData;
if(resampled){
//TH Grandpa needs some help here, resampled in what way, I assume simplify by less timesteps ?
NB_INPUTS-=2;
featuresDir = new File(baseDir, "resampled");
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(1, ",");
trainFeatures.initialize( new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load validation data
//TH I could review numbered file input split and all that, but a comment from either of you might be quicker
SequenceRecordReader validFeatures = new CSVSequenceRecordReader(1, ",");
validFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
SequenceRecordReader validLabels = new CSVSequenceRecordReader();
validLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
validData = new SequenceRecordReaderDataSetIterator(validFeatures, validLabels,
BATCH_SIZE, numLabelClasses, false,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load test data
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(1, ",");
testFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
testLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
}
else{
//TH I get the schema, no need to dive to deep for me on this.
Schema schema = new SequenceSchema.Builder()
.addColumnsDouble("Time","Elapsed","ALP").addColumnCategorical("ALPMissing")
.addColumnDouble("ALT").addColumnCategorical("ALTMissing").addColumnDouble("AST")
.addColumnCategorical("ASTMissing").addColumnDouble("Age").addColumnCategorical("AgeMissing")
.addColumnDouble("Albumin").addColumnCategorical("AlbuminMissing").addColumnDouble("BUN")
.addColumnCategorical("BUNMissing").addColumnDouble("Bilirubin").addColumnCategorical("BilirubinMissing")
.addColumnDouble("Cholesterol").addColumnCategorical("CholesterolMissing").addColumnDouble("Creatinine")
.addColumnCategorical("CreatinineMissing").addColumnDouble("DiasABP").addColumnCategorical("DiasABPMissing")
.addColumnDouble("FiO2").addColumnCategorical("FiO2Missing").addColumnDouble("GCS")
.addColumnCategorical("GCSMissing").addColumnCategorical("Gender0").addColumnCategorical("Gender1")
.addColumnDouble("Glucose").addColumnCategorical("GlucoseMissing").addColumnDouble("HCO3")
.addColumnCategorical("HCO3Missing").addColumnDouble("HCT").addColumnCategorical("HCTMissing")
.addColumnDouble("HR").addColumnCategorical("HRMissing").addColumnDouble("Height")
.addColumnCategorical("HeightMissing").addColumnCategorical("ICUType1").addColumnCategorical("ICUType2")
.addColumnCategorical("ICUType3").addColumnCategorical("ICUType4").addColumnDouble("K")
.addColumnCategorical("KMissing").addColumnDouble("Lactate").addColumnCategorical("LactateMissing")
.addColumnDouble("MAP").addColumnCategorical("MAPMissing").addColumnDouble("MechVent")
.addColumnCategorical("MechVentMissing").addColumnDouble("Mg").addColumnCategorical("MgMissing")
.addColumnDouble("NIDiasABP").addColumnCategorical("NIDiasABPMissing").addColumnDouble("NIMAP")
.addColumnCategorical("NIMAPMissing").addColumnDouble("NISysABP").addColumnCategorical("NISysABPMissing")
.addColumnDouble("Na").addColumnCategorical("NaMissing").addColumnDouble("PaCO2")
.addColumnCategorical("PaCO2Missing").addColumnDouble("PaO2").addColumnCategorical("PaO2Missing")
.addColumnDouble("Platelets").addColumnCategorical("PlateletsMissing").addColumnDouble("RespRate")
.addColumnCategorical("RespRateMissing").addColumnDouble("SaO2").addColumnCategorical("SaO2Missing")
.addColumnDouble("SysABP").addColumnCategorical("SysABPMissing").addColumnDouble("Temp")
.addColumnCategorical("TempMissing").addColumnDouble("TroponinI").addColumnCategorical("TroponinIMissing")
.addColumnDouble("TroponinT").addColumnCategorical("TroponinTMissing").addColumnDouble("Urine")
.addColumnCategorical("UrineMissing").addColumnDouble("WBC").addColumnCategorical("WBCMissing")
.addColumnDouble("Weight").addColumnCategorical("WeightMissing").addColumnDouble("pH")
.addColumnCategorical("pHMissing").build();
TransformProcess transformProcess;
//TH so we are removing some or more depending on a variable that is set, please comment that for me, or split into separate classes, I know you hate redundant code, I sort of love it, keep it simple for the masses I say. Either way, do what is efficient and workable
if(remove == 0){
transformProcess = new TransformProcess.Builder(schema).removeColumns("Time", "Elapsed").build();
NB_INPUTS-=2;
}
else if(remove == 1){
transformProcess = new TransformProcess.Builder(schema).removeColumns("Time").build();
NB_INPUTS-=1;
}
else if(remove == 2){
transformProcess = new TransformProcess.Builder(schema).removeColumns("Elapsed").build();
NB_INPUTS-=1;
}
else{
transformProcess = new TransformProcess.Builder(schema).build();
}
// Load training data
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(1, ",");
trainFeatures.initialize( new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
TransformProcessSequenceRecordReader trainRemovedFeatures = new TransformProcessSequenceRecordReader(trainFeatures, transformProcess);
trainData = new SequenceRecordReaderDataSetIterator(trainRemovedFeatures, trainLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load validation data
SequenceRecordReader validFeatures = new CSVSequenceRecordReader(1, ",");
validFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
SequenceRecordReader validLabels = new CSVSequenceRecordReader();
validLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
TransformProcessSequenceRecordReader validRemovedFeatures = new TransformProcessSequenceRecordReader(validFeatures, transformProcess);
validData = new SequenceRecordReaderDataSetIterator(validRemovedFeatures, validLabels,
BATCH_SIZE, numLabelClasses, false,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load test data
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(1, ",");
testFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
testLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
TransformProcessSequenceRecordReader testRemovedFeatures = new TransformProcessSequenceRecordReader(testFeatures, transformProcess);
testData = new SequenceRecordReaderDataSetIterator(testRemovedFeatures, testLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
}
// STEP 1: ETL/vectorization
// STEP 2: Model configuration and initialization
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(RANDOM_SEED)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(LEARNING_RATE)
.weightInit(WeightInit.XAVIER)
.updater(Updater.ADAM)
.graphBuilder()
.addInputs("trainFeatures")
.setOutputs("predictMortality")
.addLayer("L1", new GravesLSTM.Builder()
.nIn(NB_INPUTS)
.nOut(lstmLayerSize)
.activation(Activation.TANH)
.build(),
"trainFeatures")
.addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.XENT)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.nIn(lstmLayerSize).nOut(numLabelClasses).build(),"L1")
.pretrain(false).backprop(true)
.build();
// STEP 3 Performance monitoring
ComputationGraph model = new ComputationGraph(conf);
model.init();
model.setListeners(new ScoreIterationListener(10));
// STEP 4 Model training
for( int i=0; i<NB_EPOCHS; i++ ){
model.fit(trainData); // implicit inner loop over minibatches
// loop over batches in training data to compute training AUC
//TH would like to define AUC, and ROC and maybe point to docs
//TH not exactly the correct place for this, but they will want to discuss what the output is, so step back from the output layer which I assume is softmax or what is needed for classification, what is the output at each neuron, like one layer back? It gets a collection of sequences in, emits a collection of sequences?
//TH Thanks
ROC roc = new ROC(100);
trainData.reset();
while(trainData.hasNext()){
DataSet batch = trainData.next();
INDArray[] output = model.output(batch.getFeatures());
roc.evalTimeSeries(batch.getLabels(), output[0]);
}
log.info("EPOCH " + i + " TRAIN AUC: " + roc.calculateAUC());
roc = new ROC(100);
while (validData.hasNext()) {
DataSet batch = validData.next();
INDArray[] output = model.output(batch.getFeatures());
roc.evalTimeSeries(batch.getLabels(), output[0]);
}
log.info("EPOCH " + i + " VALID AUC: " + roc.calculateAUC());
trainData.reset();
validData.reset();
}
ROC roc = new ROC(100);
while (testData.hasNext()) {
DataSet batch = testData.next();
INDArray[] output = model.output(batch.getFeatures());
roc.evalTimeSeries(batch.getLabels(), output[0]);
}
log.info("***** Test Evaluation *****");
log.info("{}", roc.calculateAUC());
}
}
A declarative, efficient, and flexible JavaScript library for building user interfaces.
Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Some thing interesting about web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Some thing interesting about visualization, use data art
Some thing interesting about game, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.