-
Notifications
You must be signed in to change notification settings - Fork 0
/
vis_utils.py
39 lines (31 loc) · 1005 Bytes
/
vis_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
Utility functions for visualizing reward functions
"""
import numpy as np
import matplotlib.pyplot as plt
def heatmap_2d(reward_matrix, title='', display_text=True, blocking=True, figure_num=1):
    """
    Draws and then displays a 2d heatmap using matplotlib.

    Cell (row y, column x) of ``reward_matrix`` is drawn at plot coordinate
    (x, y). The y axis is inverted after drawing, so with matplotlib's default
    ``origin='upper'`` for imshow, row 0 ends up at the bottom of the plot.

    Args:
        reward_matrix (MxN array-like): Array mapping from position to reward
            value. Anything ``np.asanyarray`` accepts (numpy arrays, masked
            arrays, nested lists) is allowed.
        title (str): Title shown above the heatmap.
        display_text (bool): If True, write each cell's value (formatted with
            one decimal place) centred on that cell.
        blocking (bool): If True, draw into figure ``figure_num`` after
            clearing it, show the figure, and wait until the user presses
            Enter. If False, draw into the current figure and return without
            showing it or waiting.
        figure_num (int): Matplotlib figure number, used only when
            ``blocking`` is True.

    Returns:
        None

    Raises:
        ValueError: If ``reward_matrix`` is not two-dimensional.
    """
    # asanyarray (not asarray) keeps ndarray subclasses such as MaskedArray,
    # so imshow still honours masks; it also lets callers pass nested lists,
    # which the .shape / [y, x] accesses below could not handle directly.
    rewards = np.asanyarray(reward_matrix)
    # Validate before touching any figure so bad input fails with a clear
    # message instead of an IndexError/TypeError after the figure is cleared.
    if rewards.ndim != 2:
        raise ValueError(
            "reward_matrix must be a 2-D (MxN) array, got shape %s"
            % (rewards.shape,)
        )

    if blocking:
        plt.figure(figure_num)
        plt.clf()
    plt.imshow(rewards, interpolation="nearest")
    plt.title(title)
    plt.colorbar()
    plt.gca().invert_yaxis()

    if display_text:
        # Image pixel (x, y) in data coordinates is the centre of cell [y, x].
        for y, x in np.ndindex(rewards.shape):
            plt.text(x, y, "%.1f" % rewards[y, x],
                     horizontalalignment="center",
                     verticalalignment="center")

    if blocking:
        plt.ion()
        plt.show()
        # In interactive mode show() returns without running the GUI event
        # loop, and input() does not run it either, so the window may never
        # be painted. A short pause processes the pending draw events first.
        plt.pause(0.001)
        input("Press Enter to continue...")