五子棋的蒙特卡洛树搜索(MCTS)实现 - Java版
下面是一个完整的Java实现,使用蒙特卡洛树搜索算法来实现五子棋AI。这个实现包含了MCTS的核心组件:选择(Selection)、扩展(Expansion)、模拟(Simulation)和反向传播(Backpropagation)。
import java.util.*;
public class GomokuMCTS {
    // Board dimension: the board is a BOARD_SIZE x BOARD_SIZE grid.
    private static final int BOARD_SIZE = 15;
    // Cell contents; HUMAN and AI double as player identifiers and winner codes.
    private static final int EMPTY = 0;
    private static final int HUMAN = 1;
    private static final int AI = 2;
    // isTerminal result for a full board without a winner.
    private static final int DRAW = 0;
    // isTerminal result while the game is still undecided.
    private static final int IN_PROGRESS = Integer.MIN_VALUE;
    // Stones in an unbroken line needed to win (longer lines also win).
    private static final int WIN_LENGTH = 5;
    // Candidate moves are empty cells at most this many rows/columns away from a stone.
    private static final int NEIGHBOUR_RADIUS = 2;
    // Line directions checked for a win: horizontal, vertical, diagonal, anti-diagonal.
    private static final int[][] DIRECTIONS = {{0, 1}, {1, 0}, {1, 1}, {1, -1}};
    // Default number of MCTS iterations per findBestMove call.
    private static final int SIMULATION_TIMES = 10000;
    // Exploration constant c in the UCT formula w/n + c * sqrt(ln N / n).
    private static final double EXPLORATION_FACTOR = Math.sqrt(2);
    // Randomness for picking which new child to simulate and for rollout moves.
    private static final Random RANDOM = new Random();

    /**
     * A search-tree node: the position reached by one move, plus the MCTS statistics
     * collected for it.
     */
    private static final class Node {
        // Player who made the move leading to this position. For the root this is the
        // opponent of the player to move, so that children alternate correctly.
        final int lastMovePlayer;
        // Row/column of that move; -1 for the root.
        final int lastMoveX;
        final int lastMoveY;
        final Node parent;
        final List<Node> children = new ArrayList<>();
        int visitCount;
        // Total reward from lastMovePlayer's point of view: +1 per win, +0.5 per draw.
        double winScore;
        // This node's own board; null until board() is first called (children only).
        private int[][] cachedBoard;

        private Node(int[][] board, int lastMovePlayer, int lastMoveX, int lastMoveY, Node parent) {
            this.cachedBoard = board;
            this.lastMovePlayer = lastMovePlayer;
            this.lastMoveX = lastMoveX;
            this.lastMoveY = lastMoveY;
            this.parent = parent;
        }

        /** Creates a root node holding a private copy of {@code board}. */
        static Node root(int[][] board, int lastMovePlayer) {
            return new Node(copyBoard(board), lastMovePlayer, -1, -1, null);
        }

        /** Returns a deep copy of a BOARD_SIZE x BOARD_SIZE board. */
        static int[][] copyBoard(int[][] original) {
            int[][] copy = new int[BOARD_SIZE][];
            for (int i = 0; i < BOARD_SIZE; i++) {
                copy[i] = Arrays.copyOf(original[i], BOARD_SIZE);
            }
            return copy;
        }

        /**
         * Returns this node's position. A child's board is built on first use from its
         * parent's board plus the child's move, so unvisited children cost no board copy.
         */
        int[][] board() {
            if (cachedBoard == null) {
                cachedBoard = copyBoard(parent.board());
                cachedBoard[lastMoveX][lastMoveY] = lastMovePlayer;
            }
            return cachedBoard;
        }

        /** Creates (but does not attach) the child reached by {@code player} playing at (x, y). */
        Node createChild(int x, int y, int player) {
            return new Node(null, player, x, y, this);
        }

        /**
         * Builds one child per candidate move (see getAvailableMoves) for {@code player}.
         * The returned nodes are not added to {@link #children}.
         */
        List<Node> getAllPossibleNodes(int player) {
            List<int[]> moves = getAvailableMoves(board());
            List<Node> nodes = new ArrayList<>(moves.size());
            for (int[] move : moves) {
                nodes.add(createChild(move[0], move[1], player));
            }
            return nodes;
        }

        /** Returns a uniformly random child; the node must have at least one child. */
        Node getRandomChildNode() {
            return children.get(RANDOM.nextInt(children.size()));
        }

        /**
         * UCT score used during selection. Unvisited nodes score Double.MAX_VALUE so every
         * child is tried once before any is revisited. Must not be called on the root,
         * which has no parent.
         */
        double getUCTValue() {
            if (visitCount == 0) {
                return Double.MAX_VALUE;
            }
            return winScore / visitCount
                    + EXPLORATION_FACTOR * Math.sqrt(Math.log(parent.visitCount) / visitCount);
        }

        /** Returns the child with the most visits (ties: the earliest child). */
        Node getBestChild() {
            return Collections.max(children, Comparator.comparingInt((Node c) -> c.visitCount));
        }
    }

    /**
     * Chooses a move for {@code player} using {@value #SIMULATION_TIMES} MCTS iterations.
     *
     * @param board BOARD_SIZE x BOARD_SIZE grid of EMPTY/HUMAN/AI; not modified
     * @param player HUMAN or AI, the side to move
     * @return {row, column} of the chosen move
     * @throws IllegalArgumentException if the board has the wrong shape, the player is
     *     unknown, or the game on {@code board} is already over
     */
    public static int[] findBestMove(int[][] board, int player) {
        return findBestMove(board, player, SIMULATION_TIMES);
    }

    /**
     * Chooses a move for {@code player} by running {@code iterations} rounds of
     * selection, expansion, simulation and back-propagation, then returning the most
     * visited root child.
     *
     * <p>Each expansion adds one child per candidate move, so the tree size grows with
     * iterations times the branching factor.
     *
     * @param board BOARD_SIZE x BOARD_SIZE grid of EMPTY/HUMAN/AI; not modified
     * @param player HUMAN or AI, the side to move
     * @param iterations number of MCTS iterations, at least 1
     * @return {row, column} of the chosen move
     * @throws IllegalArgumentException if the board has the wrong shape, the player is
     *     unknown, {@code iterations < 1}, or the game on {@code board} is already over
     */
    public static int[] findBestMove(int[][] board, int player, int iterations) {
        validateBoard(board);
        if (player != HUMAN && player != AI) {
            throw new IllegalArgumentException(
                    "player must be " + HUMAN + " or " + AI + ", got " + player);
        }
        if (iterations < 1) {
            throw new IllegalArgumentException("iterations must be >= 1, got " + iterations);
        }
        if (isTerminal(board) != IN_PROGRESS) {
            throw new IllegalArgumentException("the game on this board is already over");
        }

        Node root = Node.root(board, opponent(player));
        for (int i = 0; i < iterations; i++) {
            // 1. Selection: descend by UCT to a leaf.
            Node promisingNode = selectPromisingNode(root);
            // 2. Expansion: finished positions are never expanded.
            if (isTerminal(promisingNode.board()) == IN_PROGRESS) {
                expandNode(promisingNode);
            }
            // 3. Simulation from one of the new children (or the leaf itself if terminal).
            Node nodeToExplore = promisingNode.children.isEmpty()
                    ? promisingNode
                    : promisingNode.getRandomChildNode();
            int playoutResult = simulateRandomPlayout(nodeToExplore);
            // 4. Back-propagation.
            backPropagation(nodeToExplore, playoutResult);
        }
        // The root is non-terminal, so it was expanded on the first iteration.
        Node bestNode = root.getBestChild();
        return new int[]{bestNode.lastMoveX, bestNode.lastMoveY};
    }

    /** Rejects a null board or one that is not exactly BOARD_SIZE x BOARD_SIZE. */
    private static void validateBoard(int[][] board) {
        Objects.requireNonNull(board, "board");
        if (board.length != BOARD_SIZE) {
            throw new IllegalArgumentException(
                    "board must have " + BOARD_SIZE + " rows, got " + board.length);
        }
        for (int i = 0; i < BOARD_SIZE; i++) {
            if (board[i] == null || board[i].length != BOARD_SIZE) {
                throw new IllegalArgumentException(
                        "board row " + i + " must have " + BOARD_SIZE + " columns");
            }
        }
    }

    /** Returns the other player. */
    private static int opponent(int player) {
        return player == HUMAN ? AI : HUMAN;
    }

    /** Selection phase: follows the highest-UCT child from the root down to a leaf. */
    private static Node selectPromisingNode(Node rootNode) {
        Node node = rootNode;
        while (!node.children.isEmpty()) {
            node = findBestNodeWithUCT(node);
        }
        return node;
    }

    /** Returns the child of {@code node} with the highest UCT value (ties: earliest child). */
    private static Node findBestNodeWithUCT(Node node) {
        return Collections.max(node.children, Comparator.comparingDouble(Node::getUCTValue));
    }

    /**
     * Expansion phase: attaches one child per candidate move. The mover is the opponent
     * of whoever produced {@code node}, so the tree alternates between the players.
     */
    private static void expandNode(Node node) {
        node.children.addAll(node.getAllPossibleNodes(opponent(node.lastMovePlayer)));
    }

    /**
     * Simulation phase: plays uniformly random candidate moves, alternating players and
     * starting with the opponent of node.lastMovePlayer, until someone completes five or
     * the board is full.
     *
     * @return HUMAN or AI for the winner, or DRAW
     */
    private static int simulateRandomPlayout(Node node) {
        int status = isTerminal(node.board());
        if (status != IN_PROGRESS) {
            return status;
        }
        int[][] tempBoard = Node.copyBoard(node.board());
        // Candidate cells are maintained incrementally instead of rescanning the board
        // after every move; 'queued' marks cells that are (or were) in the list.
        boolean[][] queued = new boolean[BOARD_SIZE][BOARD_SIZE];
        List<int[]> candidates = getAvailableMoves(tempBoard);
        for (int[] cell : candidates) {
            queued[cell[0]][cell[1]] = true;
        }
        int currentPlayer = opponent(node.lastMovePlayer);
        // The candidate list only becomes empty once the board is full.
        while (!candidates.isEmpty()) {
            int index = RANDOM.nextInt(candidates.size());
            int[] move = candidates.get(index);
            // O(1) unordered removal: move the last element into the vacated slot.
            int[] last = candidates.remove(candidates.size() - 1);
            if (index < candidates.size()) {
                candidates.set(index, last);
            }
            tempBoard[move[0]][move[1]] = currentPlayer;
            // The position was undecided before this move, so only it can create a five.
            if (completesFive(tempBoard, move[0], move[1])) {
                return currentPlayer;
            }
            addCandidatesAround(tempBoard, move[0], move[1], queued, candidates);
            currentPlayer = opponent(currentPlayer);
        }
        return DRAW;
    }

    /**
     * Returns the candidate moves: every empty cell within NEIGHBOUR_RADIUS rows and
     * columns (a 5x5 window) of an existing stone, in scan order. On an empty board the
     * only candidate is the centre; a full board yields an empty list.
     */
    private static List<int[]> getAvailableMoves(int[][] board) {
        List<int[]> moves = new ArrayList<>();
        boolean[][] seen = new boolean[BOARD_SIZE][BOARD_SIZE];
        for (int i = 0; i < BOARD_SIZE; i++) {
            for (int j = 0; j < BOARD_SIZE; j++) {
                if (board[i][j] != EMPTY) {
                    addCandidatesAround(board, i, j, seen, moves);
                }
            }
        }
        if (moves.isEmpty()) {
            int center = BOARD_SIZE / 2;
            if (board[center][center] == EMPTY) {
                moves.add(new int[]{center, center});
            }
        }
        return moves;
    }

    /**
     * Appends to {@code out} each empty, not-yet-seen cell in the window of radius
     * NEIGHBOUR_RADIUS around (x, y), marking it in {@code seen}.
     */
    private static void addCandidatesAround(int[][] board, int x, int y,
                                            boolean[][] seen, List<int[]> out) {
        for (int dx = -NEIGHBOUR_RADIUS; dx <= NEIGHBOUR_RADIUS; dx++) {
            for (int dy = -NEIGHBOUR_RADIUS; dy <= NEIGHBOUR_RADIUS; dy++) {
                int nx = x + dx;
                int ny = y + dy;
                if (nx >= 0 && nx < BOARD_SIZE && ny >= 0 && ny < BOARD_SIZE
                        && board[nx][ny] == EMPTY && !seen[nx][ny]) {
                    seen[nx][ny] = true;
                    out.add(new int[]{nx, ny});
                }
            }
        }
    }

    /** Returns true if the stone at (x, y) lies on a line of at least WIN_LENGTH equal stones. */
    private static boolean completesFive(int[][] board, int x, int y) {
        int stone = board[x][y];
        if (stone == EMPTY) {
            return false;
        }
        for (int[] d : DIRECTIONS) {
            int length = 1
                    + countStones(board, x, y, d[0], d[1], stone)
                    + countStones(board, x, y, -d[0], -d[1], stone);
            if (length >= WIN_LENGTH) {
                return true;
            }
        }
        return false;
    }

    /** Counts consecutive {@code stone}s starting next to (x, y) in direction (dx, dy). */
    private static int countStones(int[][] board, int x, int y, int dx, int dy, int stone) {
        int count = 0;
        int cx = x + dx;
        int cy = y + dy;
        while (cx >= 0 && cx < BOARD_SIZE && cy >= 0 && cy < BOARD_SIZE && board[cx][cy] == stone) {
            count++;
            cx += dx;
            cy += dy;
        }
        return count;
    }

    /**
     * Back-propagation phase: adds one visit to every node from {@code nodeToExplore} up
     * to the root. A node also gains 1 if its lastMovePlayer won, or 0.5 on a draw.
     */
    private static void backPropagation(Node nodeToExplore, int playoutResult) {
        for (Node node = nodeToExplore; node != null; node = node.parent) {
            node.visitCount++;
            if (playoutResult == DRAW) {
                node.winScore += 0.5;
            } else if (node.lastMovePlayer == playoutResult) {
                node.winScore += 1;
            }
        }
    }

    /**
     * Evaluates a position from scratch.
     *
     * @return HUMAN or AI if that player has five in a row (rows are checked first, then
     *     columns, then both diagonals), DRAW if the board is full without a winner, and
     *     IN_PROGRESS otherwise
     */
    private static int isTerminal(int[][] board) {
        // Rows.
        for (int i = 0; i < BOARD_SIZE; i++) {
            for (int j = 0; j < BOARD_SIZE - 4; j++) {
                if (board[i][j] != EMPTY &&
                        board[i][j] == board[i][j + 1] &&
                        board[i][j] == board[i][j + 2] &&
                        board[i][j] == board[i][j + 3] &&
                        board[i][j] == board[i][j + 4]) {
                    return board[i][j];
                }
            }
        }
        // Columns.
        for (int j = 0; j < BOARD_SIZE; j++) {
            for (int i = 0; i < BOARD_SIZE - 4; i++) {
                if (board[i][j] != EMPTY &&
                        board[i][j] == board[i + 1][j] &&
                        board[i][j] == board[i + 2][j] &&
                        board[i][j] == board[i + 3][j] &&
                        board[i][j] == board[i + 4][j]) {
                    return board[i][j];
                }
            }
        }
        // Diagonals, top-left to bottom-right.
        for (int i = 0; i < BOARD_SIZE - 4; i++) {
            for (int j = 0; j < BOARD_SIZE - 4; j++) {
                if (board[i][j] != EMPTY &&
                        board[i][j] == board[i + 1][j + 1] &&
                        board[i][j] == board[i + 2][j + 2] &&
                        board[i][j] == board[i + 3][j + 3] &&
                        board[i][j] == board[i + 4][j + 4]) {
                    return board[i][j];
                }
            }
        }
        // Anti-diagonals, top-right to bottom-left.
        for (int i = 0; i < BOARD_SIZE - 4; i++) {
            for (int j = 4; j < BOARD_SIZE; j++) {
                if (board[i][j] != EMPTY &&
                        board[i][j] == board[i + 1][j - 1] &&
                        board[i][j] == board[i + 2][j - 2] &&
                        board[i][j] == board[i + 3][j - 3] &&
                        board[i][j] == board[i + 4][j - 4]) {
                    return board[i][j];
                }
            }
        }
        // No winner: undecided while any cell is still empty.
        for (int i = 0; i < BOARD_SIZE; i++) {
            for (int j = 0; j < BOARD_SIZE; j++) {
                if (board[i][j] == EMPTY) {
                    return IN_PROGRESS;
                }
            }
        }
        return DRAW;
    }

    /** Prints the board with row/column indices; X marks HUMAN, O marks AI, '.' is empty. */
    public static void printBoard(int[][] board) {
        // Three spaces so the column headers line up with the "%2d " row labels.
        System.out.print("   ");
        for (int j = 0; j < BOARD_SIZE; j++) {
            System.out.printf("%2d ", j);
        }
        System.out.println();
        for (int i = 0; i < BOARD_SIZE; i++) {
            System.out.printf("%2d ", i);
            for (int j = 0; j < BOARD_SIZE; j++) {
                char c = '.';
                if (board[i][j] == HUMAN) {
                    c = 'X';
                } else if (board[i][j] == AI) {
                    c = 'O';
                }
                System.out.printf("%2c ", c);
            }
            System.out.println();
        }
    }

    /**
     * Console game: the human (X) moves first by typing "row column" (0-based), then the
     * AI (O) replies, until someone wins or the board is full. Ends quietly if standard
     * input is closed.
     */
    public static void main(String[] args) {
        int[][] board = new int[BOARD_SIZE][BOARD_SIZE];
        int currentPlayer = HUMAN;
        // try-with-resources closes the Scanner on every exit path.
        try (Scanner scanner = new Scanner(System.in)) {
            while (true) {
                printBoard(board);
                if (currentPlayer == HUMAN) {
                    System.out.println("你的回合(X), 输入行和列(例如: 7 7):");
                    int x;
                    int y;
                    try {
                        x = scanner.nextInt();
                        y = scanner.nextInt();
                    } catch (InputMismatchException ignored) {
                        // Drop the non-numeric token; otherwise nextInt() would fail on it forever.
                        scanner.next();
                        System.out.println("无效的移动,请重试!");
                        continue;
                    } catch (NoSuchElementException ignored) {
                        // Standard input was closed: nobody is left to play.
                        return;
                    }
                    if (x < 0 || x >= BOARD_SIZE || y < 0 || y >= BOARD_SIZE || board[x][y] != EMPTY) {
                        System.out.println("无效的移动,请重试!");
                        continue;
                    }
                    board[x][y] = HUMAN;
                } else {
                    System.out.println("AI思考中...");
                    int[] bestMove = findBestMove(board, AI);
                    board[bestMove[0]][bestMove[1]] = AI;
                    System.out.println("AI下在: " + bestMove[0] + " " + bestMove[1]);
                }
                int result = isTerminal(board);
                if (result != IN_PROGRESS) {
                    printBoard(board);
                    if (result == HUMAN) {
                        System.out.println("你赢了!");
                    } else if (result == AI) {
                        System.out.println("AI赢了!");
                    } else {
                        System.out.println("平局!");
                    }
                    return;
                }
                currentPlayer = opponent(currentPlayer);
            }
        }
    }
}
代码说明
- MCTS核心组件:
  - `Node` 类表示搜索树中的节点,包含棋盘状态、最后一手的玩家与位置,以及访问次数、得分等统计信息。
  - `findBestMove` 是主入口,依次执行MCTS的四个阶段:选择、扩展、模拟、反向传播。
- 游戏逻辑:
  - 15x15的五子棋棋盘
  - 检查横、竖及两条斜线方向上的五子连珠胜利条件
  - 简单的移动生成策略:只考虑距已有棋子两格以内的空位;棋盘为空时下在中心点
- 优化点:
  - 在选择阶段使用UCT(Upper Confidence Bound for Trees)公式平衡探索与利用
  - 扩展与模拟阶段都只考虑已有棋子附近的空位,以减少候选移动数量
  - 模拟阶段采用均匀随机落子,直到一方连成五子或棋盘下满(平局)
- 使用方法:
  - 运行main方法开始人机对战,人类(X)先手,AI(O)后手
  - 人类玩家输入行号和列号(从0开始,如"7 7")下棋
  - AI会通过MCTS计算最佳移动并自动落子
这个实现可以根据需要进行优化,例如添加更智能的模拟策略、并行化MCTS搜索、或者使用神经网络来指导搜索等。