diff --git a/deeplearning4j/src/main/java/com/baeldung/logreg/MnistClassifier.java b/deeplearning4j/src/main/java/com/baeldung/logreg/MnistClassifier.java index c8580b9c27..01f1ef1bdc 100644 --- a/deeplearning4j/src/main/java/com/baeldung/logreg/MnistClassifier.java +++ b/deeplearning4j/src/main/java/com/baeldung/logreg/MnistClassifier.java @@ -11,6 +11,7 @@ import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.eval.Evaluation; +import org.deeplearning4j.nn.conf.LearningRatePolicy; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -69,7 +70,8 @@ public class MnistClassifier { String localFilePath = basePath + "mnist_png.tar.gz"; File file = new File(localFilePath); if (!file.exists()) { - file.getParentFile().mkdirs(); + file.getParentFile() + .mkdirs(); Utils.downloadAndSave(dataUrl, file); Utils.extractTarArchive(file, basePath); } @@ -132,7 +134,9 @@ public class MnistClassifier { .build(); final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed) .l2(0.0005) // ridge regression value - .updater(new Nesterovs()) //TODO new MapSchedule(ScheduleType.ITERATION, learningRateSchedule) + .updater(new Nesterovs()) + .learningRateSchedule(learningRateSchedule) + .learningRateDecayPolicy(LearningRatePolicy.Schedule) .weightInit(WeightInit.XAVIER) .list() .layer(0, layer1)