wip incremental distance update

This commit is contained in:
Julia Ryan 2025-09-12 00:45:48 -07:00
parent 2bf743a272
commit 3bd28451e9
No known key found for this signature in database
3 changed files with 55 additions and 14 deletions

View file

@ -201,12 +201,20 @@ impl Default for GameState {
impl Display for GameState { impl Display for GameState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let dm = GoalDistanceMap::new(&self.walls, self.current_player); let d1 = GoalDistanceMap::new(&self.walls, PlayerIdentifier::P1);
let d2 = GoalDistanceMap::new(&self.walls, PlayerIdentifier::P2);
writeln!( writeln!(
f, f,
"P1: {}, P2: {}\n", "P1: {} walls, {} away from win",
self.p1.walls_left, self.p2.walls_left self.p1.walls_left,
d1.at(self.p1.x(), self.p1.y())
)?;
writeln!(
f,
"P2: {} walls, {} away from win",
self.p2.walls_left,
d2.at(self.p2.x(), self.p2.y())
)?; )?;
writeln!(f, "┏━━┬━━┬━━┬━━┬━━┬━━┬━━┬━━┬━━┓")?; writeln!(f, "┏━━┬━━┬━━┬━━┬━━┬━━┬━━┬━━┬━━┓")?;
for y in 0..9 { for y in 0..9 {
@ -237,11 +245,12 @@ impl Display for GameState {
write!(f, "{wall}")?; write!(f, "{wall}")?;
} }
let player = if self.p1.x() == x && self.p1.y() == y { let player = if self.p1.x() == x && self.p1.y() == y {
"\x1b[1mP1\x1b[0m".to_string() "\x1b[1mP1\x1b[0m".to_owned()
} else if self.p2.x() == x && self.p2.y() == y { } else if self.p2.x() == x && self.p2.y() == y {
"\x1b[1mP2\x1b[0m".to_string() "\x1b[1mP2\x1b[0m".to_owned()
} else { } else {
format!("{:^2}", dm.at(x, y)) // format!("{:^2}", dm.at(x, y))
" ".to_owned()
}; };
write!(f, "{player}")?; write!(f, "{player}")?;
} }

View file

@ -81,6 +81,11 @@ impl Game<NUM_NEXT_STATES> for Quoridor {
} else { } else {
self.state.walls.can_walk_between(x, y, x, y + 1) self.state.walls.can_walk_between(x, y, x, y + 1)
}; };
if self.state.current_player_state().walls_left == 0 {
if res[..128].iter().any(|x| *x) {
panic!("{res:?}")
}
}
// TODO: detect jumps for the 4 directions and also blocked jumps for the 2 sides of the other pawn // TODO: detect jumps for the 4 directions and also blocked jumps for the 2 sides of the other pawn
res res
} }
@ -90,7 +95,7 @@ impl Game<NUM_NEXT_STATES> for Quoridor {
} }
fn play(&mut self, action: usize) { fn play(&mut self, action: usize) {
if action <= 128 { if action < 128 {
self.state.current_player_state_mut().walls_left -= 1; self.state.current_player_state_mut().walls_left -= 1;
} }
let mut set_block = |i: usize, vertical| { let mut set_block = |i: usize, vertical| {
@ -171,23 +176,19 @@ fn main() -> Result<(), MctsError> {
let mut mcts: Mcts<Quoridor, _> = Mcts::<Quoridor, _>::new(); let mut mcts: Mcts<Quoridor, _> = Mcts::<Quoridor, _>::new();
let evaluator = ProgressEvaluator; let evaluator = ProgressEvaluator;
for _ in 0..100 { while g.get_result().is_none() {
// Perform 100 MCTS iterations for _ in 0..5000 {
for _ in 0..10_000 {
mcts.iterate(&evaluator)?; mcts.iterate(&evaluator)?;
} }
// Get the best action based on visit counts
let (score, policy) = mcts.get_result(); let (score, policy) = mcts.get_result();
println!("Best action score: {}, Policy: {:?}", score, policy);
// Play the best action and update the MCTS tree
let best_action_index = policy let best_action_index = policy
.iter() .iter()
.enumerate() .enumerate()
.max_by(|&(_, &a), &(_, &b)| a.partial_cmp(&b).unwrap()) .max_by(|&(_, &a), &(_, &b)| a.partial_cmp(&b).unwrap())
.map(|(index, _)| index) .map(|(index, _)| index)
.unwrap_or(0); // Default to first action if policy is empty .unwrap_or(0);
// let mut x = policy.iter().enumerate().collect::<Vec<_>>(); // let mut x = policy.iter().enumerate().collect::<Vec<_>>();
// x.sort_by(|&(_, &a), &(_, &b)| a.total_cmp(&b)); // x.sort_by(|&(_, &a), &(_, &b)| a.total_cmp(&b));

View file

@ -38,6 +38,37 @@ impl GoalDistanceMap {
Self { distances: res } Self { distances: res }
} }
/// Places a wall in `w` and incrementally repairs this distance map to match.
///
/// Adding a wall can only remove edges, so distances can only stay the same or
/// grow. The update therefore runs in two phases:
///
/// 1. **Invalidate.** A cell with distance `d > 0` is still trustworthy as long
///    as it has a walkable neighbour with distance `d - 1`. Starting from the
///    four cells around the wall anchor, every cell that lost all such
///    neighbours is reset to `u8::MAX`, and its neighbours at `d + 1` (which may
///    have depended on it) are re-examined.
/// 2. **Repair.** Each invalidated cell is relaxed to `1 + min(neighbour)`;
///    whenever a cell improves, its neighbours are queued again until nothing
///    changes.
///
/// Cells with distance 0 are never invalidated (presumably these are the goal
/// cells -- confirm against `GoalDistanceMap::new`). `u8::MAX` is treated as
/// "no known path", matching the sentinel this method already wrote before;
/// NOTE(review): verify `new` uses the same value for unreachable cells.
///
/// # Panics
///
/// Indexes the cells `(x + 1, y)`, `(x, y + 1)` and `(x + 1, y + 1)`, so the
/// anchor must leave room for them inside the distance grid.
pub fn add_wall(&mut self, w: &mut WallState, x: u8, y: u8, vertical: bool) {
    w.place(x, y, vertical);

    // Up to four orthogonal neighbours of a cell that are reachable in one step
    // with the new wall in place. The `then` closures keep `cx - 1` / `cy - 1`
    // lazy so they are never evaluated at the board edge.
    let walkable = |cx: u8, cy: u8| -> [Option<(u8, u8)>; 4] {
        [
            (cx > 0 && w.can_walk_between(cx, cy, cx - 1, cy)).then(|| (cx - 1, cy)),
            (cx < 8 && w.can_walk_between(cx, cy, cx + 1, cy)).then(|| (cx + 1, cy)),
            (cy > 0 && w.can_walk_between(cx, cy, cx, cy - 1)).then(|| (cx, cy - 1)),
            (cy < 8 && w.can_walk_between(cx, cy, cx, cy + 1)).then(|| (cx, cy + 1)),
        ]
    };

    // Phase 1: find every cell whose recorded distance is no longer supported.
    // Only the four cells touching the wall can have lost an edge directly.
    let mut suspects = VecDeque::from([(x, y), (x + 1, y), (x, y + 1), (x + 1, y + 1)]);
    let mut stale = VecDeque::new();
    while let Some((cx, cy)) = suspects.pop_front() {
        let old = self.at(cx, cy);
        // Distance-0 cells need no predecessor; MAX cells are already invalid.
        if old == 0 || old == u8::MAX {
            continue;
        }
        let neighbours = walkable(cx, cy);
        let supported = neighbours
            .iter()
            .flatten()
            .any(|&(nx, ny)| self.at(nx, ny) == old - 1);
        if supported {
            continue;
        }
        self.distances[cy as usize][cx as usize] = u8::MAX;
        stale.push_back((cx, cy));
        // Cells one step further away may have relied on this one; re-check
        // them now that it is invalid (a cell can be re-checked several times).
        suspects.extend(
            neighbours
                .iter()
                .flatten()
                .copied()
                .filter(|&(nx, ny)| self.at(nx, ny) == old + 1),
        );
    }

    // Phase 2: relax the invalidated cells from their trusted surroundings.
    // Values only ever decrease here, so the worklist is guaranteed to drain.
    while let Some((cx, cy)) = stale.pop_front() {
        let neighbours = walkable(cx, cy);
        let best = neighbours
            .iter()
            .flatten()
            .map(|&(nx, ny)| self.at(nx, ny))
            .min()
            // saturating_add keeps "no path" (MAX) from wrapping around to 0.
            .map_or(u8::MAX, |d| d.saturating_add(1));
        if best < self.at(cx, cy) {
            self.distances[cy as usize][cx as usize] = best;
            // An improvement here may shorten the path of any neighbour.
            stale.extend(neighbours.iter().flatten().copied());
        }
    }
}
pub fn at(&self, x: u8, y: u8) -> u8 { pub fn at(&self, x: u8, y: u8) -> u8 {
self.distances[y as usize][x as usize] self.distances[y as usize][x as usize]
} }