CNN example with Deeplearning4j in Java

This commit is contained in:
helga_sh 2020-07-21 16:24:31 +03:00
parent bf5c396967
commit adc586c566
6 changed files with 226 additions and 0 deletions

View File

@ -37,6 +37,16 @@
<artifactId>deeplearning4j-nn</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- SLF4J logging API; the CNN example classes log through Lombok's @Slf4j -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.5</version>
</dependency>
<!-- SLF4J binding artifact, same version as the API above.
     NOTE(review): presumably routes SLF4J calls to Log4j 1.2 and needs a log4j configuration on the classpath - confirm -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.5</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
<dependency>
<groupId>org.datavec</groupId>

View File

@ -0,0 +1,21 @@
package com.baeldung.deeplearning4j.cnn;
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModel;
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModelProperties;
import com.baeldung.deeplearning4j.cnn.service.dataset.CifarDataSetService;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation;
/**
 * Entry point of the CNN example: builds a {@link CnnModel} backed by the CIFAR data set
 * service, trains it, evaluates it and logs the evaluation statistics.
 */
@Slf4j
public class CnnExample {

    /**
     * Runs the full train / evaluate cycle and logs the resulting statistics.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String... args) {
        // Constructed in the same order as the arguments are passed to the model.
        CifarDataSetService dataSetService = new CifarDataSetService();
        CnnModelProperties modelProperties = new CnnModelProperties();
        CnnModel model = new CnnModel(dataSetService, modelProperties);

        model.train();

        Evaluation result = model.evaluate();
        log.info(result.stats());
    }
}

View File

@ -0,0 +1,120 @@
package com.baeldung.deeplearning4j.cnn.domain.network;
import com.baeldung.deeplearning4j.cnn.service.dataset.IDataSetService;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.stream.IntStream;
/**
 * Convolutional neural network assembled with Deeplearning4j.
 *
 * <p>The layer stack, in order, is: 5x5 convolution, 2x2 max pooling (stride 2),
 * 3x3 convolution (padding 2), 2x2 max pooling (stride 1), 3x3 convolution (padding 1),
 * 2x2 max pooling (stride 1) and a softmax output layer. Training and test data, the
 * input type and the label list are supplied by an {@link IDataSetService}; the
 * hyper-parameters come from {@link CnnModelProperties}.
 */
@Slf4j
public class CnnModel {

    private final IDataSetService dataSetService;

    private final MultiLayerNetwork network;

    private final CnnModelProperties properties;

    /**
     * Builds the network configuration and wraps it in a {@link MultiLayerNetwork}.
     * The network is not initialized here; {@link #train()} does that.
     *
     * @param dataSetService source of the iterators, the input type and the labels
     * @param properties     learning rate, updater and number of epochs
     */
    public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
        // Both fields must be assigned before the layer factories run: dense() reads dataSetService.
        this.dataSetService = dataSetService;
        this.properties = properties;

        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(1611) // fixed seed
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(properties.getLearningRate())
            .regularization(true) // no l1/l2 coefficient is set on this builder
            .updater(properties.getOptimizer())
            .list()
            .layer(0, conv5x5())
            .layer(1, pooling2x2Stride2())
            .layer(2, conv3x3Stride1Padding2())
            .layer(3, pooling2x2Stride1())
            .layer(4, conv3x3Stride1Padding1())
            .layer(5, pooling2x2Stride1())
            .layer(6, dense())
            .pretrain(false)
            .backprop(true)
            .setInputType(dataSetService.inputType())
            .build();

        network = new MultiLayerNetwork(configuration);
    }

    /**
     * Initializes the network and fits it on the training iterator once per configured
     * epoch, logging the current epoch number before each pass.
     */
    public void train() {
        network.init();
        int epochsNum = properties.getEpochsNum();
        IntStream.rangeClosed(1, epochsNum).forEach(epoch -> {
            // Parameterized message: "Epoch <current>/<total>".
            log.info("Epoch {}/{}", epoch, epochsNum);
            network.fit(dataSetService.trainIterator());
        });
    }

    /**
     * Evaluates the network on the test iterator of the data set service.
     *
     * @return the evaluation produced by the network
     */
    public Evaluation evaluate() {
        return network.evaluate(dataSetService.testIterator());
    }

    /** 5x5 convolution: 3 inputs, 16 outputs, stride 1, padding 1, Xavier-uniform weights, ReLU. */
    private ConvolutionLayer conv5x5() {
        return new ConvolutionLayer.Builder(5, 5)
            .nIn(3)
            .nOut(16)
            .stride(1, 1)
            .padding(1, 1)
            .weightInit(WeightInit.XAVIER_UNIFORM)
            .activation(Activation.RELU)
            .build();
    }

    /** Max pooling with a 2x2 kernel and stride 2. */
    private SubsamplingLayer pooling2x2Stride2() {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build();
    }

    /** 3x3 convolution: 32 outputs, stride 1, padding 2, Xavier-uniform weights, ReLU. */
    private ConvolutionLayer conv3x3Stride1Padding2() {
        return new ConvolutionLayer.Builder(3, 3)
            .nOut(32)
            .stride(1, 1)
            .padding(2, 2)
            .weightInit(WeightInit.XAVIER_UNIFORM)
            .activation(Activation.RELU)
            .build();
    }

    /** Max pooling with a 2x2 kernel and stride 1. */
    private SubsamplingLayer pooling2x2Stride1() {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(1, 1)
            .build();
    }

    /** 3x3 convolution: 64 outputs, stride 1, padding 1, Xavier-uniform weights, ReLU. */
    private ConvolutionLayer conv3x3Stride1Padding1() {
        return new ConvolutionLayer.Builder(3, 3)
            .nOut(64)
            .stride(1, 1)
            .padding(1, 1)
            .weightInit(WeightInit.XAVIER_UNIFORM)
            .activation(Activation.RELU)
            .build();
    }

    /**
     * Softmax output layer with Xavier-uniform weights.
     *
     * <p>NOTE(review): the output size is one less than the number of labels reported by
     * the data set service; presumably this compensates for an extra entry in the label
     * list - confirm against the iterator in use.
     *
     * <p>NOTE(review): mean squared logarithmic error combined with softmax is unusual for
     * classification; a cross-entropy style loss is the more common choice. Left unchanged
     * so training behaviour stays the same.
     */
    private OutputLayer dense() {
        return new OutputLayer.Builder(LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
            .activation(Activation.SOFTMAX)
            .weightInit(WeightInit.XAVIER_UNIFORM)
            .nOut(dataSetService.labels().size() - 1)
            .build();
    }
}

View File

@ -0,0 +1,13 @@
package com.baeldung.deeplearning4j.cnn.domain.network;
import lombok.Value;
import org.deeplearning4j.nn.conf.Updater;
/**
 * Fixed hyper-parameters of the CNN example.
 *
 * <p>Lombok's {@code @Value} makes every field {@code private final} and generates the
 * getters, so the modifiers are not repeated on the declarations.
 */
@Value
public class CnnModelProperties {

    /** Number of passes over the training iterator. */
    int epochsNum = 512;

    /** Learning rate handed to the network configuration. */
    double learningRate = 0.001;

    /** Updater handed to the network configuration. */
    Updater optimizer = Updater.ADAM;
}

View File

@ -0,0 +1,46 @@
package com.baeldung.deeplearning4j.cnn.service.dataset;
import lombok.Getter;
import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List;
/**
 * {@link IDataSetService} backed by two {@link CifarDataSetIterator} instances, one for
 * training and one for testing, both created once in the constructor.
 */
@Getter
public class CifarDataSetService implements IDataSetService {

    // Assigned exactly once in the constructor, hence final.
    private final CifarDataSetIterator trainIterator;
    private final CifarDataSetIterator testIterator;

    // NOTE(review): presumably 32x32 inputs with 3 channels - confirm the argument order of InputType.convolutional.
    private final InputType inputType = InputType.convolutional(32, 32, 3);

    private final int trainImagesNum = 512;
    private final int testImagesNum = 128;
    private final int trainBatch = 16;
    private final int testBatch = 8;

    /**
     * Creates the training and test iterators from the batch sizes and image counts above.
     * NOTE(review): the boolean argument presumably selects the training (true) or test
     * (false) split - confirm against the CifarDataSetIterator constructor.
     */
    public CifarDataSetService() {
        trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
        testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
    }

    /** @return the training iterator created in the constructor */
    @Override
    public DataSetIterator trainIterator() {
        return trainIterator;
    }

    /** @return the test iterator created in the constructor */
    @Override
    public DataSetIterator testIterator() {
        return testIterator;
    }

    /** @return the input type declared for this data set */
    @Override
    public InputType inputType() {
        return inputType;
    }

    /** @return the labels reported by the training iterator */
    @Override
    public List<String> labels() {
        return trainIterator.getLabels();
    }
}

View File

@ -0,0 +1,16 @@
package com.baeldung.deeplearning4j.cnn.service.dataset;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List;
/**
 * Supplies a network with its data: the training and test iterators, the input type used
 * when building the network configuration, and the list of labels.
 */
public interface IDataSetService {

    /** @return the input type passed to the network configuration */
    InputType inputType();

    /** @return the labels of the data set */
    List<String> labels();

    /** @return the iterator the network is fitted on */
    DataSetIterator trainIterator();

    /** @return the iterator the network is evaluated on */
    DataSetIterator testIterator();
}