Refactor the utility class
This commit is contained in:
parent
0527d816cc
commit
c20918329f
|
@ -1,102 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package com.baeldung.logreg;
|
|
||||||
|
|
||||||
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
|
|
||||||
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
|
|
||||||
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
|
|
||||||
import org.apache.http.HttpEntity;
|
|
||||||
import org.apache.http.client.methods.CloseableHttpResponse;
|
|
||||||
import org.apache.http.client.methods.HttpGet;
|
|
||||||
import org.apache.http.impl.client.CloseableHttpClient;
|
|
||||||
import org.apache.http.impl.client.HttpClientBuilder;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Common data utility functions.
|
|
||||||
*
|
|
||||||
* @author fvaleri
|
|
||||||
*/
|
|
||||||
public class DataUtilities {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Download a remote file if it doesn't exist.
|
|
||||||
* @param remoteUrl URL of the remote file.
|
|
||||||
* @param localPath Where to download the file.
|
|
||||||
* @return True if and only if the file has been downloaded.
|
|
||||||
* @throws Exception IO error.
|
|
||||||
*/
|
|
||||||
public static boolean downloadFile(String remoteUrl, String localPath) throws IOException {
|
|
||||||
boolean downloaded = false;
|
|
||||||
if (remoteUrl == null || localPath == null)
|
|
||||||
return downloaded;
|
|
||||||
File file = new File(localPath);
|
|
||||||
if (!file.exists()) {
|
|
||||||
file.getParentFile().mkdirs();
|
|
||||||
HttpClientBuilder builder = HttpClientBuilder.create();
|
|
||||||
CloseableHttpClient client = builder.build();
|
|
||||||
try (CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl))) {
|
|
||||||
HttpEntity entity = response.getEntity();
|
|
||||||
if (entity != null) {
|
|
||||||
try (FileOutputStream outstream = new FileOutputStream(file)) {
|
|
||||||
entity.writeTo(outstream);
|
|
||||||
outstream.flush();
|
|
||||||
outstream.close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
downloaded = true;
|
|
||||||
}
|
|
||||||
if (!file.exists())
|
|
||||||
throw new IOException("File doesn't exist: " + localPath);
|
|
||||||
return downloaded;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract a "tar.gz" file into a local folder.
|
|
||||||
* @param inputPath Input file path.
|
|
||||||
* @param outputPath Output directory path.
|
|
||||||
* @throws IOException IO error.
|
|
||||||
*/
|
|
||||||
public static void extractTarGz(String inputPath, String outputPath) throws IOException {
|
|
||||||
if (inputPath == null || outputPath == null)
|
|
||||||
return;
|
|
||||||
final int bufferSize = 4096;
|
|
||||||
if (!outputPath.endsWith("" + File.separatorChar))
|
|
||||||
outputPath = outputPath + File.separatorChar;
|
|
||||||
try (TarArchiveInputStream tais = new TarArchiveInputStream(
|
|
||||||
new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))))) {
|
|
||||||
TarArchiveEntry entry;
|
|
||||||
while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
|
|
||||||
if (entry.isDirectory()) {
|
|
||||||
new File(outputPath + entry.getName()).mkdirs();
|
|
||||||
} else {
|
|
||||||
int count;
|
|
||||||
byte data[] = new byte[bufferSize];
|
|
||||||
FileOutputStream fos = new FileOutputStream(outputPath + entry.getName());
|
|
||||||
BufferedOutputStream dest = new BufferedOutputStream(fos, bufferSize);
|
|
||||||
while ((count = tais.read(data, 0, bufferSize)) != -1) {
|
|
||||||
dest.write(data, 0, count);
|
|
||||||
}
|
|
||||||
dest.close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -67,18 +67,20 @@ public class MnistClassifier {
|
||||||
|
|
||||||
final String path = basePath + "mnist_png" + File.separator;
|
final String path = basePath + "mnist_png" + File.separator;
|
||||||
if (!new File(path).exists()) {
|
if (!new File(path).exists()) {
|
||||||
logger.debug("Downloading data {}", dataUrl);
|
logger.info("Downloading data {}", dataUrl);
|
||||||
String localFilePath = basePath + "mnist_png.tar.gz";
|
String localFilePath = basePath + "mnist_png.tar.gz";
|
||||||
logger.info("local file: {}", localFilePath);
|
File file = new File(localFilePath);
|
||||||
if (DataUtilities.downloadFile(dataUrl, localFilePath)) {
|
if (!file.exists()) {
|
||||||
DataUtilities.extractTarGz(localFilePath, basePath);
|
file.getParentFile()
|
||||||
|
.mkdirs();
|
||||||
|
Utils.downloadAndSave(dataUrl, file);
|
||||||
|
Utils.extractTarArchive(file, basePath);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.info("local file exists {}", path);
|
logger.info("Using the local data from folder {}", path);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("Vectorizing data...");
|
logger.info("Vectorizing the data from folder {}", path);
|
||||||
// vectorization of train data
|
// vectorization of train data
|
||||||
File trainData = new File(path + "training");
|
File trainData = new File(path + "training");
|
||||||
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
|
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
|
||||||
|
|
|
@ -36,8 +36,12 @@ public class MnistPrediction {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException {
|
public static void main(String[] args) throws IOException {
|
||||||
String path = fileChose().toString();
|
if (!modelPath.exists()) {
|
||||||
|
logger.info("The model not found. Have you trained it?");
|
||||||
|
return;
|
||||||
|
}
|
||||||
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
|
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
|
||||||
|
String path = fileChose();
|
||||||
File file = new File(path);
|
File file = new File(path);
|
||||||
|
|
||||||
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
|
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
|
||||||
|
@ -47,7 +51,7 @@ public class MnistPrediction {
|
||||||
INDArray output = model.output(image);
|
INDArray output = model.output(image);
|
||||||
|
|
||||||
logger.info("File: {}", path);
|
logger.info("File: {}", path);
|
||||||
logger.info(output.toString());
|
logger.info("Probabilities: {}", output);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
package com.baeldung.logreg;
|
||||||
|
|
||||||
|
import java.io.BufferedInputStream;
|
||||||
|
import java.io.BufferedOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
import org.apache.commons.compress.archivers.ArchiveEntry;
|
||||||
|
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
|
||||||
|
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
|
||||||
|
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
|
||||||
|
import org.apache.http.HttpEntity;
|
||||||
|
import org.apache.http.client.methods.CloseableHttpResponse;
|
||||||
|
import org.apache.http.client.methods.HttpGet;
|
||||||
|
import org.apache.http.impl.client.CloseableHttpClient;
|
||||||
|
import org.apache.http.impl.client.HttpClientBuilder;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility class for digit classifier.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class Utils {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(Utils.class);
|
||||||
|
|
||||||
|
private Utils() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Download the content of the given url and save it into a file.
|
||||||
|
* @param url
|
||||||
|
* @param file
|
||||||
|
*/
|
||||||
|
public static void downloadAndSave(String url, File file) throws IOException {
|
||||||
|
CloseableHttpClient client = HttpClientBuilder.create()
|
||||||
|
.build();
|
||||||
|
logger.info("Connecting to {}", url);
|
||||||
|
try (CloseableHttpResponse response = client.execute(new HttpGet(url))) {
|
||||||
|
HttpEntity entity = response.getEntity();
|
||||||
|
if (entity != null) {
|
||||||
|
logger.info("Downloaded {} bytes", entity.getContentLength());
|
||||||
|
try (FileOutputStream outstream = new FileOutputStream(file)) {
|
||||||
|
logger.info("Saving to the local file");
|
||||||
|
entity.writeTo(outstream);
|
||||||
|
outstream.flush();
|
||||||
|
logger.info("Local file saved");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract a "tar.gz" file into a given folder.
|
||||||
|
* @param file
|
||||||
|
* @param folder
|
||||||
|
*/
|
||||||
|
public static void extractTarArchive(File file, String folder) throws IOException {
|
||||||
|
logger.info("Extracting archive {} into folder {}", file.getName(), folder);
|
||||||
|
// @formatter:off
|
||||||
|
try (FileInputStream fis = new FileInputStream(file);
|
||||||
|
BufferedInputStream bis = new BufferedInputStream(fis);
|
||||||
|
GzipCompressorInputStream gzip = new GzipCompressorInputStream(bis);
|
||||||
|
TarArchiveInputStream tar = new TarArchiveInputStream(gzip)) {
|
||||||
|
// @formatter:on
|
||||||
|
TarArchiveEntry entry;
|
||||||
|
while ((entry = (TarArchiveEntry) tar.getNextEntry()) != null) {
|
||||||
|
extractEntry(entry, tar, folder);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.info("Archive extracted");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract an entry of the input stream into a given folder
|
||||||
|
* @param entry
|
||||||
|
* @param tar
|
||||||
|
* @param folder
|
||||||
|
* @throws IOException
|
||||||
|
*/
|
||||||
|
public static void extractEntry(ArchiveEntry entry, InputStream tar, String folder) throws IOException {
|
||||||
|
final int bufferSize = 4096;
|
||||||
|
final String path = folder + entry.getName();
|
||||||
|
if (entry.isDirectory()) {
|
||||||
|
new File(path).mkdirs();
|
||||||
|
} else {
|
||||||
|
int count;
|
||||||
|
byte[] data = new byte[bufferSize];
|
||||||
|
// @formatter:off
|
||||||
|
try (FileOutputStream os = new FileOutputStream(path);
|
||||||
|
BufferedOutputStream dest = new BufferedOutputStream(os, bufferSize)) {
|
||||||
|
// @formatter:off
|
||||||
|
while ((count = tar.read(data, 0, bufferSize)) != -1) {
|
||||||
|
dest.write(data, 0, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue