From b01bd36786e5f78c7ade055b556a84731a19505e Mon Sep 17 00:00:00 2001
From: Andrew Shcherbakov
Date: Sun, 4 Aug 2019 16:03:12 +0200
Subject: [PATCH] Add code for the logistic regression article (BAEL-3081)

---
 ml/README.md                                  |   5 +
 ml/pom.xml                                    |  52 ++++++
 .../com/baeldung/logreg/DataUtilities.java    | 102 +++++++++++
 .../com/baeldung/logreg/MnistClassifier.java  | 166 ++++++++++++++++++
 .../com/baeldung/logreg/MnistPrediction.java  |  53 ++++++
 ml/src/main/resources/logback.xml             |  13 ++
 pom.xml                                       |   2 +
 7 files changed, 393 insertions(+)
 create mode 100644 ml/README.md
 create mode 100644 ml/pom.xml
 create mode 100644 ml/src/main/java/com/baeldung/logreg/DataUtilities.java
 create mode 100644 ml/src/main/java/com/baeldung/logreg/MnistClassifier.java
 create mode 100644 ml/src/main/java/com/baeldung/logreg/MnistPrediction.java
 create mode 100644 ml/src/main/resources/logback.xml

diff --git a/ml/README.md b/ml/README.md
new file mode 100644
index 0000000000..14e585cd97
--- /dev/null
+++ b/ml/README.md
@@ -0,0 +1,5 @@
+### Sample deeplearning4j Project
+This is a sample project for the [deeplearning4j](https://deeplearning4j.org) library.
+
+### Relevant Articles:
+- [A Guide to Deeplearning4j](http://www.baeldung.com/deeplearning4j)
diff --git a/ml/pom.xml b/ml/pom.xml
new file mode 100644
index 0000000000..80afcc24f4
--- /dev/null
+++ b/ml/pom.xml
@@ -0,0 +1,52 @@
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <groupId>com.baeldung.deeplearning4j</groupId>
+    <artifactId>ml</artifactId>
+    <version>1.0-SNAPSHOT</version>
+    <name>Machine Learning</name>
+    <packaging>jar</packaging>
+
+    <parent>
+        <groupId>com.baeldung</groupId>
+        <artifactId>parent-modules</artifactId>
+        <version>1.0.0-SNAPSHOT</version>
+    </parent>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-native-platform</artifactId>
+            <version>${dl4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-core</artifactId>
+            <version>${dl4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-nn</artifactId>
+            <version>${dl4j.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.datavec</groupId>
+            <artifactId>datavec-api</artifactId>
+            <version>${dl4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.httpcomponents</groupId>
+            <artifactId>httpclient</artifactId>
+            <version>4.3.5</version>
+        </dependency>
+    </dependencies>
+
+    <properties>
+        <dl4j.version>1.0.0-beta4</dl4j.version>
+    </properties>
+
+</project>
\ No newline at end of file
diff --git a/ml/src/main/java/com/baeldung/logreg/DataUtilities.java b/ml/src/main/java/com/baeldung/logreg/DataUtilities.java
new file mode 100644
index 0000000000..2f18d30219
--- /dev/null
+++ b/ml/src/main/java/com/baeldung/logreg/DataUtilities.java
@@ -0,0 +1,102 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package com.baeldung.logreg;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.apache.http.HttpEntity;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClientBuilder;
+
+import java.io.*;
+
+/**
+ * Common data utility functions.
+ *
+ * @author fvaleri
+ */
+public class DataUtilities {
+
+    /**
+     * Download a remote file if it doesn't exist.
+     * @param remoteUrl URL of the remote file.
+     * @param localPath Where to download the file.
+     * @return True if and only if the file has been downloaded.
+     * @throws IOException if an I/O error occurs.
+     */
+    public static boolean downloadFile(String remoteUrl, String localPath) throws IOException {
+        boolean downloaded = false;
+        if (remoteUrl == null || localPath == null)
+            return downloaded;
+        File file = new File(localPath);
+        if (!file.exists()) {
+            file.getParentFile().mkdirs();
+            HttpClientBuilder builder = HttpClientBuilder.create();
+            // close both the client and the response once the download is done
+            try (CloseableHttpClient client = builder.build();
+                CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl))) {
+                HttpEntity entity = response.getEntity();
+                if (entity != null) {
+                    try (FileOutputStream outstream = new FileOutputStream(file)) {
+                        entity.writeTo(outstream);
+                    }
+                }
+            }
+            downloaded = true;
+        }
+        if (!file.exists())
+            throw new IOException("File doesn't exist: " + localPath);
+        return downloaded;
+    }
+
+    /**
+     * Extract a "tar.gz" file into a local folder.
+     * @param inputPath Input file path.
+     * @param outputPath Output directory path.
+     * @throws IOException if an I/O error occurs.
+     */
+    public static void extractTarGz(String inputPath, String outputPath) throws IOException {
+        if (inputPath == null || outputPath == null)
+            return;
+        final int bufferSize = 4096;
+        if (!outputPath.endsWith("" + File.separatorChar))
+            outputPath = outputPath + File.separatorChar;
+        try (TarArchiveInputStream tais = new TarArchiveInputStream(
+            new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))))) {
+            TarArchiveEntry entry;
+            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
+                if (entry.isDirectory()) {
+                    new File(outputPath + entry.getName()).mkdirs();
+                } else {
+                    int count;
+                    byte[] data = new byte[bufferSize];
+                    // close the output stream even if copying a single entry fails
+                    try (BufferedOutputStream dest = new BufferedOutputStream(
+                        new FileOutputStream(outputPath + entry.getName()), bufferSize)) {
+                        while ((count = tais.read(data, 0, bufferSize)) != -1) {
+                            dest.write(data, 0, count);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+}
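For context, a minimal usage sketch of the two helpers above. The `DataUtilitiesDemo` class name is illustrative and not part of this patch; the URL and temp-directory layout are the ones `MnistClassifier` uses further down.

```java
import java.io.File;

import com.baeldung.logreg.DataUtilities;

public class DataUtilitiesDemo {

    public static void main(String[] args) throws Exception {
        // same locations that MnistClassifier uses below
        String basePath = System.getProperty("java.io.tmpdir") + File.separator + "mnist" + File.separator;
        String archivePath = basePath + "mnist_png.tar.gz";

        // downloadFile returns true only when it actually fetched the archive,
        // so the extraction runs right after a fresh download
        if (DataUtilities.downloadFile("http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz", archivePath)) {
            DataUtilities.extractTarGz(archivePath, basePath);
        }
    }
}
```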
diff --git a/ml/src/main/java/com/baeldung/logreg/MnistClassifier.java b/ml/src/main/java/com/baeldung/logreg/MnistClassifier.java
new file mode 100644
index 0000000000..395307712d
--- /dev/null
+++ b/ml/src/main/java/com/baeldung/logreg/MnistClassifier.java
@@ -0,0 +1,166 @@
+package com.baeldung.logreg;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+import org.datavec.api.io.labels.ParentPathLabelGenerator;
+import org.datavec.api.split.FileSplit;
+import org.datavec.image.loader.NativeImageLoader;
+import org.datavec.image.recordreader.ImageRecordReader;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.util.ModelSerializer;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
+import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.nd4j.linalg.schedule.MapSchedule;
+import org.nd4j.linalg.schedule.ScheduleType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Handwritten digit image classification, based on the LeNet-5 architecture by Yann LeCun.
+ *
+ * This code accompanies the article "Logistic Regression in Java" and is heavily based on
+ * the MnistClassifier example from the dl4j-examples project.
+ * Some minor changes have been made to keep the article's flow smoother.
+ */
+public class MnistClassifier {
+
+    private static final Logger logger = LoggerFactory.getLogger(MnistClassifier.class);
+    private static final String basePath = System.getProperty("java.io.tmpdir") + File.separator + "mnist" + File.separator;
+    private static final File modelPath = new File(basePath + "mnist-model.zip");
+    private static final String dataUrl = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
+
+    public static void main(String[] args) throws Exception {
+        // input image size in pixels
+        int height = 28;
+        int width = 28;
+        // input image colour depth (1 for grayscale images)
+        int channels = 1;
+        // the number of output classes (digits 0-9)
+        int outputClasses = 10;
+        // number of samples propagated through the network in each iteration
+        int batchSize = 54;
+        // total number of training epochs
+        int epochs = 1;
+
+        // initialize a pseudorandom number generator
+        int seed = 1234;
+        Random randNumGen = new Random(seed);
+
+        final String path = basePath + "mnist_png" + File.separator;
+        if (!new File(path).exists()) {
+            logger.debug("Downloading data {}", dataUrl);
+            String localFilePath = basePath + "mnist_png.tar.gz";
+            logger.info("local file: {}", localFilePath);
+            if (DataUtilities.downloadFile(dataUrl, localFilePath)) {
+                DataUtilities.extractTarGz(localFilePath, basePath);
+            }
+        } else {
+            logger.info("local file exists {}", path);
+        }
+
+        logger.info("Vectorizing data...");
+        // vectorization of the training data
+        File trainData = new File(path + "training");
+        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
+        // use the parent directory name as the image label
+        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
+        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
+        trainRR.initialize(trainSplit);
+        DataSetIterator train = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputClasses);
+
+        // scale pixel values from 0-255 to 0-1 (min-max scaling)
+        DataNormalization imageScaler = new ImagePreProcessingScaler();
+        imageScaler.fit(train);
+        train.setPreProcessor(imageScaler);
+
+        // vectorization of the test data
+        File testData = new File(path + "testing");
+        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
+        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
+        testRR.initialize(testSplit);
+        DataSetIterator test = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputClasses);
+        // apply the same normalization to the test set for better results
+        test.setPreProcessor(imageScaler);
+
+        logger.info("Network configuration and training...");
training..."); + // reduce the learning rate as the number of training epochs increases + // iteration #, learning rate + Map learningRateSchedule = new HashMap<>(); + learningRateSchedule.put(0, 0.06); + learningRateSchedule.put(200, 0.05); + learningRateSchedule.put(600, 0.028); + learningRateSchedule.put(800, 0.0060); + learningRateSchedule.put(1000, 0.001); + + final ConvolutionLayer layer1 = new ConvolutionLayer.Builder(5, 5).nIn(channels) + .stride(1, 1) + .nOut(20) + .activation(Activation.IDENTITY) + .build(); + final SubsamplingLayer layer2 = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) + .stride(2, 2) + .build(); + // nIn need not specified in later layers + final ConvolutionLayer layer3 = new ConvolutionLayer.Builder(5, 5).stride(1, 1) + .nOut(50) + .activation(Activation.IDENTITY) + .build(); + final DenseLayer layer4 = new DenseLayer.Builder().activation(Activation.RELU) + .nOut(500) + .build(); + final OutputLayer layer5 = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputClasses) + .activation(Activation.SOFTMAX) + .build(); + final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed) + .l2(0.0005) // ridge regression value + .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule))) + .weightInit(WeightInit.XAVIER) + .list() + .layer(layer1) + .layer(layer2) + .layer(layer3) + .layer(layer2) + .layer(layer4) + .layer(layer5) + .setInputType(InputType.convolutionalFlat(height, width, channels)) + .build(); + + final MultiLayerNetwork model = new MultiLayerNetwork(config); + model.init(); + model.setListeners(new ScoreIterationListener(100)); + logger.info("Total num of params: {}", model.numParams()); + + // evaluation while training (the score should go down) + for (int i = 0; i < epochs; i++) { + model.fit(train); + logger.info("Completed epoch {}", i); + train.reset(); + test.reset(); + } + Evaluation eval = model.evaluate(test); + logger.info(eval.stats()); + + ModelSerializer.writeModel(model, modelPath, true); + logger.info("The MINIST model has been saved in {}", modelPath.getPath()); + } +} \ No newline at end of file diff --git a/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java b/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java new file mode 100644 index 0000000000..5ec1348e07 --- /dev/null +++ b/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java @@ -0,0 +1,53 @@ +package com.baeldung.logreg; + +import java.io.File; +import java.io.IOException; + +import javax.swing.JFileChooser; + +import org.datavec.image.loader.NativeImageLoader; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class MnistPrediction { + private static final Logger logger = LoggerFactory.getLogger(MnistPrediction.class); + private static final File modelPath = new File(System.getProperty("java.io.tmpdir") + "mnist" + File.separator + "mnist-model.zip"); + private static final int height = 28; + private static final int width = 28; + private static final int channels = 1; + + /** + * Opens a popup that allows to select a file from the filesystem. 
diff --git a/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java b/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java
new file mode 100644
index 0000000000..5ec1348e07
--- /dev/null
+++ b/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java
@@ -0,0 +1,53 @@
+package com.baeldung.logreg;
+
+import java.io.File;
+import java.io.IOException;
+
+import javax.swing.JFileChooser;
+
+import org.datavec.image.loader.NativeImageLoader;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.util.ModelSerializer;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class MnistPrediction {
+
+    private static final Logger logger = LoggerFactory.getLogger(MnistPrediction.class);
+    private static final File modelPath = new File(System.getProperty("java.io.tmpdir") + File.separator + "mnist" + File.separator + "mnist-model.zip");
+    private static final int height = 28;
+    private static final int width = 28;
+    private static final int channels = 1;
+
+    /**
+     * Opens a dialog that lets the user select a file from the filesystem.
+     * @return the absolute path of the selected file, or null if no file was selected.
+     */
+    public static String fileChose() {
+        JFileChooser fc = new JFileChooser();
+        int ret = fc.showOpenDialog(null);
+        if (ret == JFileChooser.APPROVE_OPTION) {
+            File file = fc.getSelectedFile();
+            return file.getAbsolutePath();
+        } else {
+            return null;
+        }
+    }
+
+    public static void main(String[] args) throws IOException {
+        String path = fileChose();
+        if (path == null) {
+            logger.info("No file selected, nothing to predict.");
+            return;
+        }
+        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
+        File file = new File(path);
+
+        INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
+        new ImagePreProcessingScaler(0, 1).transform(image);
+
+        // pass the image through the network
+        INDArray output = model.output(image);
+
+        logger.info("File: {}", path);
+        logger.info(output.toString());
+    }
+
+}
diff --git a/ml/src/main/resources/logback.xml b/ml/src/main/resources/logback.xml
new file mode 100644
index 0000000000..7d900d8ea8
--- /dev/null
+++ b/ml/src/main/resources/logback.xml
@@ -0,0 +1,13 @@
+<configuration>
+
+    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
+        <encoder>
+            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
+        </encoder>
+    </appender>
+
+    <root level="INFO">
+        <appender-ref ref="STDOUT" />
+    </root>
+
+</configuration>
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 352da33fee..7e9ce85b50 100644
--- a/pom.xml
+++ b/pom.xml
@@ -524,6 +524,7 @@
                 <module>metrics</module>
                 <module>microprofile</module>
+                <module>ml</module>
                 <module>msf4j</module>
                 <module>mustache</module>
@@ -1216,6 +1217,7 @@
                 <module>metrics</module>
                 <module>microprofile</module>
+                <module>ml</module>
                 <module>msf4j</module>
                 <module>mustache</module>
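MnistPrediction above only logs the raw output vector; to report an actual digit, the output can be reduced with an argmax over the class dimension. A minimal sketch, not part of this patch (the class name is hypothetical, and it assumes the [1, 10] output shape produced by the network above):

```java
import org.nd4j.linalg.api.ndarray.INDArray;

public class MnistOutputUtils {

    // the network output for a single image has shape [1, 10];
    // the predicted digit is the index of the largest softmax probability
    public static int predictedDigit(INDArray output) {
        return output.argMax(1).getInt(0);
    }
}
```

With such a helper, the last log line of MnistPrediction could become `logger.info("Predicted digit: {}", MnistOutputUtils.predictedDigit(output));`.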