Add a code on logistic regression article (BAEL-3081)

This commit is contained in:
Andrew Shcherbakov 2019-08-04 16:03:12 +02:00
parent c06471c727
commit b01bd36786
7 changed files with 393 additions and 0 deletions

5
ml/README.md Normal file
View File

@ -0,0 +1,5 @@
### Sample deeplearning4j Project
This is a sample project for the [deeplearning4j](https://deeplearning4j.org) library.
### Relevant Articles:
- [A Guide to Deeplearning4j](http://www.baeldung.com/deeplearning4j)

52
ml/pom.xml Normal file
View File

@ -0,0 +1,52 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.baeldung.deeplearning4j</groupId>
<artifactId>ml</artifactId>
<version>1.0-SNAPSHOT</version>
<name>Machine Learning</name>
<packaging>jar</packaging>
<parent>
<groupId>com.baeldung</groupId>
<artifactId>parent-modules</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.3.5</version>
</dependency>
</dependencies>
<properties>
<dl4j.version>1.0.0-beta4</dl4j.version>
</properties>
</project>

View File

@ -0,0 +1,102 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package com.baeldung.logreg;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import java.io.*;
/**
* Common data utility functions.
*
* @author fvaleri
*/
public class DataUtilities {
/**
* Download a remote file if it doesn't exist.
* @param remoteUrl URL of the remote file.
* @param localPath Where to download the file.
* @return True if and only if the file has been downloaded.
* @throws Exception IO error.
*/
public static boolean downloadFile(String remoteUrl, String localPath) throws IOException {
boolean downloaded = false;
if (remoteUrl == null || localPath == null)
return downloaded;
File file = new File(localPath);
if (!file.exists()) {
file.getParentFile().mkdirs();
HttpClientBuilder builder = HttpClientBuilder.create();
CloseableHttpClient client = builder.build();
try (CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl))) {
HttpEntity entity = response.getEntity();
if (entity != null) {
try (FileOutputStream outstream = new FileOutputStream(file)) {
entity.writeTo(outstream);
outstream.flush();
outstream.close();
}
}
}
downloaded = true;
}
if (!file.exists())
throw new IOException("File doesn't exist: " + localPath);
return downloaded;
}
/**
* Extract a "tar.gz" file into a local folder.
* @param inputPath Input file path.
* @param outputPath Output directory path.
* @throws IOException IO error.
*/
public static void extractTarGz(String inputPath, String outputPath) throws IOException {
if (inputPath == null || outputPath == null)
return;
final int bufferSize = 4096;
if (!outputPath.endsWith("" + File.separatorChar))
outputPath = outputPath + File.separatorChar;
try (TarArchiveInputStream tais = new TarArchiveInputStream(
new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))))) {
TarArchiveEntry entry;
while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
if (entry.isDirectory()) {
new File(outputPath + entry.getName()).mkdirs();
} else {
int count;
byte data[] = new byte[bufferSize];
FileOutputStream fos = new FileOutputStream(outputPath + entry.getName());
BufferedOutputStream dest = new BufferedOutputStream(fos, bufferSize);
while ((count = tais.read(data, 0, bufferSize)) != -1) {
dest.write(data, 0, count);
}
dest.close();
}
}
}
}
}

View File

@ -0,0 +1,166 @@
package com.baeldung.logreg;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Handwritten digit image classification based on LeNet-5 architecture by Yann LeCun.
*
* This code accompanies the article "Logistic regression in Java" and is heavily based on
* <a href="https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/mnist/MnistClassifier.java">MnistClassifier</a>.
* Some minor changes have been made in order to make article's flow smoother.
*
*/
public class MnistClassifier {
private static final Logger logger = LoggerFactory.getLogger(MnistClassifier.class);
private static final String basePath = System.getProperty("java.io.tmpdir") + "mnist" + File.separator;
private static final File modelPath = new File(basePath + "mnist-model.zip");
private static final String dataUrl = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
public static void main(String[] args) throws Exception {
// input image sizes in pixels
int height = 28;
int width = 28;
// input image colour depth (1 for gray scale images)
int channels = 1;
// the number of output classes
int outputClasses = 10;
// number of samples that will be propagated through the network in each iteration
int batchSize = 54;
// total number of training epochs
int epochs = 1;
// initialize a pseudorandom number generator
int seed = 1234;
Random randNumGen = new Random(seed);
final String path = basePath + "mnist_png" + File.separator;
if (!new File(path).exists()) {
logger.debug("Downloading data {}", dataUrl);
String localFilePath = basePath + "mnist_png.tar.gz";
logger.info("local file: {}", localFilePath);
if (DataUtilities.downloadFile(dataUrl, localFilePath)) {
DataUtilities.extractTarGz(localFilePath, basePath);
}
} else {
logger.info("local file exists {}", path);
}
logger.info("Vectorizing data...");
// vectorization of train data
File trainData = new File(path + "training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
// use parent directory name as the image label
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator train = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputClasses);
// pixel values from 0-255 to 0-1 (min-max scaling)
DataNormalization imageScaler = new ImagePreProcessingScaler();
imageScaler.fit(train);
train.setPreProcessor(imageScaler);
// vectorization of test data
File testData = new File(path + "testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator test = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputClasses);
// same normalization for better results
test.setPreProcessor(imageScaler);
logger.info("Network configuration and training...");
// reduce the learning rate as the number of training epochs increases
// iteration #, learning rate
Map<Integer, Double> learningRateSchedule = new HashMap<>();
learningRateSchedule.put(0, 0.06);
learningRateSchedule.put(200, 0.05);
learningRateSchedule.put(600, 0.028);
learningRateSchedule.put(800, 0.0060);
learningRateSchedule.put(1000, 0.001);
final ConvolutionLayer layer1 = new ConvolutionLayer.Builder(5, 5).nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build();
final SubsamplingLayer layer2 = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2)
.build();
// nIn need not specified in later layers
final ConvolutionLayer layer3 = new ConvolutionLayer.Builder(5, 5).stride(1, 1)
.nOut(50)
.activation(Activation.IDENTITY)
.build();
final DenseLayer layer4 = new DenseLayer.Builder().activation(Activation.RELU)
.nOut(500)
.build();
final OutputLayer layer5 = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputClasses)
.activation(Activation.SOFTMAX)
.build();
final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed)
.l2(0.0005) // ridge regression value
.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
.weightInit(WeightInit.XAVIER)
.list()
.layer(layer1)
.layer(layer2)
.layer(layer3)
.layer(layer2)
.layer(layer4)
.layer(layer5)
.setInputType(InputType.convolutionalFlat(height, width, channels))
.build();
final MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();
model.setListeners(new ScoreIterationListener(100));
logger.info("Total num of params: {}", model.numParams());
// evaluation while training (the score should go down)
for (int i = 0; i < epochs; i++) {
model.fit(train);
logger.info("Completed epoch {}", i);
train.reset();
test.reset();
}
Evaluation eval = model.evaluate(test);
logger.info(eval.stats());
ModelSerializer.writeModel(model, modelPath, true);
logger.info("The MINIST model has been saved in {}", modelPath.getPath());
}
}

View File

@ -0,0 +1,53 @@
package com.baeldung.logreg;
import java.io.File;
import java.io.IOException;
import javax.swing.JFileChooser;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MnistPrediction {
private static final Logger logger = LoggerFactory.getLogger(MnistPrediction.class);
private static final File modelPath = new File(System.getProperty("java.io.tmpdir") + "mnist" + File.separator + "mnist-model.zip");
private static final int height = 28;
private static final int width = 28;
private static final int channels = 1;
/**
* Opens a popup that allows to select a file from the filesystem.
* @return
*/
public static String fileChose() {
JFileChooser fc = new JFileChooser();
int ret = fc.showOpenDialog(null);
if (ret == JFileChooser.APPROVE_OPTION) {
File file = fc.getSelectedFile();
return file.getAbsolutePath();
} else {
return null;
}
}
public static void main(String[] args) throws IOException {
String path = fileChose().toString();
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
File file = new File(path);
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
new ImagePreProcessingScaler(0, 1).transform(image);
// Pass through to neural Net
INDArray output = model.output(image);
logger.info("File: {}", path);
logger.info(output.toString());
}
}

View File

@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
</pattern>
</encoder>
</appender>
<root level="INFO">
<appender-ref ref="STDOUT" />
</root>
</configuration>

View File

@ -524,6 +524,7 @@
<module>metrics</module>
<!-- <module>micronaut</module> --> <!-- Fixing in BAEL-10877 -->
<module>microprofile</module>
<module>ml</module>
<module>msf4j</module>
<!-- <module>muleesb</module> --> <!-- Fixing in BAEL-10878 -->
<module>mustache</module>
@ -1216,6 +1217,7 @@
<module>metrics</module>
<!-- <module>micronaut</module> --> <!-- Fixing in BAEL-10877 -->
<module>microprofile</module>
<module>ml</module>
<module>msf4j</module>
<!-- <module>muleesb</module> --> <!-- Fixing in BAEL-10878 -->
<module>mustache</module>