Set the learning rates (#13052)
This commit is contained in:
parent
fd1bf7a029
commit
c5944c1d3d
|
@ -11,6 +11,7 @@ import org.datavec.image.loader.NativeImageLoader;
|
|||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.eval.Evaluation;
|
||||
import org.deeplearning4j.nn.conf.LearningRatePolicy;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
@ -69,7 +70,8 @@ public class MnistClassifier {
|
|||
String localFilePath = basePath + "mnist_png.tar.gz";
|
||||
File file = new File(localFilePath);
|
||||
if (!file.exists()) {
|
||||
file.getParentFile().mkdirs();
|
||||
file.getParentFile()
|
||||
.mkdirs();
|
||||
Utils.downloadAndSave(dataUrl, file);
|
||||
Utils.extractTarArchive(file, basePath);
|
||||
}
|
||||
|
@ -132,7 +134,9 @@ public class MnistClassifier {
|
|||
.build();
|
||||
final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed)
|
||||
.l2(0.0005) // ridge regression value
|
||||
.updater(new Nesterovs()) //TODO new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)
|
||||
.updater(new Nesterovs())
|
||||
.learningRateSchedule(learningRateSchedule)
|
||||
.learningRateDecayPolicy(LearningRatePolicy.Schedule)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(0, layer1)
|
||||
|
|
Loading…
Reference in New Issue