-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconnectn-player.py
39 lines (30 loc) · 1.34 KB
/
connectn-player.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from models.dqn import DQNAgent
from games.connectn_driver import Driver
from games.connectn import Game, RED, YELLOW
from models.player import Player
from utils.helpers import args
# --- Training schedule -------------------------------------------------------
EPISODES = 20000  # number of episodes passed to redPlayer.train()
MAX_MOVES = 20  # per-game move cap handed to both Player instances
# NOTE(review): a COLS x ROWS board has 30 cells, more than MAX_MOVES;
# confirm the cap is intentional.

# --- Board dimensions --------------------------------------------------------
COLS = 5
ROWS = 6

# --- Environment -------------------------------------------------------------
# Build the board, wrap it in the driver, and run the driver's test() once
# before any agent is constructed.
baseGame = Game(cols=COLS, rows=ROWS)
game = Driver(base_game=baseGame)
game.test()

# --- Agent sizing ------------------------------------------------------------
# The state width is read from axis 1 of the driver's state array.
# NOTE(review): the original reminder says the state has to be flat;
# presumably that means shape (1, n) -- confirm against Driver.state().
state_size = game.state().shape[1]
# The action count is set equal to the number of columns.
action_size = COLS

# Keyword arguments that are identical for the red and yellow Player.
_player_common = {
    "game": game,
    "max_moves": MAX_MOVES,
    "log": 'wins',
}

# --- Red side ----------------------------------------------------------------
agentRed = DQNAgent(
    state_size,
    action_size,
    epsilon=args.start_epsilon,
    epsilon_decay=0.99,
    epsilon_min=0.1,
    batch_size=32,
)
redPlayer = Player(
    name='connectn-red-qlearner',
    agent=agentRed,
    role=RED,
    **_player_common,
)

# --- Yellow side -------------------------------------------------------------
# NOTE(review): yellow uses epsilon_min=0.01 while red uses 0.1; confirm the
# asymmetry is deliberate.
agentYellow = DQNAgent(
    state_size,
    action_size,
    epsilon=args.start_epsilon,
    epsilon_decay=0.99,
    epsilon_min=0.01,
    batch_size=32,
)
yellowPlayer = Player(
    name='connectn-yellow-qlearner',
    agent=agentYellow,
    role=YELLOW,
    **_player_common,
)

# TODO factor out role from player

# --- Training ----------------------------------------------------------------
# Red drives the training loop with yellow supplied as its opponent; resume,
# plotting and display options come from the parsed `args`.
redPlayer.train(
    episodes=EPISODES,
    resume=args.save_resume,
    save_freq=100,
    plot_freq=args.plot_freq,
    show=args.show,
    opponent=yellowPlayer,
)