From 20a6d2faeb80a24c8478720b7a69baa59ab77cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jana=20D=C3=B6nszelmann?= Date: Thu, 11 Sep 2025 11:46:42 -0700 Subject: [PATCH] ai --- Cargo.lock | 159 +++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + src/gamestate.rs | 51 ++++++++++-- src/main.rs | 198 ++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 401 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f9e4843..8ed135a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,165 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "cfg-if" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" + +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi", +] + +[[package]] +name = "libc" +version = "0.2.175" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +dependencies = [ + "unicode-ident", +] + [[package]] name = "quoridor" version = "0.1.0" +dependencies = [ + "simple-mcts", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom", +] + +[[package]] +name = "simple-mcts" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12206ee44e7faaf94c033a47d2168d59f4957e8f50d71b5fdf7e16cfd419644e" +dependencies = [ + "rand", +] + +[[package]] +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" + +[[package]] +name = "wasi" +version = "0.14.5+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4494f6290a82f5fe584817a676a34b9d6763e8d9d18204009fb31dceca98fd4" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.0+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03fa2761397e5bd52002cd7e73110c71af2109aca4e521a9f40473fe685b0a24" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wit-bindgen" +version = "0.45.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36" + +[[package]] +name = "zerocopy" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 843d6de..c7da30e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,4 @@ version = "0.1.0" edition = "2024" [dependencies] +simple-mcts = "0.1" diff --git a/src/gamestate.rs b/src/gamestate.rs index c5d2cfa..d3290cb 100644 --- a/src/gamestate.rs +++ b/src/gamestate.rs @@ -1,7 +1,9 @@ use std::fmt::Display; +use simple_mcts::Game; + #[derive(Clone, Copy, PartialEq)] -struct PlayerState { +pub struct PlayerState { xy: u8, walls_left: u8, } @@ -45,7 +47,7 @@ impl PlayerState { } #[derive(Copy, Clone)] -struct WallState { +pub struct WallState { verticals: [u8; 9], horizontals: [u8; 9], } @@ -101,10 +103,46 @@ impl Default for WallState { } } -struct GameState { - p1: PlayerState, - p2: PlayerState, - walls: WallState, +#[derive(Clone, Copy)] +pub struct GameState { + pub p1: PlayerState, + pub p2: PlayerState, + pub walls: WallState, + pub whose_turn: bool, +} + +impl GameState { + pub fn current_player(&self) -> &PlayerState { + if self.whose_turn { &self.p1 } else { &self.p2 } + } + + pub fn current_player_mut(&mut self) -> &mut PlayerState { + if self.whose_turn { + &mut self.p1 + } else { + &mut self.p2 + } + } + + pub fn mcts_result(&self) -> Option { + let p1_won = self.p1.y() == 8; + let p2_won = self.p2.y() == 0; + + let outcome_for_p1 = match (p1_won, p2_won) { + (false, false) => return None, + (true, false) => 1.0, + (false, true) => -1.0, + (true, true) => 0.0, + }; + + Some(if self.whose_turn { + //p1 wants to win + outcome_for_p1 + } else { + //p2 wants to win + -1.0 * outcome_for_p1 + }) + } } impl Default for GameState { @@ -113,6 +151,7 @@ impl Default for GameState { p1: PlayerState::P1_START, p2: PlayerState::P2_START, walls: Default::default(), + whose_turn: true, } } } diff --git a/src/main.rs b/src/main.rs index a71187c..8347bca 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,199 @@ +use std::error::Error; +use std::thread; +use std::time::Duration; + +use simple_mcts::Game; +use simple_mcts::GameEvaluator; +use simple_mcts::Mcts; +use simple_mcts::MctsError; + +use crate::gamestate::GameState; + mod gamestate; -fn main() { - println!("Hello, world!"); +struct Quoridor { + state: GameState, +} + +impl Default for Quoridor { + fn default() -> Self { + Self { + state: Default::default(), + } + } +} + +const NUM_NEXT_STATES: usize = 64 /* horizontal */ + 64 /* vertical */ + + 4 /* move directions */ + 2 /* blocked jumps */; + +impl Game for Quoridor { + type State = GameState; + + fn new() -> Self { + Self { + state: Default::default(), + } + } + + fn get_actions(&self) -> [bool; NUM_NEXT_STATES] { + let mut res = [false; NUM_NEXT_STATES]; + + for x in 0..8 { + for y in 0..8 { + if self.state.walls.can_walk_between(x, y, x + 1, y) + && self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1) + { + res[x as usize * 8 + y as usize] = true; + } + + if self.state.walls.can_walk_between(x, y, x, y + 1) + && self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1) + { + res[x as usize * 8 + y as usize + 64] = true; + } + } + } + + let p = self.state.current_player(); + let x = p.x(); + let y = p.y(); + res[128] = if x == 0 { + false + } else { + self.state.walls.can_walk_between(x, y, x - 1, y) + }; + res[129] = if y == 0 { + false + } else { + self.state.walls.can_walk_between(x, y, x, y - 1) + }; + res[130] = if x == 8 { + false + } else { + self.state.walls.can_walk_between(x, y, x + 1, y) + }; + res[131] = if y == 8 { + false + } else { + self.state.walls.can_walk_between(x, y, x, y + 1) + }; + // TODO: detect jumps for the 4 directions and also blocked jumps for the 2 sides of the other pawn + + // dbg!(&res[128..132]); + res + } + + fn is_finish(&self) -> bool { + self.state.p1.y() == 8 || self.state.p2.y() == 0 + } + + fn play(&mut self, action: usize) { + let mut set_block = |i: usize, vertical| { + let x = (i / 8) as u8; + let y = (i % 8) as u8; + + if vertical { + debug_assert!(self.state.walls.can_walk_between(x, y, x + 1, y)); + debug_assert!(self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1)); + + self.state.walls.block(x, y, x + 1, y); + self.state.walls.block(x, y + 1, x + 1, y + 1); + + debug_assert!(!self.state.walls.can_walk_between(x, y, x + 1, y)); + debug_assert!(!self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1)); + } else { + debug_assert!(self.state.walls.can_walk_between(x, y, x, y + 1)); + debug_assert!(self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1)); + + self.state.walls.block(x, y, x, y + 1); + self.state.walls.block(x + 1, y, x + 1, y + 1); + + debug_assert!(!self.state.walls.can_walk_between(x, y, x, y + 1)); + debug_assert!(!self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1)); + } + }; + + match action { + i @ 0..64 => set_block(i, true), + i @ 64..128 => set_block(i - 64, false), + 128 => { + let x = self.state.current_player().x(); + self.state.current_player_mut().set_x(x - 1); + } + 129 => { + let y = self.state.current_player().y(); + self.state.current_player_mut().set_y(y - 1); + } + 130 => { + let x = self.state.current_player().x(); + self.state.current_player_mut().set_x(x + 1); + } + 131 => { + let y = self.state.current_player().y(); + self.state.current_player_mut().set_x(y + 1); + } + 132 => todo!(), + 133 => todo!(), + _ => unreachable!(), + } + + self.state.whose_turn = !self.state.whose_turn; + } + + fn get_state(&self) -> Self::State { + self.state + } + + fn get_result(&self) -> Option { + self.get_state().mcts_result() + } + + fn clone(&self) -> Self { + Self { state: self.state } + } +} + +struct QuoridorEvaluator; + +impl GameEvaluator for QuoridorEvaluator { + fn evaluate(&self, state: GameState) -> (f64, [f64; NUM_NEXT_STATES]) { + ( + state.mcts_result().unwrap_or(0.0), + [const { 1.0 / NUM_NEXT_STATES as f64 }; NUM_NEXT_STATES], + ) + } +} + +fn main() -> Result<(), MctsError> { + let g = Quoridor::default(); + + let mut mcts: Mcts = Mcts::::new(); + let evaluator = QuoridorEvaluator; + + for _ in 0..100 { + // Perform 100 MCTS iterations + for _ in 0..1000 { + mcts.iterate(&evaluator)?; + } + + // Get the best action based on visit counts + let (score, policy) = mcts.get_result(); + println!("Best action score: {}, Policy: {:?}", score, policy); + + // Play the best action and update the MCTS tree + let best_action_index = policy + .iter() + .enumerate() + .max_by(|&(_, &a), &(_, &b)| a.partial_cmp(&b).unwrap()) + .map(|(index, _)| index) + .unwrap_or(0); // Default to first action if policy is empty + println!("best action: {best_action_index}"); + mcts.play(best_action_index)?; + + println!("{}", mcts.get_game().state); + thread::sleep(Duration::from_millis(50)); + } + + // Continue with the next game state + Ok(()) }