java-tutorials/ml/src/main/java/com/baeldung/logreg/MnistPrediction.java
2019-08-04 16:03:12 +02:00

54 lines
1.8 KiB
Java

package com.baeldung.logreg;
import java.io.File;
import java.io.IOException;
import javax.swing.JFileChooser;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MnistPrediction {
private static final Logger logger = LoggerFactory.getLogger(MnistPrediction.class);
private static final File modelPath = new File(System.getProperty("java.io.tmpdir") + "mnist" + File.separator + "mnist-model.zip");
private static final int height = 28;
private static final int width = 28;
private static final int channels = 1;
/**
* Opens a popup that allows to select a file from the filesystem.
* @return
*/
public static String fileChose() {
JFileChooser fc = new JFileChooser();
int ret = fc.showOpenDialog(null);
if (ret == JFileChooser.APPROVE_OPTION) {
File file = fc.getSelectedFile();
return file.getAbsolutePath();
} else {
return null;
}
}
public static void main(String[] args) throws IOException {
String path = fileChose().toString();
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
File file = new File(path);
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
new ImagePreProcessingScaler(0, 1).transform(image);
// Pass through to neural Net
INDArray output = model.output(image);
logger.info("File: {}", path);
logger.info(output.toString());
}
}