BAEL-1208: A guide to deeplearning4j (#2717)
This commit is contained in:
parent
491649a866
commit
a48a401dc2
|
@ -0,0 +1,5 @@
|
||||||
|
### Sample deeplearning4j Project
|
||||||
|
This is a sample project for the [deeplearning4j](https://deeplearning4j.org) library.
|
||||||
|
|
||||||
|
### Relevant Articles:
|
||||||
|
- [A Guide to deeplearning4j](http://www.baeldung.com/a-guide-to-deeplearning4j/)
|
|
@ -0,0 +1,33 @@
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
<groupId>com.baeldung.deeplearning4j</groupId>
|
||||||
|
<artifactId>deeplearning4j</artifactId>
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
<name>deeplearning4j</name>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
<maven.compiler.source>1.8</maven.compiler.source>
|
||||||
|
<maven.compiler.target>1.8</maven.compiler.target>
|
||||||
|
<dl4j.version>0.9.1</dl4j.version>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-native-platform</artifactId>
|
||||||
|
<version>${dl4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.deeplearning4j</groupId>
|
||||||
|
<artifactId>deeplearning4j-core</artifactId>
|
||||||
|
<version>${dl4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
</project>
|
|
@ -0,0 +1,80 @@
|
||||||
|
package com.baeldung.deeplearning4j;
|
||||||
|
|
||||||
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
|
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
|
import org.datavec.api.split.FileSplit;
|
||||||
|
import org.datavec.api.util.ClassPathResource;
|
||||||
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
|
import org.deeplearning4j.eval.Evaluation;
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.SplitTestAndTrain;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class IrisClassifier {
|
||||||
|
|
||||||
|
private static final int CLASSES_COUNT = 3;
|
||||||
|
private static final int FEATURES_COUNT = 4;
|
||||||
|
|
||||||
|
public static void main(String[] args) throws IOException, InterruptedException {
|
||||||
|
|
||||||
|
DataSet allData;
|
||||||
|
try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
|
||||||
|
recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
|
||||||
|
|
||||||
|
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
|
||||||
|
allData = iterator.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
allData.shuffle(42);
|
||||||
|
|
||||||
|
DataNormalization normalizer = new NormalizerStandardize();
|
||||||
|
normalizer.fit(allData);
|
||||||
|
normalizer.transform(allData);
|
||||||
|
|
||||||
|
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
|
||||||
|
DataSet trainingData = testAndTrain.getTrain();
|
||||||
|
DataSet testData = testAndTrain.getTest();
|
||||||
|
|
||||||
|
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
|
||||||
|
.iterations(1000)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.learningRate(0.1)
|
||||||
|
.regularization(true).l2(0.0001)
|
||||||
|
.list()
|
||||||
|
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3)
|
||||||
|
.build())
|
||||||
|
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
|
||||||
|
.build())
|
||||||
|
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nIn(3).nOut(CLASSES_COUNT).build())
|
||||||
|
.backprop(true).pretrain(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork model = new MultiLayerNetwork(configuration);
|
||||||
|
model.init();
|
||||||
|
model.fit(trainingData);
|
||||||
|
|
||||||
|
INDArray output = model.output(testData.getFeatureMatrix());
|
||||||
|
|
||||||
|
Evaluation eval = new Evaluation(CLASSES_COUNT);
|
||||||
|
eval.eval(testData.getLabels(), output);
|
||||||
|
System.out.println(eval.stats());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
5.1,3.5,1.4,0.2,0
|
||||||
|
4.9,3.0,1.4,0.2,0
|
||||||
|
4.7,3.2,1.3,0.2,0
|
||||||
|
4.6,3.1,1.5,0.2,0
|
||||||
|
5.0,3.6,1.4,0.2,0
|
||||||
|
5.4,3.9,1.7,0.4,0
|
||||||
|
4.6,3.4,1.4,0.3,0
|
||||||
|
5.0,3.4,1.5,0.2,0
|
||||||
|
4.4,2.9,1.4,0.2,0
|
||||||
|
4.9,3.1,1.5,0.1,0
|
||||||
|
5.4,3.7,1.5,0.2,0
|
||||||
|
4.8,3.4,1.6,0.2,0
|
||||||
|
4.8,3.0,1.4,0.1,0
|
||||||
|
4.3,3.0,1.1,0.1,0
|
||||||
|
5.8,4.0,1.2,0.2,0
|
||||||
|
5.7,4.4,1.5,0.4,0
|
||||||
|
5.4,3.9,1.3,0.4,0
|
||||||
|
5.1,3.5,1.4,0.3,0
|
||||||
|
5.7,3.8,1.7,0.3,0
|
||||||
|
5.1,3.8,1.5,0.3,0
|
||||||
|
5.4,3.4,1.7,0.2,0
|
||||||
|
5.1,3.7,1.5,0.4,0
|
||||||
|
4.6,3.6,1.0,0.2,0
|
||||||
|
5.1,3.3,1.7,0.5,0
|
||||||
|
4.8,3.4,1.9,0.2,0
|
||||||
|
5.0,3.0,1.6,0.2,0
|
||||||
|
5.0,3.4,1.6,0.4,0
|
||||||
|
5.2,3.5,1.5,0.2,0
|
||||||
|
5.2,3.4,1.4,0.2,0
|
||||||
|
4.7,3.2,1.6,0.2,0
|
||||||
|
4.8,3.1,1.6,0.2,0
|
||||||
|
5.4,3.4,1.5,0.4,0
|
||||||
|
5.2,4.1,1.5,0.1,0
|
||||||
|
5.5,4.2,1.4,0.2,0
|
||||||
|
4.9,3.1,1.5,0.1,0
|
||||||
|
5.0,3.2,1.2,0.2,0
|
||||||
|
5.5,3.5,1.3,0.2,0
|
||||||
|
4.9,3.1,1.5,0.1,0
|
||||||
|
4.4,3.0,1.3,0.2,0
|
||||||
|
5.1,3.4,1.5,0.2,0
|
||||||
|
5.0,3.5,1.3,0.3,0
|
||||||
|
4.5,2.3,1.3,0.3,0
|
||||||
|
4.4,3.2,1.3,0.2,0
|
||||||
|
5.0,3.5,1.6,0.6,0
|
||||||
|
5.1,3.8,1.9,0.4,0
|
||||||
|
4.8,3.0,1.4,0.3,0
|
||||||
|
5.1,3.8,1.6,0.2,0
|
||||||
|
4.6,3.2,1.4,0.2,0
|
||||||
|
5.3,3.7,1.5,0.2,0
|
||||||
|
5.0,3.3,1.4,0.2,0
|
||||||
|
7.0,3.2,4.7,1.4,1
|
||||||
|
6.4,3.2,4.5,1.5,1
|
||||||
|
6.9,3.1,4.9,1.5,1
|
||||||
|
5.5,2.3,4.0,1.3,1
|
||||||
|
6.5,2.8,4.6,1.5,1
|
||||||
|
5.7,2.8,4.5,1.3,1
|
||||||
|
6.3,3.3,4.7,1.6,1
|
||||||
|
4.9,2.4,3.3,1.0,1
|
||||||
|
6.6,2.9,4.6,1.3,1
|
||||||
|
5.2,2.7,3.9,1.4,1
|
||||||
|
5.0,2.0,3.5,1.0,1
|
||||||
|
5.9,3.0,4.2,1.5,1
|
||||||
|
6.0,2.2,4.0,1.0,1
|
||||||
|
6.1,2.9,4.7,1.4,1
|
||||||
|
5.6,2.9,3.6,1.3,1
|
||||||
|
6.7,3.1,4.4,1.4,1
|
||||||
|
5.6,3.0,4.5,1.5,1
|
||||||
|
5.8,2.7,4.1,1.0,1
|
||||||
|
6.2,2.2,4.5,1.5,1
|
||||||
|
5.6,2.5,3.9,1.1,1
|
||||||
|
5.9,3.2,4.8,1.8,1
|
||||||
|
6.1,2.8,4.0,1.3,1
|
||||||
|
6.3,2.5,4.9,1.5,1
|
||||||
|
6.1,2.8,4.7,1.2,1
|
||||||
|
6.4,2.9,4.3,1.3,1
|
||||||
|
6.6,3.0,4.4,1.4,1
|
||||||
|
6.8,2.8,4.8,1.4,1
|
||||||
|
6.7,3.0,5.0,1.7,1
|
||||||
|
6.0,2.9,4.5,1.5,1
|
||||||
|
5.7,2.6,3.5,1.0,1
|
||||||
|
5.5,2.4,3.8,1.1,1
|
||||||
|
5.5,2.4,3.7,1.0,1
|
||||||
|
5.8,2.7,3.9,1.2,1
|
||||||
|
6.0,2.7,5.1,1.6,1
|
||||||
|
5.4,3.0,4.5,1.5,1
|
||||||
|
6.0,3.4,4.5,1.6,1
|
||||||
|
6.7,3.1,4.7,1.5,1
|
||||||
|
6.3,2.3,4.4,1.3,1
|
||||||
|
5.6,3.0,4.1,1.3,1
|
||||||
|
5.5,2.5,4.0,1.3,1
|
||||||
|
5.5,2.6,4.4,1.2,1
|
||||||
|
6.1,3.0,4.6,1.4,1
|
||||||
|
5.8,2.6,4.0,1.2,1
|
||||||
|
5.0,2.3,3.3,1.0,1
|
||||||
|
5.6,2.7,4.2,1.3,1
|
||||||
|
5.7,3.0,4.2,1.2,1
|
||||||
|
5.7,2.9,4.2,1.3,1
|
||||||
|
6.2,2.9,4.3,1.3,1
|
||||||
|
5.1,2.5,3.0,1.1,1
|
||||||
|
5.7,2.8,4.1,1.3,1
|
||||||
|
6.3,3.3,6.0,2.5,2
|
||||||
|
5.8,2.7,5.1,1.9,2
|
||||||
|
7.1,3.0,5.9,2.1,2
|
||||||
|
6.3,2.9,5.6,1.8,2
|
||||||
|
6.5,3.0,5.8,2.2,2
|
||||||
|
7.6,3.0,6.6,2.1,2
|
||||||
|
4.9,2.5,4.5,1.7,2
|
||||||
|
7.3,2.9,6.3,1.8,2
|
||||||
|
6.7,2.5,5.8,1.8,2
|
||||||
|
7.2,3.6,6.1,2.5,2
|
||||||
|
6.5,3.2,5.1,2.0,2
|
||||||
|
6.4,2.7,5.3,1.9,2
|
||||||
|
6.8,3.0,5.5,2.1,2
|
||||||
|
5.7,2.5,5.0,2.0,2
|
||||||
|
5.8,2.8,5.1,2.4,2
|
||||||
|
6.4,3.2,5.3,2.3,2
|
||||||
|
6.5,3.0,5.5,1.8,2
|
||||||
|
7.7,3.8,6.7,2.2,2
|
||||||
|
7.7,2.6,6.9,2.3,2
|
||||||
|
6.0,2.2,5.0,1.5,2
|
||||||
|
6.9,3.2,5.7,2.3,2
|
||||||
|
5.6,2.8,4.9,2.0,2
|
||||||
|
7.7,2.8,6.7,2.0,2
|
||||||
|
6.3,2.7,4.9,1.8,2
|
||||||
|
6.7,3.3,5.7,2.1,2
|
||||||
|
7.2,3.2,6.0,1.8,2
|
||||||
|
6.2,2.8,4.8,1.8,2
|
||||||
|
6.1,3.0,4.9,1.8,2
|
||||||
|
6.4,2.8,5.6,2.1,2
|
||||||
|
7.2,3.0,5.8,1.6,2
|
||||||
|
7.4,2.8,6.1,1.9,2
|
||||||
|
7.9,3.8,6.4,2.0,2
|
||||||
|
6.4,2.8,5.6,2.2,2
|
||||||
|
6.3,2.8,5.1,1.5,2
|
||||||
|
6.1,2.6,5.6,1.4,2
|
||||||
|
7.7,3.0,6.1,2.3,2
|
||||||
|
6.3,3.4,5.6,2.4,2
|
||||||
|
6.4,3.1,5.5,1.8,2
|
||||||
|
6.0,3.0,4.8,1.8,2
|
||||||
|
6.9,3.1,5.4,2.1,2
|
||||||
|
6.7,3.1,5.6,2.4,2
|
||||||
|
6.9,3.1,5.1,2.3,2
|
||||||
|
5.8,2.7,5.1,1.9,2
|
||||||
|
6.8,3.2,5.9,2.3,2
|
||||||
|
6.7,3.3,5.7,2.5,2
|
||||||
|
6.7,3.0,5.2,2.3,2
|
||||||
|
6.3,2.5,5.0,1.9,2
|
||||||
|
6.5,3.0,5.2,2.0,2
|
||||||
|
6.2,3.4,5.4,2.3,2
|
||||||
|
5.9,3.0,5.1,1.8,2
|
Loading…
Reference in New Issue