Algorithms refactor (#2136)

This commit is contained in:
Grzegorz Piwowarek 2017-06-23 10:18:04 +02:00 committed by GitHub
parent 1b0d5f0b73
commit ef7400484c
11 changed files with 124 additions and 126 deletions

View File

@ -1,4 +1,4 @@
package com.baeldung.automata;
package com.baeldung.algorithms.automata;
/**
* Finite state machine.

View File

@ -1,4 +1,4 @@
package com.baeldung.automata;
package com.baeldung.algorithms.automata;
/**
* Default implementation of a finite state machine.

View File

@ -1,4 +1,4 @@
package com.baeldung.automata;
package com.baeldung.algorithms.automata;
import java.util.ArrayList;
import java.util.List;

View File

@ -1,4 +1,4 @@
package com.baeldung.automata;
package com.baeldung.algorithms.automata;
/**

View File

@ -1,4 +1,4 @@
package com.baeldung.automata;
package com.baeldung.algorithms.automata;
/**
* State. Part of a finite state machine.

View File

@ -1,4 +1,4 @@
package com.baeldung.automata;
package com.baeldung.algorithms.automata;
/**
* Transition in a finite State machine.

View File

@ -8,9 +8,9 @@ import com.baeldung.algorithms.mcts.tree.Tree;
public class MonteCarloTreeSearch {
static final int WIN_SCORE = 10;
int level;
int oponent;
private static final int WIN_SCORE = 10;
private int level;
private int oponent;
public MonteCarloTreeSearch() {
this.level = 3;

View File

@ -7,10 +7,10 @@ import com.baeldung.algorithms.mcts.tictactoe.Board;
import com.baeldung.algorithms.mcts.tictactoe.Position;
public class State {
Board board;
int playerNo;
int visitCount;
double winScore;
private Board board;
private int playerNo;
private int visitCount;
private double winScore;
public State() {
board = new Board();
@ -27,23 +27,23 @@ public class State {
this.board = new Board(board);
}
public Board getBoard() {
Board getBoard() {
return board;
}
public void setBoard(Board board) {
void setBoard(Board board) {
this.board = board;
}
public int getPlayerNo() {
int getPlayerNo() {
return playerNo;
}
public void setPlayerNo(int playerNo) {
void setPlayerNo(int playerNo) {
this.playerNo = playerNo;
}
public int getOpponent() {
int getOpponent() {
return 3 - playerNo;
}
@ -55,11 +55,11 @@ public class State {
this.visitCount = visitCount;
}
public double getWinScore() {
double getWinScore() {
return winScore;
}
public void setWinScore(double winScore) {
void setWinScore(double winScore) {
this.winScore = winScore;
}
@ -75,23 +75,23 @@ public class State {
return possibleStates;
}
public void incrementVisit() {
void incrementVisit() {
this.visitCount++;
}
public void addScore(double score) {
void addScore(double score) {
if (this.winScore != Integer.MIN_VALUE)
this.winScore += score;
}
public void randomPlay() {
void randomPlay() {
List<Position> availablePositions = this.board.getEmptyPositions();
int totalPossibilities = availablePositions.size();
int selectRandom = (int) (Math.random() * ((totalPossibilities - 1) + 1));
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
}
public void togglePlayer() {
void togglePlayer() {
this.playerNo = 3 - this.playerNo;
}
}

View File

@ -7,20 +7,18 @@ import java.util.List;
import com.baeldung.algorithms.mcts.tree.Node;
public class UCT {
final static double C = 1.41;
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
if (nodeVisit == 0)
if (nodeVisit == 0) {
return Integer.MAX_VALUE;
return ((double) nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
}
return (nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
}
public static Node findBestNodeWithUCT(Node node) {
static Node findBestNodeWithUCT(Node node) {
int parentVisit = node.getState().getVisitCount();
List<Node> childNodes = node.getChildArray();
return Collections.max(childNodes, Comparator.comparing(c -> {
double score = uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount());
return score;
}));
return Collections.max(
node.getChildArray(),
Comparator.comparing(c -> uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount())));
}
}

View File

@ -1,6 +1,6 @@
package algorithms;
import com.baeldung.automata.*;
import com.baeldung.algorithms.automata.*;
import org.junit.Test;
import static org.junit.Assert.assertTrue;

View File

@ -1,92 +1,92 @@
package algorithms;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
import com.baeldung.algorithms.mcts.montecarlo.State;
import com.baeldung.algorithms.mcts.montecarlo.UCT;
import com.baeldung.algorithms.mcts.tictactoe.Board;
import com.baeldung.algorithms.mcts.tictactoe.Position;
import com.baeldung.algorithms.mcts.tree.Tree;
public class MCTSTest {
Tree gameTree;
MonteCarloTreeSearch mcts;
@Before
public void initGameTree() {
gameTree = new Tree();
mcts = new MonteCarloTreeSearch();
}
@Test
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
double uctValue = 15.79;
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
}
@Test
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
State initState = gameTree.getRoot().getState();
List<State> possibleStates = initState.getAllPossibleStates();
assertTrue(possibleStates.size() > 0);
}
@Test
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
Board board = new Board();
int initAvailablePositions = board.getEmptyPositions().size();
board.performMove(Board.P1, new Position(1, 1));
int availablePositions = board.getEmptyPositions().size();
assertTrue(initAvailablePositions > availablePositions);
}
@Test
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
Board board = new Board();
int player = Board.P1;
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
for (int i = 0; i < totalMoves; i++) {
board = mcts.findNextMove(board, player);
if (board.checkStatus() != -1) {
break;
}
player = 3 - player;
}
int winStatus = board.checkStatus();
assertEquals(winStatus, Board.DRAW);
}
@Test
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
Board board = new Board();
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
mcts1.setLevel(1);
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
mcts3.setLevel(3);
int player = Board.P1;
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
for (int i = 0; i < totalMoves; i++) {
if (player == Board.P1)
board = mcts3.findNextMove(board, player);
else
board = mcts1.findNextMove(board, player);
if (board.checkStatus() != -1) {
break;
}
player = 3 - player;
}
int winStatus = board.checkStatus();
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
}
}
package algorithms.mcts;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import com.baeldung.algorithms.mcts.montecarlo.MonteCarloTreeSearch;
import com.baeldung.algorithms.mcts.montecarlo.State;
import com.baeldung.algorithms.mcts.montecarlo.UCT;
import com.baeldung.algorithms.mcts.tictactoe.Board;
import com.baeldung.algorithms.mcts.tictactoe.Position;
import com.baeldung.algorithms.mcts.tree.Tree;
public class MCTSTest {
private Tree gameTree;
private MonteCarloTreeSearch mcts;
@Before
public void initGameTree() {
gameTree = new Tree();
mcts = new MonteCarloTreeSearch();
}
@Test
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
double uctValue = 15.79;
assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
}
@Test
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
State initState = gameTree.getRoot().getState();
List<State> possibleStates = initState.getAllPossibleStates();
assertTrue(possibleStates.size() > 0);
}
@Test
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
Board board = new Board();
int initAvailablePositions = board.getEmptyPositions().size();
board.performMove(Board.P1, new Position(1, 1));
int availablePositions = board.getEmptyPositions().size();
assertTrue(initAvailablePositions > availablePositions);
}
@Test
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
Board board = new Board();
int player = Board.P1;
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
for (int i = 0; i < totalMoves; i++) {
board = mcts.findNextMove(board, player);
if (board.checkStatus() != -1) {
break;
}
player = 3 - player;
}
int winStatus = board.checkStatus();
assertEquals(winStatus, Board.DRAW);
}
@Test
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
Board board = new Board();
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
mcts1.setLevel(1);
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
mcts3.setLevel(3);
int player = Board.P1;
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
for (int i = 0; i < totalMoves; i++) {
if (player == Board.P1)
board = mcts3.findNextMove(board, player);
else
board = mcts1.findNextMove(board, player);
if (board.checkStatus() != -1) {
break;
}
player = 3 - player;
}
int winStatus = board.checkStatus();
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
}
}