BAEL-984 Monte Carlo tree search (#2129)
* BAEL-984 Monte Carlo tree search BAEL-984 Implementation for tic tac toe using Monte Carlo tree search * BAEL-984 test cases for MCTS BAEL-984 test cases for Monte Carlo tree search implementation
This commit is contained in:
parent
2cde0e37c2
commit
1b0d5f0b73
|
@ -0,0 +1,109 @@
|
||||||
|
package com.baeldung.algorithms.mcts.montecarlo;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||||
|
import com.baeldung.algorithms.mcts.tree.Node;
|
||||||
|
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||||
|
|
||||||
|
public class MonteCarloTreeSearch {
|
||||||
|
|
||||||
|
static final int WIN_SCORE = 10;
|
||||||
|
int level;
|
||||||
|
int oponent;
|
||||||
|
|
||||||
|
public MonteCarloTreeSearch() {
|
||||||
|
this.level = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getLevel() {
|
||||||
|
return level;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLevel(int level) {
|
||||||
|
this.level = level;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getMillisForCurrentLevel() {
|
||||||
|
return 2 * (this.level - 1) + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Board findNextMove(Board board, int playerNo) {
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
|
long end = start + 60 * getMillisForCurrentLevel();
|
||||||
|
|
||||||
|
oponent = 3 - playerNo;
|
||||||
|
Tree tree = new Tree();
|
||||||
|
Node rootNode = tree.getRoot();
|
||||||
|
rootNode.getState().setBoard(board);
|
||||||
|
rootNode.getState().setPlayerNo(oponent);
|
||||||
|
|
||||||
|
while (System.currentTimeMillis() < end) {
|
||||||
|
// Phase 1 - Selection
|
||||||
|
Node promisingNode = selectPromisingNode(rootNode);
|
||||||
|
// Phase 2 - Expansion
|
||||||
|
if (promisingNode.getState().getBoard().checkStatus() == Board.IN_PROGRESS)
|
||||||
|
expandNode(promisingNode);
|
||||||
|
|
||||||
|
// Phase 3 - Simulation
|
||||||
|
Node nodeToExplore = promisingNode;
|
||||||
|
if (promisingNode.getChildArray().size() > 0) {
|
||||||
|
nodeToExplore = promisingNode.getRandomChildNode();
|
||||||
|
}
|
||||||
|
int playoutResult = simulateRandomPlayout(nodeToExplore);
|
||||||
|
// Phase 4 - Update
|
||||||
|
backPropogation(nodeToExplore, playoutResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
Node winnerNode = rootNode.getChildWithMaxScore();
|
||||||
|
tree.setRoot(winnerNode);
|
||||||
|
return winnerNode.getState().getBoard();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node selectPromisingNode(Node rootNode) {
|
||||||
|
Node node = rootNode;
|
||||||
|
while (node.getChildArray().size() != 0) {
|
||||||
|
node = UCT.findBestNodeWithUCT(node);
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void expandNode(Node node) {
|
||||||
|
List<State> possibleStates = node.getState().getAllPossibleStates();
|
||||||
|
possibleStates.forEach(state -> {
|
||||||
|
Node newNode = new Node(state);
|
||||||
|
newNode.setParent(node);
|
||||||
|
newNode.getState().setPlayerNo(node.getState().getOpponent());
|
||||||
|
node.getChildArray().add(newNode);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private void backPropogation(Node nodeToExplore, int playerNo) {
|
||||||
|
Node tempNode = nodeToExplore;
|
||||||
|
while (tempNode != null) {
|
||||||
|
tempNode.getState().incrementVisit();
|
||||||
|
if (tempNode.getState().getPlayerNo() == playerNo)
|
||||||
|
tempNode.getState().addScore(WIN_SCORE);
|
||||||
|
tempNode = tempNode.getParent();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private int simulateRandomPlayout(Node node) {
|
||||||
|
Node tempNode = new Node(node);
|
||||||
|
State tempState = tempNode.getState();
|
||||||
|
int boardStatus = tempState.getBoard().checkStatus();
|
||||||
|
|
||||||
|
if (boardStatus == oponent) {
|
||||||
|
tempNode.getParent().getState().setWinScore(Integer.MIN_VALUE);
|
||||||
|
return boardStatus;
|
||||||
|
}
|
||||||
|
while (boardStatus == Board.IN_PROGRESS) {
|
||||||
|
tempState.togglePlayer();
|
||||||
|
tempState.randomPlay();
|
||||||
|
boardStatus = tempState.getBoard().checkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
return boardStatus;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,97 @@
|
||||||
|
package com.baeldung.algorithms.mcts.montecarlo;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||||
|
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||||
|
|
||||||
|
public class State {
|
||||||
|
Board board;
|
||||||
|
int playerNo;
|
||||||
|
int visitCount;
|
||||||
|
double winScore;
|
||||||
|
|
||||||
|
public State() {
|
||||||
|
board = new Board();
|
||||||
|
}
|
||||||
|
|
||||||
|
public State(State state) {
|
||||||
|
this.board = new Board(state.getBoard());
|
||||||
|
this.playerNo = state.getPlayerNo();
|
||||||
|
this.visitCount = state.getVisitCount();
|
||||||
|
this.winScore = state.getWinScore();
|
||||||
|
}
|
||||||
|
|
||||||
|
public State(Board board) {
|
||||||
|
this.board = new Board(board);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Board getBoard() {
|
||||||
|
return board;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBoard(Board board) {
|
||||||
|
this.board = board;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getPlayerNo() {
|
||||||
|
return playerNo;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setPlayerNo(int playerNo) {
|
||||||
|
this.playerNo = playerNo;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getOpponent() {
|
||||||
|
return 3 - playerNo;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getVisitCount() {
|
||||||
|
return visitCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setVisitCount(int visitCount) {
|
||||||
|
this.visitCount = visitCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getWinScore() {
|
||||||
|
return winScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setWinScore(double winScore) {
|
||||||
|
this.winScore = winScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<State> getAllPossibleStates() {
|
||||||
|
List<State> possibleStates = new ArrayList<>();
|
||||||
|
List<Position> availablePositions = this.board.getEmptyPositions();
|
||||||
|
availablePositions.forEach(p -> {
|
||||||
|
State newState = new State(this.board);
|
||||||
|
newState.setPlayerNo(3 - this.playerNo);
|
||||||
|
newState.getBoard().performMove(newState.getPlayerNo(), p);
|
||||||
|
possibleStates.add(newState);
|
||||||
|
});
|
||||||
|
return possibleStates;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void incrementVisit() {
|
||||||
|
this.visitCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addScore(double score) {
|
||||||
|
if (this.winScore != Integer.MIN_VALUE)
|
||||||
|
this.winScore += score;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void randomPlay() {
|
||||||
|
List<Position> availablePositions = this.board.getEmptyPositions();
|
||||||
|
int totalPossibilities = availablePositions.size();
|
||||||
|
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
|
||||||
|
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void togglePlayer() {
|
||||||
|
this.playerNo = 3 - this.playerNo;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package com.baeldung.algorithms.mcts.montecarlo;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.baeldung.algorithms.mcts.tree.Node;
|
||||||
|
|
||||||
|
public class UCT {
|
||||||
|
final static double C = 1.41;
|
||||||
|
|
||||||
|
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
|
||||||
|
if (nodeVisit == 0)
|
||||||
|
return Integer.MAX_VALUE;
|
||||||
|
return ((double) nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Node findBestNodeWithUCT(Node node) {
|
||||||
|
int parentVisit = node.getState().getVisitCount();
|
||||||
|
List<Node> childNodes = node.getChildArray();
|
||||||
|
return Collections.max(childNodes, Comparator.comparing(c -> {
|
||||||
|
double score = uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount());
|
||||||
|
return score;
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,155 @@
|
||||||
|
package com.baeldung.algorithms.mcts.tictactoe;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class Board {
|
||||||
|
int[][] boardValues;
|
||||||
|
int totalMoves;
|
||||||
|
|
||||||
|
public static final int DEFAULT_BOARD_SIZE = 3;
|
||||||
|
|
||||||
|
public static final int IN_PROGRESS = -1;
|
||||||
|
public static final int DRAW = 0;
|
||||||
|
public static final int P1 = 1;
|
||||||
|
public static final int P2 = 2;
|
||||||
|
|
||||||
|
public Board() {
|
||||||
|
boardValues = new int[DEFAULT_BOARD_SIZE][DEFAULT_BOARD_SIZE];
|
||||||
|
}
|
||||||
|
|
||||||
|
public Board(int boardSize) {
|
||||||
|
boardValues = new int[boardSize][boardSize];
|
||||||
|
}
|
||||||
|
|
||||||
|
public Board(int[][] boardValues) {
|
||||||
|
this.boardValues = boardValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Board(int[][] boardValues, int totalMoves) {
|
||||||
|
this.boardValues = boardValues;
|
||||||
|
this.totalMoves = totalMoves;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Board(Board board) {
|
||||||
|
int boardLength = board.getBoardValues().length;
|
||||||
|
this.boardValues = new int[boardLength][boardLength];
|
||||||
|
int[][] boardValues = board.getBoardValues();
|
||||||
|
int n = boardValues.length;
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
int m = boardValues[i].length;
|
||||||
|
for (int j = 0; j < m; j++) {
|
||||||
|
this.boardValues[i][j] = boardValues[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void performMove(int player, Position p) {
|
||||||
|
this.totalMoves++;
|
||||||
|
boardValues[p.getX()][p.getY()] = player;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int[][] getBoardValues() {
|
||||||
|
return boardValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBoardValues(int[][] boardValues) {
|
||||||
|
this.boardValues = boardValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int checkStatus() {
|
||||||
|
int boardSize = boardValues.length;
|
||||||
|
int maxIndex = boardSize - 1;
|
||||||
|
int[] diag1 = new int[boardSize];
|
||||||
|
int[] diag2 = new int[boardSize];
|
||||||
|
|
||||||
|
for (int i = 0; i < boardSize; i++) {
|
||||||
|
int[] row = boardValues[i];
|
||||||
|
int[] col = new int[boardSize];
|
||||||
|
for (int j = 0; j < boardSize; j++) {
|
||||||
|
col[j] = boardValues[j][i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int checkRowForWin = checkForWin(row);
|
||||||
|
if(checkRowForWin!=0)
|
||||||
|
return checkRowForWin;
|
||||||
|
|
||||||
|
int checkColForWin = checkForWin(col);
|
||||||
|
if(checkColForWin!=0)
|
||||||
|
return checkColForWin;
|
||||||
|
|
||||||
|
diag1[i] = boardValues[i][i];
|
||||||
|
diag2[i] = boardValues[maxIndex - i][i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int checkDia1gForWin = checkForWin(diag1);
|
||||||
|
if(checkDia1gForWin!=0)
|
||||||
|
return checkDia1gForWin;
|
||||||
|
|
||||||
|
int checkDiag2ForWin = checkForWin(diag2);
|
||||||
|
if(checkDiag2ForWin!=0)
|
||||||
|
return checkDiag2ForWin;
|
||||||
|
|
||||||
|
if (getEmptyPositions().size() > 0)
|
||||||
|
return IN_PROGRESS;
|
||||||
|
else
|
||||||
|
return DRAW;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int checkForWin(int[] row) {
|
||||||
|
boolean isEqual = true;
|
||||||
|
int size = row.length;
|
||||||
|
int previous = row[0];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
if (previous != row[i]) {
|
||||||
|
isEqual = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
previous = row[i];
|
||||||
|
}
|
||||||
|
if(isEqual)
|
||||||
|
return previous;
|
||||||
|
else
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void printBoard() {
|
||||||
|
int size = this.boardValues.length;
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
for (int j = 0; j < size; j++) {
|
||||||
|
System.out.print(boardValues[i][j] + " ");
|
||||||
|
}
|
||||||
|
System.out.println();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Position> getEmptyPositions() {
|
||||||
|
int size = this.boardValues.length;
|
||||||
|
List<Position> emptyPositions = new ArrayList<>();
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
for (int j = 0; j < size; j++) {
|
||||||
|
if (boardValues[i][j] == 0)
|
||||||
|
emptyPositions.add(new Position(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return emptyPositions;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void printStatus() {
|
||||||
|
switch (this.checkStatus()) {
|
||||||
|
case P1:
|
||||||
|
System.out.println("Player 1 wins");
|
||||||
|
break;
|
||||||
|
case P2:
|
||||||
|
System.out.println("Player 2 wins");
|
||||||
|
break;
|
||||||
|
case DRAW:
|
||||||
|
System.out.println("Game Draw");
|
||||||
|
break;
|
||||||
|
case IN_PROGRESS:
|
||||||
|
System.out.println("Game In rogress");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
package com.baeldung.algorithms.mcts.tictactoe;
|
||||||
|
|
||||||
|
public class Position {
|
||||||
|
int x;
|
||||||
|
int y;
|
||||||
|
|
||||||
|
public Position() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public Position(int x, int y) {
|
||||||
|
this.x = x;
|
||||||
|
this.y = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getX() {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setX(int x) {
|
||||||
|
this.x = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getY() {
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setY(int y) {
|
||||||
|
this.y = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
package com.baeldung.algorithms.mcts.tree;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.baeldung.algorithms.mcts.montecarlo.State;
|
||||||
|
|
||||||
|
public class Node {
|
||||||
|
State state;
|
||||||
|
Node parent;
|
||||||
|
List<Node> childArray;
|
||||||
|
|
||||||
|
public Node() {
|
||||||
|
this.state = new State();
|
||||||
|
childArray = new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node(State state) {
|
||||||
|
this.state = state;
|
||||||
|
childArray = new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node(State state, Node parent, List<Node> childArray) {
|
||||||
|
this.state = state;
|
||||||
|
this.parent = parent;
|
||||||
|
this.childArray = childArray;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node(Node node) {
|
||||||
|
this.childArray = new ArrayList<>();
|
||||||
|
this.state = new State(node.getState());
|
||||||
|
if (node.getParent() != null)
|
||||||
|
this.parent = node.getParent();
|
||||||
|
List<Node> childArray = node.getChildArray();
|
||||||
|
for (Node child : childArray) {
|
||||||
|
this.childArray.add(new Node(child));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public State getState() {
|
||||||
|
return state;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setState(State state) {
|
||||||
|
this.state = state;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node getParent() {
|
||||||
|
return parent;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setParent(Node parent) {
|
||||||
|
this.parent = parent;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Node> getChildArray() {
|
||||||
|
return childArray;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setChildArray(List<Node> childArray) {
|
||||||
|
this.childArray = childArray;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node getRandomChildNode() {
|
||||||
|
int noOfPossibleMoves = this.childArray.size();
|
||||||
|
int selectRandom = (int) (Math.random() * ((noOfPossibleMoves - 1) + 1));
|
||||||
|
return this.childArray.get(selectRandom);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node getChildWithMaxScore() {
|
||||||
|
return Collections.max(this.childArray, Comparator.comparing(c -> {
|
||||||
|
return c.getState().getVisitCount();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package com.baeldung.algorithms.mcts.tree;
|
||||||
|
|
||||||
|
public class Tree {
|
||||||
|
Node root;
|
||||||
|
|
||||||
|
public Tree() {
|
||||||
|
root = new Node();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Tree(Node root) {
|
||||||
|
this.root = root;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node getRoot() {
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setRoot(Node root) {
|
||||||
|
this.root = root;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addChild(Node parent, Node child) {
|
||||||
|
parent.getChildArray().add(child);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
package algorithms;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
|
||||||
|
import com.baeldung.algorithms.mcts.montecarlo.State;
|
||||||
|
import com.baeldung.algorithms.mcts.montecarlo.UCT;
|
||||||
|
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||||
|
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||||
|
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||||
|
|
||||||
|
public class MCTSTest {
|
||||||
|
Tree gameTree;
|
||||||
|
MonteCarloTreeSearch mcts;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void initGameTree() {
|
||||||
|
gameTree = new Tree();
|
||||||
|
mcts = new MonteCarloTreeSearch();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
|
||||||
|
double uctValue = 15.79;
|
||||||
|
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
|
||||||
|
State initState = gameTree.getRoot().getState();
|
||||||
|
List<State> possibleStates = initState.getAllPossibleStates();
|
||||||
|
assertTrue(possibleStates.size() > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
|
||||||
|
Board board = new Board();
|
||||||
|
int initAvailablePositions = board.getEmptyPositions().size();
|
||||||
|
board.performMove(Board.P1, new Position(1, 1));
|
||||||
|
int availablePositions = board.getEmptyPositions().size();
|
||||||
|
assertTrue(initAvailablePositions > availablePositions);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
|
||||||
|
Board board = new Board();
|
||||||
|
|
||||||
|
int player = Board.P1;
|
||||||
|
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||||
|
for (int i = 0; i < totalMoves; i++) {
|
||||||
|
board = mcts.findNextMove(board, player);
|
||||||
|
if (board.checkStatus() != -1) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
player = 3 - player;
|
||||||
|
}
|
||||||
|
int winStatus = board.checkStatus();
|
||||||
|
assertEquals(winStatus, Board.DRAW);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
|
||||||
|
Board board = new Board();
|
||||||
|
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
|
||||||
|
mcts1.setLevel(1);
|
||||||
|
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
|
||||||
|
mcts3.setLevel(3);
|
||||||
|
|
||||||
|
int player = Board.P1;
|
||||||
|
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||||
|
for (int i = 0; i < totalMoves; i++) {
|
||||||
|
if (player == Board.P1)
|
||||||
|
board = mcts3.findNextMove(board, player);
|
||||||
|
else
|
||||||
|
board = mcts1.findNextMove(board, player);
|
||||||
|
|
||||||
|
if (board.checkStatus() != -1) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
player = 3 - player;
|
||||||
|
}
|
||||||
|
int winStatus = board.checkStatus();
|
||||||
|
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue