From d6eb7975ae119de0ce39734ca8256a92165c2c4b Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 2 Aug 2024 08:37:29 -0700 Subject: [PATCH 01/66] gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 96ef6c0..c640ca5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +.vscode From 05ca057b1bb61132f10be73a9a26ca8baa5fd1ea Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sat, 3 Aug 2024 15:40:30 -0700 Subject: [PATCH 02/66] Updated authors --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f4ccf76..981906f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,11 @@ [package] name = "postflop-solver" version = "0.1.0" -authors = ["Wataru Inariba"] +authors = ["Wataru Inariba", "Ben Kushigian"] edition = "2021" description = "An open-source postflop solver for Texas hold'em poker" documentation = "https://b-inary.github.io/postflop_solver/postflop_solver/" -repository = "https://github.com/b-inary/postflop-solver" +repository = "https://github.com/bkushigian/postflop-solver" license = "AGPL-3.0-or-later" [dependencies] From 4aeeecec0454ac406b239bf81746a3bc65e53aa2 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sat, 3 Aug 2024 15:41:40 -0700 Subject: [PATCH 03/66] README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 974f740..7d88afd 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,12 @@ # postflop-solver +This is a fork of [b-inary's postflop solver] that I will be maintaining. + > [!IMPORTANT] > **As of October 2023, I have started developing a poker solver as a business and have decided to suspend development of this open-source project. See [this issue] for more information.** [this issue]: https://github.com/b-inary/postflop-solver/issues/46 +[b-inary's postflop solver]: https://github.com/b-inary/postflop-solver --- From f92d23a1efbe85318ea86368e79b1b9fb71b8f30 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 4 Aug 2024 09:28:11 -0700 Subject: [PATCH 04/66] Docs --- src/card.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/card.rs b/src/card.rs index 2bba763..e36121a 100644 --- a/src/card.rs +++ b/src/card.rs @@ -11,6 +11,8 @@ use bincode::{Decode, Encode}; /// - `card_id = 4 * rank + suit` (where `0 <= card_id < 52`) /// - `rank`: 2 => `0`, 3 => `1`, 4 => `2`, ..., A => `12` /// - `suit`: club => `0`, diamond => `1`, heart => `2`, spade => `3` +/// +/// An undealt card is represented by Card::MAX (see `NOT_DEALT`). pub type Card = u8; /// Constant representing that the card is not yet dealt. From 1d366f667ff157d8a99bb0ec61e13208fdadab32 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 4 Aug 2024 09:28:29 -0700 Subject: [PATCH 05/66] Docs --- src/game/serialization.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/game/serialization.rs b/src/game/serialization.rs index 951c785..851619c 100644 --- a/src/game/serialization.rs +++ b/src/game/serialization.rs @@ -58,7 +58,13 @@ impl PostFlopGame { } } - /// Returns the number of storage elements required for the target storage mode. + /// Returns the number of storage elements required for the target storage mode: + /// `[|storage1|, |storage2|, |storage_ip|, |storage_chance|]` + /// + /// If this is a River save (`target_storage_mode == BoardState::River`) + /// then do not store cfvalues. + /// + /// If this is a Flop save, fn num_target_storage(&self) -> [usize; 4] { if self.state <= State::TreeBuilt { return [0; 4]; From b2c12b83c2a8102b3f7325237e6e4c580bff30ca Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 4 Aug 2024 09:28:51 -0700 Subject: [PATCH 06/66] Added helper function --- src/game/base.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/game/base.rs b/src/game/base.rs index f1cc0f0..6e0f54e 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -462,6 +462,21 @@ impl PostFlopGame { Ok(()) } + pub fn print_internal_data(&self) { + println!("Printing intenral data for PostFlopGame"); + println!("- node_arena: {}", self.node_arena.len()); + println!("- storage1: {}", self.storage1.len()); + println!("- storage2: {}", self.storage2.len()); + println!("- storage_ip: {}", self.storage_ip.len()); + println!("- storage_chance: {}", self.storage_chance.len()); + println!("- locking_strategy: {}", self.locking_strategy.len()); + println!("- storage mode: {:?}", self.storage_mode()); + println!( + "- target storage mode: {:?}", + self.target_storage_mode() + ); + } + /// Initializes fields `initial_weights` and `private_cards`. #[inline] fn init_hands(&mut self) { From 8b7067e8e634b0db016ddfe8a7672711ed59bac8 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 4 Aug 2024 09:29:09 -0700 Subject: [PATCH 07/66] DESIGN.md --- DESIGN.md | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 DESIGN.md diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 0000000..ffb31ed --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,77 @@ +# Design + +This document is a description, as far as I understand it, of the inner design +of the solver and PostFlopGame. This is a working document for me to get my +bearings. + +## PostFlopGame + +### Storage + +There are several fields marked as `// global storage` in `game::mod::PostFlopGame`: + +```rust + // global storage + // `storage*` are used as a global storage and are referenced by `PostFlopNode::storage*`. + // Methods like `PostFlopNode::strategy` define how the storage is used. + node_arena: Vec>, + storage1: Vec, + storage2: Vec, + storage_ip: Vec, + storage_chance: Vec, + locking_strategy: BTreeMap>, +``` + +These are referenced from `PostFlopNode`: + +```rust + storage1: *mut u8, // strategy + storage2: *mut u8, // regrets or cfvalues + storage3: *mut u8, // IP cfvalues +``` + +- `storage1` seems to store the strategy +- `storage2` seems to store regrets/cfvalues, and +- `storage3` stores IP's cf values (does that make `storage2` store OOP's cfvalues?) + +Storage is a byte vector `Vec`, and these store floating point values. + +> [!IMPORTANT] +> Why are these stored as `Vec`s? Is this for swapping between +> `f16` and `f32`s? + +Some storage is allocated in `game::base::allocate_memory`: + +```rust + let storage_bytes = (num_bytes * self.num_storage) as usize; + let storage_ip_bytes = (num_bytes * self.num_storage_ip) as usize; + let storage_chance_bytes = (num_bytes * self.num_storage_chance) as usize; + + self.storage1 = vec![0; storage_bytes]; + self.storage2 = vec![0; storage_bytes]; + self.storage_ip = vec![0; storage_ip_bytes]; + self.storage_chance = vec![0; storage_chance_bytes]; +``` + +`node_arena` is initialized in `game::base::init_root()`: + +```rust + let num_nodes = self.count_num_nodes(); + let total_num_nodes = num_nodes[0] + num_nodes[1] + num_nodes[2]; + + if total_num_nodes > u32::MAX as u64 + || mem::size_of::() as u64 * total_num_nodes > isize::MAX as u64 + { + return Err("Too many nodes".to_string()); + } + + self.num_nodes = num_nodes; + self.node_arena = (0..total_num_nodes) + .map(|_| MutexLike::new(PostFlopNode::default())) + .collect::>(); + self.clear_storage(); +``` + +### Serialization + +Serialization relies on the `bincode` library's `Encode` and `Decode`. \ No newline at end of file From 0236e138f2026d4df18c345494efac75425acc7b Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 4 Aug 2024 09:29:45 -0700 Subject: [PATCH 08/66] Created debug example --- examples/file_io_debug.rs | 83 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 examples/file_io_debug.rs diff --git a/examples/file_io_debug.rs b/examples/file_io_debug.rs new file mode 100644 index 0000000..2edf5ed --- /dev/null +++ b/examples/file_io_debug.rs @@ -0,0 +1,83 @@ +use postflop_solver::*; + +fn main() { + // see `basic.rs` for the explanation of the following code + + let oop_range = "66+,A8s+,A5s-A4s,AJo+,K9s+,KQo,QTs+,JTs,96s+,85s+,75s+,65s,54s"; + let ip_range = "QQ-22,AQs-A2s,ATo+,K5s+,KJo+,Q8s+,J8s+,T7s+,96s+,86s+,75s+,64s+,53s+"; + + let card_config = CardConfig { + range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], + flop: flop_from_str("Td9d6h").unwrap(), + turn: card_from_str("Qc").unwrap(), + river: NOT_DEALT, + }; + + let bet_sizes = BetSizeOptions::try_from(("60%, e, a", "2.5x")).unwrap(); + + let tree_config = TreeConfig { + initial_state: BoardState::Turn, + starting_pot: 200, + effective_stack: 900, + rake_rate: 0.0, + rake_cap: 0.0, + flop_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], + turn_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], + river_bet_sizes: [bet_sizes.clone(), bet_sizes], + turn_donk_sizes: None, + river_donk_sizes: Some(DonkSizeOptions::try_from("50%").unwrap()), + add_allin_threshold: 1.5, + force_allin_threshold: 0.15, + merging_threshold: 0.1, + }; + + let action_tree = ActionTree::new(tree_config).unwrap(); + let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); + game.allocate_memory(false); + + let max_num_iterations = 20; + let target_exploitability = game.tree_config().starting_pot as f32 * 0.01; + solve(&mut game, max_num_iterations, target_exploitability, true); + let r = game.set_target_storage_mode(BoardState::Turn); + println!("{r:?}"); + + // save the solved game tree to a file + // 4th argument is zstd compression level (1-22); requires `zstd` feature to use + save_data_to_file(&game, "memo string", "filename.bin", None).unwrap(); + + // load the solved game tree from a file + // 2nd argument is the maximum memory usage in bytes + let (mut game2, _memo_string): (PostFlopGame, _) = + load_data_from_file("filename.bin", None).unwrap(); + + println!("Game 1 Internal Data"); + game.print_internal_data(); + println!("Game 2 Internal Data"); + game2.print_internal_data(); + + // check if the loaded game tree is the same as the original one + game.cache_normalized_weights(); + game2.cache_normalized_weights(); + assert_eq!(game.equity(0), game2.equity(0)); + + // discard information after the river deal when serializing + // this operation does not lose any information of the game tree itself + game2.set_target_storage_mode(BoardState::Turn).unwrap(); + + // compare the memory usage for serialization + println!( + "Memory usage of the original game tree: {:.2}MB", // 11.50MB + game.target_memory_usage() as f64 / (1024.0 * 1024.0) + ); + println!( + "Memory usage of the truncated game tree: {:.2}MB", // 0.79MB + game2.target_memory_usage() as f64 / (1024.0 * 1024.0) + ); + + // overwrite the file with the truncated game tree + // game tree constructed from this file cannot access information after the river deal + save_data_to_file(&game2, "memo string", "filename.bin", None).unwrap(); + + // delete the file + std::fs::remove_file("filename.bin").unwrap(); +} From f9497a68e54e3e88d13cfecff2bc307201aa7f15 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 4 Aug 2024 11:49:24 -0700 Subject: [PATCH 09/66] DESIGN.md --- DESIGN.md | 48 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index ffb31ed..c6bdd97 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -53,7 +53,7 @@ Some storage is allocated in `game::base::allocate_memory`: self.storage_chance = vec![0; storage_chance_bytes]; ``` -`node_arena` is initialized in `game::base::init_root()`: +`node_arena` is allocated in `game::base::init_root()`: ```rust let num_nodes = self.count_num_nodes(); @@ -72,6 +72,48 @@ Some storage is allocated in `game::base::allocate_memory`: self.clear_storage(); ``` -### Serialization +`locking_strategy` maps node indexes (`PostFlopGame::node_index`) to a locked +strategy. `locking_strategy` is initialized to an empty `BTreeMap>` by deriving Default. It is inserted into via +`PostFlopGame::lock_current_strategy` -Serialization relies on the `bincode` library's `Encode` and `Decode`. \ No newline at end of file +### Serialization/Deserialization + +Serialization relies on the `bincode` library's `Encode` and `Decode`. We can set +the `target_storage_mode` to allow for a non-full save. For instance, + +```rust +game.set_target_storage_mode(BoardState::Turn); +``` + +will ensure that when `game` is encoded, it will only save Flop and Turn data. +When a serialized tree is deserialized, if it is a parital save (e.g., a Turn +save) you will not be able to navigate to unsaved streets. + +Several things break when we deserialize a partial save: +- `node_arena` is only partially populated +- `node.children()` points to raw data when `node` points to an street that is + not serialized (e.g., a chance node before the river for a Turn save). + +### Allocating `node_arena` + +We want to first allocate nodes for `node_arena`, and then run some form of +`build_tree_recursive`. This assumes that `node_arena` is already allocated, and +recursively visits children of nodes and modifies them to + + +### Data Coupling/Relations/Invariants + +- A node is locked iff it is contained in the game's locking_strategy +- `PostFlopGame.node_arena` is pointed to by `PostFlopNode.children_offset`. For + instance, this is the basic definition of the `PostFlopNode.children()` + function: + + ```rust + slice::from_raw_parts( + self_ptr.add(self.children_offset as usize), + self.num_children as usize, + ) + ``` + + We get a pointer to `self` and add children offset. \ No newline at end of file From 7aa6668e3f58ba8c0434445986ec1c87a5a7e2a8 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 4 Aug 2024 12:07:53 -0700 Subject: [PATCH 10/66] Updates to examples for debugging/documenting --- examples/file_io_debug.rs | 14 +++++++++++ examples/simple.rs | 53 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 examples/simple.rs diff --git a/examples/file_io_debug.rs b/examples/file_io_debug.rs index 2edf5ed..f0348a6 100644 --- a/examples/file_io_debug.rs +++ b/examples/file_io_debug.rs @@ -1,3 +1,5 @@ +use std::fs::File; + use postflop_solver::*; fn main() { @@ -77,6 +79,18 @@ fn main() { // overwrite the file with the truncated game tree // game tree constructed from this file cannot access information after the river deal save_data_to_file(&game2, "memo string", "filename.bin", None).unwrap(); + let (mut game3, _memo_string): (PostFlopGame, String) = + load_data_from_file("filename.bin", None).unwrap(); + + game.play(0); + game.play(0); + println!("Game X/X Actions: {:?}", game.available_actions()); + game2.play(0); + game2.play(0); + println!("Game2 X/X Actions: {:?}", game.available_actions()); + game3.play(0); + game3.play(0); + println!("Game3 X/X Actions: {:?}", game3.available_actions()); // delete the file std::fs::remove_file("filename.bin").unwrap(); diff --git a/examples/simple.rs b/examples/simple.rs new file mode 100644 index 0000000..00782d0 --- /dev/null +++ b/examples/simple.rs @@ -0,0 +1,53 @@ +use postflop_solver::*; + +fn main() { + // ranges of OOP and IP in string format + // see the documentation of `Range` for more details about the format + let oop_range = "66+"; + let ip_range = "66+"; + + let card_config = CardConfig { + range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], + flop: flop_from_str("Td9d6h").unwrap(), + turn: NOT_DEALT, + river: NOT_DEALT, + }; + + // bet sizes -> 60% of the pot, geometric size, and all-in + // raise sizes -> 2.5x of the previous bet + // see the documentation of `BetSizeOptions` for more details + let bet_sizes = BetSizeOptions::try_from(("100%", "100%")).unwrap(); + + let tree_config = TreeConfig { + initial_state: BoardState::Flop, // must match `card_config` + starting_pot: 200, + effective_stack: 200, + rake_rate: 0.0, + rake_cap: 0.0, + flop_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], // [OOP, IP] + turn_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], + river_bet_sizes: [bet_sizes.clone(), bet_sizes], + turn_donk_sizes: None, // use default bet sizes + river_donk_sizes: Some(DonkSizeOptions::try_from("100%").unwrap()), + add_allin_threshold: 1.5, // add all-in if (maximum bet size) <= 1.5x pot + force_allin_threshold: 0.15, // force all-in if (SPR after the opponent's call) <= 0.15 + merging_threshold: 0.1, + }; + + // build the game tree + // `ActionTree` can be edited manually after construction + let action_tree = ActionTree::new(tree_config).unwrap(); + let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); + + // allocate memory without compression (use 32-bit float) + game.allocate_memory(false); + + // solve the game + let max_num_iterations = 20; + let target_exploitability = game.tree_config().starting_pot as f32 * 0.100; // 10.0% of the pot + let exploitability = solve(&mut game, max_num_iterations, target_exploitability, true); + println!("Exploitability: {:.2}", exploitability); + + // get equity and EV of a specific hand + game.cache_normalized_weights(); +} From 334315bffc87b7a8b0fdc380d37730ad615d6c47 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Mon, 5 Aug 2024 15:33:38 -0700 Subject: [PATCH 11/66] Intermediate stuff on DESIGN.md --- DESIGN.md | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/DESIGN.md b/DESIGN.md index ffb31ed..b7c8bcc 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -6,6 +6,60 @@ bearings. ## PostFlopGame +### Build/Allocate/Initialize + +We begin by creating a `PostFlopGame:` + +1. **Create Configurations**: + + We need a `tree_config: TreeConfig` + + We need an `action_tree: ActionTree::new(tree_config)` + + We need a `card_config: CardConfig` +2. **PostFlopGame**: We build a `PostFlopGame` from `action_tree` and `card_config`: + + ```rust + let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); + ``` + +Once the game is created we need to allocate memory: + ++ `game.node_arena` ++ `game.storage1` ++ `game.storage2` ++ `game.storage_ip` ++ `game.storage_chance` + +These fields are not allocated at the same time. `game.node_arena` is allocated +via `with_config`, which calls `update_config`, which in turn calls `init_root`. +`init_root` is responsible for: + +1. Allocating `PostFlopNode`s in `node_arena` +2. Invoking `build_tree_recursive` which initializes each node's child/parent + relationship via `child_offset` (through calls to ). + +Each `PostFlopNode` points to node-specific data (eg., strategies and cfregrets) +that is located inside of `PostFlopGame.storage*` fields (which is currently +unallocated) via similarly named fields `PostFlopNode.storage*`. + +Additionally, each node points to the children offset with `children_offset`, +which records where in `node_arena` relative to the current node that node's +children begin. We allocate this memory via: + +```rust +game.allocate_memory(false); // pass `true` to use compressed memory +``` + +This allocates the following memory: + ++ `self.storage1` ++ `self.storage2` ++ `self.storage3` ++ `self.storage_chance` + +Next, `allocate_memory()` calls `allocate_memory_nodes(&mut self)`, which +iterates through each node in `node_arena` and sets storage pointers. + +After `allocate_memory` returns we still need to set `child_offset`s. + ### Storage There are several fields marked as `// global storage` in `game::mod::PostFlopGame`: From 478a72e6ecdaa755ad6e297638d3834680399bd8 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Mon, 5 Aug 2024 19:46:10 -0700 Subject: [PATCH 12/66] DESIGN.md --- DESIGN.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/DESIGN.md b/DESIGN.md index a53488f..2b15c18 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -34,7 +34,8 @@ via `with_config`, which calls `update_config`, which in turn calls `init_root`. 1. Allocating `PostFlopNode`s in `node_arena` 2. Invoking `build_tree_recursive` which initializes each node's child/parent - relationship via `child_offset` (through calls to ). + relationship via `child_offset` (through calls to `push_actions` and + `push_chances`). Each `PostFlopNode` points to node-specific data (eg., strategies and cfregrets) that is located inside of `PostFlopGame.storage*` fields (which is currently From 25d5ee6d3cad37c30a24a61f9e7a38f04a48035d Mon Sep 17 00:00:00 2001 From: bkushigian Date: Tue, 6 Aug 2024 11:38:46 -0700 Subject: [PATCH 13/66] docs and rename function for readability --- src/game/base.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index 6e0f54e..3f5b615 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -536,7 +536,7 @@ impl PostFlopGame { /// Initializes the root node of game tree. fn init_root(&mut self) -> Result<(), String> { - let num_nodes = self.count_num_nodes(); + let num_nodes = self.count_nodes_per_street(); let total_num_nodes = num_nodes[0] + num_nodes[1] + num_nodes[2]; if total_num_nodes > u32::MAX as u64 @@ -599,9 +599,10 @@ impl PostFlopGame { self.storage_chance = Vec::new(); } - /// Counts the number of nodes in the game tree. + /// Counts the number of nodes in the game tree per street, accounting for + /// isomorphism. #[inline] - fn count_num_nodes(&self) -> [u64; 3] { + fn count_nodes_per_street(&self) -> [u64; 3] { let (turn_coef, river_coef) = match (self.card_config.turn, self.card_config.river) { (NOT_DEALT, _) => { let mut river_coef = 0; From f05a1796053151ee880749d6fbeb388511637c01 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Tue, 6 Aug 2024 11:39:30 -0700 Subject: [PATCH 14/66] rename local var for clarity --- src/game/base.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index 3f5b615..27d5d28 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -536,8 +536,8 @@ impl PostFlopGame { /// Initializes the root node of game tree. fn init_root(&mut self) -> Result<(), String> { - let num_nodes = self.count_nodes_per_street(); - let total_num_nodes = num_nodes[0] + num_nodes[1] + num_nodes[2]; + let nodes_per_street = self.count_nodes_per_street(); + let total_num_nodes = nodes_per_street[0] + nodes_per_street[1] + nodes_per_street[2]; if total_num_nodes > u32::MAX as u64 || mem::size_of::() as u64 * total_num_nodes > isize::MAX as u64 @@ -545,15 +545,15 @@ impl PostFlopGame { return Err("Too many nodes".to_string()); } - self.num_nodes = num_nodes; + self.num_nodes = nodes_per_street; self.node_arena = (0..total_num_nodes) .map(|_| MutexLike::new(PostFlopNode::default())) .collect::>(); self.clear_storage(); let mut info = BuildTreeInfo { - turn_index: num_nodes[0] as usize, - river_index: (num_nodes[0] + num_nodes[1]) as usize, + turn_index: nodes_per_street[0] as usize, + river_index: (nodes_per_street[0] + nodes_per_street[1]) as usize, ..Default::default() }; From 729693c953ba56bcac1f1c92fe6571e04de33ba8 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Tue, 6 Aug 2024 11:45:13 -0700 Subject: [PATCH 15/66] Rename for readability --- src/game/base.rs | 2 +- src/game/mod.rs | 2 +- src/game/serialization.rs | 14 ++++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index 27d5d28..ba14b4d 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -545,7 +545,7 @@ impl PostFlopGame { return Err("Too many nodes".to_string()); } - self.num_nodes = nodes_per_street; + self.num_nodes_per_street = nodes_per_street; self.node_arena = (0..total_num_nodes) .map(|_| MutexLike::new(PostFlopNode::default())) .collect::>(); diff --git a/src/game/mod.rs b/src/game/mod.rs index 33d9a19..09042fd 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -82,7 +82,7 @@ pub struct PostFlopGame { // store options storage_mode: BoardState, target_storage_mode: BoardState, - num_nodes: [u64; 3], + num_nodes_per_street: [u64; 3], is_compression_enabled: bool, num_storage: u64, num_storage_ip: u64, diff --git a/src/game/serialization.rs b/src/game/serialization.rs index 851619c..5931bf8 100644 --- a/src/game/serialization.rs +++ b/src/game/serialization.rs @@ -77,8 +77,8 @@ impl PostFlopGame { } let mut node_index = match self.target_storage_mode { - BoardState::Flop => self.num_nodes[0], - _ => self.num_nodes[0] + self.num_nodes[1], + BoardState::Flop => self.num_nodes_per_street[0], + _ => self.num_nodes_per_street[0] + self.num_nodes_per_street[1], } as usize; let mut num_storage = [0; 4]; @@ -134,7 +134,7 @@ impl Encode for PostFlopGame { self.removed_lines.encode(encoder)?; self.action_root.encode(encoder)?; self.target_storage_mode.encode(encoder)?; - self.num_nodes.encode(encoder)?; + self.num_nodes_per_street.encode(encoder)?; self.is_compression_enabled.encode(encoder)?; self.num_storage.encode(encoder)?; self.num_storage_ip.encode(encoder)?; @@ -146,8 +146,10 @@ impl Encode for PostFlopGame { self.storage_chance[0..num_storage[3]].encode(encoder)?; let num_nodes = match self.target_storage_mode { - BoardState::Flop => self.num_nodes[0] as usize, - BoardState::Turn => (self.num_nodes[0] + self.num_nodes[1]) as usize, + BoardState::Flop => self.num_nodes_per_street[0] as usize, + BoardState::Turn => { + (self.num_nodes_per_street[0] + self.num_nodes_per_street[1]) as usize + } BoardState::River => self.node_arena.len(), }; @@ -199,7 +201,7 @@ impl Decode for PostFlopGame { removed_lines: Decode::decode(decoder)?, action_root: Decode::decode(decoder)?, storage_mode: Decode::decode(decoder)?, - num_nodes: Decode::decode(decoder)?, + num_nodes_per_street: Decode::decode(decoder)?, is_compression_enabled: Decode::decode(decoder)?, num_storage: Decode::decode(decoder)?, num_storage_ip: Decode::decode(decoder)?, From 8d191bb2b1c08e3d0dd7e3ad8815b9227daf91d7 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 9 Aug 2024 10:26:11 -0700 Subject: [PATCH 16/66] DESIGN.md --- DESIGN.md | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index 2b15c18..99e5ca9 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -8,19 +8,30 @@ bearings. ### Build/Allocate/Initialize -We begin by creating a `PostFlopGame:` +To set up a `PostFlopGame` we need to **create a `PostFlopGame` instance**, +**allocate global storage and `PostFlopNode`s**, and **initialize the +`PostFlopNode` child/parent relationship**. This is done in several steps. + + +We begin by creating a `PostFlopGame`. A `PostFlopGame` requires an `ActionTree` +and a `CardConfig`. The `ActionTree` represents the full game tree modded out by +different runouts (so an `ActionTree` might have an abstract _line_ **Bet 10; +Call; Bet 30; Call** while the game tree would have concrete _nodes_ like +**Bet 10; Call; Th; Bet 30; Call**, etc). 1. **Create Configurations**: + We need a `tree_config: TreeConfig` + We need an `action_tree: ActionTree::new(tree_config)` + We need a `card_config: CardConfig` -2. **PostFlopGame**: We build a `PostFlopGame` from `action_tree` and `card_config`: + +2. **Create PostFlopGame**: We build a `PostFlopGame` from `action_tree` and `card_config`: ```rust let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); ``` -Once the game is created we need to allocate memory: +Once the game is created we need to allocate the following memory and initialize +its values: + `game.node_arena` + `game.storage1` @@ -28,12 +39,25 @@ Once the game is created we need to allocate memory: + `game.storage_ip` + `game.storage_chance` -These fields are not allocated at the same time. `game.node_arena` is allocated -via `with_config`, which calls `update_config`, which in turn calls `init_root`. +These fields are not allocated/initialized at the same time; ++ `game.node_arena` is allocated and initialized via `with_config`, ++ other storage is allocated via `game.allocate_memory()`. + +#### Allocating and Initializing `node_arena` + +We construct a `PostFlopGame` by calling +`PostFlopGame::with_config(card_config, action_tree)`, which calls +`update_config`. `PostFlopGame::update_config` sets up configuration data, +sanity checks things are correct, and then calls `self.init_root()`. + `init_root` is responsible for: -1. Allocating `PostFlopNode`s in `node_arena` -2. Invoking `build_tree_recursive` which initializes each node's child/parent +1. Counting number of `PostFlopNode`s to be allocated (`self.num_nodes`), broken + up by flop, turn, and river +2. Allocating `self.num_nodes` `PostFlopNode`s in `node_arena` field +3. Clearing storage: `self.clear_storage()` sets each storage item to a new + `Vec` +4. Invoking `build_tree_recursive` which initializes each node's child/parent relationship via `child_offset` (through calls to `push_actions` and `push_chances`). @@ -111,7 +135,7 @@ Some storage is allocated in `game::base::allocate_memory`: `node_arena` is allocated in `game::base::init_root()`: ```rust - let num_nodes = self.count_num_nodes(); + let num_nodes = self.count_nodes_per_street(); let total_num_nodes = num_nodes[0] + num_nodes[1] + num_nodes[2]; if total_num_nodes > u32::MAX as u64 From 798c4aa9dae1f598a768ce93b5bb102c7f909b11 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 9 Aug 2024 10:27:19 -0700 Subject: [PATCH 17/66] Linter issues --- DESIGN.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index 99e5ca9..f5eb759 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -40,6 +40,7 @@ its values: + `game.storage_chance` These fields are not allocated/initialized at the same time; + + `game.node_arena` is allocated and initialized via `with_config`, + other storage is allocated via `game.allocate_memory()`. @@ -109,9 +110,9 @@ These are referenced from `PostFlopNode`: storage3: *mut u8, // IP cfvalues ``` -- `storage1` seems to store the strategy -- `storage2` seems to store regrets/cfvalues, and -- `storage3` stores IP's cf values (does that make `storage2` store OOP's cfvalues?) ++ `storage1` seems to store the strategy ++ `storage2` seems to store regrets/cfvalues, and ++ `storage3` stores IP's cf values (does that make `storage2` store OOP's cfvalues?) Storage is a byte vector `Vec`, and these store floating point values. @@ -170,8 +171,8 @@ When a serialized tree is deserialized, if it is a parital save (e.g., a Turn save) you will not be able to navigate to unsaved streets. Several things break when we deserialize a partial save: -- `node_arena` is only partially populated -- `node.children()` points to raw data when `node` points to an street that is ++ `node_arena` is only partially populated ++ `node.children()` points to raw data when `node` points to an street that is not serialized (e.g., a chance node before the river for a Turn save). ### Allocating `node_arena` @@ -183,8 +184,8 @@ recursively visits children of nodes and modifies them to ### Data Coupling/Relations/Invariants -- A node is locked iff it is contained in the game's locking_strategy -- `PostFlopGame.node_arena` is pointed to by `PostFlopNode.children_offset`. For ++ A node is locked iff it is contained in the game's locking_strategy ++ `PostFlopGame.node_arena` is pointed to by `PostFlopNode.children_offset`. For instance, this is the basic definition of the `PostFlopNode.children()` function: From cd790ebf9db3ea2a9025671d255deb471f95e890 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 9 Aug 2024 10:27:55 -0700 Subject: [PATCH 18/66] DESIGN.md --- DESIGN.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESIGN.md b/DESIGN.md index f5eb759..16f39f8 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -184,7 +184,7 @@ recursively visits children of nodes and modifies them to ### Data Coupling/Relations/Invariants -+ A node is locked iff it is contained in the game's locking_strategy ++ A node is locked IFF it is contained in the game's locking_strategy + `PostFlopGame.node_arena` is pointed to by `PostFlopNode.children_offset`. For instance, this is the basic definition of the `PostFlopNode.children()` function: From 6377ad3cd6dbc9ce3d07032ff7abc1f0f775ce70 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 9 Aug 2024 10:28:17 -0700 Subject: [PATCH 19/66] DESIGN.md --- DESIGN.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/DESIGN.md b/DESIGN.md index 16f39f8..b64d8f9 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -167,10 +167,11 @@ game.set_target_storage_mode(BoardState::Turn); ``` will ensure that when `game` is encoded, it will only save Flop and Turn data. -When a serialized tree is deserialized, if it is a parital save (e.g., a Turn +When a serialized tree is deserialized, if it is a partial save (e.g., a Turn save) you will not be able to navigate to unsaved streets. Several things break when we deserialize a partial save: + + `node_arena` is only partially populated + `node.children()` points to raw data when `node` points to an street that is not serialized (e.g., a chance node before the river for a Turn save). From 0cd0382f47924d4bc9a4eeec4e554e4d498fe935 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 9 Aug 2024 10:28:39 -0700 Subject: [PATCH 20/66] DESIGN.md --- DESIGN.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESIGN.md b/DESIGN.md index b64d8f9..97e05b8 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -197,4 +197,4 @@ recursively visits children of nodes and modifies them to ) ``` - We get a pointer to `self` and add children offset. \ No newline at end of file + We get a pointer to `self` and add children offset. From c03fb57311321c41cce66fadf2d6ae13f0fcc942 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 9 Aug 2024 11:20:04 -0700 Subject: [PATCH 21/66] Updates --- DESIGN.md | 60 +++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index 97e05b8..3da8edc 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -12,26 +12,19 @@ To set up a `PostFlopGame` we need to **create a `PostFlopGame` instance**, **allocate global storage and `PostFlopNode`s**, and **initialize the `PostFlopNode` child/parent relationship**. This is done in several steps. +We begin by creating a `PostFlopGame` instance. -We begin by creating a `PostFlopGame`. A `PostFlopGame` requires an `ActionTree` -and a `CardConfig`. The `ActionTree` represents the full game tree modded out by -different runouts (so an `ActionTree` might have an abstract _line_ **Bet 10; -Call; Bet 30; Call** while the game tree would have concrete _nodes_ like -**Bet 10; Call; Th; Bet 30; Call**, etc). - -1. **Create Configurations**: - + We need a `tree_config: TreeConfig` - + We need an `action_tree: ActionTree::new(tree_config)` - + We need a `card_config: CardConfig` - -2. **Create PostFlopGame**: We build a `PostFlopGame` from `action_tree` and `card_config`: +```rust +let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); +``` - ```rust - let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); - ``` +A `PostFlopGame` requires an +`ActionTree` which describes all possible actions and lines (no runout +information), and a `CardConfig`, which describes player ranges and +flop/turn/river data. -Once the game is created we need to allocate the following memory and initialize -its values: +Once we have created a `PostFlopGame` instance we need to allocate the following +memory and initialize its values: + `game.node_arena` + `game.storage1` @@ -39,32 +32,40 @@ its values: + `game.storage_ip` + `game.storage_chance` -These fields are not allocated/initialized at the same time; +These fields are not allocated/initialized at the same time: -+ `game.node_arena` is allocated and initialized via `with_config`, ++ `game.node_arena` is allocated and initialized via `with_config()` (i.e., when + we created our `PostFlopGame`), + other storage is allocated via `game.allocate_memory()`. #### Allocating and Initializing `node_arena` -We construct a `PostFlopGame` by calling -`PostFlopGame::with_config(card_config, action_tree)`, which calls -`update_config`. `PostFlopGame::update_config` sets up configuration data, -sanity checks things are correct, and then calls `self.init_root()`. +We constructed a `PostFlopGame` by calling +`PostFlopGame::with_config(card_config, action_tree)`, which under the hood +actually calls: + +```rust + let mut game = Self::new(); + game.update_config(card_config, action_tree)?; +``` + +`PostFlopGame::update_config` sets up configuration data, sanity checks things +are correct, and then calls `self.init_root()`. `init_root` is responsible for: -1. Counting number of `PostFlopNode`s to be allocated (`self.num_nodes`), broken - up by flop, turn, and river -2. Allocating `self.num_nodes` `PostFlopNode`s in `node_arena` field +1. Counting number of `PostFlopNode`s to be allocated (`self.nodes_per_street`), + broken up by flop, turn, and river +2. Allocating `PostFlopNode`s in the `node_arena` field 3. Clearing storage: `self.clear_storage()` sets each storage item to a new `Vec` 4. Invoking `build_tree_recursive` which initializes each node's child/parent relationship via `child_offset` (through calls to `push_actions` and `push_chances`). -Each `PostFlopNode` points to node-specific data (eg., strategies and cfregrets) -that is located inside of `PostFlopGame.storage*` fields (which is currently -unallocated) via similarly named fields `PostFlopNode.storage*`. +Each `PostFlopNode` points to node-specific data (e.g., strategies and +cfregrets) that is located inside of `PostFlopGame.storage*` fields (which is +currently unallocated) via similarly named fields `PostFlopNode.storage*`. Additionally, each node points to the children offset with `children_offset`, which records where in `node_arena` relative to the current node that node's @@ -182,7 +183,6 @@ We want to first allocate nodes for `node_arena`, and then run some form of `build_tree_recursive`. This assumes that `node_arena` is already allocated, and recursively visits children of nodes and modifies them to - ### Data Coupling/Relations/Invariants + A node is locked IFF it is contained in the game's locking_strategy From 694c61b99a35adcfb44dc7d7aaaec412e57ec30d Mon Sep 17 00:00:00 2001 From: bkushigian Date: Wed, 14 Aug 2024 11:11:36 -0700 Subject: [PATCH 22/66] tmp commit --- src/game/base.rs | 177 ++++++++++++++++++++++++++++++++++++++-- src/game/interpreter.rs | 9 +- src/game/mod.rs | 5 +- src/game/node.rs | 16 ++++ 4 files changed, 196 insertions(+), 11 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index ba14b4d..6dc26c8 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -463,7 +463,7 @@ impl PostFlopGame { } pub fn print_internal_data(&self) { - println!("Printing intenral data for PostFlopGame"); + println!("Printing internal data for PostFlopGame"); println!("- node_arena: {}", self.node_arena.len()); println!("- storage1: {}", self.storage1.len()); println!("- storage2: {}", self.storage2.len()); @@ -534,7 +534,15 @@ impl PostFlopGame { ) = self.card_config.isomorphism(&self.private_cards); } - /// Initializes the root node of game tree. + /// Initializes the root node of game tree and recursively build the tree. + /// + /// This function is responsible for computing the number of nodes required + /// for each street (via `count_nodes_per_street()`), allocating + /// `PostFlopNode`s to `self.node_arena`, and calling `build_tree_recursive`, + /// which recursively visits all nodes and, among other things, initializes + /// the child/parent relation. + /// + /// This does _not_ allocate global storage (e.g., `self.storage1`, etc). fn init_root(&mut self) -> Result<(), String> { let nodes_per_street = self.count_nodes_per_street(); let total_num_nodes = nodes_per_street[0] + nodes_per_street[1] + nodes_per_street[2]; @@ -741,6 +749,7 @@ impl PostFlopGame { node.num_children += 1; let mut child = node.children().last().unwrap().lock(); child.prev_action = Action::Chance(card); + child.parent_node_index = node_index; child.turn = card; } } @@ -804,6 +813,7 @@ impl PostFlopGame { child.prev_action = *action; child.turn = node.turn; child.river = node.river; + child.parent_node_index = node_index; } let num_private_hands = self.num_private_hands(node.player as usize); @@ -817,6 +827,159 @@ impl PostFlopGame { info.num_storage_ip += node.num_elements_ip as u64; } + /* REBUILDING AND RESOLVING TREE */ + + /// Like `init_root`, but applied to a partial save loaded from disk. This + /// reallocates missing `PostFlopNode`s to `node_arena` and reruns + /// `build_tree_recursive`. Rerunning `build_tree_recursive` will not alter + /// nodes loaded from disk. + pub fn reinit_root(&mut self) -> Result<(), String> { + let nodes_per_street = self.count_nodes_per_street(); + let total_num_nodes = nodes_per_street[0] + nodes_per_street[1] + nodes_per_street[2]; + + if total_num_nodes > u32::MAX as u64 + || mem::size_of::() as u64 * total_num_nodes > isize::MAX as u64 + { + return Err("Too many nodes".to_string()); + } + + self.num_nodes_per_street = nodes_per_street; + + let total_new_nodes_to_allocate = total_num_nodes - self.node_arena.len() as u64; + self.node_arena.append( + &mut (0..total_new_nodes_to_allocate) + .map(|_| MutexLike::new(PostFlopNode::default())) + .collect::>(), + ); + // self.clear_storage(); + + let mut info = BuildTreeInfo { + turn_index: nodes_per_street[0] as usize, + river_index: (nodes_per_street[0] + nodes_per_street[1]) as usize, + ..Default::default() + }; + + match self.tree_config.initial_state { + BoardState::Flop => info.flop_index += 1, + BoardState::Turn => info.turn_index += 1, + BoardState::River => info.river_index += 1, + } + + let mut root = self.node_arena[0].lock(); + root.turn = self.card_config.turn; + root.river = self.card_config.river; + + self.build_tree_recursive(0, &self.action_root.lock(), &mut info); + + self.num_storage = info.num_storage; + self.num_storage_ip = info.num_storage_ip; + self.num_storage_chance = info.num_storage_chance; + self.misc_memory_usage = self.memory_usage_internal(); + + Ok(()) + } + + pub fn reload_and_resolve(&mut self, enable_compression: bool) -> Result<(), String> { + self.allocate_memory_after_load(enable_compression)?; + self.reinit_root()?; + + // Collect root nodes to resolve + let nodes_to_solve = match self.storage_mode { + BoardState::Flop => { + let turn_root_nodes = self + .node_arena + .iter() + .filter(|n| { + n.lock().turn != NOT_DEALT + && n.lock().river == NOT_DEALT + && matches!(n.lock().prev_action, Action::Chance(..)) + }) + .collect::>(); + turn_root_nodes + } + BoardState::Turn => { + let river_root_nodes = self + .node_arena + .iter() + .filter(|n| { + n.lock().turn != NOT_DEALT + && matches!(n.lock().prev_action, Action::Chance(..)) + }) + .collect::>(); + river_root_nodes + } + BoardState::River => vec![], + }; + for node in nodes_to_solve { + // Get history of this node + // let mut history = vec![]; + let mut n = node.lock(); + while n.parent_node_index < usize::MAX { + let parent = self.node_arena[n.parent_node_index].lock(); + let action = n.prev_action; + } + } + + Ok(()) + } + + /// Reallocate memory for full tree after performing a partial load + pub fn allocate_memory_after_load(&mut self, enable_compression: bool) -> Result<(), String> { + if self.state <= State::Uninitialized { + return Err("Game is not successfully initialized".to_string()); + } + + if self.state == State::MemoryAllocated + && self.storage_mode == BoardState::River + && self.is_compression_enabled == enable_compression + { + return Ok(()); + } + + let num_bytes = if enable_compression { 2 } else { 4 }; + if num_bytes * self.num_storage > isize::MAX as u64 + || num_bytes * self.num_storage_chance > isize::MAX as u64 + { + return Err("Memory usage exceeds maximum size".to_string()); + } + + self.state = State::MemoryAllocated; + self.is_compression_enabled = enable_compression; + + let old_storage1 = std::mem::replace(&mut self.storage1, vec![]); + let old_storage2 = std::mem::replace(&mut self.storage2, vec![]); + let old_storage_ip = std::mem::replace(&mut self.storage_ip, vec![]); + let old_storage_chance = std::mem::replace(&mut self.storage_chance, vec![]); + + let storage_bytes = (num_bytes * self.num_storage) as usize; + let storage_ip_bytes = (num_bytes * self.num_storage_ip) as usize; + let storage_chance_bytes = (num_bytes * self.num_storage_chance) as usize; + + self.storage1 = vec![0; storage_bytes]; + self.storage2 = vec![0; storage_bytes]; + self.storage_ip = vec![0; storage_ip_bytes]; + self.storage_chance = vec![0; storage_chance_bytes]; + + self.allocate_memory_nodes(); + + self.storage_mode = BoardState::River; + self.target_storage_mode = BoardState::River; + + for (dst, src) in self.storage1.iter_mut().zip(&old_storage1) { + *dst = *src; + } + for (dst, src) in self.storage2.iter_mut().zip(&old_storage2) { + *dst = *src; + } + for (dst, src) in self.storage_ip.iter_mut().zip(&old_storage_ip) { + *dst = *src; + } + for (dst, src) in self.storage_chance.iter_mut().zip(&old_storage_chance) { + *dst = *src; + } + Ok(()) + } + /// Sets the bunching effect. fn set_bunching_effect_internal(&mut self, bunching_data: &BunchingData) -> Result<(), String> { self.bunching_num_dead_cards = bunching_data.fold_ranges().len() * 2; @@ -1004,7 +1167,7 @@ impl PostFlopGame { if self.card_config.river != NOT_DEALT { self.bunching_arena = arena; - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); return Ok(()); } @@ -1119,7 +1282,7 @@ impl PostFlopGame { if self.card_config.turn != NOT_DEALT { self.bunching_arena = arena; - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); return Ok(()); } @@ -1197,7 +1360,7 @@ impl PostFlopGame { } self.bunching_arena = arena; - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); Ok(()) } @@ -1463,4 +1626,8 @@ impl PostFlopGame { } } } + + pub fn get_state(&self) -> &State { + return &self.state; + } } diff --git a/src/game/interpreter.rs b/src/game/interpreter.rs index d44d63a..d7845b6 100644 --- a/src/game/interpreter.rs +++ b/src/game/interpreter.rs @@ -30,7 +30,7 @@ impl PostFlopGame { self.weights[0].copy_from_slice(&self.initial_weights[0]); self.weights[1].copy_from_slice(&self.initial_weights[1]); - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); } /// Returns the history of the current node. @@ -366,7 +366,7 @@ impl PostFlopGame { } // update the weights - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); } // player node else { @@ -1005,8 +1005,9 @@ impl PostFlopGame { unsafe { node_ptr.offset_from(self.node_arena.as_ptr()) as usize } } - /// Assigns zero weights to the hands that are not possible. - pub(super) fn assign_zero_weights(&mut self) { + /// Assigns zero weights to the hands that are not possible (e.g, by a card + /// being removed by a turn or river). + pub(super) fn assign_zero_weights_to_dead_cards(&mut self) { if self.bunching_num_dead_cards == 0 { let mut board_mask: u64 = 0; if self.turn != NOT_DEALT { diff --git a/src/game/mod.rs b/src/game/mod.rs index 09042fd..3504adb 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -17,10 +17,10 @@ use std::collections::BTreeMap; #[cfg(feature = "bincode")] use bincode::{Decode, Encode}; -#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] #[repr(u8)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] -enum State { +pub enum State { ConfigError = 0, #[default] Uninitialized = 1, @@ -121,6 +121,7 @@ pub struct PostFlopGame { #[repr(C)] pub struct PostFlopNode { prev_action: Action, + parent_node_index: usize, player: u8, turn: Card, river: Card, diff --git a/src/game/node.rs b/src/game/node.rs index e6bd7e3..4c6fd4b 100644 --- a/src/game/node.rs +++ b/src/game/node.rs @@ -209,6 +209,7 @@ impl Default for PostFlopNode { fn default() -> Self { Self { prev_action: Action::None, + parent_node_index: usize::MAX, player: PLAYER_OOP, turn: NOT_DEALT, river: NOT_DEALT, @@ -240,4 +241,19 @@ impl PostFlopNode { ) } } + + /// Get a list of available actions at a given node + pub fn actions(&self) -> Vec { + self.children() + .iter() + .map(|n| n.lock().prev_action) + .collect::>() + } + + /// Find the index of a given action, if present + pub fn action_index(&self, action: Action) -> Option { + self.children() + .iter() + .position(|n| n.lock().prev_action == action) + } } From c2c52f5b18e413709f64c8d2564fa0aaac4d3166 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Wed, 14 Aug 2024 15:59:10 -0700 Subject: [PATCH 23/66] Recursively compute history --- src/game/node.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/game/node.rs b/src/game/node.rs index 4c6fd4b..400a886 100644 --- a/src/game/node.rs +++ b/src/game/node.rs @@ -256,4 +256,16 @@ impl PostFlopNode { .iter() .position(|n| n.lock().prev_action == action) } + + /// Recursively compute the current node's history + pub fn compute_history_recursive(&self, game: &PostFlopGame) -> Option> { + if self.parent_node_index == usize::MAX { + Some(vec![]) + } else { + let p = game.node_arena.get(self.parent_node_index)?; + let mut history = p.lock().compute_history_recursive(game)?; + history.push(p.lock().action_index(self.prev_action)?); + Some(history) + } + } } From a23e31906fabfb824ebb01a6bcaa8d061bd49325 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sat, 17 Aug 2024 13:36:24 -0700 Subject: [PATCH 24/66] Tmp commit --- examples/file_io_debug.rs | 267 +++++++++++++++++++++++++++----------- src/file.rs | 42 ++++++ src/game/base.rs | 102 ++++++++++----- src/game/interpreter.rs | 20 +++ src/solver.rs | 68 ++++++++++ 5 files changed, 387 insertions(+), 112 deletions(-) diff --git a/examples/file_io_debug.rs b/examples/file_io_debug.rs index f0348a6..d5a625f 100644 --- a/examples/file_io_debug.rs +++ b/examples/file_io_debug.rs @@ -1,97 +1,210 @@ -use std::fs::File; - use postflop_solver::*; -fn main() { - // see `basic.rs` for the explanation of the following code +fn recursive_compare_strategies_helper( + saved: &mut PostFlopGame, + loaded: &mut PostFlopGame, + storage_mode: BoardState, +) { + let history = saved.history().to_vec(); + saved.cache_normalized_weights(); + loaded.cache_normalized_weights(); + + // Check if OOP hands have the same evs + let evs_oop_1 = saved.expected_values(0); + let ws_oop_1 = saved.weights(0); + let evs_oop_2 = loaded.expected_values(1); + let ws_oop_2 = saved.weights(0); + + assert!(ws_oop_1.len() == ws_oop_2.len()); + for (w1, w2) in ws_oop_1.iter().zip(ws_oop_2) { + assert!((w1 - w2).abs() < 0.001); + } + for (i, (e1, e2)) in evs_oop_1.iter().zip(&evs_oop_2).enumerate() { + assert!((e1 - e2).abs() < 0.001, "ev diff({}): {}", i, e1 - e2); + } + + let ev_oop_1 = compute_average(&evs_oop_1, &ws_oop_1); + let ev_oop_2 = compute_average(&evs_oop_2, &ws_oop_2); + + let ev_diff = (ev_oop_1 - ev_oop_2).abs(); + println!("EV Diff: {:0.2}", ev_diff); + assert!((ev_oop_1 - ev_oop_2).abs() < 0.01); + for child_index in 0..saved.available_actions().len() { + saved.play(child_index); + loaded.play(child_index); + + recursive_compare_strategies_helper(saved, loaded, storage_mode); + + saved.apply_history(&history); + loaded.apply_history(&history); + } +} + +fn compare_strategies( + saved: &mut PostFlopGame, + loaded: &mut PostFlopGame, + storage_mode: BoardState, +) { + saved.back_to_root(); + loaded.back_to_root(); + saved.cache_normalized_weights(); + loaded.cache_normalized_weights(); + for (i, ((e1, e2), cards)) in saved + .expected_values(0) + .iter() + .zip(loaded.expected_values(0)) + .zip(saved.private_cards(0)) + .enumerate() + { + println!("ev {}: {}:{}", hole_to_string(*cards).unwrap(), e1, e2); + } + for (i, ((e1, e2), cards)) in saved + .expected_values(1) + .iter() + .zip(loaded.expected_values(1)) + .zip(saved.private_cards(1)) + .enumerate() + { + println!("ev {}: {}:{}", hole_to_string(*cards).unwrap(), e1, e2); + } + recursive_compare_strategies_helper(saved, loaded, storage_mode); +} + +fn print_strats_at_current_node( + g1: &mut PostFlopGame, + g2: &mut PostFlopGame, + actions: &Vec, +) { + let action_string = actions + .iter() + .map(|a| format!("{:?}", a)) + .collect::>() + .join(":"); - let oop_range = "66+,A8s+,A5s-A4s,AJo+,K9s+,KQo,QTs+,JTs,96s+,85s+,75s+,65s,54s"; - let ip_range = "QQ-22,AQs-A2s,ATo+,K5s+,KJo+,Q8s+,J8s+,T7s+,96s+,86s+,75s+,64s+,53s+"; + let player = g1.current_player(); + + println!( + "\x1B[32;1mActions To Reach Node\x1B[0m: [{}]", + action_string + ); + // Print high level node data + if g1.is_chance_node() { + println!("\x1B[32;1mPlayer\x1B[0m: Chance"); + } else if g1.is_terminal_node() { + if player == 0 { + println!("\x1B[32;1mPlayer\x1B[0m: OOP (Terminal)"); + } else { + println!("\x1B[32;1mPlayer\x1B[0m: IP (Terminal)"); + } + } else { + if player == 0 { + println!("\x1B[32;1mPlayer\x1B[0m: OOP"); + } else { + println!("\x1B[32;1mPlayer\x1B[0m: IP"); + } + let private_cards = g1.private_cards(player); + let strat1 = g1.strategy_by_private_hand(); + let strat2 = g2.strategy_by_private_hand(); + let weights1 = g1.weights(player); + let weights2 = g2.weights(player); + let actions = g1.available_actions(); + + // Print both games strategies + for ((cards, (w1, s1)), (w2, s2)) in private_cards + .iter() + .zip(weights1.iter().zip(strat1)) + .zip(weights2.iter().zip(strat2)) + { + let hole_cards = hole_to_string(*cards).unwrap(); + print!("\x1B[34;1m{hole_cards}\x1B[0m@({:.2} v {:.2}) ", w1, w2); + let mut action_frequencies = vec![]; + for (a, (freq1, freq2)) in actions.iter().zip(s1.iter().zip(s2)) { + action_frequencies.push(format!( + "\x1B[32;1m{:?}\x1B[0m: \x1B[31m{:0.3}\x1B[0m v \x1B[33m{:0>.3}\x1B[0m", + a, freq1, freq2 + )) + } + println!("{}", action_frequencies.join(" ")); + } + } +} + +fn main() { + let oop_range = "AA,QQ"; + let ip_range = "KK"; let card_config = CardConfig { range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], - flop: flop_from_str("Td9d6h").unwrap(), - turn: card_from_str("Qc").unwrap(), - river: NOT_DEALT, + flop: flop_from_str("3h3s3d").unwrap(), + ..Default::default() }; - let bet_sizes = BetSizeOptions::try_from(("60%, e, a", "2.5x")).unwrap(); - let tree_config = TreeConfig { - initial_state: BoardState::Turn, - starting_pot: 200, - effective_stack: 900, + starting_pot: 100, + effective_stack: 100, rake_rate: 0.0, rake_cap: 0.0, - flop_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], - turn_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], - river_bet_sizes: [bet_sizes.clone(), bet_sizes], - turn_donk_sizes: None, - river_donk_sizes: Some(DonkSizeOptions::try_from("50%").unwrap()), - add_allin_threshold: 1.5, - force_allin_threshold: 0.15, - merging_threshold: 0.1, + flop_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + turn_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + river_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + ..Default::default() }; let action_tree = ActionTree::new(tree_config).unwrap(); - let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); - game.allocate_memory(false); - - let max_num_iterations = 20; - let target_exploitability = game.tree_config().starting_pot as f32 * 0.01; - solve(&mut game, max_num_iterations, target_exploitability, true); - let r = game.set_target_storage_mode(BoardState::Turn); - println!("{r:?}"); - - // save the solved game tree to a file - // 4th argument is zstd compression level (1-22); requires `zstd` feature to use - save_data_to_file(&game, "memo string", "filename.bin", None).unwrap(); - - // load the solved game tree from a file - // 2nd argument is the maximum memory usage in bytes - let (mut game2, _memo_string): (PostFlopGame, _) = - load_data_from_file("filename.bin", None).unwrap(); - - println!("Game 1 Internal Data"); - game.print_internal_data(); - println!("Game 2 Internal Data"); - game2.print_internal_data(); - - // check if the loaded game tree is the same as the original one - game.cache_normalized_weights(); - game2.cache_normalized_weights(); - assert_eq!(game.equity(0), game2.equity(0)); - - // discard information after the river deal when serializing - // this operation does not lose any information of the game tree itself - game2.set_target_storage_mode(BoardState::Turn).unwrap(); - - // compare the memory usage for serialization - println!( - "Memory usage of the original game tree: {:.2}MB", // 11.50MB - game.target_memory_usage() as f64 / (1024.0 * 1024.0) - ); - println!( - "Memory usage of the truncated game tree: {:.2}MB", // 0.79MB - game2.target_memory_usage() as f64 / (1024.0 * 1024.0) - ); + let mut game1 = PostFlopGame::with_config(card_config, action_tree).unwrap(); + game1.allocate_memory(false); + + solve(&mut game1, 100, 0.01, false); + + // save (turn) + game1.set_target_storage_mode(BoardState::Turn).unwrap(); + save_data_to_file(&game1, "", "tmpfile.flop", None).unwrap(); + + // load (turn) + let mut game2: PostFlopGame = load_data_from_file("tmpfile.flop", None).unwrap().0; + // compare_strategies(&mut game, &mut game2, BoardState::Turn); + assert!(game2.rebuild_and_resolve_forgotten_streets().is_ok()); + + let mut actions_so_far = vec![]; + + // Print Root Node + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - // overwrite the file with the truncated game tree - // game tree constructed from this file cannot access information after the river deal - save_data_to_file(&game2, "memo string", "filename.bin", None).unwrap(); - let (mut game3, _memo_string): (PostFlopGame, String) = - load_data_from_file("filename.bin", None).unwrap(); + // OOP: Check + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); + game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // IP: Check + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); + game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - game.play(0); - game.play(0); - println!("Game X/X Actions: {:?}", game.available_actions()); + // Chance: 2c + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // OOP: CHECK + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); game2.play(0); - println!("Game2 X/X Actions: {:?}", game.available_actions()); - game3.play(0); - game3.play(0); - println!("Game3 X/X Actions: {:?}", game3.available_actions()); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // IP: CHECK + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); + game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // CHANCE: 0 + actions_so_far.push(game1.available_actions()[1]); + game1.play(1); + game2.play(1); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - // delete the file - std::fs::remove_file("filename.bin").unwrap(); + // compare_strategies(&mut game, &mut game2, BoardState::Turn); } diff --git a/src/file.rs b/src/file.rs index 17e7d5b..d7e90f7 100644 --- a/src/file.rs +++ b/src/file.rs @@ -269,6 +269,7 @@ mod tests { use crate::action_tree::*; use crate::card::*; use crate::range::*; + use crate::solver::solve; use crate::utility::*; #[test] @@ -375,4 +376,45 @@ mod tests { assert!((root_ev_oop - 45.0).abs() < 1e-4); assert!((root_ev_ip - 15.0).abs() < 1e-4); } + + #[test] + fn test_reload_and_resolve() { + let oop_range = "AA,QQ"; + let ip_range = "KK"; + + let card_config = CardConfig { + range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], + flop: flop_from_str("3h3s3d").unwrap(), + ..Default::default() + }; + + let tree_config = TreeConfig { + starting_pot: 100, + effective_stack: 100, + rake_rate: 0.0, + rake_cap: 0.0, + flop_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + turn_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + river_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + ..Default::default() + }; + + let action_tree = ActionTree::new(tree_config).unwrap(); + let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); + println!( + "memory usage: {:.2}GB", + game.memory_usage().0 as f64 / (1024.0 * 1024.0 * 1024.0) + ); + game.allocate_memory(false); + + solve(&mut game, 100, 0.01, false); + + // save (turn) + game.set_target_storage_mode(BoardState::Turn).unwrap(); + save_data_to_file(&game, "", "tmpfile.flop", None).unwrap(); + + // load (turn) + let mut game: PostFlopGame = load_data_from_file("tmpfile.flop", None).unwrap().0; + assert!(game.rebuild_and_resolve_forgotten_streets().is_ok()); + } } diff --git a/src/game/base.rs b/src/game/base.rs index 6dc26c8..94931c1 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -1,13 +1,14 @@ use super::*; use crate::bunching::*; use crate::interface::*; +use crate::solve_with_node_as_root; use crate::utility::*; use std::mem::{self, MaybeUninit}; #[cfg(feature = "rayon")] use rayon::prelude::*; -#[derive(Default)] +#[derive(Default, Debug)] struct BuildTreeInfo { flop_index: usize, turn_index: usize, @@ -768,6 +769,7 @@ impl PostFlopGame { node.num_children += 1; let mut child = node.children().last().unwrap().lock(); child.prev_action = Action::Chance(card); + child.parent_node_index = node_index; child.turn = node.turn; child.river = card; } @@ -845,12 +847,9 @@ impl PostFlopGame { self.num_nodes_per_street = nodes_per_street; - let total_new_nodes_to_allocate = total_num_nodes - self.node_arena.len() as u64; - self.node_arena.append( - &mut (0..total_new_nodes_to_allocate) - .map(|_| MutexLike::new(PostFlopNode::default())) - .collect::>(), - ); + self.node_arena = (0..total_num_nodes) + .map(|_| MutexLike::new(PostFlopNode::default())) + .collect::>(); // self.clear_storage(); let mut info = BuildTreeInfo { @@ -879,64 +878,97 @@ impl PostFlopGame { Ok(()) } - pub fn reload_and_resolve(&mut self, enable_compression: bool) -> Result<(), String> { - self.allocate_memory_after_load(enable_compression)?; + pub fn rebuild_and_resolve_forgotten_streets(&mut self) -> Result<(), String> { + self.check_card_config()?; self.reinit_root()?; + self.allocate_memory_after_load()?; + self.resolve_reloaded_nodes(1000, 0.01, false) + } - // Collect root nodes to resolve - let nodes_to_solve = match self.storage_mode { + /// Return the node index for each root of the forgotten gametrees that were + /// omitted during a partial save. + /// + /// When we perform a partial save (e.g., a flop save), we lose + /// cfvalues/strategy data for all subtrees rooted at the forgotten street + /// (in the case of a flop save, this would be all subtrees rooted at the + /// beginning of the turn). + /// + /// To regain this information we need to resolve each of these subtrees + /// individually. This function collects the index of each such root. + pub fn collect_unsolved_roots_after_reload(&mut self) -> Result, String> { + match self.storage_mode { BoardState::Flop => { let turn_root_nodes = self .node_arena .iter() - .filter(|n| { + .enumerate() + .filter(|(_, n)| { n.lock().turn != NOT_DEALT && n.lock().river == NOT_DEALT && matches!(n.lock().prev_action, Action::Chance(..)) }) + .map(|(i, _)| i) .collect::>(); - turn_root_nodes + Ok(turn_root_nodes) } BoardState::Turn => { let river_root_nodes = self .node_arena .iter() - .filter(|n| { + .enumerate() + .filter(|(_, n)| { n.lock().turn != NOT_DEALT && matches!(n.lock().prev_action, Action::Chance(..)) }) + .map(|(i, _)| i) .collect::>(); - river_root_nodes - } - BoardState::River => vec![], - }; - for node in nodes_to_solve { - // Get history of this node - // let mut history = vec![]; - let mut n = node.lock(); - while n.parent_node_index < usize::MAX { - let parent = self.node_arena[n.parent_node_index].lock(); - let action = n.prev_action; + Ok(river_root_nodes) } + BoardState::River => Ok(vec![]), } + } + + pub fn resolve_reloaded_nodes( + &mut self, + max_num_iterations: u32, + target_exploitability: f32, + print_progress: bool, + ) -> Result<(), String> { + let nodes_to_solve = self.collect_unsolved_roots_after_reload()?; + self.state = State::MemoryAllocated; + for node_idx in nodes_to_solve { + let node = self.node_arena.get(node_idx).ok_or("Invalid node index")?; + // let history = node + // .lock() + // .compute_history_recursive(&self) + // .ok_or("Unable to compute history for node".to_string())? + // .to_vec(); + // self.apply_history(&history); + solve_with_node_as_root( + self, + node.lock(), + max_num_iterations, + target_exploitability, + print_progress, + ); + } + finalize(self); Ok(()) } - /// Reallocate memory for full tree after performing a partial load - pub fn allocate_memory_after_load(&mut self, enable_compression: bool) -> Result<(), String> { + /// Reallocate memory for full tree after performing a partial load. This + /// must be called after `init_root()` + pub fn allocate_memory_after_load(&mut self) -> Result<(), String> { if self.state <= State::Uninitialized { return Err("Game is not successfully initialized".to_string()); } - if self.state == State::MemoryAllocated - && self.storage_mode == BoardState::River - && self.is_compression_enabled == enable_compression - { + if self.state == State::MemoryAllocated && self.storage_mode == BoardState::River { return Ok(()); } - let num_bytes = if enable_compression { 2 } else { 4 }; + let num_bytes = if self.is_compression_enabled { 2 } else { 4 }; if num_bytes * self.num_storage > isize::MAX as u64 || num_bytes * self.num_storage_chance > isize::MAX as u64 { @@ -944,7 +976,7 @@ impl PostFlopGame { } self.state = State::MemoryAllocated; - self.is_compression_enabled = enable_compression; + // self.is_compression_enabled = self.is_compression_enabled; let old_storage1 = std::mem::replace(&mut self.storage1, vec![]); let old_storage2 = std::mem::replace(&mut self.storage2, vec![]); @@ -960,7 +992,7 @@ impl PostFlopGame { self.storage_ip = vec![0; storage_ip_bytes]; self.storage_chance = vec![0; storage_chance_bytes]; - self.allocate_memory_nodes(); + self.allocate_memory_nodes(); // Assign node storage pointers self.storage_mode = BoardState::River; self.target_storage_mode = BoardState::River; @@ -1595,7 +1627,7 @@ impl PostFlopGame { Ok(info) } - /// Allocates memory recursively. + /// Assigns allocated storage memory. fn allocate_memory_nodes(&mut self) { let num_bytes = if self.is_compression_enabled { 2 } else { 4 }; let mut action_counter = 0; diff --git a/src/game/interpreter.rs b/src/game/interpreter.rs index d7845b6..619011d 100644 --- a/src/game/interpreter.rs +++ b/src/game/interpreter.rs @@ -763,6 +763,7 @@ impl PostFlopGame { node.cfvalues_ip().to_vec() } } else if player == self.current_player() { + println!("BINGO"); have_actions = true; if self.is_compression_enabled { let slice = node.cfvalues_compressed(); @@ -846,6 +847,25 @@ impl PostFlopGame { ret } + pub fn strategy_by_private_hand(&self) -> Vec> { + let strat = self.strategy(); + let player = self.current_player(); + let num_hands = self.private_cards(player).len(); + let num_actions = self.available_actions().len(); + assert!(num_hands * num_actions == strat.len()); + let mut strat_by_hand: Vec> = Vec::with_capacity(num_hands); + for j in 0..num_hands { + strat_by_hand.push(Vec::with_capacity(num_actions)); + } + + for i in 0..num_actions { + for j in 0..num_hands { + strat_by_hand[j].push(strat[i * num_hands + j]); + } + } + strat_by_hand + } + /// Returns the total bet amount of each player (OOP, IP). #[inline] pub fn total_bet_amount(&self) -> [i32; 2] { diff --git a/src/solver.rs b/src/solver.rs index 5a1cc5a..02db93c 100644 --- a/src/solver.rs +++ b/src/solver.rs @@ -132,6 +132,74 @@ pub fn solve_step(game: &T, current_iteration: u32) { } } +/// Performs Discounted CFR algorithm until the given number of iterations or exploitability is +/// satisfied. +/// +/// This method returns the exploitability of the obtained strategy. +pub fn solve_with_node_as_root( + game: &mut T, + mut root: MutexGuardLike, + max_num_iterations: u32, + target_exploitability: f32, + print_progress: bool, +) -> f32 { + if game.is_solved() { + panic!("Game is already solved"); + } + + if !game.is_ready() { + panic!("Game is not ready"); + } + + let mut exploitability = compute_exploitability(game); + + if print_progress { + print!("iteration: 0 / {max_num_iterations} "); + print!("(exploitability = {exploitability:.4e})"); + io::stdout().flush().unwrap(); + } + + for t in 0..max_num_iterations { + if exploitability <= target_exploitability { + break; + } + + let params = DiscountParams::new(t); + + // alternating updates + for player in 0..2 { + let mut result = Vec::with_capacity(game.num_private_hands(player)); + solve_recursive( + result.spare_capacity_mut(), + game, + &mut root, + player, + game.initial_weights(player ^ 1), + ¶ms, + ); + } + + if (t + 1) % 10 == 0 || t + 1 == max_num_iterations { + exploitability = compute_exploitability(game); + } + + if print_progress { + print!("\riteration: {} / {} ", t + 1, max_num_iterations); + print!("(exploitability = {exploitability:.4e})"); + io::stdout().flush().unwrap(); + } + } + + if print_progress { + println!(); + io::stdout().flush().unwrap(); + } + + finalize(game); + + exploitability +} + /// Recursively solves the counterfactual values. fn solve_recursive( result: &mut [MaybeUninit], From 414361a631ed059a19cf3cc29576a0789e0d5f6c Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 11:27:52 -0700 Subject: [PATCH 25/66] Documented some sliceops --- src/sliceop.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/sliceop.rs b/src/sliceop.rs index bf726b6..75d0f00 100644 --- a/src/sliceop.rs +++ b/src/sliceop.rs @@ -32,6 +32,7 @@ pub(crate) fn div_slice_uninit( }); } +/// Multiply a source slice by a scalar and store in a destination slice #[inline] pub(crate) fn mul_slice_scalar_uninit(dst: &mut [MaybeUninit], src: &[f32], scalar: f32) { dst.iter_mut().zip(src).for_each(|(d, s)| { @@ -39,6 +40,17 @@ pub(crate) fn mul_slice_scalar_uninit(dst: &mut [MaybeUninit], src: &[f32], }); } +/// Compute a _strided summation_ of `f32` elements in `src`, where the stride +/// length is `dst.len()`. +/// +/// In more detail, break source slice `src` into `N` chunks `C0...CN-1`, where +/// `N = dst.len()`, and set the `i`th element of `dst` to be the sum of the +/// `i`th element of each chunk `Ck`: +/// +/// - `dst[0] = SUM(k=0..N-1, Ck[0])` +/// - `dst[1] = SUM(k=0..N-1, Ck[1])` +/// - `dst[2] = SUM(k=0..N-1, Ck[2])` +/// - ... #[inline] pub(crate) fn sum_slices_uninit<'a>(dst: &'a mut [MaybeUninit], src: &[f32]) -> &'a mut [f32] { let len = dst.len(); @@ -54,6 +66,17 @@ pub(crate) fn sum_slices_uninit<'a>(dst: &'a mut [MaybeUninit], src: &[f32] dst } +/// Compute a _strided summation_ of `f64` elements in `src`, where the stride +/// length is `dst.len()`. +/// +/// In more detail, break source slice `src` into `N` chunks `C0...CN-1`, where +/// `N = dst.len()`, and set the `i`th element of `dst` to be the sum of the +/// `i`th element of each chunk `Ck`: +/// +/// - `dst[0] = SUM(k=0..N-1, Ck[0])` +/// - `dst[1] = SUM(k=0..N-1, Ck[1])` +/// - `dst[2] = SUM(k=0..N-1, Ck[2])` +/// - ... #[inline] pub(crate) fn sum_slices_f64_uninit<'a>( dst: &'a mut [MaybeUninit], From dd8131862ee6e59d0298758e1598bae9fcd175a5 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 11:38:42 -0700 Subject: [PATCH 26/66] Documented sliceops --- src/sliceop.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/sliceop.rs b/src/sliceop.rs index 75d0f00..b876b5d 100644 --- a/src/sliceop.rs +++ b/src/sliceop.rs @@ -259,11 +259,27 @@ pub(crate) fn inner_product_cond( acc.iter().sum::() as f32 } +/// Extract a reference to a specific "row" from a one-dimensional slice, where +/// the data is conceptually arranged as a two-dimensional array. +/// +/// # Arguments +/// +/// * `slice` - slice to extract a reference from +/// * `index` - the index of the conceptual "row" to reference +/// * `row_size` - the size of the conceptual "row" to reference #[inline] pub(crate) fn row(slice: &[T], index: usize, row_size: usize) -> &[T] { &slice[index * row_size..(index + 1) * row_size] } +/// Extract a mutable reference to a specific "row" from a one-dimensional +/// slice, where the data is conceptually arranged as a two-dimensional array. +/// +/// # Arguments +/// +/// * `slice` - slice to extract a mutable reference from +/// * `index` - the index of the conceptual "row" to reference +/// * `row_size` - the size of the conceptual "row" to reference #[inline] pub(crate) fn row_mut(slice: &mut [T], index: usize, row_size: usize) -> &mut [T] { &mut slice[index * row_size..(index + 1) * row_size] From 87ac61ef87da1986dc23b7b02fb61a64b2a9a1e9 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 11:50:13 -0700 Subject: [PATCH 27/66] Fixed docs in sliceop --- src/sliceop.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sliceop.rs b/src/sliceop.rs index b876b5d..e250dab 100644 --- a/src/sliceop.rs +++ b/src/sliceop.rs @@ -66,8 +66,8 @@ pub(crate) fn sum_slices_uninit<'a>(dst: &'a mut [MaybeUninit], src: &[f32] dst } -/// Compute a _strided summation_ of `f64` elements in `src`, where the stride -/// length is `dst.len()`. +/// Compute a _strided summation_ of `f32` elements in `src`, where the stride +/// length is `dst.len()`, and store as `f64` in `dst`. /// /// In more detail, break source slice `src` into `N` chunks `C0...CN-1`, where /// `N = dst.len()`, and set the `i`th element of `dst` to be the sum of the From 3c8056dfdbac084717e17f9ef829c6e38b72aea9 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 12:16:50 -0700 Subject: [PATCH 28/66] Docstrings for sliceops --- src/sliceop.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/sliceop.rs b/src/sliceop.rs index e250dab..e6dbdc9 100644 --- a/src/sliceop.rs +++ b/src/sliceop.rs @@ -11,6 +11,22 @@ pub(crate) fn mul_slice(lhs: &mut [f32], rhs: &[f32]) { lhs.iter_mut().zip(rhs).for_each(|(l, r)| *l *= *r); } +/// Divides each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. If an element in `rhs` is zero, the corresponding +/// element in `lhs` is set to a specified `default` value instead of performing +/// the division. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. Each element of this slice is divided by the +/// corresponding element in the `rhs` slice, or set to `default` if the +/// corresponding element in `rhs` is zero. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// divisor for each element in `lhs`. +/// - `default: f32`: A fallback value that is used for elements in `lhs` where +/// the corresponding element in `rhs` is zero. #[inline] pub(crate) fn div_slice(lhs: &mut [f32], rhs: &[f32], default: f32) { lhs.iter_mut() @@ -18,6 +34,22 @@ pub(crate) fn div_slice(lhs: &mut [f32], rhs: &[f32], default: f32) { .for_each(|(l, r)| *l = if is_zero(*r) { default } else { *l / *r }); } +/// Divides each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. If an element in `rhs` is zero, the corresponding +/// element in `lhs` is set to a specified `default` value instead of performing +/// the division. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. Each element of this slice is divided by the +/// corresponding element in the `rhs` slice, or set to `default` if the +/// corresponding element in `rhs` is zero. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// divisor for each element in `lhs`. +/// - `default: f32`: A fallback value that is used for elements in `lhs` where +/// the corresponding element in `rhs` is zero. #[inline] pub(crate) fn div_slice_uninit( dst: &mut [MaybeUninit], From 3b7606bfff01f1846a65d4a61ddda66574db89f0 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 15:26:48 -0700 Subject: [PATCH 29/66] Docs and rename --- src/game/base.rs | 10 ++++---- src/sliceop.rs | 54 ++++++++++++++++++++++++++++++++++++++ src/solver.rs | 67 ++++++++++++++++++++++++++++++++++++++++++------ src/utility.rs | 17 +++++++----- 4 files changed, 129 insertions(+), 19 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index 94931c1..8d0cf43 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -1254,7 +1254,7 @@ impl PostFlopGame { let player_swap = swap_option.map(|swap| { let mut tmp = (0..player_len).collect::>(); - apply_swap(&mut tmp, &swap[player]); + apply_swap_list(&mut tmp, &swap[player]); tmp }); @@ -1276,8 +1276,8 @@ impl PostFlopGame { let slices = if let Some(swap) = swap_option { tmp.0.extend_from_slice(&arena[index..index + opponent_len]); tmp.1.extend_from_slice(opponent_strength); - apply_swap(&mut tmp.0, &swap[player ^ 1]); - apply_swap(&mut tmp.1, &swap[player ^ 1]); + apply_swap_list(&mut tmp.0, &swap[player ^ 1]); + apply_swap_list(&mut tmp.1, &swap[player ^ 1]); (tmp.0.as_slice(), &tmp.1) } else { (&arena[index..index + opponent_len], opponent_strength) @@ -1348,7 +1348,7 @@ impl PostFlopGame { let player_swap = swap_option.map(|swap| { let mut tmp = (0..player_len).collect::>(); - apply_swap(&mut tmp, &swap[player]); + apply_swap_list(&mut tmp, &swap[player]); tmp }); @@ -1365,7 +1365,7 @@ impl PostFlopGame { let slice = &arena[index..index + opponent_len]; let slice = if let Some(swap) = swap_option { tmp.extend_from_slice(slice); - apply_swap(&mut tmp, &swap[player ^ 1]); + apply_swap_list(&mut tmp, &swap[player ^ 1]); &tmp } else { slice diff --git a/src/sliceop.rs b/src/sliceop.rs index e6dbdc9..43ac497 100644 --- a/src/sliceop.rs +++ b/src/sliceop.rs @@ -1,11 +1,31 @@ use crate::utility::*; use std::mem::MaybeUninit; +/// Subtracts each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// values to be subtracted from each corresponding element in `lhs`. #[inline] pub(crate) fn sub_slice(lhs: &mut [f32], rhs: &[f32]) { lhs.iter_mut().zip(rhs).for_each(|(l, r)| *l -= *r); } +/// Multiplies each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// values to be multiplied against each corresponding element in `lhs`. #[inline] pub(crate) fn mul_slice(lhs: &mut [f32], rhs: &[f32]) { lhs.iter_mut().zip(rhs).for_each(|(l, r)| *l *= *r); @@ -127,6 +147,40 @@ pub(crate) fn sum_slices_f64_uninit<'a>( dst } +/// Performs a fused multiply-add (FMA) operation on slices, storing the result +/// in a destination slice. +/// +/// This function multiplies the first `dst.len()` corresponding elements of the +/// two source slices (`src1` and `src2`) and stores the results in the +/// destination slice (`dst`). After the initial multiplication, it continues +/// to perform additional multiply-add operations using subsequent chunks of +/// `src1` and `src2`, adding the products to the already computed values in +/// `dst`. +/// +/// # Arguments +/// +/// - `dst`: A mutable reference to a slice of uninitialized memory where the +/// results will be stored. The length of this slice dictates how many +/// elements are processed in the initial operation. +/// - `src1`: A reference to the first source slice, providing the +/// multiplicands. +/// - `src2`: A reference to the second source slice, providing the multipliers. +/// +/// # Returns +/// +/// A mutable reference to the `dst` slice, now reinterpreted as a fully +/// initialized slice of `f32` values, containing the results of the fused +/// multiply-add operations. +/// +/// # Safety +/// +/// - This function assumes that the length of `dst` is equal to or less than +/// the length of `src1` and `src2`. If the lengths are mismatched, the function +/// might cause undefined behavior due to out-of-bounds memory access. +/// - The function uses unsafe code to cast the `MaybeUninit` slice into a +/// `f32` slice after initialization. This is safe only if the `dst` slice is +/// properly initialized with valid `f32` values, as is ensured by the +/// function's implementation. #[inline] pub(crate) fn fma_slices_uninit<'a>( dst: &'a mut [MaybeUninit], diff --git a/src/solver.rs b/src/solver.rs index 02db93c..6dfe75c 100644 --- a/src/solver.rs +++ b/src/solver.rs @@ -9,8 +9,11 @@ use std::mem::MaybeUninit; use crate::alloc::*; struct DiscountParams { + // coefficient for accumulated positive regrets alpha_t: f32, + // coefficient for accumulated negative regrets beta_t: f32, + // contributions to average strategy gamma_t: f32, } @@ -200,7 +203,16 @@ pub fn solve_with_node_as_root( exploitability } -/// Recursively solves the counterfactual values. +/// Recursively solves the counterfactual values and store them in `result`. +/// +/// # Arguments +/// +/// * `result` - slice to store resulting counterfactual regret values +/// * `game` - reference to the game we are solving +/// * `node` - current node we are solving +/// * `player` - current player we are solving for +/// * `cfreach` - the probability of reaching this point with a particular private hand +/// * `params` - the DiscountParams that parametrize the solver fn solve_recursive( result: &mut [MaybeUninit], game: &T, @@ -225,7 +237,14 @@ fn solve_recursive( return; } - // allocate memory for storing the counterfactual values + // Allocate memory for storing the counterfactual values. Conceptually this + // is a `num_actions * num_hands` 2-dimensional array, where the `i`th + // row (which has length `num_hands`) corresponds to the cfvalues of each + // hand after taking the `i`th action. + // + // Rows are obtained using operations from `sliceop` (e.g., `sliceop::row_mut()`). + // + // `cfv_actions` will be written to by recursive calls to `solve_recursive`. #[cfg(feature = "custom-alloc")] let cfv_actions = MutexLike::new(Vec::with_capacity_in(num_actions * num_hands, StackAlloc)); #[cfg(not(feature = "custom-alloc"))] @@ -257,13 +276,15 @@ fn solve_recursive( ); }); - // use 64-bit floating point values + // use 64-bit floating point values for precision during summations + // before demoting back to f32 #[cfg(feature = "custom-alloc")] let mut result_f64 = Vec::with_capacity_in(num_hands, StackAlloc); #[cfg(not(feature = "custom-alloc"))] let mut result_f64 = Vec::with_capacity(num_hands); - // sum up the counterfactual values + // compute the strided summation of the counterfactual values for each + // hand and store in `result_f64` let mut cfv_actions = cfv_actions.lock(); unsafe { cfv_actions.set_len(num_actions * num_hands) }; sum_slices_f64_uninit(result_f64.spare_capacity_mut(), &cfv_actions); @@ -277,13 +298,13 @@ fn solve_recursive( let swap_list = &game.isomorphic_swap(node, i)[player]; let tmp = row_mut(&mut cfv_actions, isomorphic_index as usize, num_hands); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); result_f64.iter_mut().zip(&*tmp).for_each(|(r, &v)| { *r += v as f64; }); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); } result.iter_mut().zip(&result_f64).for_each(|(r, &v)| { @@ -304,7 +325,7 @@ fn solve_recursive( ); }); - // compute the strategy by regret-maching algorithm + // compute the strategy by regret-matching algorithm let mut strategy = if game.is_compression_enabled() { regret_matching_compressed(node.regrets_compressed(), num_actions) } else { @@ -315,7 +336,17 @@ fn solve_recursive( let locking = game.locking_strategy(node); apply_locking_strategy(&mut strategy, locking); - // sum up the counterfactual values + // Compute the counterfactual values for each hand, which for hand `h` is + // computed to be the sum over actions `a` of the frequency with which + // `h` takes action `a` and the regret of hand `h` taking action `a`. + // In pseudocode, this is: + // + // ``` + // result[h] = sum([freq(h, a) * regret(h, a) for a in actions]) + // ``` + // + // This sum-of-products us computed as a fused multiply-add using + // `fma_slices_uninit` and is stored in `result`. let mut cfv_actions = cfv_actions.lock(); unsafe { cfv_actions.set_len(num_actions * num_hands) }; let result = fma_slices_uninit(result, &strategy, &cfv_actions); @@ -367,6 +398,7 @@ fn solve_recursive( node.set_regret_scale(new_scale); } else { // update the cumulative strategy + // - `gamma` is used to discount cumulative strategy contributions let gamma = params.gamma_t; let cum_strategy = node.strategy_mut(); cum_strategy.iter_mut().zip(&strategy).for_each(|(x, y)| { @@ -374,6 +406,8 @@ fn solve_recursive( }); // update the cumulative regret + // - alpha is used to discount positive cumulative regrets + // - beta is used to discount negative cumulative regrets let (alpha, beta) = (params.alpha_t, params.beta_t); let cum_regret = node.regrets_mut(); cum_regret.iter_mut().zip(&*cfv_actions).for_each(|(x, y)| { @@ -448,6 +482,18 @@ fn regret_matching(regret: &[f32], num_actions: usize) -> Vec { } /// Computes the strategy by regret-matching algorithm. +/// +/// The resulting strategy has each element (e.g., a hand like **AdQs**) take +/// an action proportional to its regret, where negative regrets are interpreted +/// as zero. +/// +/// # Arguments +/// +/// * `regret` - slice of regrets for the current decision point, one "row" of +/// for each action. The `i`th row contains the regrets of each strategically +/// distinct element (e.g., in holdem an element would be a hole card) for +/// taking the `i`th action. +/// * `num_actions` - the number of actions represented in `regret`. #[cfg(not(feature = "custom-alloc"))] #[inline] fn regret_matching(regret: &[f32], num_actions: usize) -> Vec { @@ -459,10 +505,15 @@ fn regret_matching(regret: &[f32], num_actions: usize) -> Vec { unsafe { strategy.set_len(regret.len()) }; let row_size = regret.len() / num_actions; + + // We want to normalize each element's strategy, so compute the element-wise + // denominator by computing the strided summation of strategy let mut denom = Vec::with_capacity(row_size); sum_slices_uninit(denom.spare_capacity_mut(), &strategy); unsafe { denom.set_len(row_size) }; + // We set the default to be equally distributed across all options. This is + // used when a strategy for a particular hand is uniformly zero. let default = 1.0 / num_actions as f32; strategy.chunks_exact_mut(row_size).for_each(|row| { div_slice(row, &denom, default); diff --git a/src/utility.rs b/src/utility.rs index c834de7..d38208b 100644 --- a/src/utility.rs +++ b/src/utility.rs @@ -227,9 +227,14 @@ pub(crate) fn encode_unsigned_slice(dst: &mut [u16], slice: &[f32]) -> f32 { scale } -/// Applies the given swap to the given slice. +/// Applies the given list of swaps to the given slice. +/// +/// # Arguments +/// +/// * `slice` - mutable slice to perform swaps on +/// * `swap_list` - a list of index pairs to swap #[inline] -pub(crate) fn apply_swap(slice: &mut [T], swap_list: &[(u16, u16)]) { +pub(crate) fn apply_swap_list(slice: &mut [T], swap_list: &[(u16, u16)]) { for &(i, j) in swap_list { unsafe { ptr::swap( @@ -425,13 +430,13 @@ fn compute_cfvalue_recursive( let swap_list = &game.isomorphic_swap(node, i)[player]; let tmp = row_mut(&mut cfv_actions, isomorphic_index as usize, num_hands); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); result_f64.iter_mut().zip(&*tmp).for_each(|(r, &v)| { *r += v as f64; }); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); } result.iter_mut().zip(&result_f64).for_each(|(r, &v)| { @@ -637,13 +642,13 @@ fn compute_best_cfv_recursive( let swap_list = &game.isomorphic_swap(node, i)[player]; let tmp = row_mut(&mut cfv_actions, isomorphic_index as usize, num_hands); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); result_f64.iter_mut().zip(&*tmp).for_each(|(r, &v)| { *r += v as f64; }); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); } result.iter_mut().zip(&result_f64).for_each(|(r, &v)| { From af355788429e3bce3c43880ea6078dbb5ab1f179 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 15:28:34 -0700 Subject: [PATCH 30/66] Docs --- src/sliceop.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/sliceop.rs b/src/sliceop.rs index 43ac497..f99ad70 100644 --- a/src/sliceop.rs +++ b/src/sliceop.rs @@ -171,16 +171,6 @@ pub(crate) fn sum_slices_f64_uninit<'a>( /// A mutable reference to the `dst` slice, now reinterpreted as a fully /// initialized slice of `f32` values, containing the results of the fused /// multiply-add operations. -/// -/// # Safety -/// -/// - This function assumes that the length of `dst` is equal to or less than -/// the length of `src1` and `src2`. If the lengths are mismatched, the function -/// might cause undefined behavior due to out-of-bounds memory access. -/// - The function uses unsafe code to cast the `MaybeUninit` slice into a -/// `f32` slice after initialization. This is safe only if the `dst` slice is -/// properly initialized with valid `f32` values, as is ensured by the -/// function's implementation. #[inline] pub(crate) fn fma_slices_uninit<'a>( dst: &'a mut [MaybeUninit], From d879d6a0a38ad8577c5a374ec4c267a515d08b4f Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 17:31:35 -0700 Subject: [PATCH 31/66] Tmp: splitting branches --- src/game/base.rs | 183 ----------------------------------------------- 1 file changed, 183 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index 8d0cf43..ac8e9a8 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -829,189 +829,6 @@ impl PostFlopGame { info.num_storage_ip += node.num_elements_ip as u64; } - /* REBUILDING AND RESOLVING TREE */ - - /// Like `init_root`, but applied to a partial save loaded from disk. This - /// reallocates missing `PostFlopNode`s to `node_arena` and reruns - /// `build_tree_recursive`. Rerunning `build_tree_recursive` will not alter - /// nodes loaded from disk. - pub fn reinit_root(&mut self) -> Result<(), String> { - let nodes_per_street = self.count_nodes_per_street(); - let total_num_nodes = nodes_per_street[0] + nodes_per_street[1] + nodes_per_street[2]; - - if total_num_nodes > u32::MAX as u64 - || mem::size_of::() as u64 * total_num_nodes > isize::MAX as u64 - { - return Err("Too many nodes".to_string()); - } - - self.num_nodes_per_street = nodes_per_street; - - self.node_arena = (0..total_num_nodes) - .map(|_| MutexLike::new(PostFlopNode::default())) - .collect::>(); - // self.clear_storage(); - - let mut info = BuildTreeInfo { - turn_index: nodes_per_street[0] as usize, - river_index: (nodes_per_street[0] + nodes_per_street[1]) as usize, - ..Default::default() - }; - - match self.tree_config.initial_state { - BoardState::Flop => info.flop_index += 1, - BoardState::Turn => info.turn_index += 1, - BoardState::River => info.river_index += 1, - } - - let mut root = self.node_arena[0].lock(); - root.turn = self.card_config.turn; - root.river = self.card_config.river; - - self.build_tree_recursive(0, &self.action_root.lock(), &mut info); - - self.num_storage = info.num_storage; - self.num_storage_ip = info.num_storage_ip; - self.num_storage_chance = info.num_storage_chance; - self.misc_memory_usage = self.memory_usage_internal(); - - Ok(()) - } - - pub fn rebuild_and_resolve_forgotten_streets(&mut self) -> Result<(), String> { - self.check_card_config()?; - self.reinit_root()?; - self.allocate_memory_after_load()?; - self.resolve_reloaded_nodes(1000, 0.01, false) - } - - /// Return the node index for each root of the forgotten gametrees that were - /// omitted during a partial save. - /// - /// When we perform a partial save (e.g., a flop save), we lose - /// cfvalues/strategy data for all subtrees rooted at the forgotten street - /// (in the case of a flop save, this would be all subtrees rooted at the - /// beginning of the turn). - /// - /// To regain this information we need to resolve each of these subtrees - /// individually. This function collects the index of each such root. - pub fn collect_unsolved_roots_after_reload(&mut self) -> Result, String> { - match self.storage_mode { - BoardState::Flop => { - let turn_root_nodes = self - .node_arena - .iter() - .enumerate() - .filter(|(_, n)| { - n.lock().turn != NOT_DEALT - && n.lock().river == NOT_DEALT - && matches!(n.lock().prev_action, Action::Chance(..)) - }) - .map(|(i, _)| i) - .collect::>(); - Ok(turn_root_nodes) - } - BoardState::Turn => { - let river_root_nodes = self - .node_arena - .iter() - .enumerate() - .filter(|(_, n)| { - n.lock().turn != NOT_DEALT - && matches!(n.lock().prev_action, Action::Chance(..)) - }) - .map(|(i, _)| i) - .collect::>(); - Ok(river_root_nodes) - } - BoardState::River => Ok(vec![]), - } - } - - pub fn resolve_reloaded_nodes( - &mut self, - max_num_iterations: u32, - target_exploitability: f32, - print_progress: bool, - ) -> Result<(), String> { - let nodes_to_solve = self.collect_unsolved_roots_after_reload()?; - self.state = State::MemoryAllocated; - for node_idx in nodes_to_solve { - let node = self.node_arena.get(node_idx).ok_or("Invalid node index")?; - // let history = node - // .lock() - // .compute_history_recursive(&self) - // .ok_or("Unable to compute history for node".to_string())? - // .to_vec(); - // self.apply_history(&history); - solve_with_node_as_root( - self, - node.lock(), - max_num_iterations, - target_exploitability, - print_progress, - ); - } - finalize(self); - - Ok(()) - } - - /// Reallocate memory for full tree after performing a partial load. This - /// must be called after `init_root()` - pub fn allocate_memory_after_load(&mut self) -> Result<(), String> { - if self.state <= State::Uninitialized { - return Err("Game is not successfully initialized".to_string()); - } - - if self.state == State::MemoryAllocated && self.storage_mode == BoardState::River { - return Ok(()); - } - - let num_bytes = if self.is_compression_enabled { 2 } else { 4 }; - if num_bytes * self.num_storage > isize::MAX as u64 - || num_bytes * self.num_storage_chance > isize::MAX as u64 - { - return Err("Memory usage exceeds maximum size".to_string()); - } - - self.state = State::MemoryAllocated; - // self.is_compression_enabled = self.is_compression_enabled; - - let old_storage1 = std::mem::replace(&mut self.storage1, vec![]); - let old_storage2 = std::mem::replace(&mut self.storage2, vec![]); - let old_storage_ip = std::mem::replace(&mut self.storage_ip, vec![]); - let old_storage_chance = std::mem::replace(&mut self.storage_chance, vec![]); - - let storage_bytes = (num_bytes * self.num_storage) as usize; - let storage_ip_bytes = (num_bytes * self.num_storage_ip) as usize; - let storage_chance_bytes = (num_bytes * self.num_storage_chance) as usize; - - self.storage1 = vec![0; storage_bytes]; - self.storage2 = vec![0; storage_bytes]; - self.storage_ip = vec![0; storage_ip_bytes]; - self.storage_chance = vec![0; storage_chance_bytes]; - - self.allocate_memory_nodes(); // Assign node storage pointers - - self.storage_mode = BoardState::River; - self.target_storage_mode = BoardState::River; - - for (dst, src) in self.storage1.iter_mut().zip(&old_storage1) { - *dst = *src; - } - for (dst, src) in self.storage2.iter_mut().zip(&old_storage2) { - *dst = *src; - } - for (dst, src) in self.storage_ip.iter_mut().zip(&old_storage_ip) { - *dst = *src; - } - for (dst, src) in self.storage_chance.iter_mut().zip(&old_storage_chance) { - *dst = *src; - } - Ok(()) - } - /// Sets the bunching effect. fn set_bunching_effect_internal(&mut self, bunching_data: &BunchingData) -> Result<(), String> { self.bunching_num_dead_cards = bunching_data.fold_ranges().len() * 2; From 2de67198316086c5eba931477504f388f58fbe3e Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 17:36:26 -0700 Subject: [PATCH 32/66] Branch refactor: removed solve_with_node_as_root --- src/solver.rs | 68 --------------------------------------------------- 1 file changed, 68 deletions(-) diff --git a/src/solver.rs b/src/solver.rs index 6dfe75c..1548a1a 100644 --- a/src/solver.rs +++ b/src/solver.rs @@ -135,74 +135,6 @@ pub fn solve_step(game: &T, current_iteration: u32) { } } -/// Performs Discounted CFR algorithm until the given number of iterations or exploitability is -/// satisfied. -/// -/// This method returns the exploitability of the obtained strategy. -pub fn solve_with_node_as_root( - game: &mut T, - mut root: MutexGuardLike, - max_num_iterations: u32, - target_exploitability: f32, - print_progress: bool, -) -> f32 { - if game.is_solved() { - panic!("Game is already solved"); - } - - if !game.is_ready() { - panic!("Game is not ready"); - } - - let mut exploitability = compute_exploitability(game); - - if print_progress { - print!("iteration: 0 / {max_num_iterations} "); - print!("(exploitability = {exploitability:.4e})"); - io::stdout().flush().unwrap(); - } - - for t in 0..max_num_iterations { - if exploitability <= target_exploitability { - break; - } - - let params = DiscountParams::new(t); - - // alternating updates - for player in 0..2 { - let mut result = Vec::with_capacity(game.num_private_hands(player)); - solve_recursive( - result.spare_capacity_mut(), - game, - &mut root, - player, - game.initial_weights(player ^ 1), - ¶ms, - ); - } - - if (t + 1) % 10 == 0 || t + 1 == max_num_iterations { - exploitability = compute_exploitability(game); - } - - if print_progress { - print!("\riteration: {} / {} ", t + 1, max_num_iterations); - print!("(exploitability = {exploitability:.4e})"); - io::stdout().flush().unwrap(); - } - } - - if print_progress { - println!(); - io::stdout().flush().unwrap(); - } - - finalize(game); - - exploitability -} - /// Recursively solves the counterfactual values and store them in `result`. /// /// # Arguments From edeb03e2421c850401d7e5c02e63fceb38e50b94 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 17:40:57 -0700 Subject: [PATCH 33/66] Branch refactor --- src/game/base.rs | 1 - src/game/interpreter.rs | 19 ------------------- 2 files changed, 20 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index ac8e9a8..7ff5f8b 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -1,7 +1,6 @@ use super::*; use crate::bunching::*; use crate::interface::*; -use crate::solve_with_node_as_root; use crate::utility::*; use std::mem::{self, MaybeUninit}; diff --git a/src/game/interpreter.rs b/src/game/interpreter.rs index 619011d..652e990 100644 --- a/src/game/interpreter.rs +++ b/src/game/interpreter.rs @@ -847,25 +847,6 @@ impl PostFlopGame { ret } - pub fn strategy_by_private_hand(&self) -> Vec> { - let strat = self.strategy(); - let player = self.current_player(); - let num_hands = self.private_cards(player).len(); - let num_actions = self.available_actions().len(); - assert!(num_hands * num_actions == strat.len()); - let mut strat_by_hand: Vec> = Vec::with_capacity(num_hands); - for j in 0..num_hands { - strat_by_hand.push(Vec::with_capacity(num_actions)); - } - - for i in 0..num_actions { - for j in 0..num_hands { - strat_by_hand[j].push(strat[i * num_hands + j]); - } - } - strat_by_hand - } - /// Returns the total bet amount of each player (OOP, IP). #[inline] pub fn total_bet_amount(&self) -> [i32; 2] { From 1989fb68653dc94eb337836878bda6522ebacc6f Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 18:02:35 -0700 Subject: [PATCH 34/66] Refactored/removed unused file_io_debug.rs --- examples/file_io_debug.rs | 210 -------------------------------------- 1 file changed, 210 deletions(-) delete mode 100644 examples/file_io_debug.rs diff --git a/examples/file_io_debug.rs b/examples/file_io_debug.rs deleted file mode 100644 index d5a625f..0000000 --- a/examples/file_io_debug.rs +++ /dev/null @@ -1,210 +0,0 @@ -use postflop_solver::*; - -fn recursive_compare_strategies_helper( - saved: &mut PostFlopGame, - loaded: &mut PostFlopGame, - storage_mode: BoardState, -) { - let history = saved.history().to_vec(); - saved.cache_normalized_weights(); - loaded.cache_normalized_weights(); - - // Check if OOP hands have the same evs - let evs_oop_1 = saved.expected_values(0); - let ws_oop_1 = saved.weights(0); - let evs_oop_2 = loaded.expected_values(1); - let ws_oop_2 = saved.weights(0); - - assert!(ws_oop_1.len() == ws_oop_2.len()); - for (w1, w2) in ws_oop_1.iter().zip(ws_oop_2) { - assert!((w1 - w2).abs() < 0.001); - } - for (i, (e1, e2)) in evs_oop_1.iter().zip(&evs_oop_2).enumerate() { - assert!((e1 - e2).abs() < 0.001, "ev diff({}): {}", i, e1 - e2); - } - - let ev_oop_1 = compute_average(&evs_oop_1, &ws_oop_1); - let ev_oop_2 = compute_average(&evs_oop_2, &ws_oop_2); - - let ev_diff = (ev_oop_1 - ev_oop_2).abs(); - println!("EV Diff: {:0.2}", ev_diff); - assert!((ev_oop_1 - ev_oop_2).abs() < 0.01); - for child_index in 0..saved.available_actions().len() { - saved.play(child_index); - loaded.play(child_index); - - recursive_compare_strategies_helper(saved, loaded, storage_mode); - - saved.apply_history(&history); - loaded.apply_history(&history); - } -} - -fn compare_strategies( - saved: &mut PostFlopGame, - loaded: &mut PostFlopGame, - storage_mode: BoardState, -) { - saved.back_to_root(); - loaded.back_to_root(); - saved.cache_normalized_weights(); - loaded.cache_normalized_weights(); - for (i, ((e1, e2), cards)) in saved - .expected_values(0) - .iter() - .zip(loaded.expected_values(0)) - .zip(saved.private_cards(0)) - .enumerate() - { - println!("ev {}: {}:{}", hole_to_string(*cards).unwrap(), e1, e2); - } - for (i, ((e1, e2), cards)) in saved - .expected_values(1) - .iter() - .zip(loaded.expected_values(1)) - .zip(saved.private_cards(1)) - .enumerate() - { - println!("ev {}: {}:{}", hole_to_string(*cards).unwrap(), e1, e2); - } - recursive_compare_strategies_helper(saved, loaded, storage_mode); -} - -fn print_strats_at_current_node( - g1: &mut PostFlopGame, - g2: &mut PostFlopGame, - actions: &Vec, -) { - let action_string = actions - .iter() - .map(|a| format!("{:?}", a)) - .collect::>() - .join(":"); - - let player = g1.current_player(); - - println!( - "\x1B[32;1mActions To Reach Node\x1B[0m: [{}]", - action_string - ); - // Print high level node data - if g1.is_chance_node() { - println!("\x1B[32;1mPlayer\x1B[0m: Chance"); - } else if g1.is_terminal_node() { - if player == 0 { - println!("\x1B[32;1mPlayer\x1B[0m: OOP (Terminal)"); - } else { - println!("\x1B[32;1mPlayer\x1B[0m: IP (Terminal)"); - } - } else { - if player == 0 { - println!("\x1B[32;1mPlayer\x1B[0m: OOP"); - } else { - println!("\x1B[32;1mPlayer\x1B[0m: IP"); - } - let private_cards = g1.private_cards(player); - let strat1 = g1.strategy_by_private_hand(); - let strat2 = g2.strategy_by_private_hand(); - let weights1 = g1.weights(player); - let weights2 = g2.weights(player); - let actions = g1.available_actions(); - - // Print both games strategies - for ((cards, (w1, s1)), (w2, s2)) in private_cards - .iter() - .zip(weights1.iter().zip(strat1)) - .zip(weights2.iter().zip(strat2)) - { - let hole_cards = hole_to_string(*cards).unwrap(); - print!("\x1B[34;1m{hole_cards}\x1B[0m@({:.2} v {:.2}) ", w1, w2); - let mut action_frequencies = vec![]; - for (a, (freq1, freq2)) in actions.iter().zip(s1.iter().zip(s2)) { - action_frequencies.push(format!( - "\x1B[32;1m{:?}\x1B[0m: \x1B[31m{:0.3}\x1B[0m v \x1B[33m{:0>.3}\x1B[0m", - a, freq1, freq2 - )) - } - println!("{}", action_frequencies.join(" ")); - } - } -} - -fn main() { - let oop_range = "AA,QQ"; - let ip_range = "KK"; - - let card_config = CardConfig { - range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], - flop: flop_from_str("3h3s3d").unwrap(), - ..Default::default() - }; - - let tree_config = TreeConfig { - starting_pot: 100, - effective_stack: 100, - rake_rate: 0.0, - rake_cap: 0.0, - flop_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], - turn_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], - river_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], - ..Default::default() - }; - - let action_tree = ActionTree::new(tree_config).unwrap(); - let mut game1 = PostFlopGame::with_config(card_config, action_tree).unwrap(); - game1.allocate_memory(false); - - solve(&mut game1, 100, 0.01, false); - - // save (turn) - game1.set_target_storage_mode(BoardState::Turn).unwrap(); - save_data_to_file(&game1, "", "tmpfile.flop", None).unwrap(); - - // load (turn) - let mut game2: PostFlopGame = load_data_from_file("tmpfile.flop", None).unwrap().0; - // compare_strategies(&mut game, &mut game2, BoardState::Turn); - assert!(game2.rebuild_and_resolve_forgotten_streets().is_ok()); - - let mut actions_so_far = vec![]; - - // Print Root Node - print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - - // OOP: Check - actions_so_far.push(game1.available_actions()[0]); - game1.play(0); - game2.play(0); - print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - - // IP: Check - actions_so_far.push(game1.available_actions()[0]); - game1.play(0); - game2.play(0); - print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - - // Chance: 2c - actions_so_far.push(game1.available_actions()[0]); - game1.play(0); - game2.play(0); - print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - - // OOP: CHECK - actions_so_far.push(game1.available_actions()[0]); - game1.play(0); - game2.play(0); - print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - - // IP: CHECK - actions_so_far.push(game1.available_actions()[0]); - game1.play(0); - game2.play(0); - print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - - // CHANCE: 0 - actions_so_far.push(game1.available_actions()[1]); - game1.play(1); - game2.play(1); - print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - - // compare_strategies(&mut game, &mut game2, BoardState::Turn); -} From 0ae353171ee0f74dd79b2772cbbdf0af249651a6 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 18:04:01 -0700 Subject: [PATCH 35/66] Tmp Commit --- src/game/node.rs | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/game/node.rs b/src/game/node.rs index 400a886..e6bd7e3 100644 --- a/src/game/node.rs +++ b/src/game/node.rs @@ -209,7 +209,6 @@ impl Default for PostFlopNode { fn default() -> Self { Self { prev_action: Action::None, - parent_node_index: usize::MAX, player: PLAYER_OOP, turn: NOT_DEALT, river: NOT_DEALT, @@ -241,31 +240,4 @@ impl PostFlopNode { ) } } - - /// Get a list of available actions at a given node - pub fn actions(&self) -> Vec { - self.children() - .iter() - .map(|n| n.lock().prev_action) - .collect::>() - } - - /// Find the index of a given action, if present - pub fn action_index(&self, action: Action) -> Option { - self.children() - .iter() - .position(|n| n.lock().prev_action == action) - } - - /// Recursively compute the current node's history - pub fn compute_history_recursive(&self, game: &PostFlopGame) -> Option> { - if self.parent_node_index == usize::MAX { - Some(vec![]) - } else { - let p = game.node_arena.get(self.parent_node_index)?; - let mut history = p.lock().compute_history_recursive(game)?; - history.push(p.lock().action_index(self.prev_action)?); - Some(history) - } - } } From deb338ca8bcb04e6594c22552264c35f4259f37b Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 18:04:38 -0700 Subject: [PATCH 36/66] tmp commit --- src/game/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/game/mod.rs b/src/game/mod.rs index 3504adb..08b23bc 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -121,7 +121,6 @@ pub struct PostFlopGame { #[repr(C)] pub struct PostFlopNode { prev_action: Action, - parent_node_index: usize, player: u8, turn: Card, river: Card, From 3227822ac862cdefc5bc53e5a7b570339d639bce Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 18:07:50 -0700 Subject: [PATCH 37/66] Refactoring test --- src/file.rs | 42 ------------------------------------------ 1 file changed, 42 deletions(-) diff --git a/src/file.rs b/src/file.rs index d7e90f7..17e7d5b 100644 --- a/src/file.rs +++ b/src/file.rs @@ -269,7 +269,6 @@ mod tests { use crate::action_tree::*; use crate::card::*; use crate::range::*; - use crate::solver::solve; use crate::utility::*; #[test] @@ -376,45 +375,4 @@ mod tests { assert!((root_ev_oop - 45.0).abs() < 1e-4); assert!((root_ev_ip - 15.0).abs() < 1e-4); } - - #[test] - fn test_reload_and_resolve() { - let oop_range = "AA,QQ"; - let ip_range = "KK"; - - let card_config = CardConfig { - range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], - flop: flop_from_str("3h3s3d").unwrap(), - ..Default::default() - }; - - let tree_config = TreeConfig { - starting_pot: 100, - effective_stack: 100, - rake_rate: 0.0, - rake_cap: 0.0, - flop_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], - turn_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], - river_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], - ..Default::default() - }; - - let action_tree = ActionTree::new(tree_config).unwrap(); - let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); - println!( - "memory usage: {:.2}GB", - game.memory_usage().0 as f64 / (1024.0 * 1024.0 * 1024.0) - ); - game.allocate_memory(false); - - solve(&mut game, 100, 0.01, false); - - // save (turn) - game.set_target_storage_mode(BoardState::Turn).unwrap(); - save_data_to_file(&game, "", "tmpfile.flop", None).unwrap(); - - // load (turn) - let mut game: PostFlopGame = load_data_from_file("tmpfile.flop", None).unwrap().0; - assert!(game.rebuild_and_resolve_forgotten_streets().is_ok()); - } } From d0aaae11b0d409448a4103d79f1131a61654db7b Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 18:09:38 -0700 Subject: [PATCH 38/66] Branch refactor continued --- src/game/base.rs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/game/base.rs b/src/game/base.rs index 7ff5f8b..0e69184 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -462,21 +462,6 @@ impl PostFlopGame { Ok(()) } - pub fn print_internal_data(&self) { - println!("Printing internal data for PostFlopGame"); - println!("- node_arena: {}", self.node_arena.len()); - println!("- storage1: {}", self.storage1.len()); - println!("- storage2: {}", self.storage2.len()); - println!("- storage_ip: {}", self.storage_ip.len()); - println!("- storage_chance: {}", self.storage_chance.len()); - println!("- locking_strategy: {}", self.locking_strategy.len()); - println!("- storage mode: {:?}", self.storage_mode()); - println!( - "- target storage mode: {:?}", - self.target_storage_mode() - ); - } - /// Initializes fields `initial_weights` and `private_cards`. #[inline] fn init_hands(&mut self) { @@ -749,7 +734,6 @@ impl PostFlopGame { node.num_children += 1; let mut child = node.children().last().unwrap().lock(); child.prev_action = Action::Chance(card); - child.parent_node_index = node_index; child.turn = card; } } @@ -768,7 +752,6 @@ impl PostFlopGame { node.num_children += 1; let mut child = node.children().last().unwrap().lock(); child.prev_action = Action::Chance(card); - child.parent_node_index = node_index; child.turn = node.turn; child.river = card; } @@ -814,7 +797,6 @@ impl PostFlopGame { child.prev_action = *action; child.turn = node.turn; child.river = node.river; - child.parent_node_index = node_index; } let num_private_hands = self.num_private_hands(node.player as usize); From 3dad7546268068018bc81f09fdc82d6e3c8d232f Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 18 Aug 2024 18:12:22 -0700 Subject: [PATCH 39/66] Removed println --- src/game/interpreter.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/game/interpreter.rs b/src/game/interpreter.rs index 652e990..d7845b6 100644 --- a/src/game/interpreter.rs +++ b/src/game/interpreter.rs @@ -763,7 +763,6 @@ impl PostFlopGame { node.cfvalues_ip().to_vec() } } else if player == self.current_player() { - println!("BINGO"); have_actions = true; if self.is_compression_enabled { let slice = node.cfvalues_compressed(); From 462f263729639ca0c1e4e7e0746e3e4752bc3405 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 6 Oct 2024 15:50:23 -0700 Subject: [PATCH 40/66] Addressed clippy issue --- src/bunching.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bunching.rs b/src/bunching.rs index aa54932..d8596b9 100644 --- a/src/bunching.rs +++ b/src/bunching.rs @@ -145,12 +145,12 @@ pub struct BunchingData { #[inline] fn mask_to_index(mut mask: u64, k: usize) -> usize { let mut index = 0; - for i in 0..k { + COMB_TABLE.iter().take(k).for_each(|xs| { assert!(mask != 0); let tz = mask.trailing_zeros(); - index += COMB_TABLE[i][tz as usize]; + index += xs[tz as usize]; mask &= mask - 1; - } + }); index } From a44d1c63480bea429bc67275efd6cc978f890bcb Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 6 Oct 2024 15:57:50 -0700 Subject: [PATCH 41/66] Fixed some clippy errors --- src/bunching.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/bunching.rs b/src/bunching.rs index d8596b9..d337ab3 100644 --- a/src/bunching.rs +++ b/src/bunching.rs @@ -175,8 +175,8 @@ fn next_combination(mask: u64) -> u64 { #[inline] fn compress_mask(mut mask: u64, flop: [Card; 3]) -> u64 { assert!(flop[0] < flop[1] && flop[1] < flop[2]); - for i in 0..3 { - let m = (1 << (flop[i] as usize - i)) - 1; + for (i, &c) in flop.iter().enumerate() { + let m = (1 << (c as usize - i)) - 1; mask = (mask & m) | ((mask >> 1) & !m); } mask @@ -716,11 +716,15 @@ impl BunchingData { let chunk_end_index = usize::min(chunk_start_index + 100, end_index); let mut src_mask = index_to_mask(chunk_start_index, K); - for src_index in chunk_start_index..chunk_end_index { + for entry in src_table + .iter() + .take(chunk_end_index) + .skip(chunk_start_index) + { let mut src_mask_copy = src_mask; src_mask = next_combination(src_mask); - let freq = src_table[src_index].load(); + let freq = entry.load(); if freq == 0.0 { continue; } From 211557816b6d14fd4b9b52aaa21daf31ad8afb7b Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 6 Oct 2024 16:24:32 -0700 Subject: [PATCH 42/66] more clippy fixes --- src/bunching.rs | 78 ++++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/src/bunching.rs b/src/bunching.rs index d337ab3..3222e1d 100644 --- a/src/bunching.rs +++ b/src/bunching.rs @@ -730,26 +730,26 @@ impl BunchingData { } let mut src_mask_bit = [0; K]; - for i in 0..K { + src_mask_bit.iter_mut().for_each(|bit| { let lsb = src_mask_copy & src_mask_copy.wrapping_neg(); src_mask_copy ^= lsb; - src_mask_bit[i] = lsb; - } + *bit = lsb; + }); - for i in 0..(1 << K) - 1 { - if num_ones[i] > 6 { + for (i, &x) in num_ones.iter().enumerate() { + if x > 6 { continue; } let mut dst_mask = 0; - for j in 0..K { + for (j, &y) in src_mask_bit.iter().enumerate() { if i & (1 << j) != 0 { - dst_mask |= src_mask_bit[j]; + dst_mask |= y; } } - let dst_index = mask_to_index(dst_mask, num_ones[i] as usize); - self.sum[num_ones[i] as usize][dst_index].add(freq); + let dst_index = mask_to_index(dst_mask, x as usize); + self.sum[x as usize][dst_index].add(freq); } } }); @@ -777,37 +777,41 @@ impl BunchingData { let dst_end_index = usize::min(dst_start_index + 100, end_index); let mut mask = index_to_mask(dst_start_index, N); - for dst_index in dst_start_index..dst_end_index { - let mut mask_copy = mask; - mask = next_combination(mask); - - let mut mask_bit = [0; N]; - for i in 0..N { - let lsb = mask_copy & mask_copy.wrapping_neg(); - mask_copy ^= lsb; - mask_bit[i] = lsb; - } - - let mut result = 0.0; - - for &(i, k) in &indices { - let mut src_mask = 0; - for j in 0..N { - if i & (1 << j) != 0 { - src_mask |= mask_bit[j]; + dst_table + .iter() + .take(dst_end_index) + .skip(dst_start_index) + .for_each(|dst| { + let mut mask_copy = mask; + mask = next_combination(mask); + + let mut mask_bit = [0; N]; + mask_bit.iter_mut().for_each(|bit| { + let lsb = mask_copy & mask_copy.wrapping_neg(); + mask_copy ^= lsb; + *bit = lsb; + }); + + let mut result = 0.0; + + for &(i, k) in &indices { + let mut src_mask = 0; + mask_bit.iter().take(N).enumerate().for_each(|(j, mb)| { + if i & (1 << j) != 0 { + src_mask |= mb; + } + }); + + let src_index = mask_to_index(src_mask, k as usize); + if k & 1 == 0 { + result += self.sum[k as usize][src_index].load(); + } else { + result -= self.sum[k as usize][src_index].load(); } } - let src_index = mask_to_index(src_mask, k as usize); - if k & 1 == 0 { - result += self.sum[k as usize][src_index].load(); - } else { - result -= self.sum[k as usize][src_index].load(); - } - } - - dst_table[dst_index].store(f32::max(result as f32, 0.0)); - } + dst.store(f32::max(result as f32, 0.0)); + }); }); } } From a322e701dd13a6bc6c85f62a8912726b797c1706 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 6 Oct 2024 16:41:33 -0700 Subject: [PATCH 43/66] clippy errors --- src/bunching.rs | 4 ++-- src/game/base.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bunching.rs b/src/bunching.rs index 3222e1d..6ae9932 100644 --- a/src/bunching.rs +++ b/src/bunching.rs @@ -828,8 +828,8 @@ mod tests { ]; let mut mask = 0b001111; - for i in 0..15 { - assert_eq!(mask, seq[i]); + for x in seq { + assert_eq!(mask, x); mask = next_combination(mask); } } diff --git a/src/game/base.rs b/src/game/base.rs index f1cc0f0..383327a 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -477,8 +477,8 @@ impl PostFlopGame { board_mask |= 1 << river; } - for player in 0..2 { - let (hands, weights) = range[player].get_hands_weights(board_mask); + for (player, r) in range.iter().enumerate() { + let (hands, weights) = r.get_hands_weights(board_mask); self.initial_weights[player] = weights; self.private_cards[player] = hands; } From bd1d869b5d302d97683574a3c914044596cd6ff5 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 6 Oct 2024 16:48:32 -0700 Subject: [PATCH 44/66] appease the clippy --- src/game/serialization.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/game/serialization.rs b/src/game/serialization.rs index 951c785..a9a7858 100644 --- a/src/game/serialization.rs +++ b/src/game/serialization.rs @@ -103,10 +103,10 @@ impl PostFlopGame { static VERSION_STR: &str = "2023-03-19"; thread_local! { - static PTR_BASE: Cell<[*const u8; 2]> = Cell::new([ptr::null(); 2]); - static CHANCE_BASE: Cell<*const u8> = Cell::new(ptr::null()); - static PTR_BASE_MUT: Cell<[*mut u8; 3]> = Cell::new([ptr::null_mut(); 3]); - static CHANCE_BASE_MUT: Cell<*mut u8> = Cell::new(ptr::null_mut()); + static PTR_BASE: Cell<[*const u8; 2]> = const {Cell::new([ptr::null(); 2])}; + static CHANCE_BASE: Cell<*const u8> = const {Cell::new(ptr::null())}; + static PTR_BASE_MUT: Cell<[*mut u8; 3]> = const {Cell::new([ptr::null_mut(); 3])}; + static CHANCE_BASE_MUT: Cell<*mut u8> = const {Cell::new(ptr::null_mut())}; } impl Encode for PostFlopGame { From b5d95fb6a5242972ccc0763567ac564508136250 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 6 Oct 2024 16:54:36 -0700 Subject: [PATCH 45/66] Clippy has been appeased --- src/mutex_like.rs | 2 +- src/range.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mutex_like.rs b/src/mutex_like.rs index 39b573f..9254e72 100644 --- a/src/mutex_like.rs +++ b/src/mutex_like.rs @@ -68,7 +68,7 @@ impl MutexLike { } } -impl Default for MutexLike { +impl Default for MutexLike { #[inline] fn default() -> Self { Self::new(Default::default()) diff --git a/src/range.rs b/src/range.rs index 45e2c33..fd1f00b 100644 --- a/src/range.rs +++ b/src/range.rs @@ -990,6 +990,7 @@ impl FromStr for Range { } } +#[allow(clippy::to_string_trait_impl)] impl ToString for Range { #[inline] fn to_string(&self) -> String { From b958e14c2bf72120dfb4f0b4e226a5338cae6048 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 6 Oct 2024 16:58:01 -0700 Subject: [PATCH 46/66] More clippy --- src/alloc.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/alloc.rs b/src/alloc.rs index 9452d81..1a1cf3c 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -29,11 +29,11 @@ struct StackAllocData { } thread_local! { - static STACK_ALLOC_DATA: RefCell = RefCell::new(StackAllocData { + static STACK_ALLOC_DATA: RefCell = const {RefCell::new(StackAllocData { index: usize::MAX, base: Vec::new(), current: Vec::new(), - }); + })}; } impl StackAllocData { From f76ef3d823767079764e2a416d7a1e776fcfc67b Mon Sep 17 00:00:00 2001 From: bkushigian Date: Mon, 7 Oct 2024 00:21:56 -0700 Subject: [PATCH 47/66] return &self.state; --> &self.state --- src/game/base.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/game/base.rs b/src/game/base.rs index 41154a3..e7501f7 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -1458,6 +1458,6 @@ impl PostFlopGame { } pub fn get_state(&self) -> &State { - return &self.state; + &self.state } } From 2e87e661cac412009b30e81bf7db36172729d864 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Mon, 7 Oct 2024 00:29:43 -0700 Subject: [PATCH 48/66] Ticked up version: v0.1.0 -> v0.1.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 981906f..f163db5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postflop-solver" -version = "0.1.0" +version = "0.1.1" authors = ["Wataru Inariba", "Ben Kushigian"] edition = "2021" description = "An open-source postflop solver for Texas hold'em poker" From 614f78ba782510c5fc790c585388194e28611d10 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Tue, 8 Oct 2024 16:44:21 -0700 Subject: [PATCH 49/66] Tried deriving serialize/deserialize for configs --- Cargo.toml | 2 ++ src/action_tree.rs | 5 +++-- src/bet_size.rs | 7 ++++--- src/card.rs | 4 +++- src/range.rs | 46 +++++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 57 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f163db5..1a1bbe3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,8 @@ once_cell = "1.18.0" rayon = { version = "1.8.0", optional = true } regex = "1.9.6" zstd = { version = "0.12.4", optional = true, default-features = false } +serde = {version = "1.0", features = ["derive"] } +serde_json = "1.0" [features] default = ["bincode", "rayon"] diff --git a/src/action_tree.rs b/src/action_tree.rs index 8366c12..04ddc76 100644 --- a/src/action_tree.rs +++ b/src/action_tree.rs @@ -4,6 +4,7 @@ use crate::mutex_like::*; #[cfg(feature = "bincode")] use bincode::{Decode, Encode}; +use serde::{Deserialize, Serialize}; pub(crate) const PLAYER_OOP: u8 = 0; pub(crate) const PLAYER_IP: u8 = 1; @@ -44,7 +45,7 @@ pub enum Action { } /// An enum representing the board state. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[repr(u8)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub enum BoardState { @@ -79,7 +80,7 @@ pub enum BoardState { /// merging_threshold: 0.1, /// }; /// ``` -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct TreeConfig { /// Initial state of the game tree (flop, turn, or river). diff --git a/src/bet_size.rs b/src/bet_size.rs index a781326..7830ffa 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -1,5 +1,6 @@ #[cfg(feature = "bincode")] use bincode::{Decode, Encode}; +use serde::{Deserialize, Serialize}; /// Bet size options for the first bets and raises. /// @@ -37,7 +38,7 @@ use bincode::{Decode, Encode}; /// /// assert_eq!(bet_size.raise, vec![PrevBetRelative(2.5)]); /// ``` -#[derive(Debug, Clone, Default, PartialEq)] +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct BetSizeOptions { /// Bet size options for first bet. @@ -50,14 +51,14 @@ pub struct BetSizeOptions { /// Bet size options for the donk bets. /// /// See the [`BetSizeOptions`] struct for the description and examples. -#[derive(Debug, Clone, Default, PartialEq)] +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct DonkSizeOptions { pub donk: Vec, } /// Bet size specification. -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub enum BetSize { /// Bet size relative to the current pot size. diff --git a/src/card.rs b/src/card.rs index e36121a..d50f83e 100644 --- a/src/card.rs +++ b/src/card.rs @@ -4,6 +4,8 @@ use std::mem; #[cfg(feature = "bincode")] use bincode::{Decode, Encode}; +use serde::Deserialize; +use serde::Serialize; /// A type representing a card, defined as an alias of `u8`. /// @@ -34,7 +36,7 @@ pub const NOT_DEALT: Card = Card::MAX; /// river: NOT_DEALT, /// }; /// ``` -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct CardConfig { /// Initial range of each player. diff --git a/src/range.rs b/src/range.rs index fd1f00b..77ccb79 100644 --- a/src/range.rs +++ b/src/range.rs @@ -1,7 +1,10 @@ use crate::card::*; use once_cell::sync::Lazy; use regex::Regex; -use std::fmt::Write; +use serde::de::{self, SeqAccess, Visitor}; +use serde::ser::SerializeSeq; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt::{self, Write}; use std::str::FromStr; #[cfg(feature = "bincode")] @@ -952,6 +955,47 @@ impl Range { } } +impl Serialize for Range { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_seq(Some(self.data.len()))?.end() + } +} + +impl<'de> Deserialize<'de> for Range { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct RangeVisitor; + + impl<'de> Visitor<'de> for RangeVisitor { + type Value = Range; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an array of 1326 floats") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut data = [0.0; 1326]; + for i in 0..1326 { + data[i] = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(i, &self))?; + } + Ok(Range { data }) + } + } + + deserializer.deserialize_seq(RangeVisitor) + } +} + impl FromStr for Range { type Err = String; From 4f5e85af398ed7a9ca1df3fbe849d21183c91ab6 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Wed, 9 Oct 2024 07:47:31 -0700 Subject: [PATCH 50/66] Serialize/deserialize ranges --- src/range.rs | 74 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 59 insertions(+), 15 deletions(-) diff --git a/src/range.rs b/src/range.rs index 77ccb79..f65d9b7 100644 --- a/src/range.rs +++ b/src/range.rs @@ -1,8 +1,7 @@ use crate::card::*; use once_cell::sync::Lazy; use regex::Regex; -use serde::de::{self, SeqAccess, Visitor}; -use serde::ser::SerializeSeq; +use serde::de::{self, Visitor}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt::{self, Write}; use std::str::FromStr; @@ -960,7 +959,7 @@ impl Serialize for Range { where S: Serializer, { - serializer.serialize_seq(Some(self.data.len()))?.end() + serializer.serialize_str(&self.to_string()) } } @@ -975,24 +974,21 @@ impl<'de> Deserialize<'de> for Range { type Value = Range; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("an array of 1326 floats") + formatter.write_str("a valid range string") } - fn visit_seq(self, mut seq: A) -> Result + fn visit_str(self, v: &str) -> Result where - A: SeqAccess<'de>, + E: de::Error, { - let mut data = [0.0; 1326]; - for i in 0..1326 { - data[i] = seq - .next_element()? - .ok_or_else(|| de::Error::invalid_length(i, &self))?; - } - Ok(Range { data }) + Range::from_str(v).or_else(|m| { + Err(de::Error::custom( + format!("Invalid range string \"{}\"\n\n{}", v, m).as_str(), + )) + }) } } - - deserializer.deserialize_seq(RangeVisitor) + deserializer.deserialize_str(RangeVisitor) } } @@ -1048,6 +1044,11 @@ impl ToString for Range { #[cfg(test)] mod tests { + use std::{ + fs::File, + io::{BufWriter, Write}, + }; + use super::*; #[test] @@ -1208,4 +1209,47 @@ mod tests { assert_eq!(range.unwrap().to_string(), expected); } } + + use serde_json; + #[test] + pub fn serialize_and_deserialize() { + let tests = [ + "AA,KK", + "KK,QQ", + "66-22,TT+", + "AA:0.5, KK:1.0, QQ:1.0, JJ:0.5", + "AA,AK,AQ", + "AK,AQ,AJs", + "KQ,KT,K9,K8,K6,K5", + "AhAs-QhQs,JJ", + "KJs+,KQo,KsJh", + "KcQh,KJ", + ]; + + for (test_no, input) in tests.iter().enumerate() { + let range = input.parse::(); + assert!(range.is_ok()); + + let range = range.unwrap(); + let json_string = serde_json::to_string(&range).unwrap(); + + let path = format!("range_{test_no}.json"); + let file = File::create(&path).unwrap(); + + let mut writer = BufWriter::new(&file); + writer.write_all(json_string.as_bytes()).unwrap(); + writer.flush().unwrap(); + + let range_deserialized = std::fs::read_to_string(&path); + assert!(range_deserialized.is_ok()); + + let range_deserialized = serde_json::from_str(&range_deserialized.unwrap()); + assert!(range_deserialized.is_ok(), "{:?}", range_deserialized); + + let range_deserialized = range_deserialized.unwrap(); + assert!(range == range_deserialized); + + std::fs::remove_file(&path).unwrap(); + } + } } From e260ef174644cf97577970abd12a2a6363ba3dde Mon Sep 17 00:00:00 2001 From: bkushigian Date: Wed, 9 Oct 2024 08:13:48 -0700 Subject: [PATCH 51/66] Serialize/deserialize --- src/action_tree.rs | 88 +++++++++++++++++++++++++++++++++++++++++++++- src/card.rs | 10 ++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/src/action_tree.rs b/src/action_tree.rs index 04ddc76..a772701 100644 --- a/src/action_tree.rs +++ b/src/action_tree.rs @@ -55,6 +55,31 @@ pub enum BoardState { River = 2, } +/// Used for default serde value +fn zero_f64() -> f64 { + 0.0 +} + +/// Used for default serde value +fn zero_point_one_f64() -> f64 { + 0.1 +} + +/// Used for default serde value +fn zero_point_two_f64() -> f64 { + 0.2 +} + +/// Used for default serde value +fn two_point_five_f64() -> f64 { + 2.5 +} + +/// Used for default serde value +fn flop() -> BoardState { + BoardState::Flop +} + /// A struct containing the game tree configuration. /// /// # Examples @@ -80,10 +105,11 @@ pub enum BoardState { /// merging_threshold: 0.1, /// }; /// ``` -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct TreeConfig { /// Initial state of the game tree (flop, turn, or river). + #[serde(default = "flop")] pub initial_state: BoardState, /// Starting pot size. Must be greater than `0`. @@ -93,34 +119,43 @@ pub struct TreeConfig { pub effective_stack: i32, /// Rake rate. Must be between `0.0` and `1.0`, inclusive. + #[serde(default = "zero_f64")] pub rake_rate: f64, /// Rake cap. Must be non-negative. + #[serde(default = "zero_f64")] pub rake_cap: f64, /// Bet size options of each player for the flop. + #[serde(default)] pub flop_bet_sizes: [BetSizeOptions; 2], /// Bet size options of each player for the turn. + #[serde(default)] pub turn_bet_sizes: [BetSizeOptions; 2], /// Bet size options of each player for the river. + #[serde(default)] pub river_bet_sizes: [BetSizeOptions; 2], /// Donk size options for the turn (set `None` to use default sizes). + #[serde(default)] pub turn_donk_sizes: Option, /// Donk size options for the river (set `None` to use default sizes). + #[serde(default)] pub river_donk_sizes: Option, /// Add all-in action if the ratio of maximum bet size to the pot is below or equal to this /// value (set `0.0` to disable). + #[serde(default = "two_point_five_f64")] pub add_allin_threshold: f64, /// Force all-in action if the SPR (stack/pot) after the opponent's call is below or equal to /// this value (set `0.0` to disable). /// /// Personal recommendation: between `0.1` and `0.2` + #[serde(default = "zero_point_two_f64")] pub force_allin_threshold: f64, /// Merge bet actions if there are bet actions with "close" values (set `0.0` to disable). @@ -131,6 +166,7 @@ pub struct TreeConfig { /// Continue this process with the next highest bet size. /// /// Personal recommendation: around `0.1` + #[serde(default = "zero_point_one_f64")] pub merging_threshold: f64, } @@ -1094,3 +1130,53 @@ fn merge_bet_actions(actions: Vec, pot: i32, offset: i32, param: f64) -> ret.reverse(); ret } + +#[cfg(test)] +mod tests { + use std::{ + fs::File, + io::{BufWriter, Write}, + }; + + use super::TreeConfig; + + #[test] + pub fn serialize_deserialize_tree_config() { + let tree_config = TreeConfig::default(); + let config_string = serde_json::to_string(&tree_config).unwrap(); + + let path = format!("tree_config_0.json"); + let file = File::create(&path).unwrap(); + + let mut writer = BufWriter::new(&file); + writer.write_all(config_string.as_bytes()).unwrap(); + writer.flush().unwrap(); + + let tree_config_deserialized = std::fs::read_to_string(&path); + assert!(tree_config_deserialized.is_ok()); + + let tree_config_deserialized = serde_json::from_str(&tree_config_deserialized.unwrap()); + assert!( + tree_config_deserialized.is_ok(), + "{:?}", + tree_config_deserialized + ); + + let tree_config_deserialized: TreeConfig = tree_config_deserialized.unwrap(); + assert!(tree_config == tree_config_deserialized); + + std::fs::remove_file(&path).unwrap(); + } + + #[test] + pub fn deserialize_partial_tree_config() { + let json = "{\"starting_pot\":20,\"effective_stack\":200}"; + let tree_config: Result = serde_json::from_str(json); + assert!( + tree_config.is_ok(), + "Unable to read json string \"{}\":\n {:?}", + json, + tree_config + ) + } +} diff --git a/src/card.rs b/src/card.rs index d50f83e..6b180de 100644 --- a/src/card.rs +++ b/src/card.rs @@ -20,6 +20,11 @@ pub type Card = u8; /// Constant representing that the card is not yet dealt. pub const NOT_DEALT: Card = Card::MAX; +/// for serde default +fn not_dealt() -> Card { + NOT_DEALT +} + /// A struct containing the card configuration. /// /// # Examples @@ -46,9 +51,11 @@ pub struct CardConfig { pub flop: [Card; 3], /// Turn card: must be in range [`0`, `52`) or `NOT_DEALT`. + #[serde(default = "not_dealt")] pub turn: Card, /// River card: must be in range [`0`, `52`) or `NOT_DEALT`. + #[serde(default = "not_dealt")] pub river: Card, } @@ -452,4 +459,7 @@ mod tests { } } } + + #[test] + fn test_serialize_deserialize_card_config() {} } From 5980a218c32ff12f2645ddfdcf48acfe9af2aae0 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Thu, 10 Oct 2024 15:33:05 -0400 Subject: [PATCH 52/66] Clippy: Result::or_else -> Result::map_err --- src/range.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/range.rs b/src/range.rs index f65d9b7..febdb5a 100644 --- a/src/range.rs +++ b/src/range.rs @@ -981,10 +981,8 @@ impl<'de> Deserialize<'de> for Range { where E: de::Error, { - Range::from_str(v).or_else(|m| { - Err(de::Error::custom( - format!("Invalid range string \"{}\"\n\n{}", v, m).as_str(), - )) + Range::from_str(v).map_err(|m| { + de::Error::custom(format!("Invalid range string \"{}\"\n\n{}", v, m).as_str()) }) } } From af09e54002ca123c45ba25b5d01dfb061f2d4047 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Fri, 11 Oct 2024 00:36:13 -0400 Subject: [PATCH 53/66] Config serialization/deserialization for PostFlopGame --- src/game/base.rs | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/game/base.rs b/src/game/base.rs index e7501f7..e657bfb 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -1460,4 +1460,49 @@ impl PostFlopGame { pub fn get_state(&self) -> &State { &self.state } + + /// This is a temporary function that is _not_ guaranteed to be supported in + /// future versions. It returns a JSON object with a game's `TreeConfig` and + /// `CardConfig`. + pub fn configs_as_json(&self) -> Result { + let tree_config = serde_json::to_value(self.tree_config()); + let tree_config = tree_config.map_err(|e| { + format!( + "Couldn't serialize TreeConfig {:?} to JSON:\n{}", + self.tree_config(), + e + ) + })?; + let card_config = serde_json::to_value(self.card_config()); + let card_config = card_config.map_err(|e| { + format!( + "Couldn't serialize CardConfig {:?} to JSON:\n{}", + self.card_config(), + e + ) + })?; + let mut map = serde_json::Map::new(); + map.insert("tree_config".to_string(), tree_config); + map.insert("card_config".to_string(), card_config); + let json_config = serde_json::Value::Object(map); + Ok(json_config) + } + + pub fn game_from_configs_json(configs_json: serde_json::Value) -> Result { + let map = configs_json.as_object().ok_or({ + "Config JSON must be a JSON object with keys \"tree_config\" and \"card_config\"" + })?; + let tree_config = map + .get("tree_config") + .ok_or("Config JSON must contain key \"tree_config\"")?; + let card_config = map + .get("card_config") + .ok_or("Config JSON must contain key \"card_config\"")?; + let tree_config: TreeConfig = serde_json::from_value(tree_config.clone()) + .map_err(|_| "Error deserializing tree_config")?; + let card_config: CardConfig = serde_json::from_value(card_config.clone()) + .map_err(|_| "Error deserializing card_config")?; + let action_tree = ActionTree::new(tree_config)?; + PostFlopGame::with_config(card_config, action_tree) + } } From 168db4c80161682571e3601b9ec34899cd73762d Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 12:21:01 -0400 Subject: [PATCH 54/66] Added utility funcitons --- src/card.rs | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/range.rs | 23 +++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/src/card.rs b/src/card.rs index 6b180de..dc313f5 100644 --- a/src/card.rs +++ b/src/card.rs @@ -255,6 +255,56 @@ impl CardConfig { ret } + /// Return the current card configuration with new board cards. + /// + /// # Examples + /// + /// ``` + /// use postflop_solver::*; + /// + /// let oop_range = "66+,A8s+,A5s-A4s,AJo+,K9s+,KQo,QTs+,JTs,96s+,85s+,75s+,65s,54s"; + /// let ip_range = "QQ-22,AQs-A2s,ATo+,K5s+,KJo+,Q8s+,J8s+,T7s+,96s+,86s+,75s+,64s+,53s+"; + /// let ranges = [oop_range.parse().unwrap(), ip_range.parse().unwrap()]; + /// + /// let card_config = CardConfig { + /// range: ranges, + /// flop: flop_from_str("Td9d6h").unwrap(), + /// turn: card_from_str("Qc").unwrap(), + /// river: NOT_DEALT, + /// }; + /// + /// let cards = cards_from_str("Th9d3c4h").unwrap(); + /// let card_config2 = card_config.with_cards(cards).unwrap(); + /// assert_eq!(card_config2.range, ranges); + /// assert_eq!(card_config2.flop, [34, 29, 4]); + /// assert_eq!(card_config2.turn, 10); + /// assert_eq!(card_config2.river, NOT_DEALT); + /// ``` + pub fn with_cards(&self, cards: Vec) -> Result { + let num_cards = + 3 + ((self.turn != NOT_DEALT) as usize) + ((self.river != NOT_DEALT) as usize); + if cards.len() != num_cards { + Err(format!( + "Current CardConfig has {} cards but supplied cards list {:?} has {} cards", + num_cards, + cards, + cards.len() + )) + } else { + let turn = cards.get(3).unwrap_or_else(|| &NOT_DEALT); + let river = cards.get(4).unwrap_or_else(|| &NOT_DEALT); + let mut flop: [Card; 3] = [cards[0], cards[1], cards[2]]; + flop.sort_by(|a, b| b.partial_cmp(a).unwrap()); + + Ok(Self { + range: self.range.clone(), + flop: flop, + turn: *turn, + river: *river, + }) + } + } + pub(crate) fn isomorphism(&self, private_cards: &[Vec<(Card, Card)>; 2]) -> IsomorphismData { let mut suit_isomorphism = [0; 4]; let mut next_index = 1; diff --git a/src/range.rs b/src/range.rs index febdb5a..eac5953 100644 --- a/src/range.rs +++ b/src/range.rs @@ -304,6 +304,29 @@ pub fn card_from_str(s: &str) -> Result { Ok(result) } +/// Attempts to convert an optionally space-separated string into an unsorted +/// card vec. +/// +/// # Examples +/// ``` +/// use postflop_solver::cards_from_str; +/// +/// assert_eq!(cards_from_str("2c3d4h"), Ok(vec![0, 5, 10])); +/// assert_eq!(cards_from_str("As Ah Ks"), Ok(vec![51, 50, 47])); +/// ``` +pub fn cards_from_str(s: &str) -> Result, String> { + let chars = s.chars(); + let mut result = vec![]; + + let mut chars = chars.peekable(); + while chars.peek().is_some() { + result.push(card_from_chars( + &mut chars.by_ref().skip_while(|c| c.is_whitespace()), + )?); + } + Ok(result) +} + /// Attempts to convert an optionally space-separated string into a sorted flop array. /// /// # Examples From cca82c65279fc31a919dcdcc2c5d7ec5770ab3e5 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 12:21:18 -0400 Subject: [PATCH 55/66] Started working on batch solve --- Cargo.toml | 3 +- examples/batch_solve.rs | 104 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 examples/batch_solve.rs diff --git a/Cargo.toml b/Cargo.toml index 1a1bbe3..a470c69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,13 @@ repository = "https://github.com/bkushigian/postflop-solver" license = "AGPL-3.0-or-later" [dependencies] +clap = { version = "4.5", features = ["derive"] } bincode = { version = "2.0.0-rc.3", optional = true } once_cell = "1.18.0" rayon = { version = "1.8.0", optional = true } regex = "1.9.6" zstd = { version = "0.12.4", optional = true, default-features = false } -serde = {version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" [features] diff --git a/examples/batch_solve.rs b/examples/batch_solve.rs new file mode 100644 index 0000000..d5912d7 --- /dev/null +++ b/examples/batch_solve.rs @@ -0,0 +1,104 @@ +use std::path::PathBuf; + +use clap::Parser; +use postflop_solver::{cards_from_str, solve, ActionTree, CardConfig, PostFlopGame, TreeConfig}; + +/// Simple program to greet a person +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + /// Path to configuration file + #[arg(required = true)] + config: String, + + /// Boards to run on + #[arg(short, long)] + boards: Option>, + + /// File with boards to run on + #[arg(short, long)] + boards_file: Option, + + /// Directory to output solves to + #[arg(short, long, default_value = ".")] + dir: String, + + /// Max number of iterations to run + #[arg(short = 'n', long, default_value = "1000")] + max_iterations: u32, + + /// Default exploitability as ratio of pot. Defaults to 0.2 (20% of pot), + /// but for accurate solves we recommend choosing a lower value. + #[arg(short = 'e', long, default_value = "0.2")] + exploitability: f32, +} + +fn main() { + let args = Args::parse(); + + let config = std::fs::read_to_string(args.config).expect("Unable to read in config"); + + let boards = if let Some(boards) = args.boards { + boards + } else { + let boards_files = args + .boards_file + .expect("Must specify boards or boards_file"); + let boards_contents = + std::fs::read_to_string(boards_files).expect("Unable to read boards_file"); + boards_contents + .lines() + .map(|s| s.to_string()) + .collect::>() + }; + let configs_json: serde_json::Value = + serde_json::from_str(&config).expect("Unable to parse config"); + let configs_map = configs_json.as_object().expect("Expected a json object"); + + let card_config = configs_map.get("card_config").unwrap(); + let card_config: CardConfig = serde_json::from_value(card_config.clone()).unwrap(); + + let tree_config = configs_map.get("tree_config").unwrap(); + let tree_config: TreeConfig = serde_json::from_value(tree_config.clone()).unwrap(); + + // Create output directory if needed. Check if ".pfs" files exist, and if so abort + let dir = PathBuf::from(args.dir); + if dir.exists() { + if !dir.is_dir() { + panic!( + "output directory {} exists but is not a directory", + dir.to_str().unwrap() + ); + } + for board in &boards { + // create board file name + let board_file_name = board + .chars() + .filter(|c| !c.is_whitespace()) + .collect::(); + let board_path = dir.join(board_file_name).with_extension("pfs"); + if board_path.exists() { + panic!("board path {} already exists", board_path.to_string_lossy()); + } + } + } else { + std::fs::create_dir_all(&dir).unwrap(); + } + + for board in &boards { + let cards = + cards_from_str(&board).expect(format!("Couldn't parse board {}", board).as_str()); + + let mut game = PostFlopGame::with_config( + card_config.with_cards(cards).unwrap(), + ActionTree::new(tree_config.clone()).unwrap(), + ) + .unwrap(); + + game.allocate_memory(false); + + let max_num_iterations = args.max_iterations; + let target_exploitability = game.tree_config().starting_pot as f32 * args.exploitability; + solve(&mut game, max_num_iterations, target_exploitability, true); + } +} From d262840d71129b7e9cc12b09a28b57bfd7caa9c7 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 14:37:13 -0400 Subject: [PATCH 56/66] Updates to batch_solve example --- examples/batch_solve.rs | 60 +++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/examples/batch_solve.rs b/examples/batch_solve.rs index d5912d7..b45e68a 100644 --- a/examples/batch_solve.rs +++ b/examples/batch_solve.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use clap::Parser; use postflop_solver::{cards_from_str, solve, ActionTree, CardConfig, PostFlopGame, TreeConfig}; @@ -33,7 +33,7 @@ struct Args { exploitability: f32, } -fn main() { +fn main() -> Result<(), String> { let args = Args::parse(); let config = std::fs::read_to_string(args.config).expect("Unable to read in config"); @@ -63,27 +63,8 @@ fn main() { // Create output directory if needed. Check if ".pfs" files exist, and if so abort let dir = PathBuf::from(args.dir); - if dir.exists() { - if !dir.is_dir() { - panic!( - "output directory {} exists but is not a directory", - dir.to_str().unwrap() - ); - } - for board in &boards { - // create board file name - let board_file_name = board - .chars() - .filter(|c| !c.is_whitespace()) - .collect::(); - let board_path = dir.join(board_file_name).with_extension("pfs"); - if board_path.exists() { - panic!("board path {} already exists", board_path.to_string_lossy()); - } - } - } else { - std::fs::create_dir_all(&dir).unwrap(); - } + setup_output_directory(&dir)?; + ensure_no_conflicts_in_output_dir(&dir, &boards)?; for board in &boards { let cards = @@ -101,4 +82,37 @@ fn main() { let target_exploitability = game.tree_config().starting_pot as f32 * args.exploitability; solve(&mut game, max_num_iterations, target_exploitability, true); } + Ok(()) +} + +fn setup_output_directory(dir: &Path) -> Result<(), String> { + if dir.exists() { + if !dir.is_dir() { + panic!( + "output directory {} exists but is not a directory", + dir.to_str().unwrap() + ); + } + Ok(()) + } else { + std::fs::create_dir_all(&dir).map_err(|_| "Couldn't create dir".to_string()) + } +} + +fn ensure_no_conflicts_in_output_dir(dir: &Path, boards: &[String]) -> Result<(), String> { + for board in boards { + // create board file name + let board_file_name = board + .chars() + .filter(|c| !c.is_whitespace()) + .collect::(); + let board_path = dir.join(board_file_name).with_extension("pfs"); + if board_path.exists() { + return Err(format!( + "board path {} already exists", + board_path.to_string_lossy() + )); + } + } + Ok(()) } From e726ad19d393f92f0bd10870df42f2e089454b7a Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 15:14:55 -0400 Subject: [PATCH 57/66] Fixed doc tests --- src/action_tree.rs | 4 +- src/bet_size.rs | 126 +++++++++++++++++++++++++++++++-------------- 2 files changed, 88 insertions(+), 42 deletions(-) diff --git a/src/action_tree.rs b/src/action_tree.rs index a772701..27b520a 100644 --- a/src/action_tree.rs +++ b/src/action_tree.rs @@ -631,7 +631,7 @@ impl ActionTree { actions.push(Action::Check); // bet - for &bet_size in &bet_options[player as usize].bet { + for &bet_size in bet_options[player as usize].bets() { match bet_size { BetSize::PotRelative(ratio) => { let amount = (pot as f64 * ratio).round() as i32; @@ -664,7 +664,7 @@ impl ActionTree { if !info.allin_flag { // raise - for &bet_size in &bet_options[player as usize].raise { + for &bet_size in bet_options[player as usize].raises() { match bet_size { BetSize::PotRelative(ratio) => { let amount = prev_amount + (pot as f64 * ratio).round() as i32; diff --git a/src/bet_size.rs b/src/bet_size.rs index 7830ffa..1c2596e 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -27,7 +27,7 @@ use serde::{Deserialize, Serialize}; /// let bet_size = BetSizeOptions::try_from(("50%, 100c, 2e, a", "2.5x")).unwrap(); /// /// assert_eq!( -/// bet_size.bet, +/// bet_size.bets(), /// vec![ /// PotRelative(0.5), /// Additive(100, 0), @@ -36,16 +36,16 @@ use serde::{Deserialize, Serialize}; /// ] /// ); /// -/// assert_eq!(bet_size.raise, vec![PrevBetRelative(2.5)]); +/// assert_eq!(bet_size.raises(), vec![PrevBetRelative(2.5)]); /// ``` #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct BetSizeOptions { /// Bet size options for first bet. - pub bet: Vec, + bet: Vec, /// Bet size options for raise. - pub raise: Vec, + raise: Vec, } /// Bet size options for the donk bets. @@ -81,6 +81,55 @@ pub enum BetSize { AllIn, } +impl BetSizeOptions { + /// Tries to create a `BetSizeOptions` from two `BetSize` vecs. + /// + /// # Errors + /// + /// Returns `Err` when: + /// - `bets` contains a `BetSize::Relative` bet size + /// - `bets` contains an `BetSize::Additive(_, cap)` with non-zero `cap` + pub fn try_from_sizes(bets: Vec, raises: Vec) -> Result { + Ok(BetSizeOptions { + bet: BetSizeOptions::as_valid_bets(bets)?, + raise: raises, + }) + } + + /// Check `bets` for well-formedness (no sizes relative to previous bet and + /// no raise caps) and return it. Return an `Err` if: + /// - `bets` contains a `BetSize::Relative` bet size + /// - `bets` contains an `BetSize::Additive(_, cap)` with non-zero `cap` + pub fn as_valid_bets(bets: Vec) -> Result, String> { + for bs in bets.iter() { + match &bs { + BetSize::PrevBetRelative(_) => { + let err_msg = "bets cannot contain `BetSize::PrevBetRelative".to_string(); + return Err(err_msg); + } + BetSize::Additive(_, cap) => { + if cap != &0 { + let err_msg = + "bets cannot contain additive bet sizes with non-zero raise caps" + .to_string(); + return Err(err_msg); + } + } + _ => (), + } + } + Ok(bets) + } + + pub fn bets(&self) -> &[BetSize] { + &self.bet + } + + pub fn raises(&self) -> &[BetSize] { + &self.raise + } +} + impl TryFrom<(&str, &str)> for BetSizeOptions { type Error = String; @@ -103,17 +152,19 @@ impl TryFrom<(&str, &str)> for BetSizeOptions { let mut raise = Vec::new(); for bet_size in bet_sizes { - bet.push(bet_size_from_str(bet_size, false)?); + bet.push(bet_size_from_str(bet_size)?); } for raise_size in raise_sizes { - raise.push(bet_size_from_str(raise_size, true)?); + raise.push(bet_size_from_str(raise_size)?); } + // Check for ill-formed bet sizes. This includes + // - bet sizes with relative amounts (e.g., "3x") bet.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); raise.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); - Ok(BetSizeOptions { bet, raise }) + Self::try_from_sizes(bet, raise) } } @@ -133,12 +184,14 @@ impl TryFrom<&str> for DonkSizeOptions { let mut donk = Vec::new(); for donk_size in donk_sizes { - donk.push(bet_size_from_str(donk_size, false)?); + donk.push(bet_size_from_str(donk_size)?); } donk.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); - Ok(DonkSizeOptions { donk }) + Ok(DonkSizeOptions { + donk: BetSizeOptions::as_valid_bets(donk)?, + }) } } @@ -150,23 +203,18 @@ fn parse_float(s: &str) -> Option { } } -fn bet_size_from_str(s: &str, is_raise: bool) -> Result { +fn bet_size_from_str(s: &str) -> Result { let s_lower = s.to_lowercase(); let err_msg = format!("Invalid bet size: {s}"); if let Some(prev_bet_rel) = s_lower.strip_suffix('x') { // Previous bet relative - if !is_raise { - let err_msg = format!("Relative size to the previous bet is not allowed: {s}"); + let float = parse_float(prev_bet_rel).ok_or(&err_msg)?; + if float <= 1.0 { + let err_msg = format!("Multiplier must be greater than 1.0: {s}"); Err(err_msg) } else { - let float = parse_float(prev_bet_rel).ok_or(&err_msg)?; - if float <= 1.0 { - let err_msg = format!("Multiplier must be greater than 1.0: {s}"); - Err(err_msg) - } else { - Ok(BetSize::PrevBetRelative(float)) - } + Ok(BetSize::PrevBetRelative(float)) } } else if s_lower.contains('c') { // Additive @@ -185,10 +233,6 @@ fn bet_size_from_str(s: &str, is_raise: bool) -> Result { let cap = if cap_str.is_empty() { 0 } else { - if !is_raise { - let err_msg = format!("Raise cap is not allowed: {s}"); - return Err(err_msg); - } let float_str = cap_str.strip_suffix('r').ok_or(&err_msg)?; let float = parse_float(float_str).ok_or(&err_msg)?; if float.trunc() != float || float == 0.0 { @@ -278,18 +322,18 @@ mod tests { ]; for (s, expected) in tests { - assert_eq!(bet_size_from_str(s, true), Ok(expected)); + assert_eq!(bet_size_from_str(s), Ok(expected)); } - let error_tests = [ - "", "0", "1.23", "%", "+42%", "-30%", "x", "0x", "1x", "c", "12.3c", "10c10", "42cr", - "c3r", "0c0r", "123c101r", "1c2r3", "12c3.4r", "0e", "2.7e", "101e", "3e7", "E%", - "1e2e3", "bet", "1a", "a1", - ]; + // let error_tests = [ + // "", "0", "1.23", "%", "+42%", "-30%", "x", "0x", "1x", "c", "12.3c", "10c10", "42cr", + // "c3r", "0c0r", "123c101r", "1c2r3", "12c3.4r", "0e", "2.7e", "101e", "3e7", "E%", + // "1e2e3", "bet", "1a", "a1", + // ]; - for s in error_tests { - assert!(bet_size_from_str(s, true).is_err()); - } + // for s in error_tests { + // assert!(bet_size_from_str(s, true).is_err()); + // } } #[test] @@ -298,18 +342,20 @@ mod tests { ( "40%, 70%", "", - BetSizeOptions { - bet: vec![PotRelative(0.4), PotRelative(0.7)], - raise: Vec::new(), - }, + BetSizeOptions::try_from_sizes( + vec![PotRelative(0.4), PotRelative(0.7)], + Vec::new(), + ) + .unwrap(), ), ( "50c, e, a,", "25%, 2.5x, e200%", - BetSizeOptions { - bet: vec![Additive(50, 0), Geometric(0, f64::INFINITY), AllIn], - raise: vec![PotRelative(0.25), PrevBetRelative(2.5), Geometric(0, 2.0)], - }, + BetSizeOptions::try_from_sizes( + vec![Additive(50, 0), Geometric(0, f64::INFINITY), AllIn], + vec![PotRelative(0.25), PrevBetRelative(2.5), Geometric(0, 2.0)], + ) + .unwrap(), ), ]; From a2158ff417ab29a773b339dbfcad714e486852c6 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 15:30:32 -0400 Subject: [PATCH 58/66] Removed old comment --- src/bet_size.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/bet_size.rs b/src/bet_size.rs index 1c2596e..d1558f7 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -158,8 +158,6 @@ impl TryFrom<(&str, &str)> for BetSizeOptions { for raise_size in raise_sizes { raise.push(bet_size_from_str(raise_size)?); } - // Check for ill-formed bet sizes. This includes - // - bet sizes with relative amounts (e.g., "3x") bet.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); raise.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); From 04d40abfff8a270b536b40229e99329653d7dfd0 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 15:32:16 -0400 Subject: [PATCH 59/66] clippy --- src/card.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/card.rs b/src/card.rs index dc313f5..0d7bdfa 100644 --- a/src/card.rs +++ b/src/card.rs @@ -291,14 +291,14 @@ impl CardConfig { cards.len() )) } else { - let turn = cards.get(3).unwrap_or_else(|| &NOT_DEALT); - let river = cards.get(4).unwrap_or_else(|| &NOT_DEALT); + let turn = cards.get(3).unwrap_or(&NOT_DEALT); + let river = cards.get(4).unwrap_or(&NOT_DEALT); let mut flop: [Card; 3] = [cards[0], cards[1], cards[2]]; flop.sort_by(|a, b| b.partial_cmp(a).unwrap()); Ok(Self { - range: self.range.clone(), - flop: flop, + range: self.range, + flop, turn: *turn, river: *river, }) From 20d0dc3cbe347bf442a8bc1dc7d1e5d9940f7228 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 15:36:11 -0400 Subject: [PATCH 60/66] Made DonkSizeOptions non-public --- src/action_tree.rs | 2 +- src/bet_size.rs | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/action_tree.rs b/src/action_tree.rs index 27b520a..f99f189 100644 --- a/src/action_tree.rs +++ b/src/action_tree.rs @@ -599,7 +599,7 @@ impl ActionTree { actions.push(Action::Check); // donk bet - for &donk_size in &donk_options.as_ref().unwrap().donk { + for &donk_size in donk_options.as_ref().unwrap().donks() { match donk_size { BetSize::PotRelative(ratio) => { let amount = (pot as f64 * ratio).round() as i32; diff --git a/src/bet_size.rs b/src/bet_size.rs index d1558f7..a5f26c8 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -54,7 +54,7 @@ pub struct BetSizeOptions { #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct DonkSizeOptions { - pub donk: Vec, + donk: Vec, } /// Bet size specification. @@ -166,6 +166,12 @@ impl TryFrom<(&str, &str)> for BetSizeOptions { } } +impl DonkSizeOptions { + pub fn donks(&self) -> &[BetSize] { + &self.donk + } +} + impl TryFrom<&str> for DonkSizeOptions { type Error = String; From 78597fef3aa9d18bdc88ddfa3d5482b927604547 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 15:37:13 -0400 Subject: [PATCH 61/66] Pluralized BetSizeOptions and DonkOptions field names --- src/bet_size.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/bet_size.rs b/src/bet_size.rs index a5f26c8..70a04f1 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -42,10 +42,10 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct BetSizeOptions { /// Bet size options for first bet. - bet: Vec, + bets: Vec, /// Bet size options for raise. - raise: Vec, + raises: Vec, } /// Bet size options for the donk bets. @@ -91,8 +91,8 @@ impl BetSizeOptions { /// - `bets` contains an `BetSize::Additive(_, cap)` with non-zero `cap` pub fn try_from_sizes(bets: Vec, raises: Vec) -> Result { Ok(BetSizeOptions { - bet: BetSizeOptions::as_valid_bets(bets)?, - raise: raises, + bets: BetSizeOptions::as_valid_bets(bets)?, + raises, }) } @@ -122,11 +122,11 @@ impl BetSizeOptions { } pub fn bets(&self) -> &[BetSize] { - &self.bet + &self.bets } pub fn raises(&self) -> &[BetSize] { - &self.raise + &self.raises } } From 245df956c461cdf9f492c840b0a924635cee26e4 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 15:38:42 -0400 Subject: [PATCH 62/66] Finished renaming fields --- src/bet_size.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/bet_size.rs b/src/bet_size.rs index 70a04f1..7fc0f89 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -54,7 +54,7 @@ pub struct BetSizeOptions { #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct DonkSizeOptions { - donk: Vec, + donks: Vec, } /// Bet size specification. @@ -168,7 +168,7 @@ impl TryFrom<(&str, &str)> for BetSizeOptions { impl DonkSizeOptions { pub fn donks(&self) -> &[BetSize] { - &self.donk + &self.donks } } @@ -194,7 +194,7 @@ impl TryFrom<&str> for DonkSizeOptions { donk.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); Ok(DonkSizeOptions { - donk: BetSizeOptions::as_valid_bets(donk)?, + donks: BetSizeOptions::as_valid_bets(donk)?, }) } } @@ -380,13 +380,13 @@ mod tests { ( "40%, 70%", DonkSizeOptions { - donk: vec![PotRelative(0.4), PotRelative(0.7)], + donks: vec![PotRelative(0.4), PotRelative(0.7)], }, ), ( "50c, e, a,", DonkSizeOptions { - donk: vec![Additive(50, 0), Geometric(0, f64::INFINITY), AllIn], + donks: vec![Additive(50, 0), Geometric(0, f64::INFINITY), AllIn], }, ), ]; From 40df53bd4febbb2ef198731add15a49c48757512 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sun, 13 Oct 2024 15:41:32 -0400 Subject: [PATCH 63/66] Uncommented err test --- src/bet_size.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/bet_size.rs b/src/bet_size.rs index 7fc0f89..4e95b68 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -329,15 +329,15 @@ mod tests { assert_eq!(bet_size_from_str(s), Ok(expected)); } - // let error_tests = [ - // "", "0", "1.23", "%", "+42%", "-30%", "x", "0x", "1x", "c", "12.3c", "10c10", "42cr", - // "c3r", "0c0r", "123c101r", "1c2r3", "12c3.4r", "0e", "2.7e", "101e", "3e7", "E%", - // "1e2e3", "bet", "1a", "a1", - // ]; - - // for s in error_tests { - // assert!(bet_size_from_str(s, true).is_err()); - // } + let error_tests = [ + "", "0", "1.23", "%", "+42%", "-30%", "x", "0x", "1x", "c", "12.3c", "10c10", "42cr", + "c3r", "0c0r", "123c101r", "1c2r3", "12c3.4r", "0e", "2.7e", "101e", "3e7", "E%", + "1e2e3", "bet", "1a", "a1", + ]; + + for s in error_tests { + assert!(bet_size_from_str(s).is_err()); + } } #[test] From 3a80ad619fdca3b3457b46bb7718bad0d239efc1 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Mon, 14 Oct 2024 12:40:05 -0400 Subject: [PATCH 64/66] Got a compiling serilize/deserialize bet size --- src/bet_size.rs | 132 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 90 insertions(+), 42 deletions(-) diff --git a/src/bet_size.rs b/src/bet_size.rs index 4e95b68..f0ba5bf 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -1,6 +1,6 @@ #[cfg(feature = "bincode")] use bincode::{Decode, Encode}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// Bet size options for the first bets and raises. /// @@ -42,9 +42,13 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct BetSizeOptions { /// Bet size options for first bet. + #[serde(deserialize_with = "deserialize_bet_sizes", default)] + #[serde(serialize_with = "serialize_bet_sizes")] bets: Vec, /// Bet size options for raise. + #[serde(deserialize_with = "deserialize_bet_sizes", default)] + #[serde(serialize_with = "serialize_bet_sizes")] raises: Vec, } @@ -54,12 +58,15 @@ pub struct BetSizeOptions { #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] pub struct DonkSizeOptions { + #[serde(deserialize_with = "deserialize_bet_sizes", default)] + #[serde(serialize_with = "serialize_bet_sizes")] donks: Vec, } /// Bet size specification. #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Serialize, Deserialize)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] +#[serde(try_from = "&str")] pub enum BetSize { /// Bet size relative to the current pot size. PotRelative(f64), @@ -130,6 +137,14 @@ impl BetSizeOptions { } } +impl TryFrom<&str> for BetSize { + type Error = String; + + fn try_from(s: &str) -> Result { + bet_size_from_str(s) + } +} + impl TryFrom<(&str, &str)> for BetSizeOptions { type Error = String; @@ -137,32 +152,7 @@ impl TryFrom<(&str, &str)> for BetSizeOptions { /// /// See the [`BetSizeOptions`] struct for the description and examples. fn try_from((bet_str, raise_str): (&str, &str)) -> Result { - let mut bet_sizes = bet_str.split(',').map(str::trim).collect::>(); - let mut raise_sizes = raise_str.split(',').map(str::trim).collect::>(); - - if bet_sizes.last().unwrap().is_empty() { - bet_sizes.pop(); - } - - if raise_sizes.last().unwrap().is_empty() { - raise_sizes.pop(); - } - - let mut bet = Vec::new(); - let mut raise = Vec::new(); - - for bet_size in bet_sizes { - bet.push(bet_size_from_str(bet_size)?); - } - - for raise_size in raise_sizes { - raise.push(bet_size_from_str(raise_size)?); - } - - bet.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); - raise.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); - - Self::try_from_sizes(bet, raise) + Self::try_from_sizes(bet_sizes_from_str(bet_str)?, bet_sizes_from_str(raise_str)?) } } @@ -179,23 +169,39 @@ impl TryFrom<&str> for DonkSizeOptions { /// /// See the [`BetSizeOptions`] struct for the description and examples. fn try_from(donk_str: &str) -> Result { - let mut donk_sizes = donk_str.split(',').map(str::trim).collect::>(); - - if donk_sizes.last().unwrap().is_empty() { - donk_sizes.pop(); - } - - let mut donk = Vec::new(); + let donks = bet_sizes_from_str(donk_str)?; + let donks = BetSizeOptions::as_valid_bets(donks)?; + Ok(DonkSizeOptions { donks }) + } +} - for donk_size in donk_sizes { - donk.push(bet_size_from_str(donk_size)?); +impl From for String { + fn from(bet_size: BetSize) -> Self { + match bet_size { + BetSize::PotRelative(x) => format!("{}%", x), + BetSize::PrevBetRelative(x) => format!("{}x", x), + BetSize::Additive(c, r) => { + if r != 0 { + format!("{}c{}r", c, r) + } else { + format!("{}c", c) + } + } + BetSize::Geometric(n, r) => { + if n == 0 { + if r == f64::INFINITY { + "e".to_string() + } else { + format!("e{}", r * 100.0) + } + } else if r == f64::INFINITY { + format!("{}e", n) + } else { + format!("{}e{}", n, r) + } + } + BetSize::AllIn => "a".to_string(), } - - donk.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); - - Ok(DonkSizeOptions { - donks: BetSizeOptions::as_valid_bets(donk)?, - }) } } @@ -207,6 +213,31 @@ fn parse_float(s: &str) -> Option { } } +fn bet_sizes_from_str(bets_str: &str) -> Result, String> { + let mut bet_sizes = bets_str.split(',').map(str::trim).collect::>(); + + if bet_sizes.last().unwrap().is_empty() { + bet_sizes.pop(); + } + + let mut bets = Vec::new(); + + for bet_size in bet_sizes { + bets.push(bet_size_from_str(bet_size)?); + } + + bets.sort_unstable_by(|l, r| l.partial_cmp(r).unwrap()); + + Ok(bets) +} + +fn deserialize_bet_sizes<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + Vec::::deserialize(deserializer) +} + fn bet_size_from_str(s: &str) -> Result { let s_lower = s.to_lowercase(); let err_msg = format!("Invalid bet size: {s}"); @@ -299,6 +330,23 @@ fn bet_size_from_str(s: &str) -> Result { } } +pub fn bet_size_to_string(bs: &BetSize) -> String { + String::from(*bs) +} + +pub fn serialize_bet_sizes(bs: &[BetSize], s: S) -> Result +where + S: Serializer, +{ + s.serialize_str( + bs.iter() + .map(|b| String::from(*b)) + .collect::>() + .join(",") + .as_str(), + ) +} + #[cfg(test)] mod tests { use super::BetSize::*; From f1142c845f9e8a09a5eb9e6218d6a8cb1dcda69d Mon Sep 17 00:00:00 2001 From: bkushigian Date: Mon, 14 Oct 2024 21:31:30 -0400 Subject: [PATCH 65/66] Serialization and deserialization works --- src/bet_size.rs | 8 +++++--- src/game/base.rs | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/bet_size.rs b/src/bet_size.rs index f0ba5bf..50d1739 100644 --- a/src/bet_size.rs +++ b/src/bet_size.rs @@ -1,6 +1,6 @@ #[cfg(feature = "bincode")] use bincode::{Decode, Encode}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; /// Bet size options for the first bets and raises. /// @@ -178,7 +178,7 @@ impl TryFrom<&str> for DonkSizeOptions { impl From for String { fn from(bet_size: BetSize) -> Self { match bet_size { - BetSize::PotRelative(x) => format!("{}%", x), + BetSize::PotRelative(x) => format!("{}%", 100.0 * x), BetSize::PrevBetRelative(x) => format!("{}x", x), BetSize::Additive(c, r) => { if r != 0 { @@ -235,7 +235,9 @@ fn deserialize_bet_sizes<'de, D>(deserializer: D) -> Result, D::Err where D: Deserializer<'de>, { - Vec::::deserialize(deserializer) + let s = String::deserialize(deserializer)?; + let bet_sizes = bet_sizes_from_str(&s); + bet_sizes.map_err(de::Error::custom) } fn bet_size_from_str(s: &str) -> Result { diff --git a/src/game/base.rs b/src/game/base.rs index e657bfb..f5282ec 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -1488,7 +1488,9 @@ impl PostFlopGame { Ok(json_config) } - pub fn game_from_configs_json(configs_json: serde_json::Value) -> Result { + pub fn game_from_configs_json( + configs_json: &serde_json::Value, + ) -> Result { let map = configs_json.as_object().ok_or({ "Config JSON must be a JSON object with keys \"tree_config\" and \"card_config\"" })?; @@ -1499,9 +1501,9 @@ impl PostFlopGame { .get("card_config") .ok_or("Config JSON must contain key \"card_config\"")?; let tree_config: TreeConfig = serde_json::from_value(tree_config.clone()) - .map_err(|_| "Error deserializing tree_config")?; + .map_err(|e| format!("Error deserializing tree_config: {:?}", e))?; let card_config: CardConfig = serde_json::from_value(card_config.clone()) - .map_err(|_| "Error deserializing card_config")?; + .map_err(|e| format!("Error deserializing card_config: {:?}", e))?; let action_tree = ActionTree::new(tree_config)?; PostFlopGame::with_config(card_config, action_tree) } From 003668efd062727f20382d6ba429169871346d4a Mon Sep 17 00:00:00 2001 From: bkushigian Date: Wed, 16 Oct 2024 07:15:40 -0400 Subject: [PATCH 66/66] Clippyw orkaround --- src/range.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/range.rs b/src/range.rs index eac5953..71f0833 100644 --- a/src/range.rs +++ b/src/range.rs @@ -993,6 +993,8 @@ impl<'de> Deserialize<'de> for Range { { struct RangeVisitor; + // A workaround in a clippy bug + #[allow(clippy::needless_lifetimes)] impl<'de> Visitor<'de> for RangeVisitor { type Value = Range;