Refactor the utility class

This commit is contained in:
Andrew Shcherbakov 2019-09-04 22:25:22 +02:00
parent 0527d816cc
commit c20918329f
4 changed files with 119 additions and 112 deletions

View File

@ -1,102 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package com.baeldung.logreg;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import java.io.*;
/**
* Common data utility functions.
*
* @author fvaleri
*/
public class DataUtilities {
/**
* Download a remote file if it doesn't exist.
* @param remoteUrl URL of the remote file.
* @param localPath Where to download the file.
* @return True if and only if the file has been downloaded.
* @throws Exception IO error.
*/
public static boolean downloadFile(String remoteUrl, String localPath) throws IOException {
boolean downloaded = false;
if (remoteUrl == null || localPath == null)
return downloaded;
File file = new File(localPath);
if (!file.exists()) {
file.getParentFile().mkdirs();
HttpClientBuilder builder = HttpClientBuilder.create();
CloseableHttpClient client = builder.build();
try (CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl))) {
HttpEntity entity = response.getEntity();
if (entity != null) {
try (FileOutputStream outstream = new FileOutputStream(file)) {
entity.writeTo(outstream);
outstream.flush();
outstream.close();
}
}
}
downloaded = true;
}
if (!file.exists())
throw new IOException("File doesn't exist: " + localPath);
return downloaded;
}
/**
* Extract a "tar.gz" file into a local folder.
* @param inputPath Input file path.
* @param outputPath Output directory path.
* @throws IOException IO error.
*/
public static void extractTarGz(String inputPath, String outputPath) throws IOException {
if (inputPath == null || outputPath == null)
return;
final int bufferSize = 4096;
if (!outputPath.endsWith("" + File.separatorChar))
outputPath = outputPath + File.separatorChar;
try (TarArchiveInputStream tais = new TarArchiveInputStream(
new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))))) {
TarArchiveEntry entry;
while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
if (entry.isDirectory()) {
new File(outputPath + entry.getName()).mkdirs();
} else {
int count;
byte data[] = new byte[bufferSize];
FileOutputStream fos = new FileOutputStream(outputPath + entry.getName());
BufferedOutputStream dest = new BufferedOutputStream(fos, bufferSize);
while ((count = tais.read(data, 0, bufferSize)) != -1) {
dest.write(data, 0, count);
}
dest.close();
}
}
}
}
}

View File

@ -67,18 +67,20 @@ public class MnistClassifier {
final String path = basePath + "mnist_png" + File.separator; final String path = basePath + "mnist_png" + File.separator;
if (!new File(path).exists()) { if (!new File(path).exists()) {
logger.debug("Downloading data {}", dataUrl); logger.info("Downloading data {}", dataUrl);
String localFilePath = basePath + "mnist_png.tar.gz"; String localFilePath = basePath + "mnist_png.tar.gz";
logger.info("local file: {}", localFilePath); File file = new File(localFilePath);
if (DataUtilities.downloadFile(dataUrl, localFilePath)) { if (!file.exists()) {
DataUtilities.extractTarGz(localFilePath, basePath); file.getParentFile()
.mkdirs();
Utils.downloadAndSave(dataUrl, file);
Utils.extractTarArchive(file, basePath);
} }
} else { } else {
logger.info("local file exists {}", path); logger.info("Using the local data from folder {}", path);
} }
logger.info("Vectorizing data..."); logger.info("Vectorizing the data from folder {}", path);
// vectorization of train data // vectorization of train data
File trainData = new File(path + "training"); File trainData = new File(path + "training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);

View File

@ -36,8 +36,12 @@ public class MnistPrediction {
} }
public static void main(String[] args) throws IOException { public static void main(String[] args) throws IOException {
String path = fileChose().toString(); if (!modelPath.exists()) {
logger.info("The model not found. Have you trained it?");
return;
}
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath); MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
String path = fileChose();
File file = new File(path); File file = new File(path);
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file); INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
@ -47,7 +51,7 @@ public class MnistPrediction {
INDArray output = model.output(image); INDArray output = model.output(image);
logger.info("File: {}", path); logger.info("File: {}", path);
logger.info(output.toString()); logger.info("Probabilities: {}", output);
} }
} }

View File

@ -0,0 +1,103 @@
package com.baeldung.logreg;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Utility class for digit classifier.
*
*/
public class Utils {
private static final Logger logger = LoggerFactory.getLogger(Utils.class);
private Utils() {
}
/**
* Download the content of the given url and save it into a file.
* @param url
* @param file
*/
public static void downloadAndSave(String url, File file) throws IOException {
CloseableHttpClient client = HttpClientBuilder.create()
.build();
logger.info("Connecting to {}", url);
try (CloseableHttpResponse response = client.execute(new HttpGet(url))) {
HttpEntity entity = response.getEntity();
if (entity != null) {
logger.info("Downloaded {} bytes", entity.getContentLength());
try (FileOutputStream outstream = new FileOutputStream(file)) {
logger.info("Saving to the local file");
entity.writeTo(outstream);
outstream.flush();
logger.info("Local file saved");
}
}
}
}
/**
* Extract a "tar.gz" file into a given folder.
* @param file
* @param folder
*/
public static void extractTarArchive(File file, String folder) throws IOException {
logger.info("Extracting archive {} into folder {}", file.getName(), folder);
// @formatter:off
try (FileInputStream fis = new FileInputStream(file);
BufferedInputStream bis = new BufferedInputStream(fis);
GzipCompressorInputStream gzip = new GzipCompressorInputStream(bis);
TarArchiveInputStream tar = new TarArchiveInputStream(gzip)) {
// @formatter:on
TarArchiveEntry entry;
while ((entry = (TarArchiveEntry) tar.getNextEntry()) != null) {
extractEntry(entry, tar, folder);
}
}
logger.info("Archive extracted");
}
/**
* Extract an entry of the input stream into a given folder
* @param entry
* @param tar
* @param folder
* @throws IOException
*/
public static void extractEntry(ArchiveEntry entry, InputStream tar, String folder) throws IOException {
final int bufferSize = 4096;
final String path = folder + entry.getName();
if (entry.isDirectory()) {
new File(path).mkdirs();
} else {
int count;
byte[] data = new byte[bufferSize];
// @formatter:off
try (FileOutputStream os = new FileOutputStream(path);
BufferedOutputStream dest = new BufferedOutputStream(os, bufferSize)) {
// @formatter:off
while ((count = tar.read(data, 0, bufferSize)) != -1) {
dest.write(data, 0, count);
}
}
}
}
}