Algorithms refactor (#2136)
This commit is contained in:
parent
1b0d5f0b73
commit
ef7400484c
|
@ -1,4 +1,4 @@
|
|||
package com.baeldung.automata;
|
||||
package com.baeldung.algorithms.automata;
|
||||
|
||||
/**
|
||||
* Finite state machine.
|
|
@ -1,4 +1,4 @@
|
|||
package com.baeldung.automata;
|
||||
package com.baeldung.algorithms.automata;
|
||||
|
||||
/**
|
||||
* Default implementation of a finite state machine.
|
|
@ -1,4 +1,4 @@
|
|||
package com.baeldung.automata;
|
||||
package com.baeldung.algorithms.automata;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
|
@ -1,4 +1,4 @@
|
|||
package com.baeldung.automata;
|
||||
package com.baeldung.algorithms.automata;
|
||||
|
||||
|
||||
/**
|
|
@ -1,4 +1,4 @@
|
|||
package com.baeldung.automata;
|
||||
package com.baeldung.algorithms.automata;
|
||||
|
||||
/**
|
||||
* State. Part of a finite state machine.
|
|
@ -1,4 +1,4 @@
|
|||
package com.baeldung.automata;
|
||||
package com.baeldung.algorithms.automata;
|
||||
|
||||
/**
|
||||
* Transition in a finite State machine.
|
|
@ -8,9 +8,9 @@ import com.baeldung.algorithms.mcts.tree.Tree;
|
|||
|
||||
public class MonteCarloTreeSearch {
|
||||
|
||||
static final int WIN_SCORE = 10;
|
||||
int level;
|
||||
int oponent;
|
||||
private static final int WIN_SCORE = 10;
|
||||
private int level;
|
||||
private int oponent;
|
||||
|
||||
public MonteCarloTreeSearch() {
|
||||
this.level = 3;
|
||||
|
|
|
@ -7,10 +7,10 @@ import com.baeldung.algorithms.mcts.tictactoe.Board;
|
|||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||
|
||||
public class State {
|
||||
Board board;
|
||||
int playerNo;
|
||||
int visitCount;
|
||||
double winScore;
|
||||
private Board board;
|
||||
private int playerNo;
|
||||
private int visitCount;
|
||||
private double winScore;
|
||||
|
||||
public State() {
|
||||
board = new Board();
|
||||
|
@ -27,23 +27,23 @@ public class State {
|
|||
this.board = new Board(board);
|
||||
}
|
||||
|
||||
public Board getBoard() {
|
||||
Board getBoard() {
|
||||
return board;
|
||||
}
|
||||
|
||||
public void setBoard(Board board) {
|
||||
void setBoard(Board board) {
|
||||
this.board = board;
|
||||
}
|
||||
|
||||
public int getPlayerNo() {
|
||||
int getPlayerNo() {
|
||||
return playerNo;
|
||||
}
|
||||
|
||||
public void setPlayerNo(int playerNo) {
|
||||
void setPlayerNo(int playerNo) {
|
||||
this.playerNo = playerNo;
|
||||
}
|
||||
|
||||
public int getOpponent() {
|
||||
int getOpponent() {
|
||||
return 3 - playerNo;
|
||||
}
|
||||
|
||||
|
@ -55,11 +55,11 @@ public class State {
|
|||
this.visitCount = visitCount;
|
||||
}
|
||||
|
||||
public double getWinScore() {
|
||||
double getWinScore() {
|
||||
return winScore;
|
||||
}
|
||||
|
||||
public void setWinScore(double winScore) {
|
||||
void setWinScore(double winScore) {
|
||||
this.winScore = winScore;
|
||||
}
|
||||
|
||||
|
@ -75,23 +75,23 @@ public class State {
|
|||
return possibleStates;
|
||||
}
|
||||
|
||||
public void incrementVisit() {
|
||||
void incrementVisit() {
|
||||
this.visitCount++;
|
||||
}
|
||||
|
||||
public void addScore(double score) {
|
||||
void addScore(double score) {
|
||||
if (this.winScore != Integer.MIN_VALUE)
|
||||
this.winScore += score;
|
||||
}
|
||||
|
||||
public void randomPlay() {
|
||||
void randomPlay() {
|
||||
List<Position> availablePositions = this.board.getEmptyPositions();
|
||||
int totalPossibilities = availablePositions.size();
|
||||
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
|
||||
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
|
||||
}
|
||||
|
||||
public void togglePlayer() {
|
||||
void togglePlayer() {
|
||||
this.playerNo = 3 - this.playerNo;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,20 +7,18 @@ import java.util.List;
|
|||
import com.baeldung.algorithms.mcts.tree.Node;
|
||||
|
||||
public class UCT {
|
||||
final static double C = 1.41;
|
||||
|
||||
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
|
||||
if (nodeVisit == 0)
|
||||
if (nodeVisit == 0) {
|
||||
return Integer.MAX_VALUE;
|
||||
return ((double) nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
||||
}
|
||||
return (nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
||||
}
|
||||
|
||||
public static Node findBestNodeWithUCT(Node node) {
|
||||
static Node findBestNodeWithUCT(Node node) {
|
||||
int parentVisit = node.getState().getVisitCount();
|
||||
List<Node> childNodes = node.getChildArray();
|
||||
return Collections.max(childNodes, Comparator.comparing(c -> {
|
||||
double score = uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount());
|
||||
return score;
|
||||
}));
|
||||
return Collections.max(
|
||||
node.getChildArray(),
|
||||
Comparator.comparing(c -> uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount())));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package algorithms;
|
||||
|
||||
import com.baeldung.automata.*;
|
||||
import com.baeldung.algorithms.automata.*;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
|
|
@ -1,92 +1,92 @@
|
|||
package algorithms;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
|
||||
import com.baeldung.algorithms.mcts.montecarlo.State;
|
||||
import com.baeldung.algorithms.mcts.montecarlo.UCT;
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||
|
||||
public class MCTSTest {
|
||||
Tree gameTree;
|
||||
MonteCarloTreeSearch mcts;
|
||||
|
||||
@Before
|
||||
public void initGameTree() {
|
||||
gameTree = new Tree();
|
||||
mcts = new MonteCarloTreeSearch();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
|
||||
double uctValue = 15.79;
|
||||
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
|
||||
State initState = gameTree.getRoot().getState();
|
||||
List<State> possibleStates = initState.getAllPossibleStates();
|
||||
assertTrue(possibleStates.size() > 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
|
||||
Board board = new Board();
|
||||
int initAvailablePositions = board.getEmptyPositions().size();
|
||||
board.performMove(Board.P1, new Position(1, 1));
|
||||
int availablePositions = board.getEmptyPositions().size();
|
||||
assertTrue(initAvailablePositions > availablePositions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
|
||||
Board board = new Board();
|
||||
|
||||
int player = Board.P1;
|
||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||
for (int i = 0; i < totalMoves; i++) {
|
||||
board = mcts.findNextMove(board, player);
|
||||
if (board.checkStatus() != -1) {
|
||||
break;
|
||||
}
|
||||
player = 3 - player;
|
||||
}
|
||||
int winStatus = board.checkStatus();
|
||||
assertEquals(winStatus, Board.DRAW);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
|
||||
Board board = new Board();
|
||||
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
|
||||
mcts1.setLevel(1);
|
||||
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
|
||||
mcts3.setLevel(3);
|
||||
|
||||
int player = Board.P1;
|
||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||
for (int i = 0; i < totalMoves; i++) {
|
||||
if (player == Board.P1)
|
||||
board = mcts3.findNextMove(board, player);
|
||||
else
|
||||
board = mcts1.findNextMove(board, player);
|
||||
|
||||
if (board.checkStatus() != -1) {
|
||||
break;
|
||||
}
|
||||
player = 3 - player;
|
||||
}
|
||||
int winStatus = board.checkStatus();
|
||||
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
|
||||
}
|
||||
|
||||
}
|
||||
package algorithms.mcts;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
|
||||
import com.baeldung.algorithms.mcts.montecarlo.State;
|
||||
import com.baeldung.algorithms.mcts.montecarlo.UCT;
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||
|
||||
public class MCTSTest {
|
||||
private Tree gameTree;
|
||||
private MonteCarloTreeSearch mcts;
|
||||
|
||||
@Before
|
||||
public void initGameTree() {
|
||||
gameTree = new Tree();
|
||||
mcts = new MonteCarloTreeSearch();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
|
||||
double uctValue = 15.79;
|
||||
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
|
||||
State initState = gameTree.getRoot().getState();
|
||||
List<State> possibleStates = initState.getAllPossibleStates();
|
||||
assertTrue(possibleStates.size() > 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
|
||||
Board board = new Board();
|
||||
int initAvailablePositions = board.getEmptyPositions().size();
|
||||
board.performMove(Board.P1, new Position(1, 1));
|
||||
int availablePositions = board.getEmptyPositions().size();
|
||||
assertTrue(initAvailablePositions > availablePositions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
|
||||
Board board = new Board();
|
||||
|
||||
int player = Board.P1;
|
||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||
for (int i = 0; i < totalMoves; i++) {
|
||||
board = mcts.findNextMove(board, player);
|
||||
if (board.checkStatus() != -1) {
|
||||
break;
|
||||
}
|
||||
player = 3 - player;
|
||||
}
|
||||
int winStatus = board.checkStatus();
|
||||
assertEquals(winStatus, Board.DRAW);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
|
||||
Board board = new Board();
|
||||
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
|
||||
mcts1.setLevel(1);
|
||||
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
|
||||
mcts3.setLevel(3);
|
||||
|
||||
int player = Board.P1;
|
||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||
for (int i = 0; i < totalMoves; i++) {
|
||||
if (player == Board.P1)
|
||||
board = mcts3.findNextMove(board, player);
|
||||
else
|
||||
board = mcts1.findNextMove(board, player);
|
||||
|
||||
if (board.checkStatus() != -1) {
|
||||
break;
|
||||
}
|
||||
player = 3 - player;
|
||||
}
|
||||
int winStatus = board.checkStatus();
|
||||
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue