This commit is contained in:
Jana Dönszelmann 2025-09-11 11:46:42 -07:00
parent d4e626d13c
commit 20a6d2faeb
No known key found for this signature in database
4 changed files with 401 additions and 8 deletions

159
Cargo.lock generated
View file

@ -2,6 +2,165 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 4 version = 4
[[package]]
name = "cfg-if"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi",
]
[[package]]
name = "libc"
version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "ppv-lite86"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
dependencies = [
"unicode-ident",
]
[[package]] [[package]]
name = "quoridor" name = "quoridor"
version = "0.1.0" version = "0.1.0"
dependencies = [
"simple-mcts",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom",
]
[[package]]
name = "simple-mcts"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12206ee44e7faaf94c033a47d2168d59f4957e8f50d71b5fdf7e16cfd419644e"
dependencies = [
"rand",
]
[[package]]
name = "syn"
version = "2.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d"
[[package]]
name = "wasi"
version = "0.14.5+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4494f6290a82f5fe584817a676a34b9d6763e8d9d18204009fb31dceca98fd4"
dependencies = [
"wasip2",
]
[[package]]
name = "wasip2"
version = "1.0.0+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03fa2761397e5bd52002cd7e73110c71af2109aca4e521a9f40473fe685b0a24"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wit-bindgen"
version = "0.45.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36"
[[package]]
name = "zerocopy"
version = "0.8.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View file

@ -4,3 +4,4 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
simple-mcts = "0.1"

View file

@ -1,7 +1,9 @@
use std::fmt::Display; use std::fmt::Display;
use simple_mcts::Game;
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]
struct PlayerState { pub struct PlayerState {
xy: u8, xy: u8,
walls_left: u8, walls_left: u8,
} }
@ -45,7 +47,7 @@ impl PlayerState {
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
struct WallState { pub struct WallState {
verticals: [u8; 9], verticals: [u8; 9],
horizontals: [u8; 9], horizontals: [u8; 9],
} }
@ -101,10 +103,46 @@ impl Default for WallState {
} }
} }
struct GameState { #[derive(Clone, Copy)]
p1: PlayerState, pub struct GameState {
p2: PlayerState, pub p1: PlayerState,
walls: WallState, pub p2: PlayerState,
pub walls: WallState,
pub whose_turn: bool,
}
impl GameState {
pub fn current_player(&self) -> &PlayerState {
if self.whose_turn { &self.p1 } else { &self.p2 }
}
pub fn current_player_mut(&mut self) -> &mut PlayerState {
if self.whose_turn {
&mut self.p1
} else {
&mut self.p2
}
}
pub fn mcts_result(&self) -> Option<f64> {
let p1_won = self.p1.y() == 8;
let p2_won = self.p2.y() == 0;
let outcome_for_p1 = match (p1_won, p2_won) {
(false, false) => return None,
(true, false) => 1.0,
(false, true) => -1.0,
(true, true) => 0.0,
};
Some(if self.whose_turn {
//p1 wants to win
outcome_for_p1
} else {
//p2 wants to win
-1.0 * outcome_for_p1
})
}
} }
impl Default for GameState { impl Default for GameState {
@ -113,6 +151,7 @@ impl Default for GameState {
p1: PlayerState::P1_START, p1: PlayerState::P1_START,
p2: PlayerState::P2_START, p2: PlayerState::P2_START,
walls: Default::default(), walls: Default::default(),
whose_turn: true,
} }
} }
} }

View file

@ -1,5 +1,199 @@
use std::error::Error;
use std::thread;
use std::time::Duration;
use simple_mcts::Game;
use simple_mcts::GameEvaluator;
use simple_mcts::Mcts;
use simple_mcts::MctsError;
use crate::gamestate::GameState;
mod gamestate; mod gamestate;
fn main() { struct Quoridor {
println!("Hello, world!"); state: GameState,
}
impl Default for Quoridor {
fn default() -> Self {
Self {
state: Default::default(),
}
}
}
const NUM_NEXT_STATES: usize = 64 /* horizontal */ + 64 /* vertical */ +
4 /* move directions */ + 2 /* blocked jumps */;
impl Game<NUM_NEXT_STATES> for Quoridor {
type State = GameState;
fn new() -> Self {
Self {
state: Default::default(),
}
}
fn get_actions(&self) -> [bool; NUM_NEXT_STATES] {
let mut res = [false; NUM_NEXT_STATES];
for x in 0..8 {
for y in 0..8 {
if self.state.walls.can_walk_between(x, y, x + 1, y)
&& self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1)
{
res[x as usize * 8 + y as usize] = true;
}
if self.state.walls.can_walk_between(x, y, x, y + 1)
&& self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1)
{
res[x as usize * 8 + y as usize + 64] = true;
}
}
}
let p = self.state.current_player();
let x = p.x();
let y = p.y();
res[128] = if x == 0 {
false
} else {
self.state.walls.can_walk_between(x, y, x - 1, y)
};
res[129] = if y == 0 {
false
} else {
self.state.walls.can_walk_between(x, y, x, y - 1)
};
res[130] = if x == 8 {
false
} else {
self.state.walls.can_walk_between(x, y, x + 1, y)
};
res[131] = if y == 8 {
false
} else {
self.state.walls.can_walk_between(x, y, x, y + 1)
};
// TODO: detect jumps for the 4 directions and also blocked jumps for the 2 sides of the other pawn
// dbg!(&res[128..132]);
res
}
fn is_finish(&self) -> bool {
self.state.p1.y() == 8 || self.state.p2.y() == 0
}
fn play(&mut self, action: usize) {
let mut set_block = |i: usize, vertical| {
let x = (i / 8) as u8;
let y = (i % 8) as u8;
if vertical {
debug_assert!(self.state.walls.can_walk_between(x, y, x + 1, y));
debug_assert!(self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1));
self.state.walls.block(x, y, x + 1, y);
self.state.walls.block(x, y + 1, x + 1, y + 1);
debug_assert!(!self.state.walls.can_walk_between(x, y, x + 1, y));
debug_assert!(!self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1));
} else {
debug_assert!(self.state.walls.can_walk_between(x, y, x, y + 1));
debug_assert!(self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1));
self.state.walls.block(x, y, x, y + 1);
self.state.walls.block(x + 1, y, x + 1, y + 1);
debug_assert!(!self.state.walls.can_walk_between(x, y, x, y + 1));
debug_assert!(!self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1));
}
};
match action {
i @ 0..64 => set_block(i, true),
i @ 64..128 => set_block(i - 64, false),
128 => {
let x = self.state.current_player().x();
self.state.current_player_mut().set_x(x - 1);
}
129 => {
let y = self.state.current_player().y();
self.state.current_player_mut().set_y(y - 1);
}
130 => {
let x = self.state.current_player().x();
self.state.current_player_mut().set_x(x + 1);
}
131 => {
let y = self.state.current_player().y();
self.state.current_player_mut().set_x(y + 1);
}
132 => todo!(),
133 => todo!(),
_ => unreachable!(),
}
self.state.whose_turn = !self.state.whose_turn;
}
fn get_state(&self) -> Self::State {
self.state
}
fn get_result(&self) -> Option<f64> {
self.get_state().mcts_result()
}
fn clone(&self) -> Self {
Self { state: self.state }
}
}
struct QuoridorEvaluator;
impl GameEvaluator<Quoridor, NUM_NEXT_STATES> for QuoridorEvaluator {
fn evaluate(&self, state: GameState) -> (f64, [f64; NUM_NEXT_STATES]) {
(
state.mcts_result().unwrap_or(0.0),
[const { 1.0 / NUM_NEXT_STATES as f64 }; NUM_NEXT_STATES],
)
}
}
fn main() -> Result<(), MctsError> {
let g = Quoridor::default();
let mut mcts: Mcts<Quoridor, _> = Mcts::<Quoridor, _>::new();
let evaluator = QuoridorEvaluator;
for _ in 0..100 {
// Perform 100 MCTS iterations
for _ in 0..1000 {
mcts.iterate(&evaluator)?;
}
// Get the best action based on visit counts
let (score, policy) = mcts.get_result();
println!("Best action score: {}, Policy: {:?}", score, policy);
// Play the best action and update the MCTS tree
let best_action_index = policy
.iter()
.enumerate()
.max_by(|&(_, &a), &(_, &b)| a.partial_cmp(&b).unwrap())
.map(|(index, _)| index)
.unwrap_or(0); // Default to first action if policy is empty
println!("best action: {best_action_index}");
mcts.play(best_action_index)?;
println!("{}", mcts.get_game().state);
thread::sleep(Duration::from_millis(50));
}
// Continue with the next game state
Ok(())
} }