ai
This commit is contained in:
parent
d4e626d13c
commit
20a6d2faeb
4 changed files with 401 additions and 8 deletions
159
Cargo.lock
generated
159
Cargo.lock
generated
|
|
@ -2,6 +2,165 @@
|
||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
version = 4
|
version = 4
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfg-if"
|
||||||
|
version = "1.0.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "getrandom"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"libc",
|
||||||
|
"r-efi",
|
||||||
|
"wasi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libc"
|
||||||
|
version = "0.2.175"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ppv-lite86"
|
||||||
|
version = "0.2.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
|
||||||
|
dependencies = [
|
||||||
|
"zerocopy",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.101"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quoridor"
|
name = "quoridor"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"simple-mcts",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.40"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "r-efi"
|
||||||
|
version = "5.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand"
|
||||||
|
version = "0.9.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||||
|
dependencies = [
|
||||||
|
"rand_chacha",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_chacha"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
||||||
|
dependencies = [
|
||||||
|
"ppv-lite86",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_core"
|
||||||
|
version = "0.9.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simple-mcts"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "12206ee44e7faaf94c033a47d2168d59f4957e8f50d71b5fdf7e16cfd419644e"
|
||||||
|
dependencies = [
|
||||||
|
"rand",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "2.0.106"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-ident"
|
||||||
|
version = "1.0.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasi"
|
||||||
|
version = "0.14.5+wasi-0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a4494f6290a82f5fe584817a676a34b9d6763e8d9d18204009fb31dceca98fd4"
|
||||||
|
dependencies = [
|
||||||
|
"wasip2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasip2"
|
||||||
|
version = "1.0.0+wasi-0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "03fa2761397e5bd52002cd7e73110c71af2109aca4e521a9f40473fe685b0a24"
|
||||||
|
dependencies = [
|
||||||
|
"wit-bindgen",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wit-bindgen"
|
||||||
|
version = "0.45.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy"
|
||||||
|
version = "0.8.27"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c"
|
||||||
|
dependencies = [
|
||||||
|
"zerocopy-derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy-derive"
|
||||||
|
version = "0.8.27"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -4,3 +4,4 @@ version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
simple-mcts = "0.1"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
use simple_mcts::Game;
|
||||||
|
|
||||||
#[derive(Clone, Copy, PartialEq)]
|
#[derive(Clone, Copy, PartialEq)]
|
||||||
struct PlayerState {
|
pub struct PlayerState {
|
||||||
xy: u8,
|
xy: u8,
|
||||||
walls_left: u8,
|
walls_left: u8,
|
||||||
}
|
}
|
||||||
|
|
@ -45,7 +47,7 @@ impl PlayerState {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
struct WallState {
|
pub struct WallState {
|
||||||
verticals: [u8; 9],
|
verticals: [u8; 9],
|
||||||
horizontals: [u8; 9],
|
horizontals: [u8; 9],
|
||||||
}
|
}
|
||||||
|
|
@ -101,10 +103,46 @@ impl Default for WallState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GameState {
|
#[derive(Clone, Copy)]
|
||||||
p1: PlayerState,
|
pub struct GameState {
|
||||||
p2: PlayerState,
|
pub p1: PlayerState,
|
||||||
walls: WallState,
|
pub p2: PlayerState,
|
||||||
|
pub walls: WallState,
|
||||||
|
pub whose_turn: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GameState {
|
||||||
|
pub fn current_player(&self) -> &PlayerState {
|
||||||
|
if self.whose_turn { &self.p1 } else { &self.p2 }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn current_player_mut(&mut self) -> &mut PlayerState {
|
||||||
|
if self.whose_turn {
|
||||||
|
&mut self.p1
|
||||||
|
} else {
|
||||||
|
&mut self.p2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mcts_result(&self) -> Option<f64> {
|
||||||
|
let p1_won = self.p1.y() == 8;
|
||||||
|
let p2_won = self.p2.y() == 0;
|
||||||
|
|
||||||
|
let outcome_for_p1 = match (p1_won, p2_won) {
|
||||||
|
(false, false) => return None,
|
||||||
|
(true, false) => 1.0,
|
||||||
|
(false, true) => -1.0,
|
||||||
|
(true, true) => 0.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(if self.whose_turn {
|
||||||
|
//p1 wants to win
|
||||||
|
outcome_for_p1
|
||||||
|
} else {
|
||||||
|
//p2 wants to win
|
||||||
|
-1.0 * outcome_for_p1
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for GameState {
|
impl Default for GameState {
|
||||||
|
|
@ -113,6 +151,7 @@ impl Default for GameState {
|
||||||
p1: PlayerState::P1_START,
|
p1: PlayerState::P1_START,
|
||||||
p2: PlayerState::P2_START,
|
p2: PlayerState::P2_START,
|
||||||
walls: Default::default(),
|
walls: Default::default(),
|
||||||
|
whose_turn: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
198
src/main.rs
198
src/main.rs
|
|
@ -1,5 +1,199 @@
|
||||||
|
use std::error::Error;
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use simple_mcts::Game;
|
||||||
|
use simple_mcts::GameEvaluator;
|
||||||
|
use simple_mcts::Mcts;
|
||||||
|
use simple_mcts::MctsError;
|
||||||
|
|
||||||
|
use crate::gamestate::GameState;
|
||||||
|
|
||||||
mod gamestate;
|
mod gamestate;
|
||||||
|
|
||||||
fn main() {
|
struct Quoridor {
|
||||||
println!("Hello, world!");
|
state: GameState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Quoridor {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
state: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const NUM_NEXT_STATES: usize = 64 /* horizontal */ + 64 /* vertical */ +
|
||||||
|
4 /* move directions */ + 2 /* blocked jumps */;
|
||||||
|
|
||||||
|
impl Game<NUM_NEXT_STATES> for Quoridor {
|
||||||
|
type State = GameState;
|
||||||
|
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
state: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_actions(&self) -> [bool; NUM_NEXT_STATES] {
|
||||||
|
let mut res = [false; NUM_NEXT_STATES];
|
||||||
|
|
||||||
|
for x in 0..8 {
|
||||||
|
for y in 0..8 {
|
||||||
|
if self.state.walls.can_walk_between(x, y, x + 1, y)
|
||||||
|
&& self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1)
|
||||||
|
{
|
||||||
|
res[x as usize * 8 + y as usize] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.state.walls.can_walk_between(x, y, x, y + 1)
|
||||||
|
&& self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1)
|
||||||
|
{
|
||||||
|
res[x as usize * 8 + y as usize + 64] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let p = self.state.current_player();
|
||||||
|
let x = p.x();
|
||||||
|
let y = p.y();
|
||||||
|
res[128] = if x == 0 {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
self.state.walls.can_walk_between(x, y, x - 1, y)
|
||||||
|
};
|
||||||
|
res[129] = if y == 0 {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
self.state.walls.can_walk_between(x, y, x, y - 1)
|
||||||
|
};
|
||||||
|
res[130] = if x == 8 {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
self.state.walls.can_walk_between(x, y, x + 1, y)
|
||||||
|
};
|
||||||
|
res[131] = if y == 8 {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
self.state.walls.can_walk_between(x, y, x, y + 1)
|
||||||
|
};
|
||||||
|
// TODO: detect jumps for the 4 directions and also blocked jumps for the 2 sides of the other pawn
|
||||||
|
|
||||||
|
// dbg!(&res[128..132]);
|
||||||
|
res
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_finish(&self) -> bool {
|
||||||
|
self.state.p1.y() == 8 || self.state.p2.y() == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn play(&mut self, action: usize) {
|
||||||
|
let mut set_block = |i: usize, vertical| {
|
||||||
|
let x = (i / 8) as u8;
|
||||||
|
let y = (i % 8) as u8;
|
||||||
|
|
||||||
|
if vertical {
|
||||||
|
debug_assert!(self.state.walls.can_walk_between(x, y, x + 1, y));
|
||||||
|
debug_assert!(self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1));
|
||||||
|
|
||||||
|
self.state.walls.block(x, y, x + 1, y);
|
||||||
|
self.state.walls.block(x, y + 1, x + 1, y + 1);
|
||||||
|
|
||||||
|
debug_assert!(!self.state.walls.can_walk_between(x, y, x + 1, y));
|
||||||
|
debug_assert!(!self.state.walls.can_walk_between(x, y + 1, x + 1, y + 1));
|
||||||
|
} else {
|
||||||
|
debug_assert!(self.state.walls.can_walk_between(x, y, x, y + 1));
|
||||||
|
debug_assert!(self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1));
|
||||||
|
|
||||||
|
self.state.walls.block(x, y, x, y + 1);
|
||||||
|
self.state.walls.block(x + 1, y, x + 1, y + 1);
|
||||||
|
|
||||||
|
debug_assert!(!self.state.walls.can_walk_between(x, y, x, y + 1));
|
||||||
|
debug_assert!(!self.state.walls.can_walk_between(x + 1, y, x + 1, y + 1));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match action {
|
||||||
|
i @ 0..64 => set_block(i, true),
|
||||||
|
i @ 64..128 => set_block(i - 64, false),
|
||||||
|
128 => {
|
||||||
|
let x = self.state.current_player().x();
|
||||||
|
self.state.current_player_mut().set_x(x - 1);
|
||||||
|
}
|
||||||
|
129 => {
|
||||||
|
let y = self.state.current_player().y();
|
||||||
|
self.state.current_player_mut().set_y(y - 1);
|
||||||
|
}
|
||||||
|
130 => {
|
||||||
|
let x = self.state.current_player().x();
|
||||||
|
self.state.current_player_mut().set_x(x + 1);
|
||||||
|
}
|
||||||
|
131 => {
|
||||||
|
let y = self.state.current_player().y();
|
||||||
|
self.state.current_player_mut().set_x(y + 1);
|
||||||
|
}
|
||||||
|
132 => todo!(),
|
||||||
|
133 => todo!(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
|
||||||
|
self.state.whose_turn = !self.state.whose_turn;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_state(&self) -> Self::State {
|
||||||
|
self.state
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_result(&self) -> Option<f64> {
|
||||||
|
self.get_state().mcts_result()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self { state: self.state }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct QuoridorEvaluator;
|
||||||
|
|
||||||
|
impl GameEvaluator<Quoridor, NUM_NEXT_STATES> for QuoridorEvaluator {
|
||||||
|
fn evaluate(&self, state: GameState) -> (f64, [f64; NUM_NEXT_STATES]) {
|
||||||
|
(
|
||||||
|
state.mcts_result().unwrap_or(0.0),
|
||||||
|
[const { 1.0 / NUM_NEXT_STATES as f64 }; NUM_NEXT_STATES],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), MctsError> {
|
||||||
|
let g = Quoridor::default();
|
||||||
|
|
||||||
|
let mut mcts: Mcts<Quoridor, _> = Mcts::<Quoridor, _>::new();
|
||||||
|
let evaluator = QuoridorEvaluator;
|
||||||
|
|
||||||
|
for _ in 0..100 {
|
||||||
|
// Perform 100 MCTS iterations
|
||||||
|
for _ in 0..1000 {
|
||||||
|
mcts.iterate(&evaluator)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the best action based on visit counts
|
||||||
|
let (score, policy) = mcts.get_result();
|
||||||
|
println!("Best action score: {}, Policy: {:?}", score, policy);
|
||||||
|
|
||||||
|
// Play the best action and update the MCTS tree
|
||||||
|
let best_action_index = policy
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|&(_, &a), &(_, &b)| a.partial_cmp(&b).unwrap())
|
||||||
|
.map(|(index, _)| index)
|
||||||
|
.unwrap_or(0); // Default to first action if policy is empty
|
||||||
|
println!("best action: {best_action_index}");
|
||||||
|
mcts.play(best_action_index)?;
|
||||||
|
|
||||||
|
println!("{}", mcts.get_game().state);
|
||||||
|
thread::sleep(Duration::from_millis(50));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue with the next game state
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue