commit
8cfb3a1db8
|
@ -0,0 +1,15 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
<artifactId>algorithms-miscellaneous-6</artifactId>
|
||||||
|
<version>0.0.1-SNAPSHOT</version>
|
||||||
|
<name>algorithms-miscellaneous-6</name>
|
||||||
|
|
||||||
|
<parent>
|
||||||
|
<groupId>com.baeldung</groupId>
|
||||||
|
<artifactId>parent-modules</artifactId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
|
||||||
|
</project>
|
|
@ -0,0 +1,33 @@
|
||||||
|
package com.baeldung.algorithms.gradientdescent;
|
||||||
|
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
|
public class GradientDescent {
|
||||||
|
|
||||||
|
private final double precision = 0.000001;
|
||||||
|
|
||||||
|
public double findLocalMinimum(Function<Double, Double> f, double initialX) {
|
||||||
|
double stepCoefficient = 0.1;
|
||||||
|
double previousStep = 1.0;
|
||||||
|
double currentX = initialX;
|
||||||
|
double previousX = initialX;
|
||||||
|
double previousY = f.apply(previousX);
|
||||||
|
int iter = 100;
|
||||||
|
|
||||||
|
currentX += stepCoefficient * previousY;
|
||||||
|
|
||||||
|
while (previousStep > precision && iter > 0) {
|
||||||
|
iter--;
|
||||||
|
double currentY = f.apply(currentX);
|
||||||
|
if (currentY > previousY) {
|
||||||
|
stepCoefficient = -stepCoefficient / 2;
|
||||||
|
}
|
||||||
|
previousX = currentX;
|
||||||
|
currentX += stepCoefficient * previousY;
|
||||||
|
previousY = currentY;
|
||||||
|
previousStep = StrictMath.abs(currentX - previousX);
|
||||||
|
}
|
||||||
|
return currentX;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
package com.baeldung.algorithms.gradientdescent;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class GradientDescentUnitTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void givenFunction_whenStartingPointIsOne_thenLocalMinimumIsFound() {
|
||||||
|
Function<Double, Double> df = x ->
|
||||||
|
StrictMath.abs(StrictMath.pow(x, 3)) - (3 * StrictMath.pow(x, 2)) + x;
|
||||||
|
GradientDescent gd = new GradientDescent();
|
||||||
|
double res = gd.findLocalMinimum(df, 1);
|
||||||
|
assertTrue(res > 1.78);
|
||||||
|
assertTrue(res < 1.84);
|
||||||
|
}
|
||||||
|
}
|
2
pom.xml
2
pom.xml
|
@ -342,6 +342,7 @@
|
||||||
<module>algorithms-miscellaneous-3</module>
|
<module>algorithms-miscellaneous-3</module>
|
||||||
<module>algorithms-miscellaneous-4</module>
|
<module>algorithms-miscellaneous-4</module>
|
||||||
<module>algorithms-miscellaneous-5</module>
|
<module>algorithms-miscellaneous-5</module>
|
||||||
|
<module>algorithms-miscellaneous-6</module>
|
||||||
<module>algorithms-searching</module>
|
<module>algorithms-searching</module>
|
||||||
<module>algorithms-sorting</module>
|
<module>algorithms-sorting</module>
|
||||||
<module>algorithms-sorting-2</module>
|
<module>algorithms-sorting-2</module>
|
||||||
|
@ -853,6 +854,7 @@
|
||||||
<module>algorithms-miscellaneous-3</module>
|
<module>algorithms-miscellaneous-3</module>
|
||||||
<module>algorithms-miscellaneous-4</module>
|
<module>algorithms-miscellaneous-4</module>
|
||||||
<module>algorithms-miscellaneous-5</module>
|
<module>algorithms-miscellaneous-5</module>
|
||||||
|
<module>algorithms-miscellaneous-6</module>
|
||||||
<module>algorithms-searching</module>
|
<module>algorithms-searching</module>
|
||||||
<module>algorithms-sorting</module>
|
<module>algorithms-sorting</module>
|
||||||
<module>algorithms-sorting-2</module>
|
<module>algorithms-sorting-2</module>
|
||||||
|
|
Loading…
Reference in New Issue