CNN example with Deeplearning4j in Java: refactor

This commit is contained in:
helga_sh 2020-07-23 16:17:04 +03:00
parent adc586c566
commit 51f1fc9b1e
6 changed files with 26 additions and 28 deletions

View File

@ -40,12 +40,12 @@
<dependency> <dependency>
<groupId>org.slf4j</groupId> <groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId> <artifactId>slf4j-api</artifactId>
<version>1.7.5</version> <version>${sl4j.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.slf4j</groupId> <groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId> <artifactId>slf4j-log4j12</artifactId>
<version>1.7.5</version> <version>${sl4j.version}</version>
</dependency> </dependency>
<!-- https://mvnrepository.com/artifact/org.datavec/datavec-api --> <!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
<dependency> <dependency>
@ -63,6 +63,7 @@
<properties> <properties>
<dl4j.version>0.9.1</dl4j.version> <!-- Latest non beta version --> <dl4j.version>0.9.1</dl4j.version> <!-- Latest non beta version -->
<httpclient.version>4.3.5</httpclient.version> <httpclient.version>4.3.5</httpclient.version>
<sl4j.version>1.7.5</sl4j.version>
</properties> </properties>
</project> </project>

View File

@ -1,4 +1,4 @@
package com.baeldung.deeplearning4j.cnn.service.dataset; package com.baeldung.deeplearning4j.cnn;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator;
@ -8,18 +8,19 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List; import java.util.List;
@Getter @Getter
public class CifarDataSetService implements IDataSetService { class CifarDataSetService implements IDataSetService {
private CifarDataSetIterator trainIterator; private final InputType inputType = InputType.convolutional(32, 32, 3);
private CifarDataSetIterator testIterator;
private final InputType inputType = InputType.convolutional(32,32,3);
private final int trainImagesNum = 512; private final int trainImagesNum = 512;
private final int testImagesNum = 128; private final int testImagesNum = 128;
private final int trainBatch = 16; private final int trainBatch = 16;
private final int testBatch = 8; private final int testBatch = 8;
public CifarDataSetService() { private final CifarDataSetIterator trainIterator;
private final CifarDataSetIterator testIterator;
CifarDataSetService() {
trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true); trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true);
testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false); testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false);
} }

View File

@ -1,14 +1,11 @@
package com.baeldung.deeplearning4j.cnn; package com.baeldung.deeplearning4j.cnn;
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModel;
import com.baeldung.deeplearning4j.cnn.domain.network.CnnModelProperties;
import com.baeldung.deeplearning4j.cnn.service.dataset.CifarDataSetService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.eval.Evaluation;
@Slf4j @Slf4j
public class CnnExample { class CnnExample {
public static void main(String... args) { public static void main(String... args) {
CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties()); CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties());

View File

@ -1,6 +1,5 @@
package com.baeldung.deeplearning4j.cnn.domain.network; package com.baeldung.deeplearning4j.cnn;
import com.baeldung.deeplearning4j.cnn.service.dataset.IDataSetService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -17,15 +16,15 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.stream.IntStream; import java.util.stream.IntStream;
@Slf4j @Slf4j
public class CnnModel { class CnnModel {
private final IDataSetService dataSetService; private final IDataSetService dataSetService;
private MultiLayerNetwork network; private final MultiLayerNetwork network;
private final CnnModelProperties properties; private final CnnModelProperties properties;
public CnnModel(IDataSetService dataSetService, CnnModelProperties properties) { CnnModel(IDataSetService dataSetService, CnnModelProperties properties) {
this.dataSetService = dataSetService; this.dataSetService = dataSetService;
this.properties = properties; this.properties = properties;
@ -52,16 +51,16 @@ public class CnnModel {
network = new MultiLayerNetwork(configuration); network = new MultiLayerNetwork(configuration);
} }
public void train() { void train() {
network.init(); network.init();
int epochsNum = properties.getEpochsNum(); int epochsNum = properties.getEpochsNum();
IntStream.range(1, epochsNum + 1).forEach(epoch -> { IntStream.range(1, epochsNum + 1).forEach(epoch -> {
log.info(String.format("Epoch %d?%d", epoch, epochsNum)); log.info("Epoch {} / {}", epoch, epochsNum);
network.fit(dataSetService.trainIterator()); network.fit(dataSetService.trainIterator());
}); });
} }
public Evaluation evaluate() { Evaluation evaluate() {
return network.evaluate(dataSetService.testIterator()); return network.evaluate(dataSetService.testIterator());
} }
@ -95,7 +94,7 @@ public class CnnModel {
private SubsamplingLayer pooling2x2Stride1() { private SubsamplingLayer pooling2x2Stride1() {
return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2) .kernelSize(2, 2)
.stride(1, 1) .stride(1, 1)
.build(); .build();
} }

View File

@ -1,10 +1,10 @@
package com.baeldung.deeplearning4j.cnn.domain.network; package com.baeldung.deeplearning4j.cnn;
import lombok.Value; import lombok.Value;
import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.Updater;
@Value @Value
public class CnnModelProperties { class CnnModelProperties {
private final int epochsNum = 512; private final int epochsNum = 512;
private final double learningRate = 0.001; private final double learningRate = 0.001;

View File

@ -1,11 +1,11 @@
package com.baeldung.deeplearning4j.cnn.service.dataset; package com.baeldung.deeplearning4j.cnn;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List; import java.util.List;
public interface IDataSetService { interface IDataSetService {
DataSetIterator trainIterator(); DataSetIterator trainIterator();
DataSetIterator testIterator(); DataSetIterator testIterator();