From ba840d9b39096e71643d09f30f72bcec127b3f2d Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Tue, 27 Feb 2024 19:43:20 +0100 Subject: [PATCH 01/10] introduces new classes for maintability Signed-off-by: Mathieu D --- baselines/baselines_utils.py | 39 +++ baselines/memory_addresses.py | 4 +- baselines/reader_pyboy.py | 110 +++++++ baselines/red_gym_env.py | 453 ++++++----------------------- baselines/rewards.py | 219 ++++++++++++++ baselines/run_baseline_parallel.py | 62 +--- 6 files changed, 471 insertions(+), 416 deletions(-) create mode 100644 baselines/baselines_utils.py create mode 100644 baselines/reader_pyboy.py create mode 100644 baselines/rewards.py diff --git a/baselines/baselines_utils.py b/baselines/baselines_utils.py new file mode 100644 index 000000000..a31ae9603 --- /dev/null +++ b/baselines/baselines_utils.py @@ -0,0 +1,39 @@ +from os.path import exists +from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3 import PPO +from stable_baselines3.common.utils import set_random_seed +from red_gym_env import RedGymEnv + + +def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cpu): + + env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) + + if exists(model_to_load_path + '.zip'): + print('\nloading checkpoint') + model = PPO.load(model_to_load_path, env=env) + model.n_steps = total_timesteps + model.n_envs = num_cpu + model.rollout_buffer.buffer_size = total_timesteps + model.rollout_buffer.n_envs = num_cpu + model.rollout_buffer.reset() + else: + model = PPO('CnnPolicy', env, verbose=1, n_steps=total_timesteps, batch_size=512, n_epochs=1, gamma=0.999) + + return model + + +def make_env(rank, env_conf, seed=0): + """ + Utility function for multiprocessed env. 
+ :param env_id: (str) the environment ID + :param num_env: (int) the number of environments you wish to have in subprocesses + :param seed: (int) the initial seed for RNG + :param rank: (int) index of the subprocess + """ + def _init(): + env = RedGymEnv(env_conf) + env.reset(seed=(seed + rank)) + return env + set_random_seed(seed) + return _init diff --git a/baselines/memory_addresses.py b/baselines/memory_addresses.py index be989ee58..b554c0a88 100644 --- a/baselines/memory_addresses.py +++ b/baselines/memory_addresses.py @@ -19,4 +19,6 @@ MONEY_ADDRESS_1 = 0xD347 MONEY_ADDRESS_2 = 0xD348 -MONEY_ADDRESS_3 = 0xD349 \ No newline at end of file +MONEY_ADDRESS_3 = 0xD349 + +SEEN_POKEMONS_ADDRESSES = [0xD30A, 0xD30B, 0xD30C, 0xD30D, 0xD30E, 0xD30F, 0xD310, 0xD311, 0xD312, 0xD313, 0xD314, 0xD315, 0xD316, 0xD317, 0xD318, 0xD319, 0xD31A, 0xD31B, 0xD31C] diff --git a/baselines/reader_pyboy.py b/baselines/reader_pyboy.py new file mode 100644 index 000000000..f7ebbe64b --- /dev/null +++ b/baselines/reader_pyboy.py @@ -0,0 +1,110 @@ +from memory_addresses import * + + +class ReaderPyBoy: + + def __init__(self, pyboy): + self.pyboy = pyboy + + def read_m(self, addr): + return self.pyboy.get_memory_value(addr) + + def read_money(self): + return (100 * 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_1)) + + 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_2)) + + self.read_bcd(self.read_m(MONEY_ADDRESS_3))) + + def read_bcd(self, num): + return 10 * ((num >> 4) & 0x0f) + (num & 0x0f) + + def read_bit(self, addr, bit: int) -> bool: + # add padding so zero will read '0b100000000' instead of '0b0' + return bin(256 + self.read_m(addr))[-bit-1] == '1' + + def read_hp_fraction(self): + hp_sum = sum([self.read_hp(add) for add in HP_ADDRESSES]) + max_hp_sum = sum([self.read_hp(add) for add in MAX_HP_ADDRESSES]) + max_hp_sum = max(max_hp_sum, 1) + return hp_sum / max_hp_sum + + def read_hp(self, start): + return 256 * self.read_m(start) + self.read_m(start+1) + + # built-in since python 
3.10 + def bit_count(self, bits): + return bin(bits).count('1') + + def read_triple(self, start_add): + return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2) + + def get_badges(self): + return self.bit_count(self.read_m(BADGE_COUNT_ADDRESS)) + + def get_opponent_level(self): + return max([self.read_m(a) for a in OPPONENT_LEVELS_ADDRESSES]) - 5 + + def read_party(self): + return [self.read_m(addr) for addr in PARTY_ADDRESSES] + + def get_levels_sum(self): + poke_levels = [max(self.read_m(a) - 2, 0) for a in LEVELS_ADDRESSES] + return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level + + def read_party_size_address(self): + return self.read_m(PARTY_SIZE_ADDRESS) + + def read_x_pos(self): + return self.read_m(X_POS_ADDRESS) + + def read_y_pos(self): + return self.read_m(Y_POS_ADDRESS) + + def read_map_n(self): + return self.read_m(MAP_N_ADDRESS) + + def read_levels(self): + return [self.read_m(a) for a in LEVELS_ADDRESSES] + + def read_seen_pokemons(self): + return [self.bit_count(self.read_m(a)) for a in SEEN_POKEMONS_ADDRESSES] + + def get_map_location(self): + map_locations = { + 0: "Pallet Town", + 1: "Viridian City", + 2: "Pewter City", + 3: "Cerulean City", + 12: "Route 1", + 13: "Route 2", + 14: "Route 3", + 15: "Route 4", + 33: "Route 22", + 37: "Red house first", + 38: "Red house second", + 39: "Blues house", + 40: "oaks lab", + 41: "Pokémon Center (Viridian City)", + 42: "Poké Mart (Viridian City)", + 43: "School (Viridian City)", + 44: "House 1 (Viridian City)", + 47: "Gate (Viridian City/Pewter City) (Route 2)", + 49: "Gate (Route 2)", + 50: "Gate (Route 2/Viridian Forest) (Route 2)", + 51: "viridian forest", + 52: "Pewter Museum (floor 1)", + 53: "Pewter Museum (floor 2)", + 54: "Pokémon Gym (Pewter City)", + 55: "House with disobedient Nidoran♂ (Pewter City)", + 56: "Poké Mart (Pewter City)", + 57: "House with two Trainers (Pewter City)", + 58: "Pokémon Center (Pewter City)", + 59: "Mt. 
Moon (Route 3 entrance)", + 60: "Mt. Moon", + 61: "Mt. Moon", + 68: "Pokémon Center (Route 4)", + 193: "Badges check gate (Route 22)" + } + if self.read_map_n() in map_locations.keys(): + return map_locations[self.read_map_n()] + else: + return "Unknown Location" diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index e1133f731..7f649b0fd 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -1,8 +1,6 @@ - import sys -import uuid -import os -from math import floor, sqrt +import uuid +from math import floor import json from pathlib import Path @@ -11,28 +9,24 @@ import matplotlib.pyplot as plt from skimage.transform import resize from pyboy import PyBoy -#from pyboy.logger import log_level -import hnswlib import mediapy as media import pandas as pd +from rewards import Reward +from reader_pyboy import ReaderPyBoy from gymnasium import Env, spaces from pyboy.utils import WindowEvent -from memory_addresses import * -class RedGymEnv(Env): +class RedGymEnv(Env): def __init__( - self, config=None): + self, config=None): self.debug = config['debug'] self.s_path = config['session_path'] self.save_final_state = config['save_final_state'] - self.print_rewards = config['print_rewards'] - self.vec_dim = 4320 #1000 self.headless = config['headless'] - self.num_elements = 20000 # max self.init_state = config['init_state'] self.act_freq = config['action_freq'] self.max_steps = config['max_steps'] @@ -42,10 +36,8 @@ def __init__( self.video_interval = 256 * self.act_freq self.downsample_factor = 2 self.frame_stacks = 3 - self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight'] self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore'] - self.similar_frame_dist = config['sim_frame_dist'] - self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale'] + self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons'] self.instance_id 
= str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id'] self.s_path.mkdir(exist_ok=True) @@ -54,7 +46,6 @@ def __init__( # Set this in SOME subclasses self.metadata = {"render.modes": []} - self.reward_range = (0, 15000) self.valid_actions = [ WindowEvent.PRESS_ARROW_DOWN, @@ -64,7 +55,7 @@ def __init__( WindowEvent.PRESS_BUTTON_A, WindowEvent.PRESS_BUTTON_B, ] - + if self.extra_buttons: self.valid_actions.extend([ WindowEvent.PRESS_BUTTON_START, @@ -89,8 +80,8 @@ def __init__( self.col_steps = 16 self.output_full = ( self.output_shape[0] * self.frame_stacks + 2 * (self.mem_padding + self.memory_height), - self.output_shape[1], - self.output_shape[2] + self.output_shape[1], + self.output_shape[2] ) # Set these in ALL subclasses @@ -99,20 +90,25 @@ def __init__( head = 'headless' if config['headless'] else 'SDL2' - #log_level("ERROR") + # log_level("ERROR") self.pyboy = PyBoy( - config['gb_path'], - debugging=False, - disable_input=False, - window_type=head, - hide_window='--quiet' in sys.argv, - ) - + config['gb_path'], + debugging=False, + disable_input=False, + window_type=head, + hide_window='--quiet' in sys.argv, + ) self.screen = self.pyboy.botsupport_manager().screen() if not config['headless']: self.pyboy.set_emulation_speed(6) - + + self.reader = ReaderPyBoy(self.pyboy) + + # Rewards + self.print_rewards = config['print_rewards'] + self.reward_service = Reward(config, self.reader, self.save_screenshot) + self.reset() def reset(self, seed=None, options=None): @@ -120,21 +116,18 @@ def reset(self, seed=None, options=None): # restart game, skipping credits with open(self.init_state, "rb") as f: self.pyboy.load_state(f) - - if self.use_screen_explore: - self.init_knn() - else: - self.init_map_mem() - self.recent_memory = np.zeros((self.output_shape[1]*self.memory_height, 3), dtype=np.uint8) - + self.reward_service.reset() + + self.recent_memory = np.zeros((self.output_shape[1] * self.memory_height, 3), dtype=np.uint8) + 
self.recent_frames = np.zeros( - (self.frame_stacks, self.output_shape[0], + (self.frame_stacks, self.output_shape[0], self.output_shape[1], self.output_shape[2]), dtype=np.uint8) self.agent_stats = [] - + if self.save_video: base_dir = self.s_path / Path('rollouts') base_dir.mkdir(exist_ok=True) @@ -144,45 +137,25 @@ def reset(self, seed=None, options=None): self.full_frame_writer.__enter__() self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60) self.model_frame_writer.__enter__() - - self.levels_satisfied = False - self.base_explore = 0 - self.max_opponent_level = 0 - self.max_event_rew = 0 - self.max_level_rew = 0 - self.last_health = 1 - self.total_healing_rew = 0 - self.died_count = 0 - self.party_size = 0 + self.step_count = 0 - self.progress_reward = self.get_game_state_reward() - self.total_reward = sum([val for _, val in self.progress_reward.items()]) + self.reset_count += 1 return self.render(), {} - - def init_knn(self): - # Declaring index - self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip - # Initing index - the maximum number of elements should be known beforehand - self.knn_index.init_index( - max_elements=self.num_elements, ef_construction=100, M=16) - - def init_map_mem(self): - self.seen_coords = {} def render(self, reduce_res=True, add_memory=True, update_mem=True): - game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3) + game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3) if reduce_res: - game_pixels_render = (255*resize(game_pixels_render, self.output_shape)).astype(np.uint8) + game_pixels_render = (255 * resize(game_pixels_render, self.output_shape)).astype(np.uint8) if update_mem: self.recent_frames[0] = game_pixels_render if add_memory: pad = np.zeros( - shape=(self.mem_padding, self.output_shape[1], 3), + shape=(self.mem_padding, self.output_shape[1], 3), dtype=np.uint8) game_pixels_render = np.concatenate( ( - 
self.create_exploration_memory(), + self.create_exploration_memory(), pad, self.create_recent_memory(), pad, @@ -190,31 +163,24 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True): ), axis=0) return game_pixels_render - + def step(self, action): self.run_action_on_emulator(action) self.append_agent_stats(action) self.recent_frames = np.roll(self.recent_frames, 1, axis=0) - obs_memory = self.render() + # OBSERVATION + + obs_memory = self.render() # trim off memory from frame for knn index frame_start = 2 * (self.memory_height + self.mem_padding) - obs_flat = obs_memory[ - frame_start:frame_start+self.output_shape[0], ...].flatten().astype(np.float32) + obs_flat = obs_memory[frame_start:frame_start + self.output_shape[0], ...].flatten().astype(np.float32) - if self.use_screen_explore: - self.update_frame_knn_index(obs_flat) - else: - self.update_seen_coords() - - self.update_heal_reward() - self.party_size = self.read_m(PARTY_SIZE_ADDRESS) + # REWARD - new_reward, new_prog = self.update_reward() - - self.last_health = self.read_hp_fraction() + reward_delta, new_prog = self.reward_service.update_rewards(obs_flat, self.step_count) # shift over short term reward memory self.recent_memory = np.roll(self.recent_memory, 3) @@ -222,13 +188,13 @@ def step(self, action): self.recent_memory[0, 1] = min(new_prog[1] * 64, 255) self.recent_memory[0, 2] = min(new_prog[2] * 128, 255) - step_limit_reached = self.check_if_done() - - self.save_and_print_info(step_limit_reached, obs_memory) + # DONE + step_limit_reached = self.check_if_done() + self.save_and_print_info(step_limit_reached, obs_memory, reward_delta) self.step_count += 1 - return obs_memory, new_reward*0.1, False, step_limit_reached, {} + return obs_memory, reward_delta * 0.1, False, step_limit_reached, {} def run_action_on_emulator(self, action): # press button then release after some steps @@ -249,107 +215,48 @@ def run_action_on_emulator(self, action): 
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START) if self.save_video and not self.fast_video: self.add_video_frame() - if i == self.act_freq-1: + if i == self.act_freq - 1: self.pyboy._rendering(True) self.pyboy.tick() if self.save_video and self.fast_video: self.add_video_frame() - + def add_video_frame(self): self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False)) self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False)) - + def append_agent_stats(self, action): - x_pos = self.read_m(X_POS_ADDRESS) - y_pos = self.read_m(Y_POS_ADDRESS) - map_n = self.read_m(MAP_N_ADDRESS) - levels = [self.read_m(a) for a in LEVELS_ADDRESSES] + x_pos = self.reader.read_x_pos() + y_pos = self.reader.read_y_pos() + map_n = self.reader.read_map_n() + levels = self.reader.read_levels() if self.use_screen_explore: - expl = ('frames', self.knn_index.get_current_count()) + expl = ('frames', self.reward_service.knn_index.get_current_count()) else: - expl = ('coord_count', len(self.seen_coords)) + expl = ('coord_count', len(self.reward_service.seen_coords)) self.agent_stats.append({ 'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n, - 'map_location': self.get_map_location(map_n), + 'map_location': self.reader.get_map_location(), 'last_action': action, - 'pcount': self.read_m(PARTY_SIZE_ADDRESS), - 'levels': levels, + 'pcount': self.reader.read_party_size_address(), + 'levels': levels, 'levels_sum': sum(levels), - 'ptypes': self.read_party(), - 'hp': self.read_hp_fraction(), + 'ptypes': self.reader.read_party(), + 'hp': self.reader.read_hp_fraction(), expl[0]: expl[1], - 'deaths': self.died_count, 'badge': self.get_badges(), - 'event': self.progress_reward['event'], 'healr': self.total_healing_rew + 'deaths': self.reward_service.died_count, + 'badge': self.reader.get_badges(), + 'event': self.reward_service.max_event_rew, + 'healr': self.reward_service.total_healing_rew }) - def update_frame_knn_index(self, frame_vec): - - 
if self.get_levels_sum() >= 22 and not self.levels_satisfied: - self.levels_satisfied = True - self.base_explore = self.knn_index.get_current_count() - self.init_knn() - - if self.knn_index.get_current_count() == 0: - # if index is empty add current frame - self.knn_index.add_items( - frame_vec, np.array([self.knn_index.get_current_count()]) - ) - else: - # check for nearest frame and add if current - labels, distances = self.knn_index.knn_query(frame_vec, k = 1) - if distances[0][0] > self.similar_frame_dist: - # print(f"distances[0][0] : {distances[0][0]} similar_frame_dist : {self.similar_frame_dist}") - self.knn_index.add_items( - frame_vec, np.array([self.knn_index.get_current_count()]) - ) - - def update_seen_coords(self): - x_pos = self.read_m(X_POS_ADDRESS) - y_pos = self.read_m(Y_POS_ADDRESS) - map_n = self.read_m(MAP_N_ADDRESS) - coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}" - if self.get_levels_sum() >= 22 and not self.levels_satisfied: - self.levels_satisfied = True - self.base_explore = len(self.seen_coords) - self.seen_coords = {} - - self.seen_coords[coord_string] = self.step_count - - def update_reward(self): - # compute reward - old_prog = self.group_rewards() - self.progress_reward = self.get_game_state_reward() - new_prog = self.group_rewards() - new_total = sum([val for _, val in self.progress_reward.items()]) #sqrt(self.explore_reward * self.progress_reward) - new_step = new_total - self.total_reward - if new_step < 0 and self.read_hp_fraction() > 0: - #print(f'\n\nreward went down! 
{self.progress_reward}\n\n') - self.save_screenshot('neg_reward') - - self.total_reward = new_total - return (new_step, - (new_prog[0]-old_prog[0], - new_prog[1]-old_prog[1], - new_prog[2]-old_prog[2]) - ) - - def group_rewards(self): - prog = self.progress_reward - # these values are only used by memory - return (prog['level'] * 100 / self.reward_scale, - self.read_hp_fraction()*2000, - prog['explore'] * 150 / (self.explore_weight * self.reward_scale)) - #(prog['events'], - # prog['levels'] + prog['party_xp'], - # prog['explore']) - def create_exploration_memory(self): w = self.output_shape[1] h = self.memory_height - + def make_reward_channel(r_val): col_steps = self.col_steps - max_r_val = (w-1) * h * col_steps + max_r_val = (w - 1) * h * col_steps # truncate progress bar. if hitting this # you should scale down the reward in group_rewards! r_val = min(r_val, max_r_val) @@ -360,26 +267,26 @@ def make_reward_channel(r_val): col = floor((r_val - row_covered) / col_steps) memory[:col, row] = 255 col_covered = col * col_steps - last_pixel = floor(r_val - row_covered - col_covered) + last_pixel = floor(r_val - row_covered - col_covered) memory[col, row] = last_pixel * (255 // col_steps) return memory - - level, hp, explore = self.group_rewards() + + level, hp, explore = self.reward_service.group_rewards_lvl_hp_explore(self.reward_service.get_game_state_rewards()) full_memory = np.stack(( make_reward_channel(level), make_reward_channel(hp), make_reward_channel(explore) ), axis=-1) - - if self.get_badges() > 0: + + if self.reader.get_badges() > 0: full_memory[:, -1, :] = 255 return full_memory def create_recent_memory(self): return rearrange( - self.recent_memory, - '(w h) c -> h w c', + self.recent_memory, + '(w h) c -> h w c', h=self.memory_height) def check_if_done(self): @@ -395,14 +302,15 @@ def check_if_done(self): def save_and_print_info(self, done, obs_memory): if self.print_rewards: prog_string = f'step: {self.step_count:6d}' - for key, val in 
self.progress_reward.items(): + rewards_state = self.reward_service.get_game_state_rewards() + for key, val in rewards_state.items(): prog_string += f' {key}: {val:5.2f}' - prog_string += f' sum: {self.total_reward:5.2f}' + prog_string += f' sum: {self.reward_service.total_reward:5.2f}' print(f'\r{prog_string}', end='', flush=True) - + if self.step_count % 50 == 0: plt.imsave( - self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), + self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), self.render(reduce_res=False)) if self.print_rewards and done: @@ -411,10 +319,10 @@ def save_and_print_info(self, done, obs_memory): fs_path = self.s_path / Path('final_states') fs_path.mkdir(exist_ok=True) plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), + fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{self.reset_count}_small.jpeg'), obs_memory) plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), + fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{self.reset_count}_full.jpeg'), self.render(reduce_res=False)) if self.save_video and done: @@ -422,204 +330,17 @@ def save_and_print_info(self, done, obs_memory): self.model_frame_writer.close() if done: - self.all_runs.append(self.progress_reward) + self.all_runs.append(self.reward_service.get_game_state_rewards()) with open(self.s_path / Path(f'all_runs_{self.instance_id}.json'), 'w') as f: json.dump(self.all_runs, f) pd.DataFrame(self.agent_stats).to_csv( self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') - - def read_m(self, addr): - return self.pyboy.get_memory_value(addr) - - def read_bit(self, addr, bit: int) -> bool: - # add padding so zero will read '0b100000000' instead of '0b0' - return bin(256 + self.read_m(addr))[-bit-1] == '1' - - def get_levels_sum(self): - poke_levels = [max(self.read_m(a) - 2, 0) for a in LEVELS_ADDRESSES] - return max(sum(poke_levels) - 4, 
0) # subtract starting pokemon level - - def get_levels_reward(self): - explore_thresh = 22 - scale_factor = 4 - level_sum = self.get_levels_sum() - if level_sum < explore_thresh: - scaled = level_sum - else: - scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh - self.max_level_rew = max(self.max_level_rew, scaled) - return self.max_level_rew - - def get_knn_reward(self): - - pre_rew = self.explore_weight * 0.005 - post_rew = self.explore_weight * 0.01 - cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) - base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew - post = (cur_size if self.levels_satisfied else 0) * post_rew - return base + post - - def get_badges(self): - return self.bit_count(self.read_m(BADGE_COUNT_ADDRESS)) - - def read_party(self): - return [self.read_m(addr) for addr in PARTY_ADDRESSES] - - def update_heal_reward(self): - cur_health = self.read_hp_fraction() - # if health increased and party size did not change - if (cur_health > self.last_health and - self.read_m(PARTY_SIZE_ADDRESS) == self.party_size): - if self.last_health > 0: - heal_amount = cur_health - self.last_health - if heal_amount > 0.5: - print(f'healed: {heal_amount}') - self.save_screenshot('healing') - self.total_healing_rew += heal_amount * 4 - else: - self.died_count += 1 - - def get_all_events_reward(self): - # adds up all event flags, exclude museum ticket - event_flags_start = EVENT_FLAGS_START_ADDRESS - event_flags_end = EVENT_FLAGS_END_ADDRESS - museum_ticket = (MUSEUM_TICKET_ADDRESS, 0) - base_event_flags = 13 - return max( - sum( - [ - self.bit_count(self.read_m(i)) - for i in range(event_flags_start, event_flags_end) - ] - ) - - base_event_flags - - int(self.read_bit(museum_ticket[0], museum_ticket[1])), - 0, - ) - - def get_game_state_reward(self, print_stats=False): - # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map - # 
https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm - ''' - num_poke = self.read_m(0xD163) - poke_xps = [self.read_triple(a) for a in [0xD179, 0xD1A5, 0xD1D1, 0xD1FD, 0xD229, 0xD255]] - #money = self.read_money() - 975 # subtract starting money - seen_poke_count = sum([self.bit_count(self.read_m(i)) for i in range(0xD30A, 0xD31D)]) - all_events_score = sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - oak_parcel = self.read_bit(0xD74E, 1) - oak_pokedex = self.read_bit(0xD74B, 5) - opponent_level = self.read_m(0xCFF3) - self.max_opponent_level = max(self.max_opponent_level, opponent_level) - enemy_poke_count = self.read_m(0xD89C) - self.max_opponent_poke = max(self.max_opponent_poke, enemy_poke_count) - - if print_stats: - print(f'num_poke : {num_poke}') - print(f'poke_levels : {poke_levels}') - print(f'poke_xps : {poke_xps}') - #print(f'money: {money}') - print(f'seen_poke_count : {seen_poke_count}') - print(f'oak_parcel: {oak_parcel} oak_pokedex: {oak_pokedex} all_events_score: {all_events_score}') - ''' - - state_scores = { - 'event': self.reward_scale*self.update_max_event_rew(), - #'party_xp': self.reward_scale*0.1*sum(poke_xps), - 'level': self.reward_scale*self.get_levels_reward(), - 'heal': self.reward_scale*self.total_healing_rew, - 'op_lvl': self.reward_scale*self.update_max_op_level(), - 'dead': self.reward_scale*-0.1*self.died_count, - 'badge': self.reward_scale*self.get_badges() * 5, - #'op_poke': self.reward_scale*self.max_opponent_poke * 800, - #'money': self.reward_scale* money * 3, - #'seen_poke': self.reward_scale * seen_poke_count * 400, - 'explore': self.reward_scale * self.get_knn_reward() - } - - return state_scores - + def save_screenshot(self, name): ss_dir = self.s_path / Path('screenshots') ss_dir.mkdir(exist_ok=True) plt.imsave( - ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), - self.render(reduce_res=False)) - - 
def update_max_op_level(self): - #opponent_level = self.read_m(0xCFE8) - 5 # base level - opponent_level = max([self.read_m(a) for a in OPPONENT_LEVELS_ADDRESSES]) - 5 - #if opponent_level >= 7: - # self.save_screenshot('highlevelop') - self.max_opponent_level = max(self.max_opponent_level, opponent_level) - return self.max_opponent_level * 0.2 - - def update_max_event_rew(self): - cur_rew = self.get_all_events_reward() - self.max_event_rew = max(cur_rew, self.max_event_rew) - return self.max_event_rew - - def read_hp_fraction(self): - hp_sum = sum([self.read_hp(add) for add in HP_ADDRESSES]) - max_hp_sum = sum([self.read_hp(add) for add in MAX_HP_ADDRESSES]) - max_hp_sum = max(max_hp_sum, 1) - return hp_sum / max_hp_sum - - def read_hp(self, start): - return 256 * self.read_m(start) + self.read_m(start+1) - - # built-in since python 3.10 - def bit_count(self, bits): - return bin(bits).count('1') - - def read_triple(self, start_add): - return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2) - - def read_bcd(self, num): - return 10 * ((num >> 4) & 0x0f) + (num & 0x0f) - - def read_money(self): - return (100 * 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_1)) + - 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_2)) + - self.read_bcd(self.read_m(MONEY_ADDRESS_3))) - - def get_map_location(self, map_idx): - map_locations = { - 0: "Pallet Town", - 1: "Viridian City", - 2: "Pewter City", - 3: "Cerulean City", - 12: "Route 1", - 13: "Route 2", - 14: "Route 3", - 15: "Route 4", - 33: "Route 22", - 37: "Red house first", - 38: "Red house second", - 39: "Blues house", - 40: "oaks lab", - 41: "Pokémon Center (Viridian City)", - 42: "Poké Mart (Viridian City)", - 43: "School (Viridian City)", - 44: "House 1 (Viridian City)", - 47: "Gate (Viridian City/Pewter City) (Route 2)", - 49: "Gate (Route 2)", - 50: "Gate (Route 2/Viridian Forest) (Route 2)", - 51: "viridian forest", - 52: "Pewter Museum (floor 1)", - 53: "Pewter Museum (floor 2)", - 
54: "Pokémon Gym (Pewter City)", - 55: "House with disobedient Nidoran♂ (Pewter City)", - 56: "Poké Mart (Pewter City)", - 57: "House with two Trainers (Pewter City)", - 58: "Pokémon Center (Pewter City)", - 59: "Mt. Moon (Route 3 entrance)", - 60: "Mt. Moon", - 61: "Mt. Moon", - 68: "Pokémon Center (Route 4)", - 193: "Badges check gate (Route 22)" - } - if map_idx in map_locations.keys(): - return map_locations[map_idx] - else: - return "Unknown Location" - + ss_dir / Path( + f'frame{self.instance_id}_r{self.reward_service.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), + self.render(reduce_res=False) + ) diff --git a/baselines/rewards.py b/baselines/rewards.py new file mode 100644 index 000000000..9d50a0604 --- /dev/null +++ b/baselines/rewards.py @@ -0,0 +1,219 @@ +from memory_addresses import EVENT_FLAGS_START_ADDRESS, EVENT_FLAGS_END_ADDRESS, MUSEUM_TICKET_ADDRESS +import hnswlib +import numpy as np + + +class Reward: + + def __init__(self, config, reader, save_screenshot): + self.save_screenshot = save_screenshot + self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale'] + self.reward_range = (0, 15000) + self.reader = reader + + # Pokedex + self.seen_pokemons_rew = 0 + + # Level + self.max_level_rew = 0 + self.levels_satisfied = False + self.max_opponent_level = 0 + + # Event + self.max_event_rew = 0 + + # Health + self.total_healing_rew = 0 + self.last_party_size = 0 + self.last_health = 1 + self.died_count = 0 + + # Explore + self.base_explore = 0 + self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight'] + self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore'] + self.similar_frame_dist = config['sim_frame_dist'] + self.vec_dim = 4320 # 1000 + self.num_elements = 20000 # max + self.knn_index = None + self.init_knn() + self.seen_coords = {} + self.init_map_mem() + + self.last_game_state_rewards = self.get_game_state_rewards() + self.total_reward = 0 + 
+ def init_knn(self): + # Declaring index + self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip + # Initing index - the maximum number of elements should be known beforehand + self.knn_index.init_index(max_elements=self.num_elements, ef_construction=100, M=16) + + def init_map_mem(self): + self.seen_coords = {} + + def get_exploration_reward(self): + pre_rew = self.explore_weight * 0.005 + post_rew = self.explore_weight * 0.01 + cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) + base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew + post = (cur_size if self.levels_satisfied else 0) * post_rew + return base + post + + def get_all_events_reward(self): + # adds up all event flags, exclude museum ticket + event_flags_start = EVENT_FLAGS_START_ADDRESS + event_flags_end = EVENT_FLAGS_END_ADDRESS + museum_ticket = (MUSEUM_TICKET_ADDRESS, 0) + base_event_flags = 13 + return max( + sum( + [ + self.reader.bit_count(self.reader.read_m(i)) + for i in range(event_flags_start, event_flags_end) + ] + ) + - base_event_flags + - int(self.reader.read_bit(museum_ticket[0], museum_ticket[1])), + 0, + ) + + def get_game_state_rewards(self): + # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map + # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm + return { + 'event': self.reward_scale * self.max_event_rew, + 'level': self.reward_scale * self.max_level_rew, + 'heal': self.reward_scale * self.total_healing_rew, + 'op_lvl': self.reward_scale * self.max_opponent_level, + 'dead': self.reward_scale * -0.1 * self.died_count, + 'badge': self.reward_scale * self.reader.get_badges() * 5, + 'explore': self.reward_scale * self.get_exploration_reward(), + # 'party_xp': self.reward_scale*0.1*sum(poke_xps), + # 'op_poke': self.reward_scale*self.max_opponent_poke * 800, + # 'money': 
self.reward_scale* money * 3, + 'seen_poke': self.reward_scale * self.seen_pokemons_rew + } + + def group_rewards_lvl_hp_explore(self, rewards): + return (rewards['level'] * 100 / self.reward_scale, + self.reader.read_hp_fraction() * 2000, + rewards['explore'] * 150 / (self.explore_weight * self.reward_scale)) + + def update_rewards(self, obs_flat, step_count): + if self.use_screen_explore: + self.update_frame_knn_index(obs_flat) + else: + self.update_seen_coords(step_count) + + self.update_max_event_rew() + self.update_heal_reward() + self.update_max_op_level() + self.update_seen_pokemons() + self.update_max_level_reward() + + return self.update_state_reward() + + def update_state_reward(self): + # compute reward + last_total = sum([val for _, val in self.last_game_state_rewards.items()]) + new_total = sum([val for _, val in self.get_game_state_rewards().items()]) + self.total_reward = new_total + reward_delta = new_total - last_total + if reward_delta < 0 and self.reader.read_hp_fraction() > 0: + self.save_screenshot('neg_reward') + + self.last_game_state_rewards = self.get_game_state_rewards() + + # used by memory + old_prog = self.group_rewards_lvl_hp_explore(self.last_game_state_rewards) + new_prog = self.group_rewards_lvl_hp_explore(self.get_game_state_rewards()) + return reward_delta, (new_prog[0] - old_prog[0],new_prog[1] - old_prog[1], new_prog[2] - old_prog[2]) + + def update_max_op_level(self): + opponent_level = self.reader.get_opponent_level() + self.max_opponent_level = max(self.max_opponent_level, opponent_level) + return self.max_opponent_level * 0.2 + + def update_seen_pokemons(self): + initial_seen_pokemon = 3 + self.seen_pokemons_rew = sum(self.reader.read_seen_pokemons()) - initial_seen_pokemon + + def update_max_level_reward(self): + explore_thresh = 22 + scale_factor = 4 + level_sum = self.reader.get_levels_sum() + if level_sum < explore_thresh: + scaled = level_sum + else: + scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh 
+ # always keeping the max, lvl can't decrease + self.max_level_rew = max(self.max_level_rew, scaled) + return self.max_level_rew + + def update_frame_knn_index(self, frame_vec): + + if self.reader.get_levels_sum() >= 22 and not self.levels_satisfied: + self.levels_satisfied = True + self.base_explore = self.knn_index.get_current_count() + self.init_knn() + + if self.knn_index.get_current_count() == 0: + # if index is empty add current frame + self.knn_index.add_items( + frame_vec, np.array([self.knn_index.get_current_count()]) + ) + else: + # check for nearest frame and add if current + _, distances = self.knn_index.knn_query(frame_vec, k=1) + if distances[0][0] > self.similar_frame_dist: + # print(f"distances[0][0] : {distances[0][0]} similar_frame_dist : {self.similar_frame_dist}") + self.knn_index.add_items( + frame_vec, np.array([self.knn_index.get_current_count()]) + ) + + def update_seen_coords(self, step_count): + x_pos = self.reader.read_x_pos() + y_pos = self.reader.read_y_pos() + map_n = self.reader.read_map_n() + coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}" + if self.reader.get_levels_sum() >= 22 and not self.levels_satisfied: + self.levels_satisfied = True + self.base_explore = len(self.seen_coords) + self.seen_coords = {} + + self.seen_coords[coord_string] = step_count + + def update_heal_reward(self): + cur_health = self.reader.read_hp_fraction() + # if health increased and party size did not change + if (cur_health > self.last_health and + self.reader.read_party_size_address() == self.last_party_size): + if self.last_health > 0: + heal_amount = cur_health - self.last_health + if heal_amount > 0.5: + print(f'healed: {heal_amount}') + self.save_screenshot('healing') + self.total_healing_rew += heal_amount * 4 + else: + self.died_count += 1 + self.last_party_size = self.reader.read_party_size_address() + self.last_health = self.reader.read_hp_fraction() + + def update_max_event_rew(self): + cur_rew = self.get_all_events_reward() + self.max_event_rew 
= max(cur_rew, self.max_event_rew) + + def reset(self): + self.max_event_rew = 0 + self.max_level_rew = 0 + self.total_healing_rew = 0 + self.max_opponent_level = 0 + self.died_count = 0 + self.seen_pokemons_rew = 0 + if self.use_screen_explore: + self.init_knn() + else: + self.init_map_mem() + self.last_game_state_rewards = self.get_game_state_rewards() + self.total_reward = 0 diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py index f4423a3a5..1deed1328 100644 --- a/baselines/run_baseline_parallel.py +++ b/baselines/run_baseline_parallel.py @@ -1,63 +1,27 @@ -from os.path import exists from pathlib import Path -import uuid -from red_gym_env import RedGymEnv -from stable_baselines3 import A2C, PPO -from stable_baselines3.common import env_checker -from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv -from stable_baselines3.common.utils import set_random_seed +from datetime import datetime from stable_baselines3.common.callbacks import CheckpointCallback - -def make_env(rank, env_conf, seed=0): - """ - Utility function for multiprocessed env. 
- :param env_id: (str) the environment ID - :param num_env: (int) the number of environments you wish to have in subprocesses - :param seed: (int) the initial seed for RNG - :param rank: (int) index of the subprocess - """ - def _init(): - env = RedGymEnv(env_conf) - env.reset(seed=(seed + rank)) - return env - set_random_seed(seed) - return _init +from baselines_utils import load_or_create_model, get_formatted_timestamp if __name__ == '__main__': - ep_length = 2048 * 8 - sess_path = Path(f'session_{str(uuid.uuid4())[:8]}') - + sess_path = Path(f'session_{datetime.now().strftime("%Y%m%d_%H%M")}') + #pretrained_session = "session_4da05e87_main_good/poke_439746560_steps" + model_to_load_path = 'session_e41c9eff/poke_38207488_steps' #'session_e41c9eff/poke_250871808_steps' env_config = { 'headless': True, 'save_final_state': True, 'early_stop': False, - 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, + 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, 'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path, - 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0, + 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0, 'use_screen_explore': True, 'extra_buttons': False } - - + num_cpu = 44 #64 #46 # Also sets the number of episodes per training iteration - env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) - - checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, - name_prefix='poke') - #env_checker.check_env(env) + model = load_or_create_model(model_to_load_path, env_config, ep_length, num_cpu) + + checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, name_prefix='poke') learn_steps = 40 - file_name = 'session_e41c9eff/poke_38207488_steps' #'session_e41c9eff/poke_250871808_steps' - - #'session_bfdca25a/poke_42532864_steps' 
#'session_d3033abb/poke_47579136_steps' #'session_a17cc1f5/poke_33546240_steps' #'session_e4bdca71/poke_8945664_steps' #'session_eb21989e/poke_40255488_steps' #'session_80f70ab4/poke_58982400_steps' - if exists(file_name + '.zip'): - print('\nloading checkpoint') - model = PPO.load(file_name, env=env) - model.n_steps = ep_length - model.n_envs = num_cpu - model.rollout_buffer.buffer_size = ep_length - model.rollout_buffer.n_envs = num_cpu - model.rollout_buffer.reset() - else: - model = PPO('CnnPolicy', env, verbose=1, n_steps=ep_length, batch_size=512, n_epochs=1, gamma=0.999) - + for i in range(learn_steps): - model.learn(total_timesteps=(ep_length)*num_cpu*1000, callback=checkpoint_callback) + model.learn(total_timesteps=ep_length * num_cpu * 1000, callback=checkpoint_callback) From fd91ab5059290d78d079391e85f3ca88a8cf26f5 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Tue, 27 Feb 2024 19:52:23 +0100 Subject: [PATCH 02/10] fix oopsies Signed-off-by: Mathieu D --- baselines/red_gym_env.py | 2 +- baselines/run_baseline_parallel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index 7f649b0fd..9efb592e0 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -191,7 +191,7 @@ def step(self, action): # DONE step_limit_reached = self.check_if_done() - self.save_and_print_info(step_limit_reached, obs_memory, reward_delta) + self.save_and_print_info(step_limit_reached, obs_memory) self.step_count += 1 return obs_memory, reward_delta * 0.1, False, step_limit_reached, {} diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py index 1deed1328..94ab8f0d6 100644 --- a/baselines/run_baseline_parallel.py +++ b/baselines/run_baseline_parallel.py @@ -1,7 +1,7 @@ from pathlib import Path from datetime import datetime from stable_baselines3.common.callbacks import CheckpointCallback -from baselines_utils import load_or_create_model, get_formatted_timestamp +from 
baselines_utils import load_or_create_model if __name__ == '__main__': From 150e4542c5f6b62e7e9433a13a2d1c3f8e524cb8 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Wed, 28 Feb 2024 10:59:21 +0100 Subject: [PATCH 03/10] weight rewards at the same place Signed-off-by: Mathieu D --- baselines/red_gym_env.py | 10 +++--- baselines/rewards.py | 70 +++++++++++++++++++++------------------- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index 9efb592e0..de5d86902 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -190,11 +190,11 @@ def step(self, action): # DONE - step_limit_reached = self.check_if_done() - self.save_and_print_info(step_limit_reached, obs_memory) + done = self.check_if_done() + self.save_and_print_info(done, obs_memory) self.step_count += 1 - return obs_memory, reward_delta * 0.1, False, step_limit_reached, {} + return obs_memory, reward_delta * 0.1, False, done, {} def run_action_on_emulator(self, action): # press button then release after some steps @@ -246,8 +246,8 @@ def append_agent_stats(self, action): expl[0]: expl[1], 'deaths': self.reward_service.died_count, 'badge': self.reader.get_badges(), - 'event': self.reward_service.max_event_rew, - 'healr': self.reward_service.total_healing_rew + 'event': self.reward_service.max_event, + 'healr': self.reward_service.total_healing }) def create_exploration_memory(self): diff --git a/baselines/rewards.py b/baselines/rewards.py index 9d50a0604..4237197a4 100644 --- a/baselines/rewards.py +++ b/baselines/rewards.py @@ -12,18 +12,18 @@ def __init__(self, config, reader, save_screenshot): self.reader = reader # Pokedex - self.seen_pokemons_rew = 0 + self.seen_pokemons = 0 # Level - self.max_level_rew = 0 + self.max_level = 0 self.levels_satisfied = False self.max_opponent_level = 0 # Event - self.max_event_rew = 0 + self.max_event = 0 # Health - self.total_healing_rew = 0 + self.total_healing = 0 self.last_party_size = 0 
self.last_health = 1 self.died_count = 0 @@ -39,6 +39,7 @@ def __init__(self, config, reader, save_screenshot): self.init_knn() self.seen_coords = {} self.init_map_mem() + self.explore_reward = 0 self.last_game_state_rewards = self.get_game_state_rewards() self.total_reward = 0 @@ -52,15 +53,16 @@ def init_knn(self): def init_map_mem(self): self.seen_coords = {} - def get_exploration_reward(self): + def update_exploration_reward(self): pre_rew = self.explore_weight * 0.005 post_rew = self.explore_weight * 0.01 cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) - base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew - post = (cur_size if self.levels_satisfied else 0) * post_rew - return base + post + if not self.levels_satisfied: + self.explore_reward = cur_size * pre_rew + else: + self.explore_reward = (self.base_explore * pre_rew) + (cur_size * post_rew) - def get_all_events_reward(self): + def get_all_events_flags(self): # adds up all event flags, exclude museum ticket event_flags_start = EVENT_FLAGS_START_ADDRESS event_flags_end = EVENT_FLAGS_END_ADDRESS @@ -82,17 +84,17 @@ def get_game_state_rewards(self): # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm return { - 'event': self.reward_scale * self.max_event_rew, - 'level': self.reward_scale * self.max_level_rew, - 'heal': self.reward_scale * self.total_healing_rew, - 'op_lvl': self.reward_scale * self.max_opponent_level, - 'dead': self.reward_scale * -0.1 * self.died_count, + 'event': self.reward_scale * self.max_event * 1, + 'level': self.reward_scale * self.max_level * 1, + 'heal': self.reward_scale * self.total_healing * 4, + 'op_lvl': self.reward_scale * self.max_opponent_level * 1, + 'dead': self.reward_scale * self.died_count * -0.1, 'badge': self.reward_scale * self.reader.get_badges() * 
5, - 'explore': self.reward_scale * self.get_exploration_reward(), + 'explore': self.reward_scale * self.explore_reward, # 'party_xp': self.reward_scale*0.1*sum(poke_xps), # 'op_poke': self.reward_scale*self.max_opponent_poke * 800, # 'money': self.reward_scale* money * 3, - 'seen_poke': self.reward_scale * self.seen_pokemons_rew + 'seen_poke': self.reward_scale * self.seen_pokemons } def group_rewards_lvl_hp_explore(self, rewards): @@ -105,12 +107,12 @@ def update_rewards(self, obs_flat, step_count): self.update_frame_knn_index(obs_flat) else: self.update_seen_coords(step_count) - - self.update_max_event_rew() - self.update_heal_reward() + self.update_exploration_reward() + self.update_max_event() + self.update_total_heal_and_death() self.update_max_op_level() self.update_seen_pokemons() - self.update_max_level_reward() + self.update_max_level() return self.update_state_reward() @@ -128,7 +130,7 @@ def update_state_reward(self): # used by memory old_prog = self.group_rewards_lvl_hp_explore(self.last_game_state_rewards) new_prog = self.group_rewards_lvl_hp_explore(self.get_game_state_rewards()) - return reward_delta, (new_prog[0] - old_prog[0],new_prog[1] - old_prog[1], new_prog[2] - old_prog[2]) + return reward_delta, (new_prog[0] - old_prog[0], new_prog[1] - old_prog[1], new_prog[2] - old_prog[2]) def update_max_op_level(self): opponent_level = self.reader.get_opponent_level() @@ -137,9 +139,9 @@ def update_max_op_level(self): def update_seen_pokemons(self): initial_seen_pokemon = 3 - self.seen_pokemons_rew = sum(self.reader.read_seen_pokemons()) - initial_seen_pokemon + self.seen_pokemons = sum(self.reader.read_seen_pokemons()) - initial_seen_pokemon - def update_max_level_reward(self): + def update_max_level(self): explore_thresh = 22 scale_factor = 4 level_sum = self.reader.get_levels_sum() @@ -148,8 +150,8 @@ def update_max_level_reward(self): else: scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh # always keeping the max, lvl can't 
decrease - self.max_level_rew = max(self.max_level_rew, scaled) - return self.max_level_rew + self.max_level = max(self.max_level, scaled) + return self.max_level def update_frame_knn_index(self, frame_vec): @@ -184,7 +186,7 @@ def update_seen_coords(self, step_count): self.seen_coords[coord_string] = step_count - def update_heal_reward(self): + def update_total_heal_and_death(self): cur_health = self.reader.read_hp_fraction() # if health increased and party size did not change if (cur_health > self.last_health and @@ -194,23 +196,23 @@ def update_heal_reward(self): if heal_amount > 0.5: print(f'healed: {heal_amount}') self.save_screenshot('healing') - self.total_healing_rew += heal_amount * 4 + self.total_healing += heal_amount else: self.died_count += 1 self.last_party_size = self.reader.read_party_size_address() self.last_health = self.reader.read_hp_fraction() - def update_max_event_rew(self): - cur_rew = self.get_all_events_reward() - self.max_event_rew = max(cur_rew, self.max_event_rew) + def update_max_event(self): + cur_rew = self.get_all_events_flags() + self.max_event = max(cur_rew, self.max_event) def reset(self): - self.max_event_rew = 0 - self.max_level_rew = 0 - self.total_healing_rew = 0 + self.max_event = 0 + self.max_level = 0 + self.total_healing = 0 self.max_opponent_level = 0 self.died_count = 0 - self.seen_pokemons_rew = 0 + self.seen_pokemons = 0 if self.use_screen_explore: self.init_knn() else: From e3546f31e747b5da01410da76f8617512e84f468 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Thu, 29 Feb 2024 11:19:32 +0100 Subject: [PATCH 04/10] events Signed-off-by: Mathieu D --- baselines/reader_pyboy.py | 10 ++++++++ baselines/rewards.py | 38 ++++++++---------------------- baselines/run_baseline_parallel.py | 11 +++++---- 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/baselines/reader_pyboy.py b/baselines/reader_pyboy.py index f7ebbe64b..b3318083e 100644 --- a/baselines/reader_pyboy.py +++ b/baselines/reader_pyboy.py @@ -62,6 
+62,16 @@ def read_y_pos(self):
     def read_map_n(self):
         return self.read_m(MAP_N_ADDRESS)
 
+    def read_events(self):
+        return [
+            self.bit_count(self.read_m(i))
+            for i in range(EVENT_FLAGS_START_ADDRESS, EVENT_FLAGS_END_ADDRESS)
+        ]
+
+    def read_museum_tickets(self):
+        museum_ticket = (MUSEUM_TICKET_ADDRESS, 0)
+        return self.read_bit(museum_ticket[0], museum_ticket[1])
+
     def read_levels(self):
         return [self.read_m(a) for a in LEVELS_ADDRESSES]
 
diff --git a/baselines/rewards.py b/baselines/rewards.py
index 4237197a4..118d61d38 100644
--- a/baselines/rewards.py
+++ b/baselines/rewards.py
@@ -64,39 +64,30 @@ def update_exploration_reward(self):
 
     def get_all_events_flags(self):
         # adds up all event flags, exclude museum ticket
-        event_flags_start = EVENT_FLAGS_START_ADDRESS
-        event_flags_end = EVENT_FLAGS_END_ADDRESS
-        museum_ticket = (MUSEUM_TICKET_ADDRESS, 0)
         base_event_flags = 13
-        return max(
-            sum(
-                [
-                    self.reader.bit_count(self.reader.read_m(i))
-                    for i in range(event_flags_start, event_flags_end)
-                ]
-            )
-            - base_event_flags
-            - int(self.reader.read_bit(museum_ticket[0], museum_ticket[1])),
-            0,
-        )
+        return max(0, sum(self.reader.read_events()) - base_event_flags - int(self.reader.read_museum_tickets()))
 
     def get_game_state_rewards(self):
         # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
         # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm
         return {
             'event': self.reward_scale * self.max_event * 1,
-            'level': self.reward_scale * self.max_level * 1,
+            'level': self.reward_scale * self.compute_level_reward() * 1,
             'heal': self.reward_scale * self.total_healing * 4,
             'op_lvl': self.reward_scale * self.max_opponent_level * 1,
             'dead': self.reward_scale * self.died_count * -0.1,
             'badge': self.reward_scale * self.reader.get_badges() * 5,
-            'explore': self.reward_scale * self.explore_reward,
+            'explore': self.reward_scale * self.explore_reward,
             # 'party_xp': 
self.reward_scale*0.1*sum(poke_xps), # 'op_poke': self.reward_scale*self.max_opponent_poke * 800, # 'money': self.reward_scale* money * 3, - 'seen_poke': self.reward_scale * self.seen_pokemons + # 'seen_poke': self.reward_scale * self.seen_pokemons } + # Levels count only quarter after 22 threshold + def compute_level_reward(self): + return int(min(22, self.max_level) + (max(0, (self.max_level - 22)) / 4)) + def group_rewards_lvl_hp_explore(self, rewards): return (rewards['level'] * 100 / self.reward_scale, self.reader.read_hp_fraction() * 2000, @@ -135,23 +126,14 @@ def update_state_reward(self): def update_max_op_level(self): opponent_level = self.reader.get_opponent_level() self.max_opponent_level = max(self.max_opponent_level, opponent_level) - return self.max_opponent_level * 0.2 def update_seen_pokemons(self): initial_seen_pokemon = 3 self.seen_pokemons = sum(self.reader.read_seen_pokemons()) - initial_seen_pokemon def update_max_level(self): - explore_thresh = 22 - scale_factor = 4 - level_sum = self.reader.get_levels_sum() - if level_sum < explore_thresh: - scaled = level_sum - else: - scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh - # always keeping the max, lvl can't decrease - self.max_level = max(self.max_level, scaled) - return self.max_level + # lvl can't decrease + self.max_level = max(self.max_level, self.reader.get_levels_sum()) def update_frame_knn_index(self, frame_vec): diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py index 94ab8f0d6..927ac1c7d 100644 --- a/baselines/run_baseline_parallel.py +++ b/baselines/run_baseline_parallel.py @@ -6,9 +6,10 @@ if __name__ == '__main__': ep_length = 2048 * 8 - sess_path = Path(f'session_{datetime.now().strftime("%Y%m%d_%H%M")}') - #pretrained_session = "session_4da05e87_main_good/poke_439746560_steps" - model_to_load_path = 'session_e41c9eff/poke_38207488_steps' #'session_e41c9eff/poke_250871808_steps' + sess_path = 
Path(f'sessions/session_{datetime.now().strftime("%Y%m%d_%H%M")}') + pretrained_model = 'session_4da05e87_main_good/poke_439746560_steps' + model_i_like = 'session_20240227_1952/poke_720896_steps' + model_to_load_path = 'sessions/' + pretrained_model env_config = { 'headless': True, 'save_final_state': True, 'early_stop': False, 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, @@ -17,11 +18,11 @@ 'use_screen_explore': True, 'extra_buttons': False } - num_cpu = 44 #64 #46 # Also sets the number of episodes per training iteration + num_cpu = 1 #64 #46 # Also sets the number of episodes per training iteration model = load_or_create_model(model_to_load_path, env_config, ep_length, num_cpu) checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, name_prefix='poke') - learn_steps = 40 + learn_steps = 10 for i in range(learn_steps): model.learn(total_timesteps=ep_length * num_cpu * 1000, callback=checkpoint_callback) From b88728e5ec2f6a32e9ca76b21cae063c65564b2d Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Thu, 29 Feb 2024 14:57:22 +0100 Subject: [PATCH 05/10] extract renderer Signed-off-by: Mathieu D --- baselines/red_gym_env.py | 210 ++++++--------------------------------- baselines/renderer.py | 164 ++++++++++++++++++++++++++++++ baselines/rewards.py | 2 +- 3 files changed, 195 insertions(+), 181 deletions(-) create mode 100644 baselines/renderer.py diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index de5d86902..09a81e5c8 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -1,16 +1,9 @@ import sys -import uuid -from math import floor -import json -from pathlib import Path import numpy as np -from einops import rearrange -import matplotlib.pyplot as plt -from skimage.transform import resize from pyboy import PyBoy -import mediapy as media -import pandas as pd + +from renderer import Renderer from rewards import Reward from reader_pyboy import ReaderPyBoy @@ -20,29 
+13,23 @@ class RedGymEnv(Env): - def __init__( - self, config=None): + def __init__(self, config=None): self.debug = config['debug'] - self.s_path = config['session_path'] - self.save_final_state = config['save_final_state'] + self.headless = config['headless'] self.init_state = config['init_state'] self.act_freq = config['action_freq'] self.max_steps = config['max_steps'] self.early_stopping = config['early_stop'] - self.save_video = config['save_video'] self.fast_video = config['fast_video'] self.video_interval = 256 * self.act_freq self.downsample_factor = 2 - self.frame_stacks = 3 + self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore'] self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons'] - self.instance_id = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id'] - self.s_path.mkdir(exist_ok=True) self.reset_count = 0 - self.all_runs = [] # Set this in SOME subclasses self.metadata = {"render.modes": []} @@ -74,20 +61,6 @@ def __init__( WindowEvent.RELEASE_BUTTON_B ] - self.output_shape = (36, 40, 3) - self.mem_padding = 2 - self.memory_height = 8 - self.col_steps = 16 - self.output_full = ( - self.output_shape[0] * self.frame_stacks + 2 * (self.mem_padding + self.memory_height), - self.output_shape[1], - self.output_shape[2] - ) - - # Set these in ALL subclasses - self.action_space = spaces.Discrete(len(self.valid_actions)) - self.observation_space = spaces.Box(low=0, high=255, shape=self.output_full, dtype=np.uint8) - head = 'headless' if config['headless'] else 'SDL2' # log_level("ERROR") @@ -98,19 +71,25 @@ def __init__( window_type=head, hide_window='--quiet' in sys.argv, ) - self.screen = self.pyboy.botsupport_manager().screen() if not config['headless']: self.pyboy.set_emulation_speed(6) self.reader = ReaderPyBoy(self.pyboy) + self.renderer = Renderer(config, self.pyboy) + + # Set these in ALL subclasses + self.action_space = 
spaces.Discrete(len(self.valid_actions)) + self.observation_space = spaces.Box(low=0, high=255, shape=self.renderer.output_full, dtype=np.uint8) # Rewards - self.print_rewards = config['print_rewards'] - self.reward_service = Reward(config, self.reader, self.save_screenshot) + self.reward_service = Reward(config, self.reader, self.renderer.save_screenshot) self.reset() + def render(self): + self.renderer.render(self.reward_service) + def reset(self, seed=None, options=None): self.seed = seed # restart game, skipping credits @@ -119,79 +98,40 @@ def reset(self, seed=None, options=None): self.reward_service.reset() - self.recent_memory = np.zeros((self.output_shape[1] * self.memory_height, 3), dtype=np.uint8) - - self.recent_frames = np.zeros( - (self.frame_stacks, self.output_shape[0], - self.output_shape[1], self.output_shape[2]), - dtype=np.uint8) + self.renderer.reset() self.agent_stats = [] - if self.save_video: - base_dir = self.s_path / Path('rollouts') - base_dir.mkdir(exist_ok=True) - full_name = Path(f'full_reset_{self.reset_count}_id{self.instance_id}').with_suffix('.mp4') - model_name = Path(f'model_reset_{self.reset_count}_id{self.instance_id}').with_suffix('.mp4') - self.full_frame_writer = media.VideoWriter(base_dir / full_name, (144, 160), fps=60) - self.full_frame_writer.__enter__() - self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60) - self.model_frame_writer.__enter__() - self.step_count = 0 self.reset_count += 1 - return self.render(), {} - - def render(self, reduce_res=True, add_memory=True, update_mem=True): - game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3) - if reduce_res: - game_pixels_render = (255 * resize(game_pixels_render, self.output_shape)).astype(np.uint8) - if update_mem: - self.recent_frames[0] = game_pixels_render - if add_memory: - pad = np.zeros( - shape=(self.mem_padding, self.output_shape[1], 3), - dtype=np.uint8) - game_pixels_render = np.concatenate( - ( - 
self.create_exploration_memory(), - pad, - self.create_recent_memory(), - pad, - rearrange(self.recent_frames, 'f h w c -> (f h) w c') - ), - axis=0) - return game_pixels_render + return self.renderer.render(self.reward_service), {} def step(self, action): self.run_action_on_emulator(action) self.append_agent_stats(action) - - self.recent_frames = np.roll(self.recent_frames, 1, axis=0) + self.renderer.recent_frames = np.roll(self.renderer.recent_frames, 1, axis=0) # OBSERVATION - obs_memory = self.render() - # trim off memory from frame for knn index - frame_start = 2 * (self.memory_height + self.mem_padding) - obs_flat = obs_memory[frame_start:frame_start + self.output_shape[0], ...].flatten().astype(np.float32) + obs_memory = self.renderer.render(self.reward_service) + obs_flat = self.renderer.get_obs_flat(obs_memory) # REWARD reward_delta, new_prog = self.reward_service.update_rewards(obs_flat, self.step_count) # shift over short term reward memory - self.recent_memory = np.roll(self.recent_memory, 3) - self.recent_memory[0, 0] = min(new_prog[0] * 64, 255) - self.recent_memory[0, 1] = min(new_prog[1] * 64, 255) - self.recent_memory[0, 2] = min(new_prog[2] * 128, 255) + self.renderer.recent_memory = np.roll(self.renderer.recent_memory, 3) + self.renderer.recent_memory[0, 0] = min(new_prog[0] * 64, 255) + self.renderer.recent_memory[0, 1] = min(new_prog[1] * 64, 255) + self.renderer.recent_memory[0, 2] = min(new_prog[2] * 128, 255) # DONE done = self.check_if_done() - self.save_and_print_info(done, obs_memory) + self.renderer.save_and_print_info(done, obs_memory, self.step_count, self.reward_service, self.reset_count, self.agent_stats) self.step_count += 1 return obs_memory, reward_delta * 0.1, False, done, {} @@ -200,7 +140,7 @@ def run_action_on_emulator(self, action): # press button then release after some steps self.pyboy.send_input(self.valid_actions[action]) # disable rendering when we don't need it - if not self.save_video and self.headless: + if not 
self.renderer.save_video and self.headless: self.pyboy._rendering(False) for i in range(self.act_freq): # release action, so they are stateless @@ -213,17 +153,13 @@ def run_action_on_emulator(self, action): self.pyboy.send_input(self.release_button[action - 4]) if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START: self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START) - if self.save_video and not self.fast_video: - self.add_video_frame() + if self.renderer.save_video and not self.fast_video: + self.renderer.add_video_frame(self.reward_service) if i == self.act_freq - 1: self.pyboy._rendering(True) self.pyboy.tick() - if self.save_video and self.fast_video: - self.add_video_frame() - - def add_video_frame(self): - self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False)) - self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False)) + if self.renderer.save_video and self.fast_video: + self.renderer.add_video_frame(self.reward_service) def append_agent_stats(self, action): x_pos = self.reader.read_x_pos() @@ -250,97 +186,11 @@ def append_agent_stats(self, action): 'healr': self.reward_service.total_healing }) - def create_exploration_memory(self): - w = self.output_shape[1] - h = self.memory_height - - def make_reward_channel(r_val): - col_steps = self.col_steps - max_r_val = (w - 1) * h * col_steps - # truncate progress bar. if hitting this - # you should scale down the reward in group_rewards! 
- r_val = min(r_val, max_r_val) - row = floor(r_val / (h * col_steps)) - memory = np.zeros(shape=(h, w), dtype=np.uint8) - memory[:, :row] = 255 - row_covered = row * h * col_steps - col = floor((r_val - row_covered) / col_steps) - memory[:col, row] = 255 - col_covered = col * col_steps - last_pixel = floor(r_val - row_covered - col_covered) - memory[col, row] = last_pixel * (255 // col_steps) - return memory - - level, hp, explore = self.reward_service.group_rewards_lvl_hp_explore(self.reward_service.get_game_state_rewards()) - full_memory = np.stack(( - make_reward_channel(level), - make_reward_channel(hp), - make_reward_channel(explore) - ), axis=-1) - - if self.reader.get_badges() > 0: - full_memory[:, -1, :] = 255 - - return full_memory - - def create_recent_memory(self): - return rearrange( - self.recent_memory, - '(w h) c -> h w c', - h=self.memory_height) - def check_if_done(self): if self.early_stopping: done = False - if self.step_count > 128 and self.recent_memory.sum() < (255 * 1): + if self.step_count > 128 and self.renderer.recent_memory.sum() < (255 * 1): done = True else: done = self.step_count >= self.max_steps - #done = self.read_hp_fraction() == 0 return done - - def save_and_print_info(self, done, obs_memory): - if self.print_rewards: - prog_string = f'step: {self.step_count:6d}' - rewards_state = self.reward_service.get_game_state_rewards() - for key, val in rewards_state.items(): - prog_string += f' {key}: {val:5.2f}' - prog_string += f' sum: {self.reward_service.total_reward:5.2f}' - print(f'\r{prog_string}', end='', flush=True) - - if self.step_count % 50 == 0: - plt.imsave( - self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), - self.render(reduce_res=False)) - - if self.print_rewards and done: - print('', flush=True) - if self.save_final_state: - fs_path = self.s_path / Path('final_states') - fs_path.mkdir(exist_ok=True) - plt.imsave( - fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{self.reset_count}_small.jpeg'), - 
obs_memory) - plt.imsave( - fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{self.reset_count}_full.jpeg'), - self.render(reduce_res=False)) - - if self.save_video and done: - self.full_frame_writer.close() - self.model_frame_writer.close() - - if done: - self.all_runs.append(self.reward_service.get_game_state_rewards()) - with open(self.s_path / Path(f'all_runs_{self.instance_id}.json'), 'w') as f: - json.dump(self.all_runs, f) - pd.DataFrame(self.agent_stats).to_csv( - self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') - - def save_screenshot(self, name): - ss_dir = self.s_path / Path('screenshots') - ss_dir.mkdir(exist_ok=True) - plt.imsave( - ss_dir / Path( - f'frame{self.instance_id}_r{self.reward_service.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), - self.render(reduce_res=False) - ) diff --git a/baselines/renderer.py b/baselines/renderer.py new file mode 100644 index 000000000..9b1373414 --- /dev/null +++ b/baselines/renderer.py @@ -0,0 +1,164 @@ +from pathlib import Path +import uuid +import numpy as np +import matplotlib.pyplot as plt +from math import floor +import json +import pandas as pd +import mediapy as media +from einops import rearrange +from skimage.transform import resize +from reader_pyboy import ReaderPyBoy + + +class Renderer: + + def __init__(self, config, pyboy): + self.print_rewards = config['print_rewards'] + self.save_video = config['save_video'] + self.save_final_state = config['save_final_state'] + self.instance_id = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id'] + self.s_path = config['session_path'] + self.s_path.mkdir(exist_ok=True) + self.output_shape = (36, 40, 3) + self.frame_stacks = 3 + self.mem_padding = 2 + self.memory_height = 8 + self.output_full = ( + self.output_shape[0] * self.frame_stacks + 2 * (self.mem_padding + self.memory_height), + self.output_shape[1], + self.output_shape[2] + ) + self.col_steps = 16 + self.screen = 
pyboy.botsupport_manager().screen() + self.all_runs = [] + self.reader = ReaderPyBoy(pyboy) + + def render(self, reward_service, reduce_res=True, add_memory=True, update_mem=True): + game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3) + if reduce_res: + game_pixels_render = (255 * resize(game_pixels_render, self.output_shape)).astype(np.uint8) + if update_mem: + self.recent_frames[0] = game_pixels_render + if add_memory: + pad = np.zeros( + shape=(self.mem_padding, self.output_shape[1], 3), + dtype=np.uint8) + game_pixels_render = np.concatenate( + ( + self.create_exploration_memory(reward_service), + pad, + self.create_recent_memory(), + pad, + rearrange(self.recent_frames, 'f h w c -> (f h) w c') + ), + axis=0) + return game_pixels_render + + def save_and_print_info(self, done, obs_memory, step_count, reward_service, reset_count, agent_stats): + if self.print_rewards: + prog_string = f'step: {step_count:6d}' + rewards_state = reward_service.get_game_state_rewards() + for key, val in rewards_state.items(): + prog_string += f' {key}: {val:5.2f}' + prog_string += f' sum: {reward_service.total_reward:5.2f}' + print(f'\r{prog_string}', end='', flush=True) + + if step_count % 50 == 0: + plt.imsave( + self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), + self.render(reward_service, reduce_res=False)) + + if self.print_rewards and done: + print('', flush=True) + if self.save_final_state: + fs_path = self.s_path / Path('final_states') + fs_path.mkdir(exist_ok=True) + plt.imsave( + fs_path / Path(f'frame_r{reward_service.total_reward:.4f}_{reset_count}_small.jpeg'), + obs_memory) + plt.imsave( + fs_path / Path(f'frame_r{reward_service.total_reward:.4f}_{reset_count}_full.jpeg'), + self.render(reward_service, reduce_res=False)) + + if self.save_video and done: + self.full_frame_writer.close() + self.model_frame_writer.close() + + if done: + self.all_runs.append(reward_service.get_game_state_rewards()) + with open(self.s_path / 
Path(f'all_runs_{self.instance_id}.json'), 'w') as f: + json.dump(self.all_runs, f) + pd.DataFrame(agent_stats).to_csv( + self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') + + def save_screenshot(self, name, reward_service, reset_count): + ss_dir = self.s_path / Path('screenshots') + ss_dir.mkdir(exist_ok=True) + plt.imsave( + ss_dir / Path( + f'frame{self.instance_id}_r{reward_service.total_reward:.4f}_{reset_count}_{name}.jpeg'), + self.render(reward_service, reduce_res=False) + ) + + def save_video(self, reset_count): + base_dir = self.s_path / Path('rollouts') + base_dir.mkdir(exist_ok=True) + full_name = Path(f'full_reset_{reset_count}_id{self.instance_id}').with_suffix('.mp4') + model_name = Path(f'model_reset_{reset_count}_id{self.instance_id}').with_suffix('.mp4') + self.full_frame_writer = media.VideoWriter(base_dir / full_name, (144, 160), fps=60) + self.full_frame_writer.__enter__() + self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60) + self.model_frame_writer.__enter__() + + def add_video_frame(self, reward_service): + self.full_frame_writer.add_image(self.render(reward_service, reduce_res=False, update_mem=False)) + self.model_frame_writer.add_image(self.render(reward_service, reduce_res=True, update_mem=False)) + + def get_obs_flat(self, obs_memory): + # trim off memory from frame for knn index + frame_start = 2 * (self.memory_height + self.mem_padding) + return obs_memory[frame_start:frame_start + self.output_shape[0], ...].flatten().astype(np.float32) + + def create_recent_memory(self): + return rearrange(self.recent_memory,'(w h) c -> h w c', h=self.memory_height) + + def create_exploration_memory(self, reward_service): + w = self.output_shape[1] + h = self.memory_height + + def make_reward_channel(r_val): + col_steps = self.col_steps + max_r_val = (w - 1) * h * col_steps + # truncate progress bar. 
if hitting this + # you should scale down the reward in group_rewards! + r_val = min(r_val, max_r_val) + row = floor(r_val / (h * col_steps)) + memory = np.zeros(shape=(h, w), dtype=np.uint8) + memory[:, :row] = 255 + row_covered = row * h * col_steps + col = floor((r_val - row_covered) / col_steps) + memory[:col, row] = 255 + col_covered = col * col_steps + last_pixel = floor(r_val - row_covered - col_covered) + memory[col, row] = last_pixel * (255 // col_steps) + return memory + + level, hp, explore = reward_service.group_rewards_lvl_hp_explore(reward_service.get_game_state_rewards()) + full_memory = np.stack(( + make_reward_channel(level), + make_reward_channel(hp), + make_reward_channel(explore) + ), axis=-1) + + if self.reader.get_badges() > 0: + full_memory[:, -1, :] = 255 + + return full_memory + def reset(self): + self.recent_memory = np.zeros((self.output_shape[1] * self.memory_height, 3), dtype=np.uint8) + self.recent_frames = np.zeros( + (self.frame_stacks, self.output_shape[0],self.output_shape[1], self.output_shape[2]), + dtype=np.uint8) + if self.save_video: + self.save_video() diff --git a/baselines/rewards.py b/baselines/rewards.py index 118d61d38..9634e21d5 100644 --- a/baselines/rewards.py +++ b/baselines/rewards.py @@ -77,7 +77,7 @@ def get_game_state_rewards(self): 'op_lvl': self.reward_scale * self.max_opponent_level * 1, 'dead': self.reward_scale * self.died_count * -0.1, 'badge': self.reward_scale * self.reader.get_badges() * 5, - 'explore': self.reward_scale * self.explore_reward,git + 'explore': self.reward_scale * self.explore_reward # 'party_xp': self.reward_scale*0.1*sum(poke_xps), # 'op_poke': self.reward_scale*self.max_opponent_poke * 800, # 'money': self.reward_scale* money * 3, From fc2df2dc78c9893adc96804ab7096fc87f25ef73 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Thu, 29 Feb 2024 18:11:08 +0100 Subject: [PATCH 06/10] renderer Signed-off-by: Mathieu D --- baselines/red_gym_env.py | 57 ++++++++++++++++++++------- 
baselines/renderer.py | 84 ++++++++++++++-------------------------- baselines/rewards.py | 15 ++++--- 3 files changed, 80 insertions(+), 76 deletions(-) diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index 09a81e5c8..0111b6ebf 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -2,7 +2,10 @@ import numpy as np from pyboy import PyBoy - +import uuid +import json +import pandas as pd +from pathlib import Path from renderer import Renderer from rewards import Reward from reader_pyboy import ReaderPyBoy @@ -16,7 +19,10 @@ class RedGymEnv(Env): def __init__(self, config=None): self.debug = config['debug'] - + self.instance_id = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id'] + self.s_path = config['session_path'] + self.save_final_state = config['save_final_state'] + self.save_video = config['save_video'] self.headless = config['headless'] self.init_state = config['init_state'] self.act_freq = config['action_freq'] @@ -25,12 +31,13 @@ def __init__(self, config=None): self.fast_video = config['fast_video'] self.video_interval = 256 * self.act_freq self.downsample_factor = 2 + self.print_rewards = config['print_rewards'] self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore'] self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons'] self.reset_count = 0 - + self.all_runs = [] # Set this in SOME subclasses self.metadata = {"render.modes": []} @@ -63,7 +70,6 @@ def __init__(self, config=None): head = 'headless' if config['headless'] else 'SDL2' - # log_level("ERROR") self.pyboy = PyBoy( config['gb_path'], debugging=False, @@ -76,15 +82,15 @@ def __init__(self, config=None): self.pyboy.set_emulation_speed(6) self.reader = ReaderPyBoy(self.pyboy) - self.renderer = Renderer(config, self.pyboy) + + # Rewards + self.reward_service = Reward(config, self.reader) + self.renderer = Renderer(config, self.pyboy, self.reward_service, 
self.instance_id) # Set these in ALL subclasses self.action_space = spaces.Discrete(len(self.valid_actions)) self.observation_space = spaces.Box(low=0, high=255, shape=self.renderer.output_full, dtype=np.uint8) - # Rewards - self.reward_service = Reward(config, self.reader, self.renderer.save_screenshot) - self.reset() def render(self): @@ -99,6 +105,8 @@ def reset(self, seed=None, options=None): self.reward_service.reset() self.renderer.reset() + if self.save_video: + self.renderer.save_video(self.reset_count) self.agent_stats = [] @@ -121,6 +129,10 @@ def step(self, action): # REWARD reward_delta, new_prog = self.reward_service.update_rewards(obs_flat, self.step_count) + if self.print_rewards: + self.reward_service.print_rewards(self.step_count) + if reward_delta < 0 and self.reader.read_hp_fraction() > 0: + self.renderer.save_screenshot('neg_reward') # shift over short term reward memory self.renderer.recent_memory = np.roll(self.renderer.recent_memory, 3) @@ -131,9 +143,26 @@ def step(self, action): # DONE done = self.check_if_done() - self.renderer.save_and_print_info(done, obs_memory, self.step_count, self.reward_service, self.reset_count, self.agent_stats) - self.step_count += 1 + if self.step_count % 50 == 0: + self.renderer.save_and_print_info() + + if done: + self.all_runs.append(self.reward_service.get_game_state_rewards()) + with open(self.s_path / Path(f'all_runs_{self.instance_id}.json'), 'w') as f: + json.dump(self.all_runs, f) + pd.DataFrame(self.agent_stats).to_csv( + self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') + if self.print_rewards: + print('', flush=True) + if self.save_final_state: + self.renderer.save_final_state(obs_memory, self.reset_count) + + if self.save_video and done: + self.renderer.full_frame_writer.close() + self.renderer.model_frame_writer.close() + + self.step_count += 1 return obs_memory, reward_delta * 0.1, False, done, {} def run_action_on_emulator(self, action): @@ -153,13 
+182,13 @@ def run_action_on_emulator(self, action): self.pyboy.send_input(self.release_button[action - 4]) if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START: self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START) - if self.renderer.save_video and not self.fast_video: - self.renderer.add_video_frame(self.reward_service) + if self.save_video and not self.fast_video: + self.renderer.add_video_frame() if i == self.act_freq - 1: self.pyboy._rendering(True) self.pyboy.tick() - if self.renderer.save_video and self.fast_video: - self.renderer.add_video_frame(self.reward_service) + if self.save_video and self.fast_video: + self.renderer.add_video_frame() def append_agent_stats(self, action): x_pos = self.reader.read_x_pos() diff --git a/baselines/renderer.py b/baselines/renderer.py index 9b1373414..61856374d 100644 --- a/baselines/renderer.py +++ b/baselines/renderer.py @@ -1,11 +1,8 @@ -from pathlib import Path -import uuid import numpy as np import matplotlib.pyplot as plt from math import floor -import json -import pandas as pd import mediapy as media +from pathlib import Path from einops import rearrange from skimage.transform import resize from reader_pyboy import ReaderPyBoy @@ -13,11 +10,9 @@ class Renderer: - def __init__(self, config, pyboy): - self.print_rewards = config['print_rewards'] - self.save_video = config['save_video'] - self.save_final_state = config['save_final_state'] - self.instance_id = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id'] + def __init__(self, config, pyboy, reward_service, instance_id): + self.reward_service = reward_service + self.instance_id = instance_id self.s_path = config['session_path'] self.s_path.mkdir(exist_ok=True) self.output_shape = (36, 40, 3) @@ -31,10 +26,9 @@ def __init__(self, config, pyboy): ) self.col_steps = 16 self.screen = pyboy.botsupport_manager().screen() - self.all_runs = [] self.reader = ReaderPyBoy(pyboy) - def render(self, reward_service, reduce_res=True, 
add_memory=True, update_mem=True): + def render(self, reduce_res=True, add_memory=True, update_mem=True): game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3) if reduce_res: game_pixels_render = (255 * resize(game_pixels_render, self.output_shape)).astype(np.uint8) @@ -44,9 +38,10 @@ def render(self, reward_service, reduce_res=True, add_memory=True, update_mem=Tr pad = np.zeros( shape=(self.mem_padding, self.output_shape[1], 3), dtype=np.uint8) + level, hp, explore = self.reward_service.group_rewards_lvl_hp_explore(self.reward_service.get_game_state_rewards()) game_pixels_render = np.concatenate( ( - self.create_exploration_memory(reward_service), + self.create_exploration_memory(level, hp, explore), pad, self.create_recent_memory(), pad, @@ -55,42 +50,11 @@ def render(self, reward_service, reduce_res=True, add_memory=True, update_mem=Tr axis=0) return game_pixels_render - def save_and_print_info(self, done, obs_memory, step_count, reward_service, reset_count, agent_stats): - if self.print_rewards: - prog_string = f'step: {step_count:6d}' - rewards_state = reward_service.get_game_state_rewards() - for key, val in rewards_state.items(): - prog_string += f' {key}: {val:5.2f}' - prog_string += f' sum: {reward_service.total_reward:5.2f}' - print(f'\r{prog_string}', end='', flush=True) - - if step_count % 50 == 0: - plt.imsave( - self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), - self.render(reward_service, reduce_res=False)) - - if self.print_rewards and done: - print('', flush=True) - if self.save_final_state: - fs_path = self.s_path / Path('final_states') - fs_path.mkdir(exist_ok=True) - plt.imsave( - fs_path / Path(f'frame_r{reward_service.total_reward:.4f}_{reset_count}_small.jpeg'), - obs_memory) - plt.imsave( - fs_path / Path(f'frame_r{reward_service.total_reward:.4f}_{reset_count}_full.jpeg'), - self.render(reward_service, reduce_res=False)) - - if self.save_video and done: - self.full_frame_writer.close() - self.model_frame_writer.close() - 
- if done: - self.all_runs.append(reward_service.get_game_state_rewards()) - with open(self.s_path / Path(f'all_runs_{self.instance_id}.json'), 'w') as f: - json.dump(self.all_runs, f) - pd.DataFrame(agent_stats).to_csv( - self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a') + def save_and_print_info(self): + plt.imsave( + self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), + self.render(reduce_res=False)) + def save_screenshot(self, name, reward_service, reset_count): ss_dir = self.s_path / Path('screenshots') @@ -98,7 +62,7 @@ def save_screenshot(self, name, reward_service, reset_count): plt.imsave( ss_dir / Path( f'frame{self.instance_id}_r{reward_service.total_reward:.4f}_{reset_count}_{name}.jpeg'), - self.render(reward_service, reduce_res=False) + self.render(reduce_res=False) ) def save_video(self, reset_count): @@ -111,9 +75,9 @@ def save_video(self, reset_count): self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60) self.model_frame_writer.__enter__() - def add_video_frame(self, reward_service): - self.full_frame_writer.add_image(self.render(reward_service, reduce_res=False, update_mem=False)) - self.model_frame_writer.add_image(self.render(reward_service, reduce_res=True, update_mem=False)) + def add_video_frame(self): + self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False)) + self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False)) def get_obs_flat(self, obs_memory): # trim off memory from frame for knn index @@ -123,7 +87,7 @@ def get_obs_flat(self, obs_memory): def create_recent_memory(self): return rearrange(self.recent_memory,'(w h) c -> h w c', h=self.memory_height) - def create_exploration_memory(self, reward_service): + def create_exploration_memory(self, level, hp, explore): w = self.output_shape[1] h = self.memory_height @@ -144,7 +108,6 @@ def make_reward_channel(r_val): memory[col, row] = last_pixel * (255 
// col_steps) return memory - level, hp, explore = reward_service.group_rewards_lvl_hp_explore(reward_service.get_game_state_rewards()) full_memory = np.stack(( make_reward_channel(level), make_reward_channel(hp), @@ -155,10 +118,19 @@ def make_reward_channel(r_val): full_memory[:, -1, :] = 255 return full_memory + + def save_final_state(self, obs_memory, reset_count): + fs_path = self.s_path / Path('final_states') + fs_path.mkdir(exist_ok=True) + plt.imsave( + fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{reset_count}_small.jpeg'), + obs_memory) + plt.imsave( + fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{reset_count}_full.jpeg'), + self.render(reduce_res=False)) + def reset(self): self.recent_memory = np.zeros((self.output_shape[1] * self.memory_height, 3), dtype=np.uint8) self.recent_frames = np.zeros( (self.frame_stacks, self.output_shape[0],self.output_shape[1], self.output_shape[2]), dtype=np.uint8) - if self.save_video: - self.save_video() diff --git a/baselines/rewards.py b/baselines/rewards.py index 9634e21d5..ae2f9fcbc 100644 --- a/baselines/rewards.py +++ b/baselines/rewards.py @@ -1,12 +1,10 @@ -from memory_addresses import EVENT_FLAGS_START_ADDRESS, EVENT_FLAGS_END_ADDRESS, MUSEUM_TICKET_ADDRESS import hnswlib import numpy as np class Reward: - def __init__(self, config, reader, save_screenshot): - self.save_screenshot = save_screenshot + def __init__(self, config, reader): self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale'] self.reward_range = (0, 15000) self.reader = reader @@ -113,8 +111,6 @@ def update_state_reward(self): new_total = sum([val for _, val in self.get_game_state_rewards().items()]) self.total_reward = new_total reward_delta = new_total - last_total - if reward_delta < 0 and self.reader.read_hp_fraction() > 0: - self.save_screenshot('neg_reward') self.last_game_state_rewards = self.get_game_state_rewards() @@ -177,7 +173,6 @@ def update_total_heal_and_death(self): 
heal_amount = cur_health - self.last_health if heal_amount > 0.5: print(f'healed: {heal_amount}') - self.save_screenshot('healing') self.total_healing += heal_amount else: self.died_count += 1 @@ -201,3 +196,11 @@ def reset(self): self.init_map_mem() self.last_game_state_rewards = self.get_game_state_rewards() self.total_reward = 0 + + def print_rewards(self, step_count): + prog_string = f'step: {step_count:6d}' + rewards_state = self.get_game_state_rewards() + for key, val in rewards_state.items(): + prog_string += f' {key}: {val:5.2f}' + prog_string += f' sum: {self.total_reward:5.2f}' + print(f'\r{prog_string}', end='', flush=True) From 3acad176af1ab8149928ac4f29340e50cac56585 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Fri, 1 Mar 2024 09:07:50 +0100 Subject: [PATCH 07/10] renderer Signed-off-by: Mathieu D --- baselines/red_gym_env.py | 12 ++++++------ baselines/renderer.py | 26 ++++++++++---------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index 0111b6ebf..b13be18d1 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -85,7 +85,7 @@ def __init__(self, config=None): # Rewards self.reward_service = Reward(config, self.reader) - self.renderer = Renderer(config, self.pyboy, self.reward_service, self.instance_id) + self.renderer = Renderer(self.s_path, self.pyboy, self.reward_service, self.instance_id) # Set these in ALL subclasses self.action_space = spaces.Discrete(len(self.valid_actions)) @@ -94,7 +94,7 @@ def __init__(self, config=None): self.reset() def render(self): - self.renderer.render(self.reward_service) + return self.renderer.render() def reset(self, seed=None, options=None): self.seed = seed @@ -113,7 +113,7 @@ def reset(self, seed=None, options=None): self.step_count = 0 self.reset_count += 1 - return self.renderer.render(self.reward_service), {} + return self.render(), {} def step(self, action): @@ -123,7 +123,7 @@ def step(self, action): # 
OBSERVATION - obs_memory = self.renderer.render(self.reward_service) + obs_memory = self.render() obs_flat = self.renderer.get_obs_flat(obs_memory) # REWARD @@ -132,7 +132,7 @@ def step(self, action): if self.print_rewards: self.reward_service.print_rewards(self.step_count) if reward_delta < 0 and self.reader.read_hp_fraction() > 0: - self.renderer.save_screenshot('neg_reward') + self.renderer.save_screenshot('neg_reward', self.reward_service.total_reward, self.reset_count) # shift over short term reward memory self.renderer.recent_memory = np.roll(self.renderer.recent_memory, 3) @@ -156,7 +156,7 @@ def step(self, action): if self.print_rewards: print('', flush=True) if self.save_final_state: - self.renderer.save_final_state(obs_memory, self.reset_count) + self.renderer.save_final_state(obs_memory, self.reset_count, self.reward_service.total_reward) if self.save_video and done: self.renderer.full_frame_writer.close() diff --git a/baselines/renderer.py b/baselines/renderer.py index 61856374d..e12eb937e 100644 --- a/baselines/renderer.py +++ b/baselines/renderer.py @@ -10,10 +10,10 @@ class Renderer: - def __init__(self, config, pyboy, reward_service, instance_id): + def __init__(self, s_path, pyboy, reward_service, instance_id): self.reward_service = reward_service self.instance_id = instance_id - self.s_path = config['session_path'] + self.s_path = s_path self.s_path.mkdir(exist_ok=True) self.output_shape = (36, 40, 3) self.frame_stacks = 3 @@ -56,12 +56,12 @@ def save_and_print_info(self): self.render(reduce_res=False)) - def save_screenshot(self, name, reward_service, reset_count): + def save_screenshot(self, name, total_reward, reset_count): ss_dir = self.s_path / Path('screenshots') ss_dir.mkdir(exist_ok=True) plt.imsave( ss_dir / Path( - f'frame{self.instance_id}_r{reward_service.total_reward:.4f}_{reset_count}_{name}.jpeg'), + f'frame{self.instance_id}_r{total_reward:.4f}_{reset_count}_{name}.jpeg'), self.render(reduce_res=False) ) @@ -108,26 +108,20 @@ def 
make_reward_channel(r_val): memory[col, row] = last_pixel * (255 // col_steps) return memory - full_memory = np.stack(( - make_reward_channel(level), - make_reward_channel(hp), - make_reward_channel(explore) - ), axis=-1) + full_memory = np.stack( + (make_reward_channel(level), make_reward_channel(hp), make_reward_channel(explore)), + axis=-1) if self.reader.get_badges() > 0: full_memory[:, -1, :] = 255 return full_memory - def save_final_state(self, obs_memory, reset_count): + def save_final_state(self, obs_memory, reset_count, total_reward): fs_path = self.s_path / Path('final_states') fs_path.mkdir(exist_ok=True) - plt.imsave( - fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{reset_count}_small.jpeg'), - obs_memory) - plt.imsave( - fs_path / Path(f'frame_r{self.reward_service.total_reward:.4f}_{reset_count}_full.jpeg'), - self.render(reduce_res=False)) + plt.imsave(fs_path / Path(f'frame_r{total_reward:.4f}_{reset_count}_small.jpeg'), obs_memory) + plt.imsave(fs_path / Path(f'frame_r{total_reward:.4f}_{reset_count}_full.jpeg'), self.render(reduce_res=False)) def reset(self): self.recent_memory = np.zeros((self.output_shape[1] * self.memory_height, 3), dtype=np.uint8) From 164b0c9c44532d322463109a2b83d92ace9583d1 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Sat, 2 Mar 2024 22:22:11 +0100 Subject: [PATCH 08/10] fix rewards reset, add more info for agent_stats Signed-off-by: Mathieu D --- baselines/baselines_utils.py | 26 +++++---- baselines/red_gym_env.py | 8 +-- baselines/rewards.py | 87 ++++++++++++++++-------------- baselines/run_baseline_parallel.py | 15 +++--- 4 files changed, 76 insertions(+), 60 deletions(-) diff --git a/baselines/baselines_utils.py b/baselines/baselines_utils.py index a31ae9603..78d395ad7 100644 --- a/baselines/baselines_utils.py +++ b/baselines/baselines_utils.py @@ -2,13 +2,24 @@ from stable_baselines3.common.vec_env import SubprocVecEnv from stable_baselines3 import PPO from stable_baselines3.common.utils import 
set_random_seed + +from stream_agent_wrapper import StreamWrapper from red_gym_env import RedGymEnv def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cpu): - env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) - + env = make_env(0, env_config) + if env_config['stream'] is True: + env = StreamWrapper( + env, + stream_metadata = { # All of this is part is optional + "user": "MATHIEU", # choose your own username + "env_id": env_config['instance_id'], # environment identifier + "color": "#d900ff", # choose your color :) + "extra": "", # any extra text you put here will be displayed + } + ) if exists(model_to_load_path + '.zip'): print('\nloading checkpoint') model = PPO.load(model_to_load_path, env=env) @@ -18,7 +29,7 @@ def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cp model.rollout_buffer.n_envs = num_cpu model.rollout_buffer.reset() else: - model = PPO('CnnPolicy', env, verbose=1, n_steps=total_timesteps, batch_size=512, n_epochs=1, gamma=0.999) + model = PPO('CnnPolicy', env, verbose=1, n_steps=total_timesteps, batch_size=512, n_epochs=1, gamma=0.999, tensorboard_log=model_to_load_path) return model @@ -31,9 +42,6 @@ def make_env(rank, env_conf, seed=0): :param seed: (int) the initial seed for RNG :param rank: (int) index of the subprocess """ - def _init(): - env = RedGymEnv(env_conf) - env.reset(seed=(seed + rank)) - return env - set_random_seed(seed) - return _init + env = RedGymEnv(env_conf) + env.reset(seed=(seed + rank)) + return env \ No newline at end of file diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py index b13be18d1..32f4d6dd3 100644 --- a/baselines/red_gym_env.py +++ b/baselines/red_gym_env.py @@ -30,9 +30,7 @@ def __init__(self, config=None): self.early_stopping = config['early_stop'] self.fast_video = config['fast_video'] self.video_interval = 256 * self.act_freq - self.downsample_factor = 2 self.print_rewards = config['print_rewards'] - 
self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore'] self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons'] @@ -129,7 +127,7 @@ def step(self, action): # REWARD reward_delta, new_prog = self.reward_service.update_rewards(obs_flat, self.step_count) - if self.print_rewards: + if self.print_rewards and self.step_count % 100 == 0: self.reward_service.print_rewards(self.step_count) if reward_delta < 0 and self.reader.read_hp_fraction() > 0: self.renderer.save_screenshot('neg_reward', self.reward_service.total_reward, self.reset_count) @@ -203,9 +201,11 @@ def append_agent_stats(self, action): 'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n, 'map_location': self.reader.get_map_location(), 'last_action': action, - 'pcount': self.reader.read_party_size_address(), + 'final_total_reward': self.reward_service.total_reward, + 'party_size': self.reader.read_party_size_address(), 'levels': levels, 'levels_sum': sum(levels), + 'seen_pokemons': self.reward_service.seen_pokemons, 'ptypes': self.reader.read_party(), 'hp': self.reader.read_hp_fraction(), expl[0]: expl[1], diff --git a/baselines/rewards.py b/baselines/rewards.py index ae2f9fcbc..3c46bbf83 100644 --- a/baselines/rewards.py +++ b/baselines/rewards.py @@ -37,11 +37,44 @@ def __init__(self, config, reader): self.init_knn() self.seen_coords = {} self.init_map_mem() - self.explore_reward = 0 + self.cur_size = 0 self.last_game_state_rewards = self.get_game_state_rewards() self.total_reward = 0 + def reset(self): + self.max_event = 0 + self.max_level = 0 + self.levels_satisfied = False + self.total_healing = 0 + self.last_party_size = 0 + self.last_health = 1 + self.max_opponent_level = 0 + self.died_count = 0 + self.seen_pokemons = 0 + self.base_explore = 0 + self.seen_coords = {} + if self.use_screen_explore: + self.init_knn() + else: + self.init_map_mem() + self.total_reward = 0 + self.last_game_state_rewards = 
self.get_game_state_rewards() + + def get_game_state_rewards(self): + # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map + # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm + return { + 'event': self.reward_scale * self.max_event * 1, + 'level': self.reward_scale * self.compute_level_reward(), + 'heal': self.reward_scale * self.total_healing * 2, + 'op_lvl': self.reward_scale * self.max_opponent_level * 1, + 'dead': self.reward_scale * self.died_count * -0.1, + 'badge': self.reward_scale * self.reader.get_badges() * 5, + 'seen_poke': self.reward_scale * self.seen_pokemons * 1, + 'explore': self.reward_scale * self.compute_explore_reward() + } + def init_knn(self): # Declaring index self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip @@ -52,39 +85,24 @@ def init_map_mem(self): self.seen_coords = {} def update_exploration_reward(self): - pre_rew = self.explore_weight * 0.005 - post_rew = self.explore_weight * 0.01 - cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) - if not self.levels_satisfied: - self.explore_reward = cur_size * pre_rew - else: - self.explore_reward = (self.base_explore * pre_rew) + (cur_size * post_rew) + self.cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) def get_all_events_flags(self): # adds up all event flags, exclude museum ticket base_event_flags = 13 return max(0, sum(self.reader.read_events()) - base_event_flags - int(self.reader.read_museum_tickets())) - def get_game_state_rewards(self): - # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map - # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm - return { - 'event': self.reward_scale * self.max_event * 1, - 'level': self.reward_scale * 
self.compute_level_reward() * 1, - 'heal': self.reward_scale * self.total_healing * 4, - 'op_lvl': self.reward_scale * self.max_opponent_level * 1, - 'dead': self.reward_scale * self.died_count * -0.1, - 'badge': self.reward_scale * self.reader.get_badges() * 5, - 'explore': self.reward_scale * self.explore_reward - # 'party_xp': self.reward_scale*0.1*sum(poke_xps), - # 'op_poke': self.reward_scale*self.max_opponent_poke * 800, - # 'money': self.reward_scale* money * 3, - # 'seen_poke': self.reward_scale * self.seen_pokemons - } - - # Levels count only quarter after 22 threshold def compute_level_reward(self): - return int(min(22, self.max_level) + (max(0, (self.max_level - 22)) / 4)) + # Levels count only quarter after 22 threshold + return int(min(22, self.max_level) + (max(0, (self.max_level - 22)) / 4)) * 1 + + def compute_explore_reward(self): + pre_rew = 0.005 + post_rew = 0.01 + if not self.levels_satisfied: + return (self.cur_size * 0.005) * self.explore_weight + else: + return ((self.base_explore * pre_rew) + (self.cur_size * post_rew)) * self.explore_weight def group_rewards_lvl_hp_explore(self, rewards): return (rewards['level'] * 100 / self.reward_scale, @@ -183,24 +201,11 @@ def update_max_event(self): cur_rew = self.get_all_events_flags() self.max_event = max(cur_rew, self.max_event) - def reset(self): - self.max_event = 0 - self.max_level = 0 - self.total_healing = 0 - self.max_opponent_level = 0 - self.died_count = 0 - self.seen_pokemons = 0 - if self.use_screen_explore: - self.init_knn() - else: - self.init_map_mem() - self.last_game_state_rewards = self.get_game_state_rewards() - self.total_reward = 0 - def print_rewards(self, step_count): prog_string = f'step: {step_count:6d}' rewards_state = self.get_game_state_rewards() for key, val in rewards_state.items(): prog_string += f' {key}: {val:5.2f}' prog_string += f' sum: {self.total_reward:5.2f}' + prog_string += f' {self.reader.get_map_location()}' print(f'\r{prog_string}', end='', flush=True) 
diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py index 927ac1c7d..b698289c9 100644 --- a/baselines/run_baseline_parallel.py +++ b/baselines/run_baseline_parallel.py @@ -1,6 +1,9 @@ from pathlib import Path from datetime import datetime -from stable_baselines3.common.callbacks import CheckpointCallback +import uuid +from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList + +from tensorboard_callback import TensorboardCallback from baselines_utils import load_or_create_model if __name__ == '__main__': @@ -9,20 +12,20 @@ sess_path = Path(f'sessions/session_{datetime.now().strftime("%Y%m%d_%H%M")}') pretrained_model = 'session_4da05e87_main_good/poke_439746560_steps' model_i_like = 'session_20240227_1952/poke_720896_steps' - model_to_load_path = 'sessions/' + pretrained_model + model_to_load_path = '' # 'sessions/session_20240302_1929/poke_1040384_steps' env_config = { 'headless': True, 'save_final_state': True, 'early_stop': False, 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length, 'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path, 'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0, - 'use_screen_explore': True, 'extra_buttons': False + 'use_screen_explore': True, 'extra_buttons': False, 'stream': False, 'instance_id': str(uuid.uuid4())[:8] } - - num_cpu = 1 #64 #46 # Also sets the number of episodes per training iteration + num_cpu = 4 #64 #46 # Also sets the number of episodes per training iteration model = load_or_create_model(model_to_load_path, env_config, ep_length, num_cpu) checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, name_prefix='poke') + callbacks = [checkpoint_callback, TensorboardCallback()] learn_steps = 10 for i in range(learn_steps): - model.learn(total_timesteps=ep_length * num_cpu * 1000, callback=checkpoint_callback) + 
model.learn(total_timesteps=ep_length * num_cpu * 40, callback=CallbackList(callbacks)) From 4b4fe066828c0ba1f3990c13293ed578e5de7627 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Sun, 3 Mar 2024 15:58:10 +0100 Subject: [PATCH 09/10] SubprocVecEnv Signed-off-by: Mathieu D --- baselines/baselines_utils.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/baselines/baselines_utils.py b/baselines/baselines_utils.py index 78d395ad7..31950f392 100644 --- a/baselines/baselines_utils.py +++ b/baselines/baselines_utils.py @@ -9,17 +9,18 @@ def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cpu): - env = make_env(0, env_config) - if env_config['stream'] is True: - env = StreamWrapper( - env, - stream_metadata = { # All of this is part is optional - "user": "MATHIEU", # choose your own username - "env_id": env_config['instance_id'], # environment identifier - "color": "#d900ff", # choose your color :) - "extra": "", # any extra text you put here will be displayed - } - ) + env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) + #env = make_env(0, env_config) + #if env_config['stream'] is True: + # env = StreamWrapper( + # env, + # stream_metadata = { # All of this is part is optional + # "user": "MATHIEU", # choose your own username + # "env_id": env_config['instance_id'], # environment identifier + # "color": "#d900ff", # choose your color :) + # "extra": "", # any extra text you put here will be displayed + # } + # ) if exists(model_to_load_path + '.zip'): print('\nloading checkpoint') model = PPO.load(model_to_load_path, env=env) @@ -29,7 +30,7 @@ def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cp model.rollout_buffer.n_envs = num_cpu model.rollout_buffer.reset() else: - model = PPO('CnnPolicy', env, verbose=1, n_steps=total_timesteps, batch_size=512, n_epochs=1, gamma=0.999, tensorboard_log=model_to_load_path) + model = PPO('CnnPolicy', env, 
verbose=1, n_steps=total_timesteps, batch_size=512, n_epochs=1, gamma=0.999, tensorboard_log=env_config['session_path']) return model @@ -42,6 +43,9 @@ def make_env(rank, env_conf, seed=0): :param seed: (int) the initial seed for RNG :param rank: (int) index of the subprocess """ - env = RedGymEnv(env_conf) - env.reset(seed=(seed + rank)) - return env \ No newline at end of file + def _init(): + env = RedGymEnv(env_conf) + env.reset(seed=(seed + rank)) + return env + set_random_seed(seed) + return _init \ No newline at end of file From 5686eb452b76930c3d92495ea1dcaaabde4820a5 Mon Sep 17 00:00:00 2001 From: Mathieu D Date: Sun, 3 Mar 2024 16:01:14 +0100 Subject: [PATCH 10/10] conditional StreamWrapper Signed-off-by: Mathieu D --- baselines/baselines_utils.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/baselines/baselines_utils.py b/baselines/baselines_utils.py index 31950f392..553107b2b 100644 --- a/baselines/baselines_utils.py +++ b/baselines/baselines_utils.py @@ -10,17 +10,6 @@ def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cpu): env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) - #env = make_env(0, env_config) - #if env_config['stream'] is True: - # env = StreamWrapper( - # env, - # stream_metadata = { # All of this is part is optional - # "user": "MATHIEU", # choose your own username - # "env_id": env_config['instance_id'], # environment identifier - # "color": "#d900ff", # choose your color :) - # "extra": "", # any extra text you put here will be displayed - # } - # ) if exists(model_to_load_path + '.zip'): print('\nloading checkpoint') model = PPO.load(model_to_load_path, env=env) @@ -35,7 +24,7 @@ def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cp return model -def make_env(rank, env_conf, seed=0): +def make_env(rank, env_config, seed=0): """ Utility function for multiprocessed env. 
    :param env_id: (str) the environment ID
@@ -44,8 +33,18 @@ def make_env(rank, env_conf, seed=0):
     :param rank: (int) index of the subprocess
     """
     def _init():
-        env = RedGymEnv(env_conf)
+        env = RedGymEnv(env_config)
         env.reset(seed=(seed + rank))
+        if env_config['stream'] is True:
+            env = StreamWrapper(
+                env,
+                stream_metadata = { # All of this part is optional
+                    "user": "MATHIEU", # choose your own username
+                    "env_id": env_config['instance_id'], # environment identifier
+                    "color": "#d900ff", # choose your color :)
+                    "extra": "", # any extra text you put here will be displayed
+                }
+            )
         return env
     set_random_seed(seed)
     return _init
\ No newline at end of file