54 lines
1.8 KiB
Java
54 lines
1.8 KiB
Java
package com.baeldung.logreg;
|
|
|
|
import java.io.File;
|
|
import java.io.IOException;
|
|
|
|
import javax.swing.JFileChooser;
|
|
|
|
import org.datavec.image.loader.NativeImageLoader;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.deeplearning4j.util.ModelSerializer;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
public class MnistPrediction {
|
|
private static final Logger logger = LoggerFactory.getLogger(MnistPrediction.class);
|
|
private static final File modelPath = new File(System.getProperty("java.io.tmpdir") + "mnist" + File.separator + "mnist-model.zip");
|
|
private static final int height = 28;
|
|
private static final int width = 28;
|
|
private static final int channels = 1;
|
|
|
|
/**
|
|
* Opens a popup that allows to select a file from the filesystem.
|
|
* @return
|
|
*/
|
|
public static String fileChose() {
|
|
JFileChooser fc = new JFileChooser();
|
|
int ret = fc.showOpenDialog(null);
|
|
if (ret == JFileChooser.APPROVE_OPTION) {
|
|
File file = fc.getSelectedFile();
|
|
return file.getAbsolutePath();
|
|
} else {
|
|
return null;
|
|
}
|
|
}
|
|
|
|
public static void main(String[] args) throws IOException {
|
|
String path = fileChose().toString();
|
|
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
|
|
File file = new File(path);
|
|
|
|
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
|
|
new ImagePreProcessingScaler(0, 1).transform(image);
|
|
|
|
// Pass through to neural Net
|
|
INDArray output = model.output(image);
|
|
|
|
logger.info("File: {}", path);
|
|
logger.info(output.toString());
|
|
}
|
|
|
|
}
|