Merge pull request #7502 from veontomo/BAEL-3081
Code for article "Logistic Regression in Java" (BAEL-3081)
commit 12d8c8dd30

ml/README.md
@ -0,0 +1,5 @@
### Logistic Regression in Java

This is a soft introduction to ML using the [deeplearning4j](https://deeplearning4j.org) library.

### Relevant Articles:

- [Logistic Regression in Java](http://www.baeldung.com/)
ml/pom.xml
@ -0,0 +1,52 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.baeldung.deeplearning4j</groupId>
    <artifactId>ml</artifactId>
    <version>1.0-SNAPSHOT</version>
    <name>Machine Learning</name>
    <packaging>jar</packaging>

    <parent>
        <groupId>com.baeldung</groupId>
        <artifactId>parent-modules</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <dependencies>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.5</version>
        </dependency>
    </dependencies>

    <properties>
        <dl4j.version>1.0.0-beta4</dl4j.version>
    </properties>

</project>
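Not part of the commit: a minimal smoke test, assuming the dependencies above resolve, to confirm the nd4j-native backend loads before running the trainer. The class name is hypothetical.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BackendSmokeTest {
    public static void main(String[] args) {
        // creating an array forces backend initialization;
        // a failure here usually means a missing native binary
        INDArray ones = Nd4j.ones(2, 2);
        System.out.println(ones.sumNumber()); // expected: 4.0
    }
}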
MnistClassifier.java
@ -0,0 +1,168 @@
package com.baeldung.logreg;

import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Handwritten digit image classification based on the LeNet-5 architecture by Yann LeCun.
 *
 * This code accompanies the article "Logistic Regression in Java" and is heavily based on
 * <a href="https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/mnist/MnistClassifier.java">MnistClassifier</a>.
 * Some minor changes have been made in order to make the article's flow smoother.
 */
public class MnistClassifier {
    private static final Logger logger = LoggerFactory.getLogger(MnistClassifier.class);
    private static final String basePath = System.getProperty("java.io.tmpdir") + "mnist" + File.separator;
    private static final File modelPath = new File(basePath + "mnist-model.zip");
    private static final String dataUrl = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";

    public static void main(String[] args) throws Exception {
        // input image size in pixels
        int height = 28;
        int width = 28;
        // input image colour depth (1 for grayscale images)
        int channels = 1;
        // the number of output classes
        int outputClasses = 10;
        // the number of samples propagated through the network in each iteration
        int batchSize = 54;
        // total number of training epochs
        int epochs = 1;

        // initialize a pseudorandom number generator with a fixed seed for reproducibility
        int seed = 1234;
        Random randNumGen = new Random(seed);

        final String path = basePath + "mnist_png" + File.separator;
        if (!new File(path).exists()) {
            logger.info("Downloading data {}", dataUrl);
            String localFilePath = basePath + "mnist_png.tar.gz";
            File file = new File(localFilePath);
            if (!file.exists()) {
                file.getParentFile()
                    .mkdirs();
                Utils.downloadAndSave(dataUrl, file);
                Utils.extractTarArchive(file, basePath);
            }
        } else {
            logger.info("Using the local data from folder {}", path);
        }

        logger.info("Vectorizing the data from folder {}", path);
        // vectorization of the training data
        File trainData = new File(path + "training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        // use the parent directory name as the image label
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);
        // arguments: reader, batch size, index of the label field, number of classes
        DataSetIterator train = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputClasses);

        // scale pixel values from 0-255 to 0-1 (min-max scaling)
        DataNormalization imageScaler = new ImagePreProcessingScaler();
        imageScaler.fit(train);
        train.setPreProcessor(imageScaler);

        // vectorization of the test data
        File testData = new File(path + "testing");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator test = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputClasses);
        // apply the same normalization to the test data for consistent results
        test.setPreProcessor(imageScaler);

        logger.info("Network configuration and training...");
        // reduce the learning rate as the number of training iterations increases
        // map: iteration # -> learning rate
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        learningRateSchedule.put(0, 0.06);
        learningRateSchedule.put(200, 0.05);
        learningRateSchedule.put(600, 0.028);
        learningRateSchedule.put(800, 0.0060);
        learningRateSchedule.put(1000, 0.001);

        final ConvolutionLayer layer1 = new ConvolutionLayer.Builder(5, 5).nIn(channels)
            .stride(1, 1)
            .nOut(20)
            .activation(Activation.IDENTITY)
            .build();
        final SubsamplingLayer layer2 = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
            .stride(2, 2)
            .build();
        // nIn need not be specified in later layers
        final ConvolutionLayer layer3 = new ConvolutionLayer.Builder(5, 5).stride(1, 1)
            .nOut(50)
            .activation(Activation.IDENTITY)
            .build();
        final DenseLayer layer4 = new DenseLayer.Builder().activation(Activation.RELU)
            .nOut(500)
            .build();
        final OutputLayer layer5 = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputClasses)
            .activation(Activation.SOFTMAX)
            .build();
        final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed)
            .l2(0.0005) // ridge regression value
            .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(layer1)
            .layer(layer2)
            .layer(layer3)
            // the same max-pooling configuration is reused for the second subsampling step
            .layer(layer2)
            .layer(layer4)
            .layer(layer5)
            .setInputType(InputType.convolutionalFlat(height, width, channels))
            .build();

        final MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init();
        model.setListeners(new ScoreIterationListener(100));
        logger.info("Total num of params: {}", model.numParams());

        // evaluation while training (the score should go down)
        for (int i = 0; i < epochs; i++) {
            model.fit(train);
            logger.info("Completed epoch {}", i);
            train.reset();
            test.reset();
        }
        Evaluation eval = model.evaluate(test);
        logger.info(eval.stats());

        ModelSerializer.writeModel(model, modelPath, true);
        logger.info("The MNIST model has been saved in {}", modelPath.getPath());
    }
}
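Not part of the commit: a small sketch of how the MapSchedule above behaves, assuming ND4J's ISchedule contract. The learning rate holds its most recent value until the next key iteration is reached.

import java.util.HashMap;
import java.util.Map;

import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

public class ScheduleSketch {
    public static void main(String[] args) {
        Map<Integer, Double> rates = new HashMap<>();
        rates.put(0, 0.06); // MapSchedule requires a value for position 0
        rates.put(200, 0.05);

        ISchedule schedule = new MapSchedule(ScheduleType.ITERATION, rates);
        // valueAt(iteration, epoch); with ScheduleType.ITERATION the epoch is ignored
        System.out.println(schedule.valueAt(0, 0));   // 0.06
        System.out.println(schedule.valueAt(150, 0)); // still 0.06 (holds until iteration 200)
        System.out.println(schedule.valueAt(200, 0)); // 0.05
    }
}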
MnistPrediction.java
@ -0,0 +1,57 @@
package com.baeldung.logreg;

import java.io.File;
import java.io.IOException;

import javax.swing.JFileChooser;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MnistPrediction {
    private static final Logger logger = LoggerFactory.getLogger(MnistPrediction.class);
    private static final File modelPath = new File(System.getProperty("java.io.tmpdir") + "mnist" + File.separator + "mnist-model.zip");
    private static final int height = 28;
    private static final int width = 28;
    private static final int channels = 1;

    /**
     * Opens a dialog that lets the user select a file from the filesystem.
     * @return the absolute path of the selected file, or null if no file was chosen
     */
    public static String fileChose() {
        JFileChooser fc = new JFileChooser();
        int ret = fc.showOpenDialog(null);
        if (ret == JFileChooser.APPROVE_OPTION) {
            File file = fc.getSelectedFile();
            return file.getAbsolutePath();
        } else {
            return null;
        }
    }

    public static void main(String[] args) throws IOException {
        if (!modelPath.exists()) {
            logger.info("The model was not found. Have you trained it?");
            return;
        }
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
        String path = fileChose();
        if (path == null) {
            logger.info("No file was selected.");
            return;
        }
        File file = new File(path);

        INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
        // apply the same 0-1 scaling that was used during training
        new ImagePreProcessingScaler(0, 1).transform(image);

        // pass the image through the neural net
        INDArray output = model.output(image);

        logger.info("File: {}", path);
        logger.info("Probabilities: {}", output);
    }
}
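Not part of the commit: the class above logs the raw class probabilities; a minimal sketch of turning them into a single predicted digit, assuming the network's usual [1, 10] output shape.

// given: INDArray output = model.output(image);  shape [1, 10]
// argMax along dimension 1 picks the index of the highest probability,
// which for MNIST is the predicted digit itself
int predictedDigit = output.argMax(1).getInt(0);
logger.info("Predicted digit: {}", predictedDigit);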
Utils.java
@ -0,0 +1,103 @@
package com.baeldung.logreg;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;

import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utility class for the digit classifier.
 */
public class Utils {

    private static final Logger logger = LoggerFactory.getLogger(Utils.class);

    private Utils() {
    }

    /**
     * Downloads the content of the given URL and saves it into a file.
     * @param url  the URL to download from
     * @param file the destination file
     */
    public static void downloadAndSave(String url, File file) throws IOException {
        logger.info("Connecting to {}", url);
        // close both the client and the response when done
        try (CloseableHttpClient client = HttpClientBuilder.create().build();
             CloseableHttpResponse response = client.execute(new HttpGet(url))) {
            HttpEntity entity = response.getEntity();
            if (entity != null) {
                logger.info("Downloaded {} bytes", entity.getContentLength());
                try (FileOutputStream outstream = new FileOutputStream(file)) {
                    logger.info("Saving to the local file");
                    entity.writeTo(outstream);
                    outstream.flush();
                    logger.info("Local file saved");
                }
            }
        }
    }

    /**
     * Extracts a "tar.gz" file into a given folder.
     * @param file   the archive to extract
     * @param folder the destination folder
     */
    public static void extractTarArchive(File file, String folder) throws IOException {
        logger.info("Extracting archive {} into folder {}", file.getName(), folder);
        // @formatter:off
        try (FileInputStream fis = new FileInputStream(file);
             BufferedInputStream bis = new BufferedInputStream(fis);
             GzipCompressorInputStream gzip = new GzipCompressorInputStream(bis);
             TarArchiveInputStream tar = new TarArchiveInputStream(gzip)) {
            // @formatter:on
            TarArchiveEntry entry;
            while ((entry = (TarArchiveEntry) tar.getNextEntry()) != null) {
                extractEntry(entry, tar, folder);
            }
        }
        logger.info("Archive extracted");
    }

    /**
     * Extracts an entry of the input stream into a given folder.
     * @param entry  the archive entry to extract
     * @param tar    the stream positioned at the entry's data
     * @param folder the destination folder
     * @throws IOException
     */
    public static void extractEntry(ArchiveEntry entry, InputStream tar, String folder) throws IOException {
        final int bufferSize = 4096;
        final String path = folder + entry.getName();
        if (entry.isDirectory()) {
            new File(path).mkdirs();
        } else {
            int count;
            byte[] data = new byte[bufferSize];
            // @formatter:off
            try (FileOutputStream os = new FileOutputStream(path);
                 BufferedOutputStream dest = new BufferedOutputStream(os, bufferSize)) {
                // @formatter:on
                while ((count = tar.read(data, 0, bufferSize)) != -1) {
                    dest.write(data, 0, count);
                }
            }
        }
    }
}
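Not part of the commit: extractEntry concatenates folder and entry.getName() without validation, so a crafted archive containing "../" components could write outside the target folder ("zip slip"). A defensive sketch that could be added at the top of extractEntry, assuming the same signature:

// resolve the entry against the folder and verify it stays inside it
File destination = new File(folder, entry.getName());
String canonicalFolder = new File(folder).getCanonicalPath() + File.separator;
if (!destination.getCanonicalPath().startsWith(canonicalFolder)) {
    throw new IOException("Archive entry escapes target folder: " + entry.getName());
}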
logback.xml
@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
            </pattern>
        </encoder>
    </appender>

    <root level="INFO">
        <appender-ref ref="STDOUT" />
    </root>
</configuration>
pom.xml
@ -535,6 +535,7 @@
<module>metrics</module>
<!-- <module>micronaut</module> --> <!-- Fixing in BAEL-10877 -->
<module>microprofile</module>
<module>ml</module>
<module>msf4j</module>
<!-- <module>muleesb</module> --> <!-- Fixing in BAEL-10878 -->
<module>mustache</module>

@ -1253,6 +1254,7 @@
<module>metrics</module>
<!-- <module>micronaut</module> --> <!-- Fixing in BAEL-10877 -->
<module>microprofile</module>
<module>ml</module>
<module>msf4j</module>
<!-- <module>muleesb</module> --> <!-- Fixing in BAEL-10878 -->
<module>mustache</module>