Algorithms refactor (#2136)
This commit is contained in:
parent
1b0d5f0b73
commit
ef7400484c
@ -1,4 +1,4 @@
|
|||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Finite state machine.
|
* Finite state machine.
|
@ -1,4 +1,4 @@
|
|||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default implementation of a finite state machine.
|
* Default implementation of a finite state machine.
|
@ -1,4 +1,4 @@
|
|||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
@ -1,4 +1,4 @@
|
|||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
@ -1,4 +1,4 @@
|
|||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State. Part of a finite state machine.
|
* State. Part of a finite state machine.
|
@ -1,4 +1,4 @@
|
|||||||
package com.baeldung.automata;
|
package com.baeldung.algorithms.automata;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transition in a finite State machine.
|
* Transition in a finite State machine.
|
@ -8,9 +8,9 @@ import com.baeldung.algorithms.mcts.tree.Tree;
|
|||||||
|
|
||||||
public class MonteCarloTreeSearch {
|
public class MonteCarloTreeSearch {
|
||||||
|
|
||||||
static final int WIN_SCORE = 10;
|
private static final int WIN_SCORE = 10;
|
||||||
int level;
|
private int level;
|
||||||
int oponent;
|
private int oponent;
|
||||||
|
|
||||||
public MonteCarloTreeSearch() {
|
public MonteCarloTreeSearch() {
|
||||||
this.level = 3;
|
this.level = 3;
|
||||||
|
@ -7,10 +7,10 @@ import com.baeldung.algorithms.mcts.tictactoe.Board;
|
|||||||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||||
|
|
||||||
public class State {
|
public class State {
|
||||||
Board board;
|
private Board board;
|
||||||
int playerNo;
|
private int playerNo;
|
||||||
int visitCount;
|
private int visitCount;
|
||||||
double winScore;
|
private double winScore;
|
||||||
|
|
||||||
public State() {
|
public State() {
|
||||||
board = new Board();
|
board = new Board();
|
||||||
@ -27,23 +27,23 @@ public class State {
|
|||||||
this.board = new Board(board);
|
this.board = new Board(board);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Board getBoard() {
|
Board getBoard() {
|
||||||
return board;
|
return board;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setBoard(Board board) {
|
void setBoard(Board board) {
|
||||||
this.board = board;
|
this.board = board;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getPlayerNo() {
|
int getPlayerNo() {
|
||||||
return playerNo;
|
return playerNo;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setPlayerNo(int playerNo) {
|
void setPlayerNo(int playerNo) {
|
||||||
this.playerNo = playerNo;
|
this.playerNo = playerNo;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getOpponent() {
|
int getOpponent() {
|
||||||
return 3 - playerNo;
|
return 3 - playerNo;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,11 +55,11 @@ public class State {
|
|||||||
this.visitCount = visitCount;
|
this.visitCount = visitCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getWinScore() {
|
double getWinScore() {
|
||||||
return winScore;
|
return winScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setWinScore(double winScore) {
|
void setWinScore(double winScore) {
|
||||||
this.winScore = winScore;
|
this.winScore = winScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,23 +75,23 @@ public class State {
|
|||||||
return possibleStates;
|
return possibleStates;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void incrementVisit() {
|
void incrementVisit() {
|
||||||
this.visitCount++;
|
this.visitCount++;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addScore(double score) {
|
void addScore(double score) {
|
||||||
if (this.winScore != Integer.MIN_VALUE)
|
if (this.winScore != Integer.MIN_VALUE)
|
||||||
this.winScore += score;
|
this.winScore += score;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void randomPlay() {
|
void randomPlay() {
|
||||||
List<Position> availablePositions = this.board.getEmptyPositions();
|
List<Position> availablePositions = this.board.getEmptyPositions();
|
||||||
int totalPossibilities = availablePositions.size();
|
int totalPossibilities = availablePositions.size();
|
||||||
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
|
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
|
||||||
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
|
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void togglePlayer() {
|
void togglePlayer() {
|
||||||
this.playerNo = 3 - this.playerNo;
|
this.playerNo = 3 - this.playerNo;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,20 +7,18 @@ import java.util.List;
|
|||||||
import com.baeldung.algorithms.mcts.tree.Node;
|
import com.baeldung.algorithms.mcts.tree.Node;
|
||||||
|
|
||||||
public class UCT {
|
public class UCT {
|
||||||
final static double C = 1.41;
|
|
||||||
|
|
||||||
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
|
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
|
||||||
if (nodeVisit == 0)
|
if (nodeVisit == 0) {
|
||||||
return Integer.MAX_VALUE;
|
return Integer.MAX_VALUE;
|
||||||
return ((double) nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
}
|
||||||
|
return (nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Node findBestNodeWithUCT(Node node) {
|
static Node findBestNodeWithUCT(Node node) {
|
||||||
int parentVisit = node.getState().getVisitCount();
|
int parentVisit = node.getState().getVisitCount();
|
||||||
List<Node> childNodes = node.getChildArray();
|
return Collections.max(
|
||||||
return Collections.max(childNodes, Comparator.comparing(c -> {
|
node.getChildArray(),
|
||||||
double score = uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount());
|
Comparator.comparing(c -> uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount())));
|
||||||
return score;
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package algorithms;
|
package algorithms;
|
||||||
|
|
||||||
import com.baeldung.automata.*;
|
import com.baeldung.algorithms.automata.*;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
@ -1,92 +1,92 @@
|
|||||||
package algorithms;
|
package algorithms.mcts;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
|
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
|
||||||
import com.baeldung.algorithms.mcts.montecarlo.State;
|
import com.baeldung.algorithms.mcts.montecarlo.State;
|
||||||
import com.baeldung.algorithms.mcts.montecarlo.UCT;
|
import com.baeldung.algorithms.mcts.montecarlo.UCT;
|
||||||
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
import com.baeldung.algorithms.mcts.tictactoe.Board;
|
||||||
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
import com.baeldung.algorithms.mcts.tictactoe.Position;
|
||||||
import com.baeldung.algorithms.mcts.tree.Tree;
|
import com.baeldung.algorithms.mcts.tree.Tree;
|
||||||
|
|
||||||
public class MCTSTest {
|
public class MCTSTest {
|
||||||
Tree gameTree;
|
private Tree gameTree;
|
||||||
MonteCarloTreeSearch mcts;
|
private MonteCarloTreeSearch mcts;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void initGameTree() {
|
public void initGameTree() {
|
||||||
gameTree = new Tree();
|
gameTree = new Tree();
|
||||||
mcts = new MonteCarloTreeSearch();
|
mcts = new MonteCarloTreeSearch();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
|
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
|
||||||
double uctValue = 15.79;
|
double uctValue = 15.79;
|
||||||
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
|
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
|
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
|
||||||
State initState = gameTree.getRoot().getState();
|
State initState = gameTree.getRoot().getState();
|
||||||
List<State> possibleStates = initState.getAllPossibleStates();
|
List<State> possibleStates = initState.getAllPossibleStates();
|
||||||
assertTrue(possibleStates.size() > 0);
|
assertTrue(possibleStates.size() > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
|
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
|
||||||
Board board = new Board();
|
Board board = new Board();
|
||||||
int initAvailablePositions = board.getEmptyPositions().size();
|
int initAvailablePositions = board.getEmptyPositions().size();
|
||||||
board.performMove(Board.P1, new Position(1, 1));
|
board.performMove(Board.P1, new Position(1, 1));
|
||||||
int availablePositions = board.getEmptyPositions().size();
|
int availablePositions = board.getEmptyPositions().size();
|
||||||
assertTrue(initAvailablePositions > availablePositions);
|
assertTrue(initAvailablePositions > availablePositions);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
|
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
|
||||||
Board board = new Board();
|
Board board = new Board();
|
||||||
|
|
||||||
int player = Board.P1;
|
int player = Board.P1;
|
||||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||||
for (int i = 0; i < totalMoves; i++) {
|
for (int i = 0; i < totalMoves; i++) {
|
||||||
board = mcts.findNextMove(board, player);
|
board = mcts.findNextMove(board, player);
|
||||||
if (board.checkStatus() != -1) {
|
if (board.checkStatus() != -1) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
player = 3 - player;
|
player = 3 - player;
|
||||||
}
|
}
|
||||||
int winStatus = board.checkStatus();
|
int winStatus = board.checkStatus();
|
||||||
assertEquals(winStatus, Board.DRAW);
|
assertEquals(winStatus, Board.DRAW);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
|
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
|
||||||
Board board = new Board();
|
Board board = new Board();
|
||||||
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
|
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
|
||||||
mcts1.setLevel(1);
|
mcts1.setLevel(1);
|
||||||
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
|
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
|
||||||
mcts3.setLevel(3);
|
mcts3.setLevel(3);
|
||||||
|
|
||||||
int player = Board.P1;
|
int player = Board.P1;
|
||||||
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
|
||||||
for (int i = 0; i < totalMoves; i++) {
|
for (int i = 0; i < totalMoves; i++) {
|
||||||
if (player == Board.P1)
|
if (player == Board.P1)
|
||||||
board = mcts3.findNextMove(board, player);
|
board = mcts3.findNextMove(board, player);
|
||||||
else
|
else
|
||||||
board = mcts1.findNextMove(board, player);
|
board = mcts1.findNextMove(board, player);
|
||||||
|
|
||||||
if (board.checkStatus() != -1) {
|
if (board.checkStatus() != -1) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
player = 3 - player;
|
player = 3 - player;
|
||||||
}
|
}
|
||||||
int winStatus = board.checkStatus();
|
int winStatus = board.checkStatus();
|
||||||
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
|
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
x
Reference in New Issue
Block a user