Merge pull request #9754 from HelgaShiryaeva/cnn-dl4j-java-example
CNN example with Deeplearning4j in Java
commit 3acad0b7cf
pom.xml
@@ -37,6 +37,16 @@
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
        <dependency>
            <groupId>org.datavec</groupId>
@@ -53,6 +63,7 @@
    <properties>
        <dl4j.version>0.9.1</dl4j.version> <!-- Latest non-beta version -->
        <httpclient.version>4.3.5</httpclient.version>
        <slf4j.version>1.7.5</slf4j.version>
    </properties>

</project>
CifarDataSetService.java
@@ -0,0 +1,47 @@
package com.baeldung.deeplearning4j.cnn;

import lombok.Getter;
import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.util.List;

@Getter
class CifarDataSetService implements IDataSetService {

    private final InputType inputType = InputType.convolutional(32, 32, 3);
    private final int trainImagesNum = 512;
    private final int testImagesNum = 128;
    private final int trainBatch = 16;
    private final int testBatch = 8;

    private final CifarDataSetIterator trainIterator;

    private final CifarDataSetIterator testIterator;

    CifarDataSetService() {
        trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
        testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
    }

    @Override
    public DataSetIterator trainIterator() {
        return trainIterator;
    }

    @Override
    public DataSetIterator testIterator() {
        return testIterator;
    }

    @Override
    public InputType inputType() {
        return inputType;
    }

    @Override
    public List<String> labels() {
        return trainIterator.getLabels();
    }
}
CnnExample.java
@@ -0,0 +1,18 @@
package com.baeldung.deeplearning4j.cnn;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation;

@Slf4j
class CnnExample {

    public static void main(String... args) {
        CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties());

        network.train();
        Evaluation evaluation = network.evaluate();

        log.info(evaluation.stats());
    }
}
CnnModel.java
@@ -0,0 +1,119 @@
package com.baeldung.deeplearning4j.cnn;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.stream.IntStream;

@Slf4j
class CnnModel {

    private final IDataSetService dataSetService;

    private final MultiLayerNetwork network;

    private final CnnModelProperties properties;

    CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {

        this.dataSetService = dataSetService;
        this.properties = properties;

        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
          .seed(1611)
          .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
          .learningRate(properties.getLearningRate())
          .regularization(true)
          .updater(properties.getOptimizer())
          .list()
          .layer(0, conv5x5())
          .layer(1, pooling2x2Stride2())
          .layer(2, conv3x3Stride1Padding2())
          .layer(3, pooling2x2Stride1())
          .layer(4, conv3x3Stride1Padding1())
          .layer(5, pooling2x2Stride1())
          .layer(6, dense())
          .pretrain(false)
          .backprop(true)
          .setInputType(dataSetService.inputType())
          .build();

        network = new MultiLayerNetwork(configuration);
    }

    void train() {
        network.init();
        int epochsNum = properties.getEpochsNum();
        IntStream.range(1, epochsNum + 1).forEach(epoch -> {
            log.info("Epoch {} / {}", epoch, epochsNum);
            network.fit(dataSetService.trainIterator());
        });
    }

    Evaluation evaluate() {
        return network.evaluate(dataSetService.testIterator());
    }

    private ConvolutionLayer conv5x5() {
        return new ConvolutionLayer.Builder(5, 5)
          .nIn(3)
          .nOut(16)
          .stride(1, 1)
          .padding(1, 1)
          .weightInit(WeightInit.XAVIER_UNIFORM)
          .activation(Activation.RELU)
          .build();
    }

    private SubsamplingLayer pooling2x2Stride2() {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
          .kernelSize(2, 2)
          .stride(2, 2)
          .build();
    }

    private ConvolutionLayer conv3x3Stride1Padding2() {
        return new ConvolutionLayer.Builder(3, 3)
          .nOut(32)
          .stride(1, 1)
          .padding(2, 2)
          .weightInit(WeightInit.XAVIER_UNIFORM)
          .activation(Activation.RELU)
          .build();
    }

    private SubsamplingLayer pooling2x2Stride1() {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
          .kernelSize(2, 2)
          .stride(1, 1)
          .build();
    }

    private ConvolutionLayer conv3x3Stride1Padding1() {
        return new ConvolutionLayer.Builder(3, 3)
          .nOut(64)
          .stride(1, 1)
          .padding(1, 1)
          .weightInit(WeightInit.XAVIER_UNIFORM)
          .activation(Activation.RELU)
          .build();
    }

    private OutputLayer dense() {
        return new OutputLayer.Builder(LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
          .activation(Activation.SOFTMAX)
          .weightInit(WeightInit.XAVIER_UNIFORM)
          .nOut(dataSetService.labels().size() - 1)
          .build();
    }
}
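CnnModel trains and evaluates entirely in memory; nothing in this PR persists the trained weights. Below is a minimal sketch, not part of the PR, of how the network could be saved and reloaded with DL4J's ModelSerializer. It assumes CnnModel exposed its MultiLayerNetwork (the field is currently private) and uses an assumed file name.

package com.baeldung.deeplearning4j.cnn;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

import java.io.File;
import java.io.IOException;

// Hypothetical helper (not in this PR): persists a trained network to disk
// and restores it, so training does not have to be repeated on every run.
class CnnModelPersistence {

    private static final File MODEL_FILE = new File("cnn-cifar.zip"); // assumed file name

    static void save(MultiLayerNetwork network) throws IOException {
        // true = also save the updater state, so training could be resumed later
        ModelSerializer.writeModel(network, MODEL_FILE, true);
    }

    static MultiLayerNetwork load() throws IOException {
        return ModelSerializer.restoreMultiLayerNetwork(MODEL_FILE);
    }
}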
CnnModelProperties.java
@@ -0,0 +1,13 @@
package com.baeldung.deeplearning4j.cnn;

import lombok.Value;
import org.deeplearning4j.nn.conf.Updater;

@Value
class CnnModelProperties {
    private final int epochsNum = 512;

    private final double learningRate = 0.001;

    private final Updater optimizer = Updater.ADAM;
}
IDataSetService.java
@@ -0,0 +1,16 @@
package com.baeldung.deeplearning4j.cnn;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.util.List;

interface IDataSetService {
    DataSetIterator trainIterator();

    DataSetIterator testIterator();

    InputType inputType();

    List<String> labels();
}
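IDataSetService is the only seam between the network and the data, so swapping CIFAR-10 for another dataset only requires a second implementation. Below is a hypothetical sketch, not part of this PR, backed by DL4J's MnistDataSetIterator; note that wiring it into CnnModel would also mean changing conv5x5() to nIn(1) for grayscale input and revisiting the nOut(labels().size() - 1) computation in dense(), which was written against the CIFAR iterator's label list.

package com.baeldung.deeplearning4j.cnn;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;

// Hypothetical alternative implementation of IDataSetService (not in this PR),
// illustrating how another dataset could be plugged in behind the same interface.
class MnistDataSetService implements IDataSetService {

    private final MnistDataSetIterator trainIterator;
    private final MnistDataSetIterator testIterator;

    MnistDataSetService() {
        try {
            trainIterator = new MnistDataSetIterator(16, true, 1611);  // batch 16, training split
            testIterator = new MnistDataSetIterator(8, false, 1611);   // batch 8, test split
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    @Override
    public DataSetIterator trainIterator() {
        return trainIterator;
    }

    @Override
    public DataSetIterator testIterator() {
        return testIterator;
    }

    @Override
    public InputType inputType() {
        // MNIST images are 28x28 grayscale, delivered as flattened vectors
        return InputType.convolutionalFlat(28, 28, 1);
    }

    @Override
    public List<String> labels() {
        return Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9");
    }
}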