BAEL-984 Monte Carlo tree search (#2129)
* BAEL-984 Monte Carlo tree search BAEL-984 Implementation for tic tac toe using Monte Carlo tree search * BAEL-984 test cases for MCTS BAEL-984 test cases for Monte Carlo tree search implementation
This commit is contained in:
parent
2cde0e37c2
commit
1b0d5f0b73
|
@ -0,0 +1,109 @@
|
|||
package com.baeldung.algorithms.mcts.montecarlo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||
import com.baeldung.algorithms.mcts.tree.Node;
|
||||
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||
|
||||
public class MonteCarloTreeSearch {
|
||||
|
||||
static final int WIN_SCORE = 10;
|
||||
int level;
|
||||
int oponent;
|
||||
|
||||
public MonteCarloTreeSearch() {
|
||||
this.level = 3;
|
||||
}
|
||||
|
||||
public int getLevel() {
|
||||
return level;
|
||||
}
|
||||
|
||||
public void setLevel(int level) {
|
||||
this.level = level;
|
||||
}
|
||||
|
||||
private int getMillisForCurrentLevel() {
|
||||
return 2 * (this.level - 1) + 1;
|
||||
}
|
||||
|
||||
public Board findNextMove(Board board, int playerNo) {
|
||||
long start = System.currentTimeMillis();
|
||||
long end = start + 60 * getMillisForCurrentLevel();
|
||||
|
||||
oponent = 3 - playerNo;
|
||||
Tree tree = new Tree();
|
||||
Node rootNode = tree.getRoot();
|
||||
rootNode.getState().setBoard(board);
|
||||
rootNode.getState().setPlayerNo(oponent);
|
||||
|
||||
while (System.currentTimeMillis() < end) {
|
||||
// Phase 1 - Selection
|
||||
Node promisingNode = selectPromisingNode(rootNode);
|
||||
// Phase 2 - Expansion
|
||||
if (promisingNode.getState().getBoard().checkStatus() == Board.IN_PROGRESS)
|
||||
expandNode(promisingNode);
|
||||
|
||||
// Phase 3 - Simulation
|
||||
Node nodeToExplore = promisingNode;
|
||||
if (promisingNode.getChildArray().size() > 0) {
|
||||
nodeToExplore = promisingNode.getRandomChildNode();
|
||||
}
|
||||
int playoutResult = simulateRandomPlayout(nodeToExplore);
|
||||
// Phase 4 - Update
|
||||
backPropogation(nodeToExplore, playoutResult);
|
||||
}
|
||||
|
||||
Node winnerNode = rootNode.getChildWithMaxScore();
|
||||
tree.setRoot(winnerNode);
|
||||
return winnerNode.getState().getBoard();
|
||||
}
|
||||
|
||||
private Node selectPromisingNode(Node rootNode) {
|
||||
Node node = rootNode;
|
||||
while (node.getChildArray().size() != 0) {
|
||||
node = UCT.findBestNodeWithUCT(node);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
private void expandNode(Node node) {
|
||||
List<State> possibleStates = node.getState().getAllPossibleStates();
|
||||
possibleStates.forEach(state -> {
|
||||
Node newNode = new Node(state);
|
||||
newNode.setParent(node);
|
||||
newNode.getState().setPlayerNo(node.getState().getOpponent());
|
||||
node.getChildArray().add(newNode);
|
||||
});
|
||||
}
|
||||
|
||||
private void backPropogation(Node nodeToExplore, int playerNo) {
|
||||
Node tempNode = nodeToExplore;
|
||||
while (tempNode != null) {
|
||||
tempNode.getState().incrementVisit();
|
||||
if (tempNode.getState().getPlayerNo() == playerNo)
|
||||
tempNode.getState().addScore(WIN_SCORE);
|
||||
tempNode = tempNode.getParent();
|
||||
}
|
||||
}
|
||||
|
||||
private int simulateRandomPlayout(Node node) {
|
||||
Node tempNode = new Node(node);
|
||||
State tempState = tempNode.getState();
|
||||
int boardStatus = tempState.getBoard().checkStatus();
|
||||
|
||||
if (boardStatus == oponent) {
|
||||
tempNode.getParent().getState().setWinScore(Integer.MIN_VALUE);
|
||||
return boardStatus;
|
||||
}
|
||||
while (boardStatus == Board.IN_PROGRESS) {
|
||||
tempState.togglePlayer();
|
||||
tempState.randomPlay();
|
||||
boardStatus = tempState.getBoard().checkStatus();
|
||||
}
|
||||
|
||||
return boardStatus;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package com.baeldung.algorithms.mcts.montecarlo;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||
|
||||
public class State {
|
||||
Board board;
|
||||
int playerNo;
|
||||
int visitCount;
|
||||
double winScore;
|
||||
|
||||
public State() {
|
||||
board = new Board();
|
||||
}
|
||||
|
||||
public State(State state) {
|
||||
this.board = new Board(state.getBoard());
|
||||
this.playerNo = state.getPlayerNo();
|
||||
this.visitCount = state.getVisitCount();
|
||||
this.winScore = state.getWinScore();
|
||||
}
|
||||
|
||||
public State(Board board) {
|
||||
this.board = new Board(board);
|
||||
}
|
||||
|
||||
public Board getBoard() {
|
||||
return board;
|
||||
}
|
||||
|
||||
public void setBoard(Board board) {
|
||||
this.board = board;
|
||||
}
|
||||
|
||||
public int getPlayerNo() {
|
||||
return playerNo;
|
||||
}
|
||||
|
||||
public void setPlayerNo(int playerNo) {
|
||||
this.playerNo = playerNo;
|
||||
}
|
||||
|
||||
public int getOpponent() {
|
||||
return 3 - playerNo;
|
||||
}
|
||||
|
||||
public int getVisitCount() {
|
||||
return visitCount;
|
||||
}
|
||||
|
||||
public void setVisitCount(int visitCount) {
|
||||
this.visitCount = visitCount;
|
||||
}
|
||||
|
||||
public double getWinScore() {
|
||||
return winScore;
|
||||
}
|
||||
|
||||
public void setWinScore(double winScore) {
|
||||
this.winScore = winScore;
|
||||
}
|
||||
|
||||
public List<State> getAllPossibleStates() {
|
||||
List<State> possibleStates = new ArrayList<>();
|
||||
List<Position> availablePositions = this.board.getEmptyPositions();
|
||||
availablePositions.forEach(p -> {
|
||||
State newState = new State(this.board);
|
||||
newState.setPlayerNo(3 - this.playerNo);
|
||||
newState.getBoard().performMove(newState.getPlayerNo(), p);
|
||||
possibleStates.add(newState);
|
||||
});
|
||||
return possibleStates;
|
||||
}
|
||||
|
||||
public void incrementVisit() {
|
||||
this.visitCount++;
|
||||
}
|
||||
|
||||
public void addScore(double score) {
|
||||
if (this.winScore != Integer.MIN_VALUE)
|
||||
this.winScore += score;
|
||||
}
|
||||
|
||||
public void randomPlay() {
|
||||
List<Position> availablePositions = this.board.getEmptyPositions();
|
||||
int totalPossibilities = availablePositions.size();
|
||||
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
|
||||
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
|
||||
}
|
||||
|
||||
public void togglePlayer() {
|
||||
this.playerNo = 3 - this.playerNo;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package com.baeldung.algorithms.mcts.montecarlo;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
import com.baeldung.algorithms.mcts.tree.Node;
|
||||
|
||||
public class UCT {
|
||||
final static double C = 1.41;
|
||||
|
||||
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
|
||||
if (nodeVisit == 0)
|
||||
return Integer.MAX_VALUE;
|
||||
return ((double) nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
||||
}
|
||||
|
||||
public static Node findBestNodeWithUCT(Node node) {
|
||||
int parentVisit = node.getState().getVisitCount();
|
||||
List<Node> childNodes = node.getChildArray();
|
||||
return Collections.max(childNodes, Comparator.comparing(c -> {
|
||||
double score = uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount());
|
||||
return score;
|
||||
}));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
package com.baeldung.algorithms.mcts.tictactoe;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class Board {
|
||||
int[][] boardValues;
|
||||
int totalMoves;
|
||||
|
||||
public static final int DEFAULT_BOARD_SIZE = 3;
|
||||
|
||||
public static final int IN_PROGRESS = -1;
|
||||
public static final int DRAW = 0;
|
||||
public static final int P1 = 1;
|
||||
public static final int P2 = 2;
|
||||
|
||||
public Board() {
|
||||
boardValues = new int[DEFAULT_BOARD_SIZE][DEFAULT_BOARD_SIZE];
|
||||
}
|
||||
|
||||
public Board(int boardSize) {
|
||||
boardValues = new int[boardSize][boardSize];
|
||||
}
|
||||
|
||||
public Board(int[][] boardValues) {
|
||||
this.boardValues = boardValues;
|
||||
}
|
||||
|
||||
public Board(int[][] boardValues, int totalMoves) {
|
||||
this.boardValues = boardValues;
|
||||
this.totalMoves = totalMoves;
|
||||
}
|
||||
|
||||
public Board(Board board) {
|
||||
int boardLength = board.getBoardValues().length;
|
||||
this.boardValues = new int[boardLength][boardLength];
|
||||
int[][] boardValues = board.getBoardValues();
|
||||
int n = boardValues.length;
|
||||
for (int i = 0; i < n; i++) {
|
||||
int m = boardValues[i].length;
|
||||
for (int j = 0; j < m; j++) {
|
||||
this.boardValues[i][j] = boardValues[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void performMove(int player, Position p) {
|
||||
this.totalMoves++;
|
||||
boardValues[p.getX()][p.getY()] = player;
|
||||
}
|
||||
|
||||
public int[][] getBoardValues() {
|
||||
return boardValues;
|
||||
}
|
||||
|
||||
public void setBoardValues(int[][] boardValues) {
|
||||
this.boardValues = boardValues;
|
||||
}
|
||||
|
||||
public int checkStatus() {
|
||||
int boardSize = boardValues.length;
|
||||
int maxIndex = boardSize - 1;
|
||||
int[] diag1 = new int[boardSize];
|
||||
int[] diag2 = new int[boardSize];
|
||||
|
||||
for (int i = 0; i < boardSize; i++) {
|
||||
int[] row = boardValues[i];
|
||||
int[] col = new int[boardSize];
|
||||
for (int j = 0; j < boardSize; j++) {
|
||||
col[j] = boardValues[j][i];
|
||||
}
|
||||
|
||||
int checkRowForWin = checkForWin(row);
|
||||
if(checkRowForWin!=0)
|
||||
return checkRowForWin;
|
||||
|
||||
int checkColForWin = checkForWin(col);
|
||||
if(checkColForWin!=0)
|
||||
return checkColForWin;
|
||||
|
||||
diag1[i] = boardValues[i][i];
|
||||
diag2[i] = boardValues[maxIndex - i][i];
|
||||
}
|
||||
|
||||
int checkDia1gForWin = checkForWin(diag1);
|
||||
if(checkDia1gForWin!=0)
|
||||
return checkDia1gForWin;
|
||||
|
||||
int checkDiag2ForWin = checkForWin(diag2);
|
||||
if(checkDiag2ForWin!=0)
|
||||
return checkDiag2ForWin;
|
||||
|
||||
if (getEmptyPositions().size() > 0)
|
||||
return IN_PROGRESS;
|
||||
else
|
||||
return DRAW;
|
||||
}
|
||||
|
||||
private int checkForWin(int[] row) {
|
||||
boolean isEqual = true;
|
||||
int size = row.length;
|
||||
int previous = row[0];
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (previous != row[i]) {
|
||||
isEqual = false;
|
||||
break;
|
||||
}
|
||||
previous = row[i];
|
||||
}
|
||||
if(isEqual)
|
||||
return previous;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
public void printBoard() {
|
||||
int size = this.boardValues.length;
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int j = 0; j < size; j++) {
|
||||
System.out.print(boardValues[i][j] + " ");
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
|
||||
public List<Position> getEmptyPositions() {
|
||||
int size = this.boardValues.length;
|
||||
List<Position> emptyPositions = new ArrayList<>();
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int j = 0; j < size; j++) {
|
||||
if (boardValues[i][j] == 0)
|
||||
emptyPositions.add(new Position(i, j));
|
||||
}
|
||||
}
|
||||
return emptyPositions;
|
||||
}
|
||||
|
||||
public void printStatus() {
|
||||
switch (this.checkStatus()) {
|
||||
case P1:
|
||||
System.out.println("Player 1 wins");
|
||||
break;
|
||||
case P2:
|
||||
System.out.println("Player 2 wins");
|
||||
break;
|
||||
case DRAW:
|
||||
System.out.println("Game Draw");
|
||||
break;
|
||||
case IN_PROGRESS:
|
||||
System.out.println("Game In rogress");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package com.baeldung.algorithms.mcts.tictactoe;
|
||||
|
||||
public class Position {
|
||||
int x;
|
||||
int y;
|
||||
|
||||
public Position() {
|
||||
}
|
||||
|
||||
public Position(int x, int y) {
|
||||
this.x = x;
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
public int getX() {
|
||||
return x;
|
||||
}
|
||||
|
||||
public void setX(int x) {
|
||||
this.x = x;
|
||||
}
|
||||
|
||||
public int getY() {
|
||||
return y;
|
||||
}
|
||||
|
||||
public void setY(int y) {
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package com.baeldung.algorithms.mcts.tree;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
import com.baeldung.algorithms.mcts.montecarlo.State;
|
||||
|
||||
public class Node {
|
||||
State state;
|
||||
Node parent;
|
||||
List<Node> childArray;
|
||||
|
||||
public Node() {
|
||||
this.state = new State();
|
||||
childArray = new ArrayList<>();
|
||||
}
|
||||
|
||||
public Node(State state) {
|
||||
this.state = state;
|
||||
childArray = new ArrayList<>();
|
||||
}
|
||||
|
||||
public Node(State state, Node parent, List<Node> childArray) {
|
||||
this.state = state;
|
||||
this.parent = parent;
|
||||
this.childArray = childArray;
|
||||
}
|
||||
|
||||
public Node(Node node) {
|
||||
this.childArray = new ArrayList<>();
|
||||
this.state = new State(node.getState());
|
||||
if (node.getParent() != null)
|
||||
this.parent = node.getParent();
|
||||
List<Node> childArray = node.getChildArray();
|
||||
for (Node child : childArray) {
|
||||
this.childArray.add(new Node(child));
|
||||
}
|
||||
}
|
||||
|
||||
public State getState() {
|
||||
return state;
|
||||
}
|
||||
|
||||
public void setState(State state) {
|
||||
this.state = state;
|
||||
}
|
||||
|
||||
public Node getParent() {
|
||||
return parent;
|
||||
}
|
||||
|
||||
public void setParent(Node parent) {
|
||||
this.parent = parent;
|
||||
}
|
||||
|
||||
public List<Node> getChildArray() {
|
||||
return childArray;
|
||||
}
|
||||
|
||||
public void setChildArray(List<Node> childArray) {
|
||||
this.childArray = childArray;
|
||||
}
|
||||
|
||||
public Node getRandomChildNode() {
|
||||
int noOfPossibleMoves = this.childArray.size();
|
||||
int selectRandom = (int) (Math.random() * ((noOfPossibleMoves - 1) + 1));
|
||||
return this.childArray.get(selectRandom);
|
||||
}
|
||||
|
||||
public Node getChildWithMaxScore() {
|
||||
return Collections.max(this.childArray, Comparator.comparing(c -> {
|
||||
return c.getState().getVisitCount();
|
||||
}));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package com.baeldung.algorithms.mcts.tree;
|
||||
|
||||
public class Tree {
|
||||
Node root;
|
||||
|
||||
public Tree() {
|
||||
root = new Node();
|
||||
}
|
||||
|
||||
public Tree(Node root) {
|
||||
this.root = root;
|
||||
}
|
||||
|
||||
public Node getRoot() {
|
||||
return root;
|
||||
}
|
||||
|
||||
public void setRoot(Node root) {
|
||||
this.root = root;
|
||||
}
|
||||
|
||||
public void addChild(Node parent, Node child) {
|
||||
parent.getChildArray().add(child);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
package algorithms;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
|
||||
import com.baeldung.algorithms.mcts.montecarlo.State;
|
||||
import com.baeldung.algorithms.mcts.montecarlo.UCT;
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||
|
||||
public class MCTSTest {
|
||||
Tree gameTree;
|
||||
MonteCarloTreeSearch mcts;
|
||||
|
||||
@Before
|
||||
public void initGameTree() {
|
||||
gameTree = new Tree();
|
||||
mcts = new MonteCarloTreeSearch();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
|
||||
double uctValue = 15.79;
|
||||
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
|
||||
State initState = gameTree.getRoot().getState();
|
||||
List<State> possibleStates = initState.getAllPossibleStates();
|
||||
assertTrue(possibleStates.size() > 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
|
||||
Board board = new Board();
|
||||
int initAvailablePositions = board.getEmptyPositions().size();
|
||||
board.performMove(Board.P1, new Position(1, 1));
|
||||
int availablePositions = board.getEmptyPositions().size();
|
||||
assertTrue(initAvailablePositions > availablePositions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
|
||||
Board board = new Board();
|
||||
|
||||
int player = Board.P1;
|
||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||
for (int i = 0; i < totalMoves; i++) {
|
||||
board = mcts.findNextMove(board, player);
|
||||
if (board.checkStatus() != -1) {
|
||||
break;
|
||||
}
|
||||
player = 3 - player;
|
||||
}
|
||||
int winStatus = board.checkStatus();
|
||||
assertEquals(winStatus, Board.DRAW);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
|
||||
Board board = new Board();
|
||||
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
|
||||
mcts1.setLevel(1);
|
||||
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
|
||||
mcts3.setLevel(3);
|
||||
|
||||
int player = Board.P1;
|
||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||
for (int i = 0; i < totalMoves; i++) {
|
||||
if (player == Board.P1)
|
||||
board = mcts3.findNextMove(board, player);
|
||||
else
|
||||
board = mcts1.findNextMove(board, player);
|
||||
|
||||
if (board.checkStatus() != -1) {
|
||||
break;
|
||||
}
|
||||
player = 3 - player;
|
||||
}
|
||||
int winStatus = board.checkStatus();
|
||||
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue