Algorithms refactor (#2136)
This commit is contained in:
parent
1b0d5f0b73
commit
ef7400484c
|
@ -1,4 +1,4 @@
|
||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Finite state machine.
|
* Finite state machine.
|
|
@ -1,4 +1,4 @@
|
||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default implementation of a finite state machine.
|
* Default implementation of a finite state machine.
|
|
@ -1,4 +1,4 @@
|
||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
|
@ -1,4 +1,4 @@
|
||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
|
@ -1,4 +1,4 @@
|
||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State. Part of a finite state machine.
|
* State. Part of a finite state machine.
|
|
@ -1,4 +1,4 @@
|
||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transition in a finite State machine.
|
* Transition in a finite State machine.
|
|
@ -8,9 +8,9 @@ import com.baeldung.algorithms.mcts.tree.Tree;
|
||||||
|
|
||||||
public class MonteCarloTreeSearch {
|
public class MonteCarloTreeSearch {
|
||||||
|
|
||||||
static final int WIN_SCORE = 10;
|
private static final int WIN_SCORE = 10;
|
||||||
int level;
|
private int level;
|
||||||
int oponent;
|
private int oponent;
|
||||||
|
|
||||||
public MonteCarloTreeSearch() {
|
public MonteCarloTreeSearch() {
|
||||||
this.level = 3;
|
this.level = 3;
|
||||||
|
|
|
@ -7,10 +7,10 @@ import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||||
|
|
||||||
public class State {
|
public class State {
|
||||||
Board board;
|
private Board board;
|
||||||
int playerNo;
|
private int playerNo;
|
||||||
int visitCount;
|
private int visitCount;
|
||||||
double winScore;
|
private double winScore;
|
||||||
|
|
||||||
public State() {
|
public State() {
|
||||||
board = new Board();
|
board = new Board();
|
||||||
|
@ -27,23 +27,23 @@ public class State {
|
||||||
this.board = new Board(board);
|
this.board = new Board(board);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Board getBoard() {
|
Board getBoard() {
|
||||||
return board;
|
return board;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setBoard(Board board) {
|
void setBoard(Board board) {
|
||||||
this.board = board;
|
this.board = board;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getPlayerNo() {
|
int getPlayerNo() {
|
||||||
return playerNo;
|
return playerNo;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setPlayerNo(int playerNo) {
|
void setPlayerNo(int playerNo) {
|
||||||
this.playerNo = playerNo;
|
this.playerNo = playerNo;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getOpponent() {
|
int getOpponent() {
|
||||||
return 3 - playerNo;
|
return 3 - playerNo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,11 +55,11 @@ public class State {
|
||||||
this.visitCount = visitCount;
|
this.visitCount = visitCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getWinScore() {
|
double getWinScore() {
|
||||||
return winScore;
|
return winScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setWinScore(double winScore) {
|
void setWinScore(double winScore) {
|
||||||
this.winScore = winScore;
|
this.winScore = winScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,23 +75,23 @@ public class State {
|
||||||
return possibleStates;
|
return possibleStates;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void incrementVisit() {
|
void incrementVisit() {
|
||||||
this.visitCount++;
|
this.visitCount++;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addScore(double score) {
|
void addScore(double score) {
|
||||||
if (this.winScore != Integer.MIN_VALUE)
|
if (this.winScore != Integer.MIN_VALUE)
|
||||||
this.winScore += score;
|
this.winScore += score;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void randomPlay() {
|
void randomPlay() {
|
||||||
List<Position> availablePositions = this.board.getEmptyPositions();
|
List<Position> availablePositions = this.board.getEmptyPositions();
|
||||||
int totalPossibilities = availablePositions.size();
|
int totalPossibilities = availablePositions.size();
|
||||||
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
|
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
|
||||||
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
|
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void togglePlayer() {
|
void togglePlayer() {
|
||||||
this.playerNo = 3 - this.playerNo;
|
this.playerNo = 3 - this.playerNo;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,20 +7,18 @@ import java.util.List;
|
||||||
import com.baeldung.algorithms.mcts.tree.Node;
|
import com.baeldung.algorithms.mcts.tree.Node;
|
||||||
|
|
||||||
public class UCT {
|
public class UCT {
|
||||||
final static double C = 1.41;
|
|
||||||
|
|
||||||
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
|
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
|
||||||
if (nodeVisit == 0)
|
if (nodeVisit == 0) {
|
||||||
return Integer.MAX_VALUE;
|
return Integer.MAX_VALUE;
|
||||||
return ((double) nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
}
|
||||||
|
return (nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Node findBestNodeWithUCT(Node node) {
|
static Node findBestNodeWithUCT(Node node) {
|
||||||
int parentVisit = node.getState().getVisitCount();
|
int parentVisit = node.getState().getVisitCount();
|
||||||
List<Node> childNodes = node.getChildArray();
|
return Collections.max(
|
||||||
return Collections.max(childNodes, Comparator.comparing(c -> {
|
node.getChildArray(),
|
||||||
double score = uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount());
|
Comparator.comparing(c -> uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount())));
|
||||||
return score;
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package algorithms;
|
package algorithms;
|
||||||
|
|
||||||
import com.baeldung.automata.*;
|
import com.baeldung.algorithms.automata.*;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package algorithms;
|
package algorithms.mcts;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
@ -16,8 +16,8 @@ import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||||
import com.baeldung.algorithms.mcts.tree.Tree;
|
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||||
|
|
||||||
public class MCTSTest {
|
public class MCTSTest {
|
||||||
Tree gameTree;
|
private Tree gameTree;
|
||||||
MonteCarloTreeSearch mcts;
|
private MonteCarloTreeSearch mcts;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void initGameTree() {
|
public void initGameTree() {
|
Loading…
Reference in New Issue