Adding source code for the tutorial tracked under BAEL-2759 (#6533)
This commit is contained in:
parent
b3fc27088b
commit
dc72b8b397
5
pom.xml
5
pom.xml
|
@ -526,6 +526,9 @@
|
||||||
<module>rxjava</module>
|
<module>rxjava</module>
|
||||||
<module>rxjava-2</module>
|
<module>rxjava-2</module>
|
||||||
<module>software-security/sql-injection-samples</module>
|
<module>software-security/sql-injection-samples</module>
|
||||||
|
|
||||||
|
<module>tensorflow-java</module>
|
||||||
|
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
</profile>
|
</profile>
|
||||||
|
@ -742,6 +745,8 @@
|
||||||
<module>xml</module>
|
<module>xml</module>
|
||||||
<module>xmlunit-2</module>
|
<module>xmlunit-2</module>
|
||||||
<module>xstream</module>
|
<module>xstream</module>
|
||||||
|
|
||||||
|
<module>tensorflow-java</module>
|
||||||
|
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
/.settings
|
||||||
|
/model
|
||||||
|
/target
|
||||||
|
.classpath
|
||||||
|
.project
|
||||||
|
.springBeans
|
|
@ -0,0 +1,3 @@
|
||||||
|
## Relevant articles:
|
||||||
|
|
||||||
|
- [TensorFlow for Java](https://www.baeldung.com/xxxx)
|
|
@ -0,0 +1,52 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
<groupId>com.baeldung</groupId>
|
||||||
|
<artifactId>tensorflow-java</artifactId>
|
||||||
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
<url>http://maven.apache.org</url>
|
||||||
|
|
||||||
|
<parent>
|
||||||
|
<groupId>com.baeldung</groupId>
|
||||||
|
<artifactId>parent-modules</artifactId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<java.version>1.8</java.version>
|
||||||
|
<tensorflow.version>1.12.0</tensorflow.version>
|
||||||
|
<junit.jupiter.version>5.4.0</junit.jupiter.version>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.tensorflow</groupId>
|
||||||
|
<artifactId>tensorflow</artifactId>
|
||||||
|
<version>${tensorflow.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
|
<version>${junit.jupiter.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
|
<version>${junit.jupiter.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-maven-plugin</artifactId>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</project>
|
|
@ -0,0 +1,41 @@
|
||||||
|
package org.baeldung.tensorflow;
|
||||||
|
|
||||||
|
import org.tensorflow.DataType;
|
||||||
|
import org.tensorflow.Graph;
|
||||||
|
import org.tensorflow.Operation;
|
||||||
|
import org.tensorflow.Session;
|
||||||
|
import org.tensorflow.Tensor;
|
||||||
|
|
||||||
|
public class TensorflowGraph {
|
||||||
|
|
||||||
|
public static Graph createGraph() {
|
||||||
|
Graph graph = new Graph();
|
||||||
|
Operation a = graph.opBuilder("Const", "a").setAttr("dtype", DataType.fromClass(Double.class))
|
||||||
|
.setAttr("value", Tensor.<Double>create(3.0, Double.class)).build();
|
||||||
|
Operation b = graph.opBuilder("Const", "b").setAttr("dtype", DataType.fromClass(Double.class))
|
||||||
|
.setAttr("value", Tensor.<Double>create(2.0, Double.class)).build();
|
||||||
|
Operation x = graph.opBuilder("Placeholder", "x").setAttr("dtype", DataType.fromClass(Double.class)).build();
|
||||||
|
Operation y = graph.opBuilder("Placeholder", "y").setAttr("dtype", DataType.fromClass(Double.class)).build();
|
||||||
|
Operation ax = graph.opBuilder("Mul", "ax").addInput(a.output(0)).addInput(x.output(0)).build();
|
||||||
|
Operation by = graph.opBuilder("Mul", "by").addInput(b.output(0)).addInput(y.output(0)).build();
|
||||||
|
graph.opBuilder("Add", "z").addInput(ax.output(0)).addInput(by.output(0)).build();
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Object runGraph(Graph graph, Double x, Double y) {
|
||||||
|
Object result;
|
||||||
|
try (Session sess = new Session(graph)) {
|
||||||
|
result = sess.runner().fetch("z").feed("x", Tensor.<Double>create(x, Double.class))
|
||||||
|
.feed("y", Tensor.<Double>create(y, Double.class)).run().get(0).expect(Double.class)
|
||||||
|
.doubleValue();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
Graph graph = TensorflowGraph.createGraph();
|
||||||
|
Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0);
|
||||||
|
System.out.println(result);
|
||||||
|
graph.close();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
package org.baeldung.tensorflow;
|
||||||
|
|
||||||
|
import org.tensorflow.SavedModelBundle;
|
||||||
|
import org.tensorflow.Tensor;
|
||||||
|
|
||||||
|
public class TensorflowSavedModel {
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
|
||||||
|
Tensor<Integer> tensor = model.session().runner().fetch("z").feed("x", Tensor.<Integer>create(3, Integer.class))
|
||||||
|
.feed("y", Tensor.<Integer>create(3, Integer.class)).run().get(0).expect(Integer.class);
|
||||||
|
System.out.println(tensor.intValue());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
import tensorflow as tf
|
||||||
|
graph = tf.Graph()
|
||||||
|
builder = tf.saved_model.builder.SavedModelBuilder('./model')
|
||||||
|
writer = tf.summary.FileWriter('.')
|
||||||
|
with graph.as_default():
|
||||||
|
a = tf.constant(2, name='a')
|
||||||
|
b = tf.constant(3, name='b')
|
||||||
|
x = tf.placeholder(tf.int32, name='x')
|
||||||
|
y = tf.placeholder(tf.int32, name='y')
|
||||||
|
z = tf.math.add(a*x, b*y, name='z')
|
||||||
|
writer.add_graph(tf.get_default_graph())
|
||||||
|
writer.flush()
|
||||||
|
sess = tf.Session()
|
||||||
|
sess.run(z, feed_dict = {x: 2, y: 3})
|
||||||
|
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
|
||||||
|
builder.save()
|
|
@ -0,0 +1,21 @@
|
||||||
|
package org.baeldung.tensorflow;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.tensorflow.Graph;
|
||||||
|
|
||||||
|
public class TensorflowGraphUnitTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void givenTensorflowGraphWhenRunInSessionReturnsExpectedResult() {
|
||||||
|
|
||||||
|
Graph graph = TensorflowGraph.createGraph();
|
||||||
|
Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0);
|
||||||
|
assertEquals(21.0, result);
|
||||||
|
System.out.println(result);
|
||||||
|
graph.close();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue