BAEL-984 Monte Carlo tree search (#2129)

* BAEL-984 Monte Carlo tree search

BAEL-984 Implementation for tic tac toe using Monte Carlo tree search

* BAEL-984 test cases for MCTS

BAEL-984 test cases for Monte Carlo tree search implementation
This commit is contained in:
Parth Karia 2017-06-23 02:21:43 +05:30 committed by maibin
parent 2cde0e37c2
commit 1b0d5f0b73
8 changed files with 614 additions and 0 deletions

View File

@ -0,0 +1,109 @@
package com.baeldung.algorithms.mcts.montecarlo;
import java.util.List;
import com.baeldung.algorithms.mcts.tictactoe.Board;
import com.baeldung.algorithms.mcts.tree.Node;
import com.baeldung.algorithms.mcts.tree.Tree;
/**
 * Time-boxed Monte Carlo tree search that picks the next move on a {@link Board}.
 *
 * <p>Players are identified by the numbers 1 and 2; the opponent of a player is
 * computed as {@code 3 - playerNo}. Every search iteration runs four phases:
 * selection, expansion, simulation and back-propagation.
 */
public class MonteCarloTreeSearch {
    /** Score added to a node for every playout won by that node's player. */
    static final int WIN_SCORE = 10;

    /** Difficulty level; a higher level gives the search a longer time budget. */
    int level;

    /** Number of the player opposing the one given to the latest {@link #findNextMove} call. */
    int oponent;

    /** Creates a search configured with level 3. */
    public MonteCarloTreeSearch() {
        this.level = 3;
    }

    /** @return the current difficulty level */
    public int getLevel() {
        return level;
    }

    /** @param level the difficulty level that scales the search time budget */
    public void setLevel(int level) {
        this.level = level;
    }

    /**
     * Maps the level to a time factor: 1 for level 1, 3 for level 2, 5 for level 3, ...
     * The caller multiplies this factor by 60 to obtain milliseconds.
     */
    private int getMillisForCurrentLevel() {
        return 2 * (this.level - 1) + 1;
    }

    /**
     * Searches for the next move of {@code playerNo} starting from {@code board}.
     *
     * @param board    the current position; it is used as the root state and is not
     *                 modified by the search, because children and playouts work on copies
     * @param playerNo the player to move (1 or 2)
     * @return the board after the chosen move, i.e. the board of the root child with the
     *         highest visit count; if {@code board} is already finished (win or draw)
     *         the very same {@code board} instance is returned unchanged
     */
    public Board findNextMove(Board board, int playerNo) {
        long start = System.currentTimeMillis();
        long end = start + 60 * getMillisForCurrentLevel();

        oponent = 3 - playerNo;

        // A finished game has no legal move. Without this guard the root would never be
        // expanded and choosing its best child would fail on an empty child list.
        if (board.checkStatus() != Board.IN_PROGRESS) {
            return board;
        }

        Tree tree = new Tree();
        Node rootNode = tree.getRoot();
        rootNode.getState().setBoard(board);
        // The root represents the position after the opponent's move, so children are ours.
        rootNode.getState().setPlayerNo(oponent);

        // do/while guarantees one full iteration even if the time budget is zero or
        // negative (level < 1), so the root always has children when a winner is picked.
        do {
            // Phase 1 - Selection
            Node promisingNode = selectPromisingNode(rootNode);

            // Phase 2 - Expansion
            if (promisingNode.getState().getBoard().checkStatus() == Board.IN_PROGRESS) {
                expandNode(promisingNode);
            }

            // Phase 3 - Simulation
            Node nodeToExplore = promisingNode;
            if (promisingNode.getChildArray().size() > 0) {
                nodeToExplore = promisingNode.getRandomChildNode();
            }
            int playoutResult = simulateRandomPlayout(nodeToExplore);

            // Phase 4 - Update
            backPropagation(nodeToExplore, playoutResult);
        } while (System.currentTimeMillis() < end);

        Node winnerNode = rootNode.getChildWithMaxScore();
        tree.setRoot(winnerNode);
        return winnerNode.getState().getBoard();
    }

    /** Walks down from {@code rootNode}, always following the best UCT child, until a leaf is reached. */
    private Node selectPromisingNode(Node rootNode) {
        Node node = rootNode;
        while (node.getChildArray().size() != 0) {
            node = UCT.findBestNodeWithUCT(node);
        }
        return node;
    }

    /** Adds one child to {@code node} for every state reachable from it in a single move. */
    private void expandNode(Node node) {
        List<State> possibleStates = node.getState().getAllPossibleStates();
        for (State state : possibleStates) {
            Node newNode = new Node(state);
            newNode.setParent(node);
            newNode.getState().setPlayerNo(node.getState().getOpponent());
            node.getChildArray().add(newNode);
        }
    }

    /**
     * Propagates a playout result from {@code nodeToExplore} up to the root: every node on
     * the path gets one more visit, and nodes whose player equals {@code winner} also get
     * {@link #WIN_SCORE}.
     *
     * @param winner the board status returned by the playout (a player number or a draw)
     */
    private void backPropagation(Node nodeToExplore, int winner) {
        Node tempNode = nodeToExplore;
        while (tempNode != null) {
            tempNode.getState().incrementVisit();
            if (tempNode.getState().getPlayerNo() == winner) {
                tempNode.getState().addScore(WIN_SCORE);
            }
            tempNode = tempNode.getParent();
        }
    }

    /**
     * Plays random moves on a copy of {@code node} until the game ends.
     *
     * <p>If the node's position is already won by the opponent, the node's parent is
     * marked with {@code Integer.MIN_VALUE} as its win score and no moves are played.
     *
     * @return the final board status
     */
    private int simulateRandomPlayout(Node node) {
        Node tempNode = new Node(node);
        State tempState = tempNode.getState();
        int boardStatus = tempState.getBoard().checkStatus();

        if (boardStatus == oponent) {
            // The copy keeps the original parent reference; a root node has no parent.
            Node parent = tempNode.getParent();
            if (parent != null) {
                parent.getState().setWinScore(Integer.MIN_VALUE);
            }
            return boardStatus;
        }
        while (boardStatus == Board.IN_PROGRESS) {
            tempState.togglePlayer();
            tempState.randomPlay();
            boardStatus = tempState.getBoard().checkStatus();
        }
        return boardStatus;
    }
}

View File

@ -0,0 +1,97 @@
package com.baeldung.algorithms.mcts.montecarlo;
import java.util.ArrayList;
import java.util.List;
import com.baeldung.algorithms.mcts.tictactoe.Board;
import com.baeldung.algorithms.mcts.tictactoe.Position;
/**
 * Search state attached to a tree node: a board, the number of the player
 * associated with it, and the visit / score statistics gathered so far.
 */
public class State {
    Board board;
    int playerNo;
    int visitCount;
    double winScore;

    /** Creates a state holding a fresh default {@link Board}. */
    public State() {
        board = new Board();
    }

    /**
     * Copy constructor: the board is copied, the player number and the
     * statistics are taken over from {@code state}.
     */
    public State(State state) {
        this.playerNo = state.getPlayerNo();
        this.visitCount = state.getVisitCount();
        this.winScore = state.getWinScore();
        this.board = new Board(state.getBoard());
    }

    /** Creates a state holding a copy of {@code board}; all other fields keep their defaults. */
    public State(Board board) {
        this.board = new Board(board);
    }

    public Board getBoard() {
        return board;
    }

    /** Stores the given board reference as-is (no copy is made). */
    public void setBoard(Board board) {
        this.board = board;
    }

    public int getPlayerNo() {
        return playerNo;
    }

    public void setPlayerNo(int playerNo) {
        this.playerNo = playerNo;
    }

    /** @return {@code 3 - playerNo}, which maps player 1 to 2 and player 2 to 1 */
    public int getOpponent() {
        return 3 - playerNo;
    }

    public int getVisitCount() {
        return visitCount;
    }

    public void setVisitCount(int visitCount) {
        this.visitCount = visitCount;
    }

    public double getWinScore() {
        return winScore;
    }

    public void setWinScore(double winScore) {
        this.winScore = winScore;
    }

    /**
     * Builds one successor state per empty board position. Each successor holds a
     * copy of this board on which player {@code 3 - playerNo} has taken that position.
     *
     * @return the successor states, in the order of {@link Board#getEmptyPositions()}
     */
    public List<State> getAllPossibleStates() {
        List<State> successors = new ArrayList<>();
        for (Position free : board.getEmptyPositions()) {
            State successor = new State(board);
            successor.setPlayerNo(3 - playerNo);
            successor.getBoard().performMove(successor.getPlayerNo(), free);
            successors.add(successor);
        }
        return successors;
    }

    /** Counts one more visit. */
    public void incrementVisit() {
        visitCount++;
    }

    /**
     * Adds {@code score} to the win score, unless the win score currently equals
     * {@code Integer.MIN_VALUE}, which is treated as a sentinel and left untouched.
     */
    public void addScore(double score) {
        if (winScore == Integer.MIN_VALUE) {
            return;
        }
        winScore += score;
    }

    /** Lets the current player take one of the empty positions, chosen uniformly at random. */
    public void randomPlay() {
        List<Position> free = board.getEmptyPositions();
        int pick = (int) (Math.random() * free.size());
        board.performMove(playerNo, free.get(pick));
    }

    /** Switches the current player between 1 and 2. */
    public void togglePlayer() {
        playerNo = 3 - playerNo;
    }
}

View File

@ -0,0 +1,26 @@
package com.baeldung.algorithms.mcts.montecarlo;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import com.baeldung.algorithms.mcts.tree.Node;
/** Helper computing UCT (upper confidence bound applied to trees) values for node selection. */
public class UCT {
    /** Weight of the exploration term. */
    final static double C = 1.41;

    /**
     * Computes {@code nodeWinScore / nodeVisit + C * sqrt(ln(totalVisit) / nodeVisit)}.
     *
     * @param totalVisit   visit count of the parent node
     * @param nodeWinScore accumulated win score of the node
     * @param nodeVisit    visit count of the node
     * @return the UCT value, or {@code Integer.MAX_VALUE} for a node that was never visited
     */
    public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
        if (nodeVisit == 0) {
            return Integer.MAX_VALUE;
        }
        double exploitation = nodeWinScore / nodeVisit;
        double exploration = C * Math.sqrt(Math.log(totalVisit) / nodeVisit);
        return exploitation + exploration;
    }

    /**
     * Returns the child of {@code node} with the highest UCT value, using the visit
     * count of {@code node} as the total.
     */
    public static Node findBestNodeWithUCT(Node node) {
        int parentVisit = node.getState().getVisitCount();
        List<Node> children = node.getChildArray();
        Comparator<Node> byUct = Comparator.comparingDouble((Node child) ->
            uctValue(parentVisit, child.getState().getWinScore(), child.getState().getVisitCount()));
        return Collections.max(children, byUct);
    }
}

View File

@ -0,0 +1,155 @@
package com.baeldung.algorithms.mcts.tictactoe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Square tic-tac-toe board. A cell holds 0 when empty, otherwise the number of
 * the player who took it.
 */
public class Board {
    int[][] boardValues;
    int totalMoves;

    public static final int DEFAULT_BOARD_SIZE = 3;

    /** Status: no winner yet and at least one empty cell. */
    public static final int IN_PROGRESS = -1;
    /** Status: no winner and no empty cell left. */
    public static final int DRAW = 0;
    public static final int P1 = 1;
    public static final int P2 = 2;

    /** Creates an empty board of {@link #DEFAULT_BOARD_SIZE} x {@link #DEFAULT_BOARD_SIZE}. */
    public Board() {
        boardValues = new int[DEFAULT_BOARD_SIZE][DEFAULT_BOARD_SIZE];
    }

    /** Creates an empty board of {@code boardSize} x {@code boardSize}. */
    public Board(int boardSize) {
        boardValues = new int[boardSize][boardSize];
    }

    /** Wraps the given grid without copying it. */
    public Board(int[][] boardValues) {
        this.boardValues = boardValues;
    }

    /** Wraps the given grid without copying it and sets the move counter. */
    public Board(int[][] boardValues, int totalMoves) {
        this.boardValues = boardValues;
        this.totalMoves = totalMoves;
    }

    /**
     * Copy constructor: copies every row of the grid and the move counter, so the new
     * board shares no mutable state with {@code board}.
     */
    public Board(Board board) {
        int[][] source = board.getBoardValues();
        this.boardValues = new int[source.length][];
        for (int i = 0; i < source.length; i++) {
            this.boardValues[i] = Arrays.copyOf(source[i], source[i].length);
        }
        // The move counter belongs to the board's state and must survive a copy.
        this.totalMoves = board.totalMoves;
    }

    /**
     * Marks position {@code p} as taken by {@code player} and counts the move.
     * The cell is overwritten without checking whether it was empty.
     */
    public void performMove(int player, Position p) {
        this.totalMoves++;
        boardValues[p.getX()][p.getY()] = player;
    }

    /** @return the backing grid itself, not a copy */
    public int[][] getBoardValues() {
        return boardValues;
    }

    public void setBoardValues(int[][] boardValues) {
        this.boardValues = boardValues;
    }

    /**
     * Evaluates the board.
     *
     * @return the value filling a complete row, column or diagonal if there is one
     *         (the winning player's number); otherwise {@link #IN_PROGRESS} while
     *         empty cells remain, or {@link #DRAW} when the board is full
     */
    public int checkStatus() {
        int boardSize = boardValues.length;
        int maxIndex = boardSize - 1;
        int[] diag1 = new int[boardSize];
        int[] diag2 = new int[boardSize];

        for (int i = 0; i < boardSize; i++) {
            int[] row = boardValues[i];
            int[] col = new int[boardSize];
            for (int j = 0; j < boardSize; j++) {
                col[j] = boardValues[j][i];
            }

            int rowWinner = checkForWin(row);
            if (rowWinner != 0) {
                return rowWinner;
            }

            int colWinner = checkForWin(col);
            if (colWinner != 0) {
                return colWinner;
            }

            // Collect both diagonals while scanning; they are checked after the loop.
            diag1[i] = boardValues[i][i];
            diag2[i] = boardValues[maxIndex - i][i];
        }

        int diag1Winner = checkForWin(diag1);
        if (diag1Winner != 0) {
            return diag1Winner;
        }

        int diag2Winner = checkForWin(diag2);
        if (diag2Winner != 0) {
            return diag2Winner;
        }

        return getEmptyPositions().isEmpty() ? DRAW : IN_PROGRESS;
    }

    /**
     * @return the common value if every cell of {@code row} holds the same value,
     *         otherwise 0; a line of empty cells therefore also yields 0
     */
    private int checkForWin(int[] row) {
        int first = row[0];
        for (int i = 1; i < row.length; i++) {
            if (row[i] != first) {
                return 0;
            }
        }
        return first;
    }

    /** Prints the grid to standard output, one row per line, cells separated by a space. */
    public void printBoard() {
        int size = this.boardValues.length;
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                System.out.print(boardValues[i][j] + " ");
            }
            System.out.println();
        }
    }

    /** @return the positions of all cells holding 0, scanned row by row */
    public List<Position> getEmptyPositions() {
        int size = this.boardValues.length;
        List<Position> emptyPositions = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                if (boardValues[i][j] == 0) {
                    emptyPositions.add(new Position(i, j));
                }
            }
        }
        return emptyPositions;
    }

    /** Prints a human-readable description of {@link #checkStatus()} to standard output. */
    public void printStatus() {
        switch (this.checkStatus()) {
        case P1:
            System.out.println("Player 1 wins");
            break;
        case P2:
            System.out.println("Player 2 wins");
            break;
        case DRAW:
            System.out.println("Game Draw");
            break;
        case IN_PROGRESS:
            System.out.println("Game In Progress");
            break;
        }
    }
}

View File

@ -0,0 +1,31 @@
package com.baeldung.algorithms.mcts.tictactoe;
/**
 * Mutable pair of integer coordinates identifying a cell.
 * A default-constructed position is (0, 0).
 */
public class Position {
    int x;
    int y;

    /** Creates the position (0, 0). */
    public Position() {
        this(0, 0);
    }

    /** Creates the position ({@code x}, {@code y}). */
    public Position(int x, int y) {
        this.x = x;
        this.y = y;
    }

    /** @return the x coordinate */
    public int getX() {
        return this.x;
    }

    /** @param x the new x coordinate */
    public void setX(int x) {
        this.x = x;
    }

    /** @return the y coordinate */
    public int getY() {
        return this.y;
    }

    /** @param y the new y coordinate */
    public void setY(int y) {
        this.y = y;
    }
}

View File

@ -0,0 +1,78 @@
package com.baeldung.algorithms.mcts.tree;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import com.baeldung.algorithms.mcts.montecarlo.State;
/**
 * Node of the search tree: a {@link State}, a reference to the parent node
 * ({@code null} when there is none) and the list of child nodes.
 */
public class Node {
    State state;
    Node parent;
    List<Node> childArray;

    /** Creates a parentless node with a default {@link State} and no children. */
    public Node() {
        this.state = new State();
        childArray = new ArrayList<>();
    }

    /** Creates a parentless node around {@code state} (stored as-is) with no children. */
    public Node(State state) {
        this.state = state;
        childArray = new ArrayList<>();
    }

    /** Creates a node from the given parts; all references are stored as-is. */
    public Node(State state, Node parent, List<Node> childArray) {
        this.state = state;
        this.parent = parent;
        this.childArray = childArray;
    }

    /**
     * Copy constructor. The state and the whole subtree below {@code node} are copied;
     * the copy keeps the same parent reference as {@code node}, and every copied child
     * points back to its copied parent so the new subtree is internally consistent.
     */
    public Node(Node node) {
        this.childArray = new ArrayList<>();
        this.state = new State(node.getState());
        this.parent = node.getParent();
        for (Node child : node.getChildArray()) {
            Node childCopy = new Node(child);
            // new Node(child) inherits the ORIGINAL parent; re-link it to this copy.
            childCopy.parent = this;
            this.childArray.add(childCopy);
        }
    }

    public State getState() {
        return state;
    }

    public void setState(State state) {
        this.state = state;
    }

    public Node getParent() {
        return parent;
    }

    public void setParent(Node parent) {
        this.parent = parent;
    }

    /** @return the live child list, not a copy */
    public List<Node> getChildArray() {
        return childArray;
    }

    public void setChildArray(List<Node> childArray) {
        this.childArray = childArray;
    }

    /**
     * @return a child chosen uniformly at random
     * @throws IndexOutOfBoundsException if this node has no children
     */
    public Node getRandomChildNode() {
        int noOfPossibleMoves = this.childArray.size();
        int selectRandom = (int) (Math.random() * noOfPossibleMoves);
        return this.childArray.get(selectRandom);
    }

    /**
     * Despite its name, this ranks children by visit count, not by win score.
     *
     * @return the child with the highest visit count
     * @throws java.util.NoSuchElementException if this node has no children
     */
    public Node getChildWithMaxScore() {
        return Collections.max(this.childArray,
            Comparator.comparingInt((Node c) -> c.getState().getVisitCount()));
    }
}

View File

@ -0,0 +1,26 @@
package com.baeldung.algorithms.mcts.tree;
/** Thin holder for the root {@link Node} of a search tree. */
public class Tree {
    Node root;

    /** Creates a tree whose root is a freshly constructed {@link Node}. */
    public Tree() {
        this(new Node());
    }

    /** Creates a tree with the given root. */
    public Tree(Node root) {
        this.root = root;
    }

    /** @return the current root node */
    public Node getRoot() {
        return this.root;
    }

    /** @param root the node to use as the new root */
    public void setRoot(Node root) {
        this.root = root;
    }

    /**
     * Appends {@code child} to the child list of {@code parent}.
     * The parent reference of {@code child} is not modified.
     */
    public void addChild(Node parent, Node child) {
        parent.getChildArray().add(child);
    }
}

View File

@ -0,0 +1,92 @@
package algorithms;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
import com.baeldung.algorithms.mcts.montecarlo.State;
import com.baeldung.algorithms.mcts.montecarlo.UCT;
import com.baeldung.algorithms.mcts.tictactoe.Board;
import com.baeldung.algorithms.mcts.tictactoe.Position;
import com.baeldung.algorithms.mcts.tree.Tree;
/** Tests for the Monte Carlo tree search tic-tac-toe implementation. */
public class MCTSTest {
    Tree gameTree;
    MonteCarloTreeSearch mcts;

    @Before
    public void initGameTree() {
        gameTree = new Tree();
        mcts = new MonteCarloTreeSearch();
    }

    @Test
    public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
        // 300 / 20 + 1.41 * sqrt(ln(600) / 20) is roughly 15.797
        double expectedUct = 15.79;
        assertEquals(expectedUct, UCT.uctValue(600, 300, 20), 0.01);
    }

    @Test
    public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
        State initState = gameTree.getRoot().getState();
        List<State> possibleStates = initState.getAllPossibleStates();
        assertTrue(possibleStates.size() > 0);
    }

    @Test
    public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
        Board board = new Board();
        int initAvailablePositions = board.getEmptyPositions().size();
        board.performMove(Board.P1, new Position(1, 1));
        int availablePositions = board.getEmptyPositions().size();
        assertTrue(initAvailablePositions > availablePositions);
    }

    // NOTE(review): the two game tests below rely on a randomized, time-bounded search,
    // so their outcome is presumably not strictly deterministic - confirm on slow machines.

    @Test
    public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
        Board board = playGame(mcts, mcts);
        assertEquals(Board.DRAW, board.checkStatus());
    }

    @Test
    public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
        MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
        mcts1.setLevel(1);
        MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
        mcts3.setLevel(3);

        // The level 3 search plays as P1, the level 1 search as P2.
        Board board = playGame(mcts3, mcts1);

        int winStatus = board.checkStatus();
        assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
    }

    /**
     * Plays a game on an empty default board, P1 moving first, until the game is over
     * or as many moves as the board has cells were made.
     *
     * @param firstPlayer  search choosing the moves of {@code Board.P1}
     * @param secondPlayer search choosing the moves of {@code Board.P2}
     * @return the board after the last move
     */
    private Board playGame(MonteCarloTreeSearch firstPlayer, MonteCarloTreeSearch secondPlayer) {
        Board board = new Board();
        int player = Board.P1;
        int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
        for (int i = 0; i < totalMoves; i++) {
            MonteCarloTreeSearch current = (player == Board.P1) ? firstPlayer : secondPlayer;
            board = current.findNextMove(board, player);
            if (board.checkStatus() != Board.IN_PROGRESS) {
                break;
            }
            player = 3 - player;
        }
        return board;
    }
}