diff --git a/baselines/baselines_utils.py b/baselines/baselines_utils.py new file mode 100644 index 00000000..553107b2 --- /dev/null +++ b/baselines/baselines_utils.py @@ -0,0 +1,50 @@ +from os.path import exists +from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3 import PPO +from stable_baselines3.common.utils import set_random_seed + +from stream_agent_wrapper import StreamWrapper +from red_gym_env import RedGymEnv + + +def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cpu): + + env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) + if exists(model_to_load_path + '.zip'): + print('\nloading checkpoint') + model = PPO.load(model_to_load_path, env=env) + model.n_steps = total_timesteps + model.n_envs = num_cpu + model.rollout_buffer.buffer_size = total_timesteps + model.rollout_buffer.n_envs = num_cpu + model.rollout_buffer.reset() + else: + model = PPO('CnnPolicy', env, verbose=1, n_steps=total_timesteps, batch_size=512, n_epochs=1, gamma=0.999, tensorboard_log=env_config['session_path']) + + return model + + +def make_env(rank, env_config, seed=0): + """ + Utility function for multiprocessed env. 
+ :param env_id: (str) the environment ID + :param num_env: (int) the number of environments you wish to have in subprocesses + :param seed: (int) the initial seed for RNG + :param rank: (int) index of the subprocess + """ + def _init(): + env = RedGymEnv(env_config) + env.reset(seed=(seed + rank)) + if env_config['stream'] is True: + env = StreamWrapper( + env, + stream_metadata = { # All of this is part is optional + "user": "MATHIEU", # choose your own username + "env_id": env_config['instance_id'], # environment identifier + "color": "#d900ff", # choose your color :) + "extra": "", # any extra text you put here will be displayed + } + ) + return env + set_random_seed(seed) + return _init \ No newline at end of file diff --git a/baselines/memory_addresses.py b/baselines/memory_addresses.py index be989ee5..b554c0a8 100644 --- a/baselines/memory_addresses.py +++ b/baselines/memory_addresses.py @@ -19,4 +19,6 @@ MONEY_ADDRESS_1 = 0xD347 MONEY_ADDRESS_2 = 0xD348 -MONEY_ADDRESS_3 = 0xD349 \ No newline at end of file +MONEY_ADDRESS_3 = 0xD349 + +SEEN_POKEMONS_ADDRESSES = [0xD30A, 0xD30B, 0xD30C, 0xD30D, 0xD30E, 0xD30F, 0xD310, 0xD311, 0xD312, 0xD313, 0xD314, 0xD315, 0xD316, 0xD317, 0xD318, 0xD319, 0xD31A, 0xD31B, 0xD31C] diff --git a/baselines/reader_pyboy.py b/baselines/reader_pyboy.py new file mode 100644 index 00000000..b3318083 --- /dev/null +++ b/baselines/reader_pyboy.py @@ -0,0 +1,120 @@ +from memory_addresses import * + + +class ReaderPyBoy: + + def __init__(self, pyboy): + self.pyboy = pyboy + + def read_m(self, addr): + return self.pyboy.get_memory_value(addr) + + def read_money(self): + return (100 * 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_1)) + + 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_2)) + + self.read_bcd(self.read_m(MONEY_ADDRESS_3))) + + def read_bcd(self, num): + return 10 * ((num >> 4) & 0x0f) + (num & 0x0f) + + def read_bit(self, addr, bit: int) -> bool: + # add padding so zero will read '0b100000000' instead of '0b0' + return bin(256 + 
self.read_m(addr))[-bit-1] == '1' + + def read_hp_fraction(self): + hp_sum = sum([self.read_hp(add) for add in HP_ADDRESSES]) + max_hp_sum = sum([self.read_hp(add) for add in MAX_HP_ADDRESSES]) + max_hp_sum = max(max_hp_sum, 1) + return hp_sum / max_hp_sum + + def read_hp(self, start): + return 256 * self.read_m(start) + self.read_m(start+1) + + # built-in since python 3.10 + def bit_count(self, bits): + return bin(bits).count('1') + + def read_triple(self, start_add): + return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2) + + def get_badges(self): + return self.bit_count(self.read_m(BADGE_COUNT_ADDRESS)) + + def get_opponent_level(self): + return max([self.read_m(a) for a in OPPONENT_LEVELS_ADDRESSES]) - 5 + + def read_party(self): + return [self.read_m(addr) for addr in PARTY_ADDRESSES] + + def get_levels_sum(self): + poke_levels = [max(self.read_m(a) - 2, 0) for a in LEVELS_ADDRESSES] + return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level + + def read_party_size_address(self): + return self.read_m(PARTY_SIZE_ADDRESS) + + def read_x_pos(self): + return self.read_m(X_POS_ADDRESS) + + def read_y_pos(self): + return self.read_m(Y_POS_ADDRESS) + + def read_map_n(self): + return self.read_m(MAP_N_ADDRESS) + + def read_events(self): + return [ + self.bit_count(self.read_m(i)) + for i in range(EVENT_FLAGS_START_ADDRESS, EVENT_FLAGS_END_ADDRESS) + ] + + def read_museum_tickets(self): + museum_ticket = (MUSEUM_TICKET_ADDRESS, 0) + return self.read_bit(museum_ticket[0], museum_ticket[1]) + + def read_levels(self): + return [self.read_m(a) for a in LEVELS_ADDRESSES] + + def read_seen_pokemons(self): + return [self.bit_count(self.read_m(a)) for a in SEEN_POKEMONS_ADDRESSES] + + def get_map_location(self): + map_locations = { + 0: "Pallet Town", + 1: "Viridian City", + 2: "Pewter City", + 3: "Cerulean City", + 12: "Route 1", + 13: "Route 2", + 14: "Route 3", + 15: "Route 4", + 33: "Route 22", + 37: "Red house first", 
+ 38: "Red house second", + 39: "Blues house", + 40: "oaks lab", + 41: "Pokémon Center (Viridian City)", + 42: "Poké Mart (Viridian City)", + 43: "School (Viridian City)", + 44: "House 1 (Viridian City)", + 47: "Gate (Viridian City/Pewter City) (Route 2)", + 49: "Gate (Route 2)", + 50: "Gate (Route 2/Viridian Forest) (Route 2)", + 51: "viridian forest", + 52: "Pewter Museum (floor 1)", + 53: "Pewter Museum (floor 2)", + 54: "Pokémon Gym (Pewter City)", + 55: "House with disobedient Nidoran♂ (Pewter City)", + 56: "Poké Mart (Pewter City)", + 57: "House with two Trainers (Pewter City)", + 58: "Pokémon Center (Pewter City)", + 59: "Mt. Moon (Route 3 entrance)", + 60: "Mt. Moon", + 61: "Mt. Moon", + 68: "Pokémon Center (Route 4)", + 193: "Badges check gate (Route 22)" + } + if self.read_map_n() in map_locations.keys(): + return map_locations[self.read_map_n()] + else: + return "Unknown Location" diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index e1133f73..32f4d6dd 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -1,60 +1,43 @@ - import sys -import uuid -import os -from math import floor, sqrt -import json -from pathlib import Path import numpy as np -from einops import rearrange -import matplotlib.pyplot as plt -from skimage.transform import resize from pyboy import PyBoy -#from pyboy.logger import log_level -import hnswlib -import mediapy as media +import uuid +import json import pandas as pd +from pathlib import Path +from renderer import Renderer +from rewards import Reward +from reader_pyboy import ReaderPyBoy from gymnasium import Env, spaces from pyboy.utils import WindowEvent -from memory_addresses import * -class RedGymEnv(Env): +class RedGymEnv(Env): - def __init__( - self, config=None): + def __init__(self, config=None): self.debug = config['debug'] + self.instance_id = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id'] self.s_path = config['session_path'] self.save_final_state = 
config['save_final_state'] - self.print_rewards = config['print_rewards'] - self.vec_dim = 4320 #1000 + self.save_video = config['save_video'] self.headless = config['headless'] - self.num_elements = 20000 # max self.init_state = config['init_state'] self.act_freq = config['action_freq'] self.max_steps = config['max_steps'] self.early_stopping = config['early_stop'] - self.save_video = config['save_video'] self.fast_video = config['fast_video'] self.video_interval = 256 * self.act_freq - self.downsample_factor = 2 - self.frame_stacks = 3 - self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight'] + self.print_rewards = config['print_rewards'] self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore'] - self.similar_frame_dist = config['sim_frame_dist'] - self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale'] + self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons'] - self.instance_id = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id'] - self.s_path.mkdir(exist_ok=True) self.reset_count = 0 self.all_runs = [] - # Set this in SOME subclasses self.metadata = {"render.modes": []} - self.reward_range = (0, 15000) self.valid_actions = [ WindowEvent.PRESS_ARROW_DOWN, @@ -64,7 +47,7 @@ def __init__( WindowEvent.PRESS_BUTTON_A, WindowEvent.PRESS_BUTTON_B, ] - + if self.extra_buttons: self.valid_actions.extend([ WindowEvent.PRESS_BUTTON_START, @@ -83,158 +66,108 @@ def __init__( WindowEvent.RELEASE_BUTTON_B ] - self.output_shape = (36, 40, 3) - self.mem_padding = 2 - self.memory_height = 8 - self.col_steps = 16 - self.output_full = ( - self.output_shape[0] * self.frame_stacks + 2 * (self.mem_padding + self.memory_height), - self.output_shape[1], - self.output_shape[2] - ) - - # Set these in ALL subclasses - self.action_space = spaces.Discrete(len(self.valid_actions)) - self.observation_space = spaces.Box(low=0, 
high=255, shape=self.output_full, dtype=np.uint8) - head = 'headless' if config['headless'] else 'SDL2' - #log_level("ERROR") self.pyboy = PyBoy( - config['gb_path'], - debugging=False, - disable_input=False, - window_type=head, - hide_window='--quiet' in sys.argv, - ) - - self.screen = self.pyboy.botsupport_manager().screen() + config['gb_path'], + debugging=False, + disable_input=False, + window_type=head, + hide_window='--quiet' in sys.argv, + ) if not config['headless']: self.pyboy.set_emulation_speed(6) - + + self.reader = ReaderPyBoy(self.pyboy) + + # Rewards + self.reward_service = Reward(config, self.reader) + self.renderer = Renderer(self.s_path, self.pyboy, self.reward_service, self.instance_id) + + # Set these in ALL subclasses + self.action_space = spaces.Discrete(len(self.valid_actions)) + self.observation_space = spaces.Box(low=0, high=255, shape=self.renderer.output_full, dtype=np.uint8) + self.reset() + def render(self): + return self.renderer.render() + def reset(self, seed=None, options=None): self.seed = seed # restart game, skipping credits with open(self.init_state, "rb") as f: self.pyboy.load_state(f) - - if self.use_screen_explore: - self.init_knn() - else: - self.init_map_mem() - self.recent_memory = np.zeros((self.output_shape[1]*self.memory_height, 3), dtype=np.uint8) - - self.recent_frames = np.zeros( - (self.frame_stacks, self.output_shape[0], - self.output_shape[1], self.output_shape[2]), - dtype=np.uint8) + self.reward_service.reset() - self.agent_stats = [] - + self.renderer.reset() if self.save_video: - base_dir = self.s_path / Path('rollouts') - base_dir.mkdir(exist_ok=True) - full_name = Path(f'full_reset_{self.reset_count}_id{self.instance_id}').with_suffix('.mp4') - model_name = Path(f'model_reset_{self.reset_count}_id{self.instance_id}').with_suffix('.mp4') - self.full_frame_writer = media.VideoWriter(base_dir / full_name, (144, 160), fps=60) - self.full_frame_writer.__enter__() - self.model_frame_writer = 
media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60) - self.model_frame_writer.__enter__() - - self.levels_satisfied = False - self.base_explore = 0 - self.max_opponent_level = 0 - self.max_event_rew = 0 - self.max_level_rew = 0 - self.last_health = 1 - self.total_healing_rew = 0 - self.died_count = 0 - self.party_size = 0 + self.renderer.save_video(self.reset_count) + + self.agent_stats = [] + self.step_count = 0 - self.progress_reward = self.get_game_state_reward() - self.total_reward = sum([val for _, val in self.progress_reward.items()]) + self.reset_count += 1 return self.render(), {} - - def init_knn(self): - # Declaring index - self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip - # Initing index - the maximum number of elements should be known beforehand - self.knn_index.init_index( - max_elements=self.num_elements, ef_construction=100, M=16) - - def init_map_mem(self): - self.seen_coords = {} - - def render(self, reduce_res=True, add_memory=True, update_mem=True): - game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3) - if reduce_res: - game_pixels_render = (255*resize(game_pixels_render, self.output_shape)).astype(np.uint8) - if update_mem: - self.recent_frames[0] = game_pixels_render - if add_memory: - pad = np.zeros( - shape=(self.mem_padding, self.output_shape[1], 3), - dtype=np.uint8) - game_pixels_render = np.concatenate( - ( - self.create_exploration_memory(), - pad, - self.create_recent_memory(), - pad, - rearrange(self.recent_frames, 'f h w c -> (f h) w c') - ), - axis=0) - return game_pixels_render - + def step(self, action): self.run_action_on_emulator(action) self.append_agent_stats(action) + self.renderer.recent_frames = np.roll(self.renderer.recent_frames, 1, axis=0) - self.recent_frames = np.roll(self.recent_frames, 1, axis=0) - obs_memory = self.render() + # OBSERVATION - # trim off memory from frame for knn index - frame_start = 2 * (self.memory_height + 
self.mem_padding) - obs_flat = obs_memory[ - frame_start:frame_start+self.output_shape[0], ...].flatten().astype(np.float32) + obs_memory = self.render() + obs_flat = self.renderer.get_obs_flat(obs_memory) - if self.use_screen_explore: - self.update_frame_knn_index(obs_flat) - else: - self.update_seen_coords() - - self.update_heal_reward() - self.party_size = self.read_m(PARTY_SIZE_ADDRESS) + # REWARD - new_reward, new_prog = self.update_reward() - - self.last_health = self.read_hp_fraction() + reward_delta, new_prog = self.reward_service.update_rewards(obs_flat, self.step_count) + if self.print_rewards and self.step_count % 100 == 0: + self.reward_service.print_rewards(self.step_count) + if reward_delta < 0 and self.reader.read_hp_fraction() > 0: + self.renderer.save_screenshot('neg_reward', self.reward_service.total_reward, self.reset_count) # shift over short term reward memory - self.recent_memory = np.roll(self.recent_memory, 3) - self.recent_memory[0, 0] = min(new_prog[0] * 64, 255) - self.recent_memory[0, 1] = min(new_prog[1] * 64, 255) - self.recent_memory[0, 2] = min(new_prog[2] * 128, 255) + self.renderer.recent_memory = np.roll(self.renderer.recent_memory, 3) + self.renderer.recent_memory[0, 0] = min(new_prog[0] * 64, 255) + self.renderer.recent_memory[0, 1] = min(new_prog[1] * 64, 255) + self.renderer.recent_memory[0, 2] = min(new_prog[2] * 128, 255) - step_limit_reached = self.check_if_done() + # DONE - self.save_and_print_info(step_limit_reached, obs_memory) + done = self.check_if_done() + if self.step_count % 50 == 0: + self.renderer.save_and_print_info() - self.step_count += 1 + if done: + self.all_runs.append(self.reward_service.get_game_state_rewards()) + with open(self.s_path / Path(f'all_runs_{self.instance_id}.json'), 'w') as f: + json.dump(self.all_runs, f) + pd.DataFrame(self.agent_stats).to_csv( + self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') - return obs_memory, new_reward*0.1, False, 
step_limit_reached, {} + if self.print_rewards: + print('', flush=True) + if self.save_final_state: + self.renderer.save_final_state(obs_memory, self.reset_count, self.reward_service.total_reward) + + if self.save_video and done: + self.renderer.full_frame_writer.close() + self.renderer.model_frame_writer.close() + + self.step_count += 1 + return obs_memory, reward_delta * 0.1, False, done, {} def run_action_on_emulator(self, action): # press button then release after some steps self.pyboy.send_input(self.valid_actions[action]) # disable rendering when we don't need it - if not self.save_video and self.headless: + if not self.renderer.save_video and self.headless: self.pyboy._rendering(False) for i in range(self.act_freq): # release action, so they are stateless @@ -248,378 +181,45 @@ def run_action_on_emulator(self, action): if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START: self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START) if self.save_video and not self.fast_video: - self.add_video_frame() - if i == self.act_freq-1: + self.renderer.add_video_frame() + if i == self.act_freq - 1: self.pyboy._rendering(True) self.pyboy.tick() if self.save_video and self.fast_video: - self.add_video_frame() - - def add_video_frame(self): - self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False)) - self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False)) - + self.renderer.add_video_frame() + def append_agent_stats(self, action): - x_pos = self.read_m(X_POS_ADDRESS) - y_pos = self.read_m(Y_POS_ADDRESS) - map_n = self.read_m(MAP_N_ADDRESS) - levels = [self.read_m(a) for a in LEVELS_ADDRESSES] + x_pos = self.reader.read_x_pos() + y_pos = self.reader.read_y_pos() + map_n = self.reader.read_map_n() + levels = self.reader.read_levels() if self.use_screen_explore: - expl = ('frames', self.knn_index.get_current_count()) + expl = ('frames', self.reward_service.knn_index.get_current_count()) else: - expl = ('coord_count', 
len(self.seen_coords)) + expl = ('coord_count', len(self.reward_service.seen_coords)) self.agent_stats.append({ 'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n, - 'map_location': self.get_map_location(map_n), + 'map_location': self.reader.get_map_location(), 'last_action': action, - 'pcount': self.read_m(PARTY_SIZE_ADDRESS), - 'levels': levels, + 'final_total_reward': self.reward_service.total_reward, + 'party_size': self.reader.read_party_size_address(), + 'levels': levels, 'levels_sum': sum(levels), - 'ptypes': self.read_party(), - 'hp': self.read_hp_fraction(), + 'seen_pokemons': self.reward_service.seen_pokemons, + 'ptypes': self.reader.read_party(), + 'hp': self.reader.read_hp_fraction(), expl[0]: expl[1], - 'deaths': self.died_count, 'badge': self.get_badges(), - 'event': self.progress_reward['event'], 'healr': self.total_healing_rew + 'deaths': self.reward_service.died_count, + 'badge': self.reader.get_badges(), + 'event': self.reward_service.max_event, + 'healr': self.reward_service.total_healing }) - def update_frame_knn_index(self, frame_vec): - - if self.get_levels_sum() >= 22 and not self.levels_satisfied: - self.levels_satisfied = True - self.base_explore = self.knn_index.get_current_count() - self.init_knn() - - if self.knn_index.get_current_count() == 0: - # if index is empty add current frame - self.knn_index.add_items( - frame_vec, np.array([self.knn_index.get_current_count()]) - ) - else: - # check for nearest frame and add if current - labels, distances = self.knn_index.knn_query(frame_vec, k = 1) - if distances[0][0] > self.similar_frame_dist: - # print(f"distances[0][0] : {distances[0][0]} similar_frame_dist : {self.similar_frame_dist}") - self.knn_index.add_items( - frame_vec, np.array([self.knn_index.get_current_count()]) - ) - - def update_seen_coords(self): - x_pos = self.read_m(X_POS_ADDRESS) - y_pos = self.read_m(Y_POS_ADDRESS) - map_n = self.read_m(MAP_N_ADDRESS) - coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}" - if 
self.get_levels_sum() >= 22 and not self.levels_satisfied: - self.levels_satisfied = True - self.base_explore = len(self.seen_coords) - self.seen_coords = {} - - self.seen_coords[coord_string] = self.step_count - - def update_reward(self): - # compute reward - old_prog = self.group_rewards() - self.progress_reward = self.get_game_state_reward() - new_prog = self.group_rewards() - new_total = sum([val for _, val in self.progress_reward.items()]) #sqrt(self.explore_reward * self.progress_reward) - new_step = new_total - self.total_reward - if new_step < 0 and self.read_hp_fraction() > 0: - #print(f'\n\nreward went down! {self.progress_reward}\n\n') - self.save_screenshot('neg_reward') - - self.total_reward = new_total - return (new_step, - (new_prog[0]-old_prog[0], - new_prog[1]-old_prog[1], - new_prog[2]-old_prog[2]) - ) - - def group_rewards(self): - prog = self.progress_reward - # these values are only used by memory - return (prog['level'] * 100 / self.reward_scale, - self.read_hp_fraction()*2000, - prog['explore'] * 150 / (self.explore_weight * self.reward_scale)) - #(prog['events'], - # prog['levels'] + prog['party_xp'], - # prog['explore']) - - def create_exploration_memory(self): - w = self.output_shape[1] - h = self.memory_height - - def make_reward_channel(r_val): - col_steps = self.col_steps - max_r_val = (w-1) * h * col_steps - # truncate progress bar. if hitting this - # you should scale down the reward in group_rewards! 
- r_val = min(r_val, max_r_val) - row = floor(r_val / (h * col_steps)) - memory = np.zeros(shape=(h, w), dtype=np.uint8) - memory[:, :row] = 255 - row_covered = row * h * col_steps - col = floor((r_val - row_covered) / col_steps) - memory[:col, row] = 255 - col_covered = col * col_steps - last_pixel = floor(r_val - row_covered - col_covered) - memory[col, row] = last_pixel * (255 // col_steps) - return memory - - level, hp, explore = self.group_rewards() - full_memory = np.stack(( - make_reward_channel(level), - make_reward_channel(hp), - make_reward_channel(explore) - ), axis=-1) - - if self.get_badges() > 0: - full_memory[:, -1, :] = 255 - - return full_memory - - def create_recent_memory(self): - return rearrange( - self.recent_memory, - '(w h) c -> h w c', - h=self.memory_height) - def check_if_done(self): if self.early_stopping: done = False - if self.step_count > 128 and self.recent_memory.sum() < (255 * 1): + if self.step_count > 128 and self.renderer.recent_memory.sum() < (255 * 1): done = True else: done = self.step_count >= self.max_steps - #done = self.read_hp_fraction() == 0 return done - - def save_and_print_info(self, done, obs_memory): - if self.print_rewards: - prog_string = f'step: {self.step_count:6d}' - for key, val in self.progress_reward.items(): - prog_string += f' {key}: {val:5.2f}' - prog_string += f' sum: {self.total_reward:5.2f}' - print(f'\r{prog_string}', end='', flush=True) - - if self.step_count % 50 == 0: - plt.imsave( - self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), - self.render(reduce_res=False)) - - if self.print_rewards and done: - print('', flush=True) - if self.save_final_state: - fs_path = self.s_path / Path('final_states') - fs_path.mkdir(exist_ok=True) - plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), - obs_memory) - plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), - self.render(reduce_res=False)) - - if self.save_video and 
done: - self.full_frame_writer.close() - self.model_frame_writer.close() - - if done: - self.all_runs.append(self.progress_reward) - with open(self.s_path / Path(f'all_runs_{self.instance_id}.json'), 'w') as f: - json.dump(self.all_runs, f) - pd.DataFrame(self.agent_stats).to_csv( - self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') - - def read_m(self, addr): - return self.pyboy.get_memory_value(addr) - - def read_bit(self, addr, bit: int) -> bool: - # add padding so zero will read '0b100000000' instead of '0b0' - return bin(256 + self.read_m(addr))[-bit-1] == '1' - - def get_levels_sum(self): - poke_levels = [max(self.read_m(a) - 2, 0) for a in LEVELS_ADDRESSES] - return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level - - def get_levels_reward(self): - explore_thresh = 22 - scale_factor = 4 - level_sum = self.get_levels_sum() - if level_sum < explore_thresh: - scaled = level_sum - else: - scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh - self.max_level_rew = max(self.max_level_rew, scaled) - return self.max_level_rew - - def get_knn_reward(self): - - pre_rew = self.explore_weight * 0.005 - post_rew = self.explore_weight * 0.01 - cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) - base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew - post = (cur_size if self.levels_satisfied else 0) * post_rew - return base + post - - def get_badges(self): - return self.bit_count(self.read_m(BADGE_COUNT_ADDRESS)) - - def read_party(self): - return [self.read_m(addr) for addr in PARTY_ADDRESSES] - - def update_heal_reward(self): - cur_health = self.read_hp_fraction() - # if health increased and party size did not change - if (cur_health > self.last_health and - self.read_m(PARTY_SIZE_ADDRESS) == self.party_size): - if self.last_health > 0: - heal_amount = cur_health - self.last_health - if heal_amount > 0.5: - print(f'healed: 
{heal_amount}') - self.save_screenshot('healing') - self.total_healing_rew += heal_amount * 4 - else: - self.died_count += 1 - - def get_all_events_reward(self): - # adds up all event flags, exclude museum ticket - event_flags_start = EVENT_FLAGS_START_ADDRESS - event_flags_end = EVENT_FLAGS_END_ADDRESS - museum_ticket = (MUSEUM_TICKET_ADDRESS, 0) - base_event_flags = 13 - return max( - sum( - [ - self.bit_count(self.read_m(i)) - for i in range(event_flags_start, event_flags_end) - ] - ) - - base_event_flags - - int(self.read_bit(museum_ticket[0], museum_ticket[1])), - 0, - ) - - def get_game_state_reward(self, print_stats=False): - # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map - # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm - ''' - num_poke = self.read_m(0xD163) - poke_xps = [self.read_triple(a) for a in [0xD179, 0xD1A5, 0xD1D1, 0xD1FD, 0xD229, 0xD255]] - #money = self.read_money() - 975 # subtract starting money - seen_poke_count = sum([self.bit_count(self.read_m(i)) for i in range(0xD30A, 0xD31D)]) - all_events_score = sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - oak_parcel = self.read_bit(0xD74E, 1) - oak_pokedex = self.read_bit(0xD74B, 5) - opponent_level = self.read_m(0xCFF3) - self.max_opponent_level = max(self.max_opponent_level, opponent_level) - enemy_poke_count = self.read_m(0xD89C) - self.max_opponent_poke = max(self.max_opponent_poke, enemy_poke_count) - - if print_stats: - print(f'num_poke : {num_poke}') - print(f'poke_levels : {poke_levels}') - print(f'poke_xps : {poke_xps}') - #print(f'money: {money}') - print(f'seen_poke_count : {seen_poke_count}') - print(f'oak_parcel: {oak_parcel} oak_pokedex: {oak_pokedex} all_events_score: {all_events_score}') - ''' - - state_scores = { - 'event': self.reward_scale*self.update_max_event_rew(), - #'party_xp': self.reward_scale*0.1*sum(poke_xps), - 'level': 
self.reward_scale*self.get_levels_reward(), - 'heal': self.reward_scale*self.total_healing_rew, - 'op_lvl': self.reward_scale*self.update_max_op_level(), - 'dead': self.reward_scale*-0.1*self.died_count, - 'badge': self.reward_scale*self.get_badges() * 5, - #'op_poke': self.reward_scale*self.max_opponent_poke * 800, - #'money': self.reward_scale* money * 3, - #'seen_poke': self.reward_scale * seen_poke_count * 400, - 'explore': self.reward_scale * self.get_knn_reward() - } - - return state_scores - - def save_screenshot(self, name): - ss_dir = self.s_path / Path('screenshots') - ss_dir.mkdir(exist_ok=True) - plt.imsave( - ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), - self.render(reduce_res=False)) - - def update_max_op_level(self): - #opponent_level = self.read_m(0xCFE8) - 5 # base level - opponent_level = max([self.read_m(a) for a in OPPONENT_LEVELS_ADDRESSES]) - 5 - #if opponent_level >= 7: - # self.save_screenshot('highlevelop') - self.max_opponent_level = max(self.max_opponent_level, opponent_level) - return self.max_opponent_level * 0.2 - - def update_max_event_rew(self): - cur_rew = self.get_all_events_reward() - self.max_event_rew = max(cur_rew, self.max_event_rew) - return self.max_event_rew - - def read_hp_fraction(self): - hp_sum = sum([self.read_hp(add) for add in HP_ADDRESSES]) - max_hp_sum = sum([self.read_hp(add) for add in MAX_HP_ADDRESSES]) - max_hp_sum = max(max_hp_sum, 1) - return hp_sum / max_hp_sum - - def read_hp(self, start): - return 256 * self.read_m(start) + self.read_m(start+1) - - # built-in since python 3.10 - def bit_count(self, bits): - return bin(bits).count('1') - - def read_triple(self, start_add): - return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2) - - def read_bcd(self, num): - return 10 * ((num >> 4) & 0x0f) + (num & 0x0f) - - def read_money(self): - return (100 * 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_1)) + - 100 * 
self.read_bcd(self.read_m(MONEY_ADDRESS_2)) + - self.read_bcd(self.read_m(MONEY_ADDRESS_3))) - - def get_map_location(self, map_idx): - map_locations = { - 0: "Pallet Town", - 1: "Viridian City", - 2: "Pewter City", - 3: "Cerulean City", - 12: "Route 1", - 13: "Route 2", - 14: "Route 3", - 15: "Route 4", - 33: "Route 22", - 37: "Red house first", - 38: "Red house second", - 39: "Blues house", - 40: "oaks lab", - 41: "Pokémon Center (Viridian City)", - 42: "Poké Mart (Viridian City)", - 43: "School (Viridian City)", - 44: "House 1 (Viridian City)", - 47: "Gate (Viridian City/Pewter City) (Route 2)", - 49: "Gate (Route 2)", - 50: "Gate (Route 2/Viridian Forest) (Route 2)", - 51: "viridian forest", - 52: "Pewter Museum (floor 1)", - 53: "Pewter Museum (floor 2)", - 54: "Pokémon Gym (Pewter City)", - 55: "House with disobedient Nidoran♂ (Pewter City)", - 56: "Poké Mart (Pewter City)", - 57: "House with two Trainers (Pewter City)", - 58: "Pokémon Center (Pewter City)", - 59: "Mt. Moon (Route 3 entrance)", - 60: "Mt. Moon", - 61: "Mt. 
class Renderer:
    """Builds model observations from the PyBoy screen and saves
    screenshots / rollout videos for one environment instance.

    The observation stacks the last ``frame_stacks`` downsampled frames
    below two small "memory" strips: an exploration/level/HP progress
    image and a recent-reward strip, separated by blank padding rows.
    """

    def __init__(self, s_path, pyboy, reward_service, instance_id):
        self.reward_service = reward_service
        self.instance_id = instance_id
        self.s_path = s_path
        self.s_path.mkdir(exist_ok=True)
        self.output_shape = (36, 40, 3)   # downsampled frame (h, w, c)
        self.frame_stacks = 3             # number of frames in the rolling stack
        self.mem_padding = 2              # blank rows between strips
        self.memory_height = 8            # height of each memory strip
        # Full observation height: stacked frames + two strips with padding.
        self.output_full = (
            self.output_shape[0] * self.frame_stacks + 2 * (self.mem_padding + self.memory_height),
            self.output_shape[1],
            self.output_shape[2]
        )
        self.col_steps = 16               # quantization steps per progress-bar pixel
        self.screen = pyboy.botsupport_manager().screen()
        self.reader = ReaderPyBoy(pyboy)

    def render(self, reduce_res=True, add_memory=True, update_mem=True):
        """Return the current frame as an ndarray.

        With ``reduce_res`` the raw (144, 160, 3) screen is downsampled to
        ``output_shape``; only then can it be recorded into the frame stack
        (``update_mem``) and composed with the memory strips (``add_memory``).
        """
        frame = self.screen.screen_ndarray()  # raw (144, 160, 3)
        if reduce_res:
            frame = (255 * resize(frame, self.output_shape)).astype(np.uint8)
            if update_mem:
                # Slot 0 holds the newest frame; older slots are shifted elsewhere.
                self.recent_frames[0] = frame
            if add_memory:
                spacer = np.zeros(
                    shape=(self.mem_padding, self.output_shape[1], 3),
                    dtype=np.uint8)
                level, hp, explore = self.reward_service.group_rewards_lvl_hp_explore(
                    self.reward_service.get_game_state_rewards())
                frame = np.concatenate(
                    (
                        self.create_exploration_memory(level, hp, explore),
                        spacer,
                        self.create_recent_memory(),
                        spacer,
                        rearrange(self.recent_frames, 'f h w c -> (f h) w c')
                    ),
                    axis=0)
        return frame

    def save_and_print_info(self):
        """Dump the current full-resolution frame to the session folder."""
        plt.imsave(
            self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
            self.render(reduce_res=False))

    def save_screenshot(self, name, total_reward, reset_count):
        """Save a named full-resolution screenshot tagged with reward/reset info."""
        shot_dir = self.s_path / Path('screenshots')
        shot_dir.mkdir(exist_ok=True)
        target = shot_dir / Path(
            f'frame{self.instance_id}_r{total_reward:.4f}_{reset_count}_{name}.jpeg')
        plt.imsave(target, self.render(reduce_res=False))

    def save_video(self, reset_count):
        """Open one full-resolution and one model-resolution video writer.

        NOTE(review): the writers are entered here but never exited in this
        class — presumably the caller closes them; confirm upstream.
        """
        base_dir = self.s_path / Path('rollouts')
        base_dir.mkdir(exist_ok=True)
        full_name = Path(f'full_reset_{reset_count}_id{self.instance_id}').with_suffix('.mp4')
        model_name = Path(f'model_reset_{reset_count}_id{self.instance_id}').with_suffix('.mp4')
        self.full_frame_writer = media.VideoWriter(base_dir / full_name, (144, 160), fps=60)
        self.full_frame_writer.__enter__()
        self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60)
        self.model_frame_writer.__enter__()

    def add_video_frame(self):
        """Append the current frame to both open video writers."""
        self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False))
        self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False))

    def get_obs_flat(self, obs_memory):
        """Flatten only the frame portion of an observation (memory strips
        trimmed off the top) as float32, for the knn exploration index."""
        frame_start = 2 * (self.memory_height + self.mem_padding)
        frame_stop = frame_start + self.output_shape[0]
        return obs_memory[frame_start:frame_stop, ...].flatten().astype(np.float32)

    def create_recent_memory(self):
        """Reshape the flat recent-reward buffer into its strip image."""
        return rearrange(self.recent_memory, '(w h) c -> h w c', h=self.memory_height)

    def create_exploration_memory(self, level, hp, explore):
        """Render (level, hp, explore) as three progress-bar channels.

        Each channel fills column by column, left to right, with the last
        pixel quantized into ``col_steps`` brightness levels.
        """
        strip_w = self.output_shape[1]
        strip_h = self.memory_height

        def make_reward_channel(r_val):
            steps = self.col_steps
            max_r_val = (strip_w - 1) * strip_h * steps
            # Truncate the bar; if this hits, scale the reward down in
            # group_rewards instead.
            r_val = min(r_val, max_r_val)
            row = floor(r_val / (strip_h * steps))
            channel = np.zeros(shape=(strip_h, strip_w), dtype=np.uint8)
            channel[:, :row] = 255
            row_covered = row * strip_h * steps
            col = floor((r_val - row_covered) / steps)
            channel[:col, row] = 255
            col_covered = col * steps
            last_pixel = floor(r_val - row_covered - col_covered)
            channel[col, row] = last_pixel * (255 // steps)
            return channel

        full_memory = np.stack(
            (make_reward_channel(level), make_reward_channel(hp), make_reward_channel(explore)),
            axis=-1)

        # Light up the right-most column once any badge is owned.
        if self.reader.get_badges() > 0:
            full_memory[:, -1, :] = 255

        return full_memory

    def save_final_state(self, obs_memory, reset_count, total_reward):
        """Save both the model-sized observation and the full frame at episode end."""
        fs_path = self.s_path / Path('final_states')
        fs_path.mkdir(exist_ok=True)
        plt.imsave(fs_path / Path(f'frame_r{total_reward:.4f}_{reset_count}_small.jpeg'), obs_memory)
        plt.imsave(fs_path / Path(f'frame_r{total_reward:.4f}_{reset_count}_full.jpeg'), self.render(reduce_res=False))

    def reset(self):
        """Zero the recent-reward strip and the rolling frame stack."""
        self.recent_memory = np.zeros((self.output_shape[1] * self.memory_height, 3), dtype=np.uint8)
        self.recent_frames = np.zeros((self.frame_stacks, *self.output_shape), dtype=np.uint8)
class Reward:
    """Computes the scalar RL reward from emulator state read via ``reader``.

    Tracks per-category progress (events, levels, healing, deaths, badges,
    pokedex, exploration) and exposes the reward both as a per-step delta
    and as grouped (level, hp, explore) progress values used by the
    observation's memory strips.
    """

    def __init__(self, config, reader):
        # Global multiplier applied to every reward category.
        self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale']
        self.reward_range = (0, 15000)
        self.reader = reader

        # Pokedex
        self.seen_pokemons = 0

        # Level
        self.max_level = 0
        self.levels_satisfied = False
        self.max_opponent_level = 0

        # Event
        self.max_event = 0

        # Health
        self.total_healing = 0
        self.last_party_size = 0
        self.last_health = 1
        self.died_count = 0

        # Explore — either a frame-similarity knn index or a coordinate set.
        self.base_explore = 0
        self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight']
        self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore']
        self.similar_frame_dist = config['sim_frame_dist']
        self.vec_dim = 4320  # flattened downsampled-frame length
        self.num_elements = 20000  # max elements in the knn index
        self.knn_index = None
        self.init_knn()
        self.seen_coords = {}
        self.init_map_mem()
        self.cur_size = 0

        self.last_game_state_rewards = self.get_game_state_rewards()
        self.total_reward = 0

    def reset(self):
        """Reset all progress counters for a new episode."""
        self.max_event = 0
        self.max_level = 0
        self.levels_satisfied = False
        self.total_healing = 0
        self.last_party_size = 0
        self.last_health = 1
        self.max_opponent_level = 0
        self.died_count = 0
        self.seen_pokemons = 0
        self.base_explore = 0
        self.seen_coords = {}
        if self.use_screen_explore:
            self.init_knn()
        else:
            self.init_map_mem()
        self.total_reward = 0
        self.last_game_state_rewards = self.get_game_state_rewards()

    def get_game_state_rewards(self):
        """Return the current reward per category, each already scaled.

        Addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
        and https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm
        """
        return {
            'event': self.reward_scale * self.max_event * 1,
            'level': self.reward_scale * self.compute_level_reward(),
            'heal': self.reward_scale * self.total_healing * 2,
            'op_lvl': self.reward_scale * self.max_opponent_level * 1,
            'dead': self.reward_scale * self.died_count * -0.1,
            'badge': self.reward_scale * self.reader.get_badges() * 5,
            'seen_poke': self.reward_scale * self.seen_pokemons * 1,
            'explore': self.reward_scale * self.compute_explore_reward()
        }

    def init_knn(self):
        """(Re)create the hnswlib index for screen-similarity exploration."""
        # Possible spaces are l2, cosine or ip.
        self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim)
        # Max element count must be known up front for hnswlib.
        self.knn_index.init_index(max_elements=self.num_elements, ef_construction=100, M=16)

    def init_map_mem(self):
        """(Re)create the visited-coordinate map for coordinate exploration."""
        self.seen_coords = {}

    def update_exploration_reward(self):
        """Refresh cur_size from whichever exploration store is active."""
        self.cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords)

    def get_all_events_flags(self):
        """Sum all event flags, excluding the starting 13 and the museum ticket."""
        base_event_flags = 13
        return max(0, sum(self.reader.read_events()) - base_event_flags - int(self.reader.read_museum_tickets()))

    def compute_level_reward(self):
        """Level reward: full credit up to 22 total levels, quarter credit beyond."""
        return int(min(22, self.max_level) + (max(0, (self.max_level - 22)) / 4)) * 1

    def compute_explore_reward(self):
        """Exploration reward; rate drops retroactively re-weighted once the
        level threshold is reached (pre-threshold size is frozen in base_explore)."""
        pre_rew = 0.005
        post_rew = 0.01
        if not self.levels_satisfied:
            return (self.cur_size * pre_rew) * self.explore_weight
        else:
            return ((self.base_explore * pre_rew) + (self.cur_size * post_rew)) * self.explore_weight

    def group_rewards_lvl_hp_explore(self, rewards):
        """Map a reward dict to the (level, hp, explore) triple used by the
        observation memory strips (scale factors removed, re-magnified).

        NOTE: the hp term always reads the *current* HP fraction, regardless
        of which reward snapshot is passed in.
        """
        return (rewards['level'] * 100 / self.reward_scale,
                self.reader.read_hp_fraction() * 2000,
                rewards['explore'] * 150 / (self.explore_weight * self.reward_scale))

    def update_rewards(self, obs_flat, step_count):
        """Advance all reward trackers for one step and return
        (reward_delta, (d_level, d_hp, d_explore))."""
        if self.use_screen_explore:
            self.update_frame_knn_index(obs_flat)
        else:
            self.update_seen_coords(step_count)
        # Order matters: cur_size must reflect the store updated just above.
        self.update_exploration_reward()
        self.update_max_event()
        self.update_total_heal_and_death()
        self.update_max_op_level()
        self.update_seen_pokemons()
        self.update_max_level()

        return self.update_state_reward()

    def update_state_reward(self):
        """Compute the per-step reward delta and grouped progress deltas.

        FIX: previously ``last_game_state_rewards`` was overwritten *before*
        ``old_prog`` was computed from it, so the returned progress deltas
        were always (0, 0, 0). The previous snapshot is now captured first.
        Also computes the current rewards once instead of three times.
        """
        previous_rewards = self.last_game_state_rewards
        current_rewards = self.get_game_state_rewards()

        last_total = sum(previous_rewards.values())
        new_total = sum(current_rewards.values())
        self.total_reward = new_total
        reward_delta = new_total - last_total

        self.last_game_state_rewards = current_rewards

        # Grouped deltas are used by the recent-reward memory strip.
        old_prog = self.group_rewards_lvl_hp_explore(previous_rewards)
        new_prog = self.group_rewards_lvl_hp_explore(current_rewards)
        return reward_delta, (new_prog[0] - old_prog[0],
                              new_prog[1] - old_prog[1],
                              new_prog[2] - old_prog[2])

    def update_max_op_level(self):
        """Track the highest opponent level seen so far."""
        opponent_level = self.reader.get_opponent_level()
        self.max_opponent_level = max(self.max_opponent_level, opponent_level)

    def update_seen_pokemons(self):
        """Pokedex 'seen' count, offset by the 3 seen at the starting state."""
        initial_seen_pokemon = 3
        self.seen_pokemons = sum(self.reader.read_seen_pokemons()) - initial_seen_pokemon

    def update_max_level(self):
        """Track the party level sum; it can only increase."""
        self.max_level = max(self.max_level, self.reader.get_levels_sum())

    def update_frame_knn_index(self, frame_vec):
        """Add frame_vec to the knn index if it is novel enough.

        When the level threshold is first reached, the index size is frozen
        into base_explore and the index restarted (see compute_explore_reward).
        """
        if self.reader.get_levels_sum() >= 22 and not self.levels_satisfied:
            self.levels_satisfied = True
            self.base_explore = self.knn_index.get_current_count()
            self.init_knn()

        if self.knn_index.get_current_count() == 0:
            # Empty index: always add the current frame.
            self.knn_index.add_items(
                frame_vec, np.array([self.knn_index.get_current_count()])
            )
        else:
            # Add only if the nearest stored frame is far enough away.
            _, distances = self.knn_index.knn_query(frame_vec, k=1)
            if distances[0][0] > self.similar_frame_dist:
                self.knn_index.add_items(
                    frame_vec, np.array([self.knn_index.get_current_count()])
                )

    def update_seen_coords(self, step_count):
        """Record the current (x, y, map) coordinate; same threshold/freeze
        behavior as the knn variant."""
        x_pos = self.reader.read_x_pos()
        y_pos = self.reader.read_y_pos()
        map_n = self.reader.read_map_n()
        coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}"
        if self.reader.get_levels_sum() >= 22 and not self.levels_satisfied:
            self.levels_satisfied = True
            self.base_explore = len(self.seen_coords)
            self.seen_coords = {}

        self.seen_coords[coord_string] = step_count

    def update_total_heal_and_death(self):
        """Accumulate healing (HP gained without party-size change) and deaths
        (HP recovering from zero)."""
        cur_health = self.reader.read_hp_fraction()
        if (cur_health > self.last_health and
                self.reader.read_party_size_address() == self.last_party_size):
            if self.last_health > 0:
                heal_amount = cur_health - self.last_health
                if heal_amount > 0.5:
                    print(f'healed: {heal_amount}')
                self.total_healing += heal_amount
            else:
                # last_health == 0 means we blacked out and respawned.
                self.died_count += 1
        self.last_party_size = self.reader.read_party_size_address()
        self.last_health = self.reader.read_hp_fraction()

    def update_max_event(self):
        """Track the highest event-flag total seen so far."""
        cur_rew = self.get_all_events_flags()
        self.max_event = max(cur_rew, self.max_event)

    def print_rewards(self, step_count):
        """Print a single-line, in-place progress summary of all rewards."""
        prog_string = f'step: {step_count:6d}'
        rewards_state = self.get_game_state_rewards()
        for key, val in rewards_state.items():
            prog_string += f' {key}: {val:5.2f}'
        prog_string += f' sum: {self.total_reward:5.2f}'
        prog_string += f' {self.reader.get_map_location()}'
        print(f'\r{prog_string}', end='', flush=True)
from pathlib import Path
from datetime import datetime
import uuid


if __name__ == '__main__':
    # Heavy project/RL imports are deferred into the entry guard so that
    # importing this module stays side-effect free and cheap.
    from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList

    from tensorboard_callback import TensorboardCallback
    from baselines_utils import load_or_create_model

    # FIX: ep_length was removed but is still referenced below (env_config,
    # save_freq, load_or_create_model, model.learn) — restore its definition.
    # It is both the episode cap and the PPO rollout n_steps.
    ep_length = 2048 * 8
    sess_path = Path(f'sessions/session_{datetime.now().strftime("%Y%m%d_%H%M")}')
    # Reference checkpoints kept for convenience:
    #   'session_4da05e87_main_good/poke_439746560_steps'  (pretrained)
    #   'session_20240227_1952/poke_720896_steps'          (known-good)
    model_to_load_path = ''  # e.g. 'sessions/session_20240302_1929/poke_1040384_steps'
    env_config = {
        'headless': True, 'save_final_state': True, 'early_stop': False,
        'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length,
        'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path,
        'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0,
        'use_screen_explore': True, 'extra_buttons': False, 'stream': False,
        'instance_id': str(uuid.uuid4())[:8]
    }

    num_cpu = 4  # Also sets the number of episodes per training iteration.
    model = load_or_create_model(model_to_load_path, env_config, ep_length, num_cpu)

    checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, name_prefix='poke')
    callbacks = [checkpoint_callback, TensorboardCallback()]
    learn_steps = 10

    for i in range(learn_steps):
        model.learn(total_timesteps=ep_length * num_cpu * 40, callback=CallbackList(callbacks))