Set the learning rates (#13052)

This commit is contained in:
Andrey Shcherbakov 2022-11-24 20:29:13 +01:00 committed by GitHub
parent fd1bf7a029
commit c5944c1d3d
1 changed files with 6 additions and 2 deletions

View File

@ -11,6 +11,7 @@ import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
@ -69,7 +70,8 @@ public class MnistClassifier {
String localFilePath = basePath + "mnist_png.tar.gz";
File file = new File(localFilePath);
if (!file.exists()) {
file.getParentFile().mkdirs();
file.getParentFile()
.mkdirs();
Utils.downloadAndSave(dataUrl, file);
Utils.extractTarArchive(file, basePath);
}
@ -132,7 +134,9 @@ public class MnistClassifier {
.build();
final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed)
.l2(0.0005) // ridge regression value
.updater(new Nesterovs()) //TODO new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)
.updater(new Nesterovs())
.learningRateSchedule(learningRateSchedule)
.learningRateDecayPolicy(LearningRatePolicy.Schedule)
.weightInit(WeightInit.XAVIER)
.list()
.layer(0, layer1)