-
Notifications
You must be signed in to change notification settings - Fork 649
/
stream_agent_wrapper.py
71 lines (61 loc) · 2.28 KB
/
stream_agent_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import asyncio
import websockets
import json
import gymnasium as gym
# Emulator memory addresses sampled by StreamWrapper.step on every step.
# NOTE(review): presumably game-specific WRAM locations of the player's
# position and current map — confirm against the game's RAM map.
X_POS_ADDRESS = 0xD362  # x position
Y_POS_ADDRESS = 0xD361  # y position
MAP_N_ADDRESS = 0xD35E  # map number
class StreamWrapper(gym.Wrapper):
    """Gym wrapper that periodically broadcasts the agent's position over a websocket.

    On every ``step`` the wrapper reads the x/y position and map number from
    emulator memory and buffers them. Every ``upload_interval`` steps the
    buffered coordinates are sent, together with ``stream_metadata``, as one
    JSON message to ``ws_address``. Streaming is best-effort: connection and
    send failures are swallowed so the wrapped environment keeps running.
    """

    def __init__(self, env, stream_metadata=None):
        """Wrap *env* and try to open the broadcast websocket.

        Args:
            env: Environment to wrap. Must expose its emulator as either a
                ``pyboy`` or a ``game`` attribute.
            stream_metadata: Dict included in every upload; the wrapper writes
                an ``"extra"`` key into it. Defaults to a new empty dict for
                each instance.

        Raises:
            Exception: if *env* has neither a ``pyboy`` nor a ``game`` attribute.
        """
        super().__init__(env)
        # Resolve the emulator first so a bad env fails before any network
        # connection or event loop is created (nothing to leak).
        if hasattr(env, "pyboy"):
            self.emulator = env.pyboy
        elif hasattr(env, "game"):
            self.emulator = env.game
        else:
            raise Exception("Could not find emulator!")
        self.ws_address = "wss://transdimensional.xyz/broadcast"
        # Fresh dict per instance: a shared ``{}`` default would be mutated by
        # every wrapper, since ``step`` writes the "extra" key into it.
        self.stream_metadata = {} if stream_metadata is None else stream_metadata
        # Private event loop so the synchronous gym API can drive the async
        # websockets client.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.websocket = None
        self.loop.run_until_complete(self.establish_wc_connection())
        self.upload_interval = 300
        # Steps recorded since the last upload (attribute name kept as-is for
        # compatibility with existing code).
        self.steam_step_counter = 0
        self.env = env
        # Buffered [x, y, map] entries awaiting upload.
        self.coord_list = []

    def step(self, action):
        """Record the current position, upload a batch if due, then step the env.

        The position is read before ``env.step`` runs, so each entry is the
        state the action is applied to.

        Returns:
            Whatever the wrapped ``env.step(action)`` returns, unchanged.
        """
        x_pos = self.emulator.get_memory_value(X_POS_ADDRESS)
        y_pos = self.emulator.get_memory_value(Y_POS_ADDRESS)
        map_n = self.emulator.get_memory_value(MAP_N_ADDRESS)
        self.coord_list.append([x_pos, y_pos, map_n])
        # Count before checking so every batch holds exactly
        # ``upload_interval`` entries (previously the first held one extra).
        self.steam_step_counter += 1
        if self.steam_step_counter >= self.upload_interval:
            self.stream_metadata["extra"] = f"coords: {len(self.env.seen_coords)}"
            message = json.dumps(
                {
                    "metadata": self.stream_metadata,
                    "coords": self.coord_list,
                }
            )
            self.loop.run_until_complete(self.broadcast_ws_message(message))
            self.steam_step_counter = 0
            self.coord_list = []
        return self.env.step(action)

    def close(self):
        """Close the websocket and private event loop, then the wrapped env.

        Safe to call more than once.
        """
        if not self.loop.is_closed():
            if self.websocket is not None:
                try:
                    self.loop.run_until_complete(self.websocket.close())
                except Exception:
                    # Best-effort shutdown: a failing close must not prevent
                    # the wrapped environment from being closed.
                    pass
                self.websocket = None
            self.loop.close()
        return super().close()

    async def broadcast_ws_message(self, message):
        """Send *message* over the websocket, reconnecting first if needed.

        If no connection can be established the message is dropped. If the
        send fails, the connection is discarded so the next call reconnects.
        """
        if self.websocket is None:
            await self.establish_wc_connection()
        if self.websocket is None:
            return
        try:
            await self.websocket.send(message)
        except (websockets.exceptions.WebSocketException, OSError):
            # Connection is broken; forget it so the next upload reconnects.
            self.websocket = None

    async def establish_wc_connection(self):
        """Try to open the websocket; leave ``self.websocket`` as None on failure."""
        try:
            self.websocket = await websockets.connect(self.ws_address)
        except Exception:
            # Best-effort: an unreachable server must not break the env.
            # Unlike a bare ``except:``, this lets KeyboardInterrupt,
            # SystemExit and task cancellation propagate.
            self.websocket = None