From a48a401dc2188eba9164153bd9b207b8f3b9758e Mon Sep 17 00:00:00 2001 From: Sergey Petunin Date: Mon, 9 Oct 2017 18:13:46 +0200 Subject: [PATCH] BAEL-1208: A guide to deeplearning4j (#2717) --- deeplearning4j/README.md | 5 + deeplearning4j/pom.xml | 33 ++++ .../deeplearning4j/IrisClassifier.java | 80 ++++++++++ deeplearning4j/src/main/resources/iris.txt | 150 ++++++++++++++++++ pom.xml | 1 + 5 files changed, 269 insertions(+) create mode 100644 deeplearning4j/README.md create mode 100644 deeplearning4j/pom.xml create mode 100644 deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java create mode 100644 deeplearning4j/src/main/resources/iris.txt diff --git a/deeplearning4j/README.md b/deeplearning4j/README.md new file mode 100644 index 0000000000..729ab101fd --- /dev/null +++ b/deeplearning4j/README.md @@ -0,0 +1,5 @@ +### Sample deeplearning4j Project +This is a sample project for the [deeplearning4j](https://deeplearning4j.org) library. + +### Relevant Articles: +- [A Guide to deeplearning4j](http://www.baeldung.com/a-guide-to-deeplearning4j/) diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml new file mode 100644 index 0000000000..a39fabc3d6 --- /dev/null +++ b/deeplearning4j/pom.xml @@ -0,0 +1,33 @@ + + 4.0.0 + com.baeldung.deeplearning4j + deeplearning4j + jar + 1.0-SNAPSHOT + deeplearning4j + + + UTF-8 + 1.8 + 1.8 + 0.9.1 + + + + + + org.nd4j + nd4j-native-platform + ${dl4j.version} + + + + org.deeplearning4j + deeplearning4j-core + ${dl4j.version} + + + + + \ No newline at end of file diff --git a/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java new file mode 100644 index 0000000000..bf341209e1 --- /dev/null +++ b/deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java @@ -0,0 +1,80 @@ +package com.baeldung.deeplearning4j; + +import org.datavec.api.records.reader.RecordReader; +import org.datavec.api.records.reader.impl.csv.CSVRecordReader; +import org.datavec.api.split.FileSplit; +import org.datavec.api.util.ClassPathResource; +import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; +import org.deeplearning4j.eval.Evaluation; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.SplitTestAndTrain; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; +import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import java.io.IOException; + +public class IrisClassifier { + + private static final int CLASSES_COUNT = 3; + private static final int FEATURES_COUNT = 4; + + public static void main(String[] args) throws IOException, InterruptedException { + + DataSet allData; + try (RecordReader recordReader = new CSVRecordReader(0, ',')) { + recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); + + DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT); + allData = iterator.next(); + } + + allData.shuffle(42); + + DataNormalization normalizer = new NormalizerStandardize(); + normalizer.fit(allData); + normalizer.transform(allData); + + SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); + DataSet trainingData = testAndTrain.getTrain(); + DataSet testData = testAndTrain.getTest(); + + MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() + .iterations(1000) + .activation(Activation.TANH) + .weightInit(WeightInit.XAVIER) + .learningRate(0.1) + .regularization(true).l2(0.0001) + .list() + .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3) + .build()) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(3) + .build()) + .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .activation(Activation.SOFTMAX) + .nIn(3).nOut(CLASSES_COUNT).build()) + .backprop(true).pretrain(false) + .build(); + + MultiLayerNetwork model = new MultiLayerNetwork(configuration); + model.init(); + model.fit(trainingData); + + INDArray output = model.output(testData.getFeatureMatrix()); + + Evaluation eval = new Evaluation(CLASSES_COUNT); + eval.eval(testData.getLabels(), output); + System.out.println(eval.stats()); + + } + +} diff --git a/deeplearning4j/src/main/resources/iris.txt b/deeplearning4j/src/main/resources/iris.txt new file mode 100644 index 0000000000..8b4511f8be --- /dev/null +++ b/deeplearning4j/src/main/resources/iris.txt @@ -0,0 +1,150 @@ +5.1,3.5,1.4,0.2,0 +4.9,3.0,1.4,0.2,0 +4.7,3.2,1.3,0.2,0 +4.6,3.1,1.5,0.2,0 +5.0,3.6,1.4,0.2,0 +5.4,3.9,1.7,0.4,0 +4.6,3.4,1.4,0.3,0 +5.0,3.4,1.5,0.2,0 +4.4,2.9,1.4,0.2,0 +4.9,3.1,1.5,0.1,0 +5.4,3.7,1.5,0.2,0 +4.8,3.4,1.6,0.2,0 +4.8,3.0,1.4,0.1,0 +4.3,3.0,1.1,0.1,0 +5.8,4.0,1.2,0.2,0 +5.7,4.4,1.5,0.4,0 +5.4,3.9,1.3,0.4,0 +5.1,3.5,1.4,0.3,0 +5.7,3.8,1.7,0.3,0 +5.1,3.8,1.5,0.3,0 +5.4,3.4,1.7,0.2,0 +5.1,3.7,1.5,0.4,0 +4.6,3.6,1.0,0.2,0 +5.1,3.3,1.7,0.5,0 +4.8,3.4,1.9,0.2,0 +5.0,3.0,1.6,0.2,0 +5.0,3.4,1.6,0.4,0 +5.2,3.5,1.5,0.2,0 +5.2,3.4,1.4,0.2,0 +4.7,3.2,1.6,0.2,0 +4.8,3.1,1.6,0.2,0 +5.4,3.4,1.5,0.4,0 +5.2,4.1,1.5,0.1,0 +5.5,4.2,1.4,0.2,0 +4.9,3.1,1.5,0.1,0 +5.0,3.2,1.2,0.2,0 +5.5,3.5,1.3,0.2,0 +4.9,3.1,1.5,0.1,0 +4.4,3.0,1.3,0.2,0 +5.1,3.4,1.5,0.2,0 +5.0,3.5,1.3,0.3,0 +4.5,2.3,1.3,0.3,0 +4.4,3.2,1.3,0.2,0 +5.0,3.5,1.6,0.6,0 +5.1,3.8,1.9,0.4,0 +4.8,3.0,1.4,0.3,0 +5.1,3.8,1.6,0.2,0 +4.6,3.2,1.4,0.2,0 +5.3,3.7,1.5,0.2,0 +5.0,3.3,1.4,0.2,0 +7.0,3.2,4.7,1.4,1 +6.4,3.2,4.5,1.5,1 +6.9,3.1,4.9,1.5,1 +5.5,2.3,4.0,1.3,1 +6.5,2.8,4.6,1.5,1 +5.7,2.8,4.5,1.3,1 +6.3,3.3,4.7,1.6,1 +4.9,2.4,3.3,1.0,1 +6.6,2.9,4.6,1.3,1 +5.2,2.7,3.9,1.4,1 +5.0,2.0,3.5,1.0,1 +5.9,3.0,4.2,1.5,1 +6.0,2.2,4.0,1.0,1 +6.1,2.9,4.7,1.4,1 +5.6,2.9,3.6,1.3,1 +6.7,3.1,4.4,1.4,1 +5.6,3.0,4.5,1.5,1 +5.8,2.7,4.1,1.0,1 +6.2,2.2,4.5,1.5,1 +5.6,2.5,3.9,1.1,1 +5.9,3.2,4.8,1.8,1 +6.1,2.8,4.0,1.3,1 +6.3,2.5,4.9,1.5,1 +6.1,2.8,4.7,1.2,1 +6.4,2.9,4.3,1.3,1 +6.6,3.0,4.4,1.4,1 +6.8,2.8,4.8,1.4,1 +6.7,3.0,5.0,1.7,1 +6.0,2.9,4.5,1.5,1 +5.7,2.6,3.5,1.0,1 +5.5,2.4,3.8,1.1,1 +5.5,2.4,3.7,1.0,1 +5.8,2.7,3.9,1.2,1 +6.0,2.7,5.1,1.6,1 +5.4,3.0,4.5,1.5,1 +6.0,3.4,4.5,1.6,1 +6.7,3.1,4.7,1.5,1 +6.3,2.3,4.4,1.3,1 +5.6,3.0,4.1,1.3,1 +5.5,2.5,4.0,1.3,1 +5.5,2.6,4.4,1.2,1 +6.1,3.0,4.6,1.4,1 +5.8,2.6,4.0,1.2,1 +5.0,2.3,3.3,1.0,1 +5.6,2.7,4.2,1.3,1 +5.7,3.0,4.2,1.2,1 +5.7,2.9,4.2,1.3,1 +6.2,2.9,4.3,1.3,1 +5.1,2.5,3.0,1.1,1 +5.7,2.8,4.1,1.3,1 +6.3,3.3,6.0,2.5,2 +5.8,2.7,5.1,1.9,2 +7.1,3.0,5.9,2.1,2 +6.3,2.9,5.6,1.8,2 +6.5,3.0,5.8,2.2,2 +7.6,3.0,6.6,2.1,2 +4.9,2.5,4.5,1.7,2 +7.3,2.9,6.3,1.8,2 +6.7,2.5,5.8,1.8,2 +7.2,3.6,6.1,2.5,2 +6.5,3.2,5.1,2.0,2 +6.4,2.7,5.3,1.9,2 +6.8,3.0,5.5,2.1,2 +5.7,2.5,5.0,2.0,2 +5.8,2.8,5.1,2.4,2 +6.4,3.2,5.3,2.3,2 +6.5,3.0,5.5,1.8,2 +7.7,3.8,6.7,2.2,2 +7.7,2.6,6.9,2.3,2 +6.0,2.2,5.0,1.5,2 +6.9,3.2,5.7,2.3,2 +5.6,2.8,4.9,2.0,2 +7.7,2.8,6.7,2.0,2 +6.3,2.7,4.9,1.8,2 +6.7,3.3,5.7,2.1,2 +7.2,3.2,6.0,1.8,2 +6.2,2.8,4.8,1.8,2 +6.1,3.0,4.9,1.8,2 +6.4,2.8,5.6,2.1,2 +7.2,3.0,5.8,1.6,2 +7.4,2.8,6.1,1.9,2 +7.9,3.8,6.4,2.0,2 +6.4,2.8,5.6,2.2,2 +6.3,2.8,5.1,1.5,2 +6.1,2.6,5.6,1.4,2 +7.7,3.0,6.1,2.3,2 +6.3,3.4,5.6,2.4,2 +6.4,3.1,5.5,1.8,2 +6.0,3.0,4.8,1.8,2 +6.9,3.1,5.4,2.1,2 +6.7,3.1,5.6,2.4,2 +6.9,3.1,5.1,2.3,2 +5.8,2.7,5.1,1.9,2 +6.8,3.2,5.9,2.3,2 +6.7,3.3,5.7,2.5,2 +6.7,3.0,5.2,2.3,2 +6.3,2.5,5.0,1.9,2 +6.5,3.0,5.2,2.0,2 +6.2,3.4,5.4,2.3,2 +5.9,3.0,5.1,1.8,2 diff --git a/pom.xml b/pom.xml index ff4b490c6b..c3915e4fce 100644 --- a/pom.xml +++ b/pom.xml @@ -248,6 +248,7 @@ mockserver undertow vertx-and-rxjava + deeplearning4j