Set the learning rates (#13052)
This commit is contained in:
parent
fd1bf7a029
commit
c5944c1d3d
@ -11,6 +11,7 @@ import org.datavec.image.loader.NativeImageLoader;
|
|||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
import org.deeplearning4j.eval.Evaluation;
|
import org.deeplearning4j.eval.Evaluation;
|
||||||
|
import org.deeplearning4j.nn.conf.LearningRatePolicy;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
@ -69,7 +70,8 @@ public class MnistClassifier {
|
|||||||
String localFilePath = basePath + "mnist_png.tar.gz";
|
String localFilePath = basePath + "mnist_png.tar.gz";
|
||||||
File file = new File(localFilePath);
|
File file = new File(localFilePath);
|
||||||
if (!file.exists()) {
|
if (!file.exists()) {
|
||||||
file.getParentFile().mkdirs();
|
file.getParentFile()
|
||||||
|
.mkdirs();
|
||||||
Utils.downloadAndSave(dataUrl, file);
|
Utils.downloadAndSave(dataUrl, file);
|
||||||
Utils.extractTarArchive(file, basePath);
|
Utils.extractTarArchive(file, basePath);
|
||||||
}
|
}
|
||||||
@ -132,7 +134,9 @@ public class MnistClassifier {
|
|||||||
.build();
|
.build();
|
||||||
final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed)
|
final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed)
|
||||||
.l2(0.0005) // ridge regression value
|
.l2(0.0005) // ridge regression value
|
||||||
.updater(new Nesterovs()) //TODO new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)
|
.updater(new Nesterovs())
|
||||||
|
.learningRateSchedule(learningRateSchedule)
|
||||||
|
.learningRateDecayPolicy(LearningRatePolicy.Schedule)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
.list()
|
||||||
.layer(0, layer1)
|
.layer(0, layer1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user