diff --git a/src/schnapsen/bots/__init__.py b/src/schnapsen/bots/__init__.py index dc1cbd0..f15e9dc 100644 --- a/src/schnapsen/bots/__init__.py +++ b/src/schnapsen/bots/__init__.py @@ -7,5 +7,6 @@ from .ml_bot import MLDataBot, MLPlayingBot, train_ML_model from .gui.guibot import SchnapsenServer from .minimax import MiniMaxBot +from .two_stage_bot import TwoStageBot -__all__ = ["RandBot", "AlphaBetaBot", "RdeepBot", "MLDataBot", "MLPlayingBot", "train_ML_model", "SchnapsenServer", "MiniMaxBot"] +__all__ = ["RandBot", "AlphaBetaBot", "RdeepBot", "MLDataBot", "MLPlayingBot", "train_ML_model", "SchnapsenServer", "MiniMaxBot", "TwoStageBot"] diff --git a/src/schnapsen/bots/two_stage_bot.py b/src/schnapsen/bots/two_stage_bot.py new file mode 100644 index 0000000..ab37585 --- /dev/null +++ b/src/schnapsen/bots/two_stage_bot.py @@ -0,0 +1,27 @@ +"""Two stage bot""" +from typing import Optional +from schnapsen.game import ( + Bot, + Move, + PlayerPerspective, + GamePhase, +) + + +class TwoStageBot(Bot): + """Bot which plays first the one, than the other startegy""" + + def __init__(self, bot1: Bot, bot2: Bot, name: Optional[str] = None) -> None: + super().__init__(name) + self.bot_phase1: Bot = bot1 + self.bot_phase2: Bot = bot2 + + def get_move( + self, perspective: PlayerPerspective, leader_move: Optional[Move] + ) -> Move: + if perspective.get_phase() == GamePhase.ONE: + return self.bot_phase1.get_move(perspective, leader_move) + elif perspective.get_phase() == GamePhase.TWO: + return self.bot_phase2.get_move(perspective, leader_move) + else: + raise AssertionError("Phase ain't right.")