Refactor the utility class
This commit is contained in:
parent
0527d816cc
commit
c20918329f
|
@ -1,102 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package com.baeldung.logreg;
|
||||
|
||||
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
|
||||
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
|
||||
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
|
||||
import org.apache.http.HttpEntity;
|
||||
import org.apache.http.client.methods.CloseableHttpResponse;
|
||||
import org.apache.http.client.methods.HttpGet;
|
||||
import org.apache.http.impl.client.CloseableHttpClient;
|
||||
import org.apache.http.impl.client.HttpClientBuilder;
|
||||
|
||||
import java.io.*;
|
||||
|
||||
/**
|
||||
* Common data utility functions.
|
||||
*
|
||||
* @author fvaleri
|
||||
*/
|
||||
public class DataUtilities {
|
||||
|
||||
/**
|
||||
* Download a remote file if it doesn't exist.
|
||||
* @param remoteUrl URL of the remote file.
|
||||
* @param localPath Where to download the file.
|
||||
* @return True if and only if the file has been downloaded.
|
||||
* @throws Exception IO error.
|
||||
*/
|
||||
public static boolean downloadFile(String remoteUrl, String localPath) throws IOException {
|
||||
boolean downloaded = false;
|
||||
if (remoteUrl == null || localPath == null)
|
||||
return downloaded;
|
||||
File file = new File(localPath);
|
||||
if (!file.exists()) {
|
||||
file.getParentFile().mkdirs();
|
||||
HttpClientBuilder builder = HttpClientBuilder.create();
|
||||
CloseableHttpClient client = builder.build();
|
||||
try (CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl))) {
|
||||
HttpEntity entity = response.getEntity();
|
||||
if (entity != null) {
|
||||
try (FileOutputStream outstream = new FileOutputStream(file)) {
|
||||
entity.writeTo(outstream);
|
||||
outstream.flush();
|
||||
outstream.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
downloaded = true;
|
||||
}
|
||||
if (!file.exists())
|
||||
throw new IOException("File doesn't exist: " + localPath);
|
||||
return downloaded;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a "tar.gz" file into a local folder.
|
||||
* @param inputPath Input file path.
|
||||
* @param outputPath Output directory path.
|
||||
* @throws IOException IO error.
|
||||
*/
|
||||
public static void extractTarGz(String inputPath, String outputPath) throws IOException {
|
||||
if (inputPath == null || outputPath == null)
|
||||
return;
|
||||
final int bufferSize = 4096;
|
||||
if (!outputPath.endsWith("" + File.separatorChar))
|
||||
outputPath = outputPath + File.separatorChar;
|
||||
try (TarArchiveInputStream tais = new TarArchiveInputStream(
|
||||
new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))))) {
|
||||
TarArchiveEntry entry;
|
||||
while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
|
||||
if (entry.isDirectory()) {
|
||||
new File(outputPath + entry.getName()).mkdirs();
|
||||
} else {
|
||||
int count;
|
||||
byte data[] = new byte[bufferSize];
|
||||
FileOutputStream fos = new FileOutputStream(outputPath + entry.getName());
|
||||
BufferedOutputStream dest = new BufferedOutputStream(fos, bufferSize);
|
||||
while ((count = tais.read(data, 0, bufferSize)) != -1) {
|
||||
dest.write(data, 0, count);
|
||||
}
|
||||
dest.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -67,18 +67,20 @@ public class MnistClassifier {
|
|||
|
||||
final String path = basePath + "mnist_png" + File.separator;
|
||||
if (!new File(path).exists()) {
|
||||
logger.debug("Downloading data {}", dataUrl);
|
||||
logger.info("Downloading data {}", dataUrl);
|
||||
String localFilePath = basePath + "mnist_png.tar.gz";
|
||||
logger.info("local file: {}", localFilePath);
|
||||
if (DataUtilities.downloadFile(dataUrl, localFilePath)) {
|
||||
DataUtilities.extractTarGz(localFilePath, basePath);
|
||||
File file = new File(localFilePath);
|
||||
if (!file.exists()) {
|
||||
file.getParentFile()
|
||||
.mkdirs();
|
||||
Utils.downloadAndSave(dataUrl, file);
|
||||
Utils.extractTarArchive(file, basePath);
|
||||
}
|
||||
} else {
|
||||
logger.info("local file exists {}", path);
|
||||
|
||||
logger.info("Using the local data from folder {}", path);
|
||||
}
|
||||
|
||||
logger.info("Vectorizing data...");
|
||||
logger.info("Vectorizing the data from folder {}", path);
|
||||
// vectorization of train data
|
||||
File trainData = new File(path + "training");
|
||||
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
|
||||
|
|
|
@ -36,18 +36,22 @@ public class MnistPrediction {
|
|||
}
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
String path = fileChose().toString();
|
||||
if (!modelPath.exists()) {
|
||||
logger.info("The model not found. Have you trained it?");
|
||||
return;
|
||||
}
|
||||
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
|
||||
String path = fileChose();
|
||||
File file = new File(path);
|
||||
|
||||
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
|
||||
new ImagePreProcessingScaler(0, 1).transform(image);
|
||||
|
||||
|
||||
// Pass through to neural Net
|
||||
INDArray output = model.output(image);
|
||||
|
||||
logger.info("File: {}", path);
|
||||
logger.info(output.toString());
|
||||
logger.info("Probabilities: {}", output);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
package com.baeldung.logreg;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
import org.apache.commons.compress.archivers.ArchiveEntry;
|
||||
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
|
||||
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
|
||||
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
|
||||
import org.apache.http.HttpEntity;
|
||||
import org.apache.http.client.methods.CloseableHttpResponse;
|
||||
import org.apache.http.client.methods.HttpGet;
|
||||
import org.apache.http.impl.client.CloseableHttpClient;
|
||||
import org.apache.http.impl.client.HttpClientBuilder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Utility class for digit classifier.
|
||||
*
|
||||
*/
|
||||
public class Utils {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(Utils.class);
|
||||
|
||||
private Utils() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Download the content of the given url and save it into a file.
|
||||
* @param url
|
||||
* @param file
|
||||
*/
|
||||
public static void downloadAndSave(String url, File file) throws IOException {
|
||||
CloseableHttpClient client = HttpClientBuilder.create()
|
||||
.build();
|
||||
logger.info("Connecting to {}", url);
|
||||
try (CloseableHttpResponse response = client.execute(new HttpGet(url))) {
|
||||
HttpEntity entity = response.getEntity();
|
||||
if (entity != null) {
|
||||
logger.info("Downloaded {} bytes", entity.getContentLength());
|
||||
try (FileOutputStream outstream = new FileOutputStream(file)) {
|
||||
logger.info("Saving to the local file");
|
||||
entity.writeTo(outstream);
|
||||
outstream.flush();
|
||||
logger.info("Local file saved");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a "tar.gz" file into a given folder.
|
||||
* @param file
|
||||
* @param folder
|
||||
*/
|
||||
public static void extractTarArchive(File file, String folder) throws IOException {
|
||||
logger.info("Extracting archive {} into folder {}", file.getName(), folder);
|
||||
// @formatter:off
|
||||
try (FileInputStream fis = new FileInputStream(file);
|
||||
BufferedInputStream bis = new BufferedInputStream(fis);
|
||||
GzipCompressorInputStream gzip = new GzipCompressorInputStream(bis);
|
||||
TarArchiveInputStream tar = new TarArchiveInputStream(gzip)) {
|
||||
// @formatter:on
|
||||
TarArchiveEntry entry;
|
||||
while ((entry = (TarArchiveEntry) tar.getNextEntry()) != null) {
|
||||
extractEntry(entry, tar, folder);
|
||||
}
|
||||
}
|
||||
logger.info("Archive extracted");
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract an entry of the input stream into a given folder
|
||||
* @param entry
|
||||
* @param tar
|
||||
* @param folder
|
||||
* @throws IOException
|
||||
*/
|
||||
public static void extractEntry(ArchiveEntry entry, InputStream tar, String folder) throws IOException {
|
||||
final int bufferSize = 4096;
|
||||
final String path = folder + entry.getName();
|
||||
if (entry.isDirectory()) {
|
||||
new File(path).mkdirs();
|
||||
} else {
|
||||
int count;
|
||||
byte[] data = new byte[bufferSize];
|
||||
// @formatter:off
|
||||
try (FileOutputStream os = new FileOutputStream(path);
|
||||
BufferedOutputStream dest = new BufferedOutputStream(os, bufferSize)) {
|
||||
// @formatter:off
|
||||
while ((count = tar.read(data, 0, bufferSize)) != -1) {
|
||||
dest.write(data, 0, count);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue