Review notes (free-form commentary ahead of the first `diff --git` header):
- `Quoridor::play`: a wall is only spent for wall actions (`action < 128`); action 128
  is a pawn move and must not touch `walls_left`.
- `Quoridor::get_actions`: wall orientation now matches `play` (0..64 vertical,
  64..128 horizontal) and relies on `WallState::can_place` alone.
- NOTE(review): line breaks, indentation and the position of blank context lines were
  rebuilt from a whitespace-collapsed copy of this patch, and some generic arguments
  (e.g. on `Mcts` in `main`) appear to be missing from context lines -- regenerate
  against the real tree before applying.

diff --git a/src/gamestate.rs b/src/gamestate.rs
index d3290cb..62ff548 100644
--- a/src/gamestate.rs
+++ b/src/gamestate.rs
@@ -2,15 +2,19 @@ use std::fmt::Display;
 
 use simple_mcts::Game;
 
+use crate::pathfind::GoalDistanceMap;
+
 #[derive(Clone, Copy, PartialEq)]
 pub struct PlayerState {
     xy: u8,
-    walls_left: u8,
+    pub walls_left: u8,
 }
 
 impl PlayerState {
-    pub const P1_START: Self = Self::new(4, 0, 9);
-    pub const P2_START: Self = Self::new(4, 8, 9);
+    /// Number of walls each player starts the game with.
+    pub const INITIAL_WALLS: u8 = 10;
+    pub const P1_START: Self = Self::new(4, 0, Self::INITIAL_WALLS);
+    pub const P2_START: Self = Self::new(4, 8, Self::INITIAL_WALLS);
 
     pub const fn new(x: u8, y: u8, walls_left: u8) -> Self {
         let mut res = Self { xy: 0, walls_left };
@@ -54,24 +58,24 @@ pub struct WallState {
 
 impl WallState {
     #[inline]
-    pub fn block_cleaned_hori(&mut self, byte_idx: u8, bit: u8) {
+    fn block_cleaned_hori(&mut self, byte_idx: u8, bit: u8) {
         self.horizontals[byte_idx as usize] |= 1 << bit;
     }
     #[inline]
-    pub fn block_cleaned_verti(&mut self, byte_idx: u8, bit: u8) {
+    fn block_cleaned_verti(&mut self, byte_idx: u8, bit: u8) {
         self.verticals[byte_idx as usize] |= 1 << bit;
     }
 
     #[inline]
-    pub fn can_walk_between_cleaned_hori(&self, byte_idx: u8, bit: u8) -> bool {
+    fn can_walk_between_cleaned_hori(&self, byte_idx: u8, bit: u8) -> bool {
         (self.horizontals[byte_idx as usize] >> bit) & 1 != 0
     }
     #[inline]
-    pub fn can_walk_between_cleaned_verti(&self, byte_idx: u8, bit: u8) -> bool {
+    fn can_walk_between_cleaned_verti(&self, byte_idx: u8, bit: u8) -> bool {
         (self.verticals[byte_idx as usize] >> bit) & 1 != 0
     }
 
-    pub fn block(&mut self, from_x: u8, from_y: u8, to_x: u8, to_y: u8) {
+    fn block(&mut self, from_x: u8, from_y: u8, to_x: u8, to_y: u8) {
         match (from_x.wrapping_sub(to_x), from_y.wrapping_sub(to_y)) {
             (1, 0) => self.block_cleaned_verti(to_y, to_x),
             (0xff, 0) => self.block_cleaned_verti(from_y, from_x),
@@ -92,6 +96,31 @@ impl WallState {
             _ => unreachable!(),
         }
     }
+
+    /// Returns `true` when both cell-to-cell edges that a wall at `(x, y)` would
+    /// block are currently walkable.
+    ///
+    /// A vertical wall blocks `(x, y)-(x + 1, y)` and `(x, y + 1)-(x + 1, y + 1)`;
+    /// a horizontal wall blocks `(x, y)-(x, y + 1)` and `(x + 1, y)-(x + 1, y + 1)`.
+    pub fn can_place(&self, x: u8, y: u8, vertical: bool) -> bool {
+        if vertical {
+            self.can_walk_between(x, y, x + 1, y) && self.can_walk_between(x, y + 1, x + 1, y + 1)
+        } else {
+            self.can_walk_between(x, y, x, y + 1) && self.can_walk_between(x + 1, y, x + 1, y + 1)
+        }
+    }
+
+    /// Blocks the two edges covered by a wall at `(x, y)`; see [`Self::can_place`]
+    /// for which edges those are. Does not check that the placement is legal.
+    pub fn place(&mut self, x: u8, y: u8, vertical: bool) {
+        if vertical {
+            self.block(x, y, x + 1, y);
+            self.block(x, y + 1, x + 1, y + 1);
+        } else {
+            self.block(x, y, x, y + 1);
+            self.block(x + 1, y, x + 1, y + 1);
+        }
+    }
 }
 
 impl Default for WallState {
@@ -103,30 +132,60 @@ impl Default for WallState {
     }
 }
 
+/// Identifies one of the two players.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum PlayerIdentifier {
+    P1,
+    P2,
+}
+
+impl PlayerIdentifier {
+    /// Flips `self` in place to the other player.
+    pub fn swap(&mut self) {
+        use PlayerIdentifier::*;
+        *self = match self {
+            P1 => P2,
+            P2 => P1,
+        };
+    }
+
+    /// Row (`y` coordinate) this player has to reach to win.
+    pub const fn y_goal(&self) -> u8 {
+        match self {
+            PlayerIdentifier::P1 => 8,
+            PlayerIdentifier::P2 => 0,
+        }
+    }
+}
+
 #[derive(Clone, Copy)]
 pub struct GameState {
     pub p1: PlayerState,
     pub p2: PlayerState,
     pub walls: WallState,
-    pub whose_turn: bool,
+    pub current_player: PlayerIdentifier,
 }
 
 impl GameState {
-    pub fn current_player(&self) -> &PlayerState {
-        if self.whose_turn { &self.p1 } else { &self.p2 }
+    /// State of the player whose turn it is.
+    pub fn current_player_state(&self) -> &PlayerState {
+        match self.current_player {
+            PlayerIdentifier::P1 => &self.p1,
+            PlayerIdentifier::P2 => &self.p2,
+        }
     }
 
-    pub fn current_player_mut(&mut self) -> &mut PlayerState {
-        if self.whose_turn {
-            &mut self.p1
-        } else {
-            &mut self.p2
+    /// Mutable state of the player whose turn it is.
+    pub fn current_player_state_mut(&mut self) -> &mut PlayerState {
+        match self.current_player {
+            PlayerIdentifier::P1 => &mut self.p1,
+            PlayerIdentifier::P2 => &mut self.p2,
         }
     }
 
     pub fn mcts_result(&self) -> Option<f64> {
-        let p1_won = self.p1.y() == 8;
-        let p2_won = self.p2.y() == 0;
+        let p1_won = self.p1.y() == PlayerIdentifier::P1.y_goal();
+        let p2_won = self.p2.y() == PlayerIdentifier::P2.y_goal();
 
         let outcome_for_p1 = match (p1_won, p2_won) {
             (false, false) => return None,
@@ -135,12 +194,9 @@ impl GameState {
             (true, true) => 0.0,
         };
 
-        Some(if self.whose_turn {
-            //p1 wants to win
-            outcome_for_p1
-        } else {
-            //p2 wants to win
-            -1.0 * outcome_for_p1
+        Some(match self.current_player {
+            PlayerIdentifier::P1 => outcome_for_p1,
+            PlayerIdentifier::P2 => -1.0 * outcome_for_p1,
         })
     }
 }
@@ -151,13 +207,16 @@ impl Default for GameState {
             p1: PlayerState::P1_START,
             p2: PlayerState::P2_START,
             walls: Default::default(),
-            whose_turn: true,
+            current_player: PlayerIdentifier::P1,
         }
     }
 }
 
 impl Display for GameState {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Goal distances for the player to move; printed in the otherwise empty cells.
+        let dm = GoalDistanceMap::new(&self.walls, self.current_player);
+
         writeln!(
             f,
             "P1: {}, P2: {}\n",
@@ -191,12 +250,12 @@ impl Display for GameState {
                     };
                     write!(f, "{wall}")?;
                 }
-                let player = if self.p1.x() == x && 8 - self.p1.y() == y {
-                    "P1"
-                } else if self.p2.x() == x && 8 - self.p2.y() == y {
-                    "P2"
+                let player = if self.p1.x() == x && self.p1.y() == y {
+                    "\x1b[1mP1\x1b[0m".to_string()
+                } else if self.p2.x() == x && self.p2.y() == y {
+                    "\x1b[1mP2\x1b[0m".to_string()
                 } else {
-                    "  "
+                    format!("{:^2}", dm.at(x, y))
                 };
                 write!(f, "{player}")?;
             }
@@ -209,7 +268,7 @@ impl Display for GameState {
 
 #[cfg(test)]
 mod tests {
-    use crate::gamestate::{GameState, PlayerState, WallState};
+    use crate::gamestate::WallState;
 
     #[test]
     fn test_blocking() {
@@ -233,14 +292,5 @@ mod tests {
         assert!(w.can_walk_between(7, 0, 8, 0));
         w.block(7, 0, 8, 0);
         assert!(!w.can_walk_between(7, 0, 8, 0));
-
-        println!(
-            "{}",
-            GameState {
-                p1: PlayerState::P1_START,
-                p2: PlayerState::P2_START,
-                walls: w,
-            }
-        );
     }
 }
diff --git a/src/main.rs b/src/main.rs
index 8347bca..d41fa46 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,3 @@
-use std::error::Error;
 use std::thread;
 use std::time::Duration;
 
@@ -8,8 +7,10 @@ use simple_mcts::Mcts;
 use simple_mcts::MctsError;
 
 use crate::gamestate::GameState;
+use crate::gamestate::PlayerIdentifier;
 
 mod gamestate;
+mod pathfind;
 
 struct Quoridor {
     state: GameState,
@@ -38,23 +39,19 @@ impl Game for Quoridor {
     fn get_actions(&self) -> [bool; NUM_NEXT_STATES] {
         let mut res = [false; NUM_NEXT_STATES];
 
-        for x in 0..8 {
-            for y in 0..8 {
-                if self.state.walls.can_walk_between(x, y, x + 1, y)
-                    && self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1)
-                {
-                    res[x as usize * 8 + y as usize] = true;
-                }
-
-                if self.state.walls.can_walk_between(x, y, x, y + 1)
-                    && self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1)
-                {
-                    res[x as usize * 8 + y as usize + 64] = true;
+        if self.state.current_player_state().walls_left > 0 {
+            for x in 0..8 {
+                for y in 0..8 {
+                    // Actions 0..64 are vertical walls and 64..128 horizontal ones,
+                    // mirroring how `play` decodes them.
+                    res[x as usize * 8 + y as usize] = self.state.walls.can_place(x, y, true);
+                    res[x as usize * 8 + y as usize + 64] =
+                        self.state.walls.can_place(x, y, false);
                 }
             }
         }
 
-        let p = self.state.current_player();
+        let p = self.state.current_player_state();
         let x = p.x();
         let y = p.y();
         res[128] = if x == 0 {
@@ -78,8 +75,6 @@ impl Game for Quoridor {
             self.state.walls.can_walk_between(x, y, x, y + 1)
         };
         // TODO: detect jumps for the 4 directions and also blocked jumps for the 2 sides of the other pawn
-
-        // dbg!(&res[128..132]);
         res
     }
 
@@ -88,56 +83,41 @@ impl Game for Quoridor {
     }
 
     fn play(&mut self, action: usize) {
+        // Only actions 0..128 place a wall (see the `match` below); 128 and up move the pawn.
+        if action < 128 {
+            self.state.current_player_state_mut().walls_left -= 1;
+        }
         let mut set_block = |i: usize, vertical| {
             let x = (i / 8) as u8;
             let y = (i % 8) as u8;
-
-            if vertical {
-                debug_assert!(self.state.walls.can_walk_between(x, y, x + 1, y));
-                debug_assert!(self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1));
-
-                self.state.walls.block(x, y, x + 1, y);
-                self.state.walls.block(x, y + 1, x + 1, y + 1);
-
-                debug_assert!(!self.state.walls.can_walk_between(x, y, x + 1, y));
-                debug_assert!(!self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1));
-            } else {
-                debug_assert!(self.state.walls.can_walk_between(x, y, x, y + 1));
-                debug_assert!(self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1));
-
-                self.state.walls.block(x, y, x, y + 1);
-                self.state.walls.block(x + 1, y, x + 1, y + 1);
-
-                debug_assert!(!self.state.walls.can_walk_between(x, y, x, y + 1));
-                debug_assert!(!self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1));
-            }
+            self.state.walls.place(x, y, vertical);
         };
 
         match action {
             i @ 0..64 => set_block(i, true),
             i @ 64..128 => set_block(i - 64, false),
             128 => {
-                let x = self.state.current_player().x();
-                self.state.current_player_mut().set_x(x - 1);
+                let x = self.state.current_player_state().x();
+                self.state.current_player_state_mut().set_x(x - 1);
             }
             129 => {
-                let y = self.state.current_player().y();
-                self.state.current_player_mut().set_y(y - 1);
+                let y = self.state.current_player_state().y();
+                self.state.current_player_state_mut().set_y(y - 1);
             }
             130 => {
-                let x = self.state.current_player().x();
-                self.state.current_player_mut().set_x(x + 1);
+                let x = self.state.current_player_state().x();
+                self.state.current_player_state_mut().set_x(x + 1);
             }
             131 => {
-                let y = self.state.current_player().y();
-                self.state.current_player_mut().set_x(y + 1);
+                let y = self.state.current_player_state().y();
+                self.state.current_player_state_mut().set_y(y + 1);
             }
             132 => todo!(),
             133 => todo!(),
             _ => unreachable!(),
         }
 
-        self.state.whose_turn = !self.state.whose_turn;
+        self.state.current_player.swap();
     }
 
     fn get_state(&self) -> Self::State {
@@ -153,9 +133,9 @@ impl Game for Quoridor {
     }
 }
 
-struct QuoridorEvaluator;
-
-impl GameEvaluator for QuoridorEvaluator {
+/// Scores a state by its finished-game result, treating unfinished games as `0.0`.
+struct ResultEvaluator;
+impl GameEvaluator for ResultEvaluator {
     fn evaluate(&self, state: GameState) -> (f64, [f64; NUM_NEXT_STATES]) {
         (
             state.mcts_result().unwrap_or(0.0),
@@ -164,15 +144,34 @@ impl GameEvaluator for QuoridorEvaluator {
     }
 }
 
+/// Heuristic evaluator: values a state by how many rows the player to move has
+/// advanced towards their goal row (`0.0` = eight rows away, `1.0` = on it) and
+/// returns a uniform policy over all actions.
+struct ProgressEvaluator;
+impl GameEvaluator for ProgressEvaluator {
+    fn evaluate(&self, state: GameState) -> (f64, [f64; NUM_NEXT_STATES]) {
+        let progress_across = 8 - state
+            .current_player
+            .y_goal()
+            .abs_diff(state.current_player_state().y());
+
+        (
+            progress_across as f64 / 8.0,
+            [const { 1.0 / NUM_NEXT_STATES as f64 }; NUM_NEXT_STATES],
+        )
+    }
+}
+
 fn main() -> Result<(), MctsError> {
-    let g = Quoridor::default();
+    let mut g = Quoridor::default();
+    g.state.walls.place(4, 4, false);
 
     let mut mcts: Mcts = Mcts::::new();
-    let evaluator = QuoridorEvaluator;
+    let evaluator = ProgressEvaluator;
 
     for _ in 0..100 {
         // Perform 100 MCTS iterations
-        for _ in 0..1000 {
+        for _ in 0..10_000 {
             mcts.iterate(&evaluator)?;
         }
 
@@ -187,11 +186,18 @@ fn main() -> Result<(), MctsError> {
             .max_by(|&(_, &a), &(_, &b)| a.partial_cmp(&b).unwrap())
             .map(|(index, _)| index)
             .unwrap_or(0); // Default to first action if policy is empty
+
+        // let mut x = policy.iter().enumerate().collect::<Vec<_>>();
+        // x.sort_by(|&(_, &a), &(_, &b)| a.total_cmp(&b));
+        // let top_5 = &x[x.len() - 5..];
+        // println!("{top_5:?}");
+        // let best_action_index = x.last().map(|(index, _)| *index).unwrap_or(0);
+
         println!("best action: {best_action_index}");
         mcts.play(best_action_index)?;
 
         println!("{}", mcts.get_game().state);
-        thread::sleep(Duration::from_millis(50));
+        // thread::sleep(Duration::from_millis(500));
     }
 
     // Continue with the next game state
diff --git a/src/pathfind.rs b/src/pathfind.rs
new file mode 100644
index 0000000..dea7707
--- /dev/null
+++ b/src/pathfind.rs
@@ -0,0 +1,50 @@
+use std::collections::VecDeque;
+
+use crate::gamestate::{PlayerIdentifier, WallState};
+
+/// Shortest walking distance from every cell of the 9x9 board to a goal row.
+pub struct GoalDistanceMap {
+    /// Indexed as `[y][x]`; `u8::MAX` marks cells the search never reached.
+    distances: [[u8; 9]; 9],
+}
+
+impl GoalDistanceMap {
+    /// Breadth-first search starting from all nine cells of `for_player`'s goal
+    /// row, stepping only across edges that `w` reports as walkable.
+    pub fn new(w: &WallState, for_player: PlayerIdentifier) -> Self {
+        let mut todo = VecDeque::with_capacity(9 * 9);
+        let mut res = [[u8::MAX; 9]; 9];
+
+        for i in 0u8..9 {
+            todo.push_back(((i, for_player.y_goal()), 0));
+        }
+
+        while let Some(((x, y), distance)) = todo.pop_front() {
+            // Already reached earlier, i.e. by a path that is at least as short.
+            if res[y as usize][x as usize] != u8::MAX {
+                continue;
+            }
+            res[y as usize][x as usize] = distance;
+
+            if x > 0 && w.can_walk_between(x, y, x - 1, y) {
+                todo.push_back(((x - 1, y), distance + 1));
+            }
+            if x < 8 && w.can_walk_between(x, y, x + 1, y) {
+                todo.push_back(((x + 1, y), distance + 1));
+            }
+            if y > 0 && w.can_walk_between(x, y, x, y - 1) {
+                todo.push_back(((x, y - 1), distance + 1));
+            }
+            if y < 8 && w.can_walk_between(x, y, x, y + 1) {
+                todo.push_back(((x, y + 1), distance + 1));
+            }
+        }
+
+        Self { distances: res }
+    }
+
+    /// Distance from `(x, y)` to the goal row, or `u8::MAX` if it was never reached.
+    pub fn at(&self, x: u8, y: u8) -> u8 {
+        self.distances[y as usize][x as usize]
+    }
+}