-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_policy.py
69 lines (61 loc) · 2.92 KB
/
plot_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from os import getcwd

import cmocean
import matplotlib.pyplot as plt
import numpy as np
import yaml
# Load run parameters and the precomputed chart, policy rollout and value
# arrays for this seed. All paths are resolved against the current working
# directory, so the script must be launched from the directory that holds
# params.yaml, data/ and PDFs/.
path = getcwd()
with open(os.path.join(path, 'params.yaml'), 'r', encoding='utf-8') as F:
    params = yaml.safe_load(F)
seed = params['seed']
n_t = params['n_t']
cur_scale = params['cur_scale']

# Chart background and water current. np.load on an .npz returns an NpzFile
# that keeps the archive open; ``with`` closes it once the arrays are read.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code -- only load files produced by this project.
with np.load(os.path.join(path, 'data', 'charts', 'charts_' + str(seed) + '.npz'),
             allow_pickle=True) as data:
    chart = data['chart']
    current = data['water_c']
dim = chart.shape[0]
# Quiver grid and current vectors, subsampled to every third cell on both axes.
water_p = np.meshgrid(np.arange(0, dim, 3), np.arange(0, dim, 3))
water_c = cur_scale * current[0::3, 0::3]

# Policy rollout: true positions, estimated positions, and no_gps, which is
# used as an index array into pos_est (presumably the steps without a GPS
# fix -- confirm against the producer of this file).
with np.load(os.path.join(path, 'data', 'policy', 'policy_gps_' + str(seed) + '.npz')) as data1:
    pos = data1['pos']
    pos_est = data1['pos_est']
    no_gps = data1['no_gps']

# Value field drawn as the background of the second figure.
with np.load(os.path.join(path, 'data', 'value', 'value_' + str(seed) + '.npz'),
             allow_pickle=True) as data2:
    value = data2['value']
# Indices into pos_est used for the green estimated-position line.
# NOTE(review): this was built with a loop that appended no_gps[i-1] only when
# it was not already present, but no_gps[i-1] is always the element appended
# on the previous iteration, so the result was always an element-wise copy of
# no_gps. If the intent was to also include the step just before each
# outage, the index should probably be no_gps[i] - 1 -- confirm before changing.
# Defined unconditionally (empty when there are no such steps); callers only
# use it when len(no_gps) > 0.
no_gps_line = list(no_gps)
def _plot_policy_figure(background, vmin, vmax, cmap, out_file, est_marker_size=None):
    """Draw the rollout over ``background`` and save the figure to ``out_file``.

    Reads module-level arrays loaded earlier in this script (``water_p``,
    ``water_c``, ``pos``, ``pos_est``, ``no_gps``, ``no_gps_line``, ``n_t``,
    ``dim``) rather than taking them as arguments.

    Args:
        background: 2-D array drawn (transposed) with ``pcolormesh``.
        vmin: Lower colour limit for ``background``.
        vmax: Upper colour limit for ``background``.
        cmap: Colormap for ``background``.
        out_file: Destination file; its parent directory is created if needed.
        est_marker_size: Marker size for the estimated-position scatter;
            ``None`` keeps matplotlib's default size.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.pcolormesh(background.transpose(), vmin=vmin, vmax=vmax, cmap=cmap)
    # pcolormesh without x/y puts cell edges on integers, so +0.5 centres the
    # current arrows in their cells.
    ax.quiver(water_p[0] + 0.5, water_p[1] + 0.5,
              water_c[:, :, 0].transpose(), water_c[:, :, 1].transpose(),
              scale=20, alpha=0.5, color='w')
    # True path in black: full line, a marker every n_t steps, and the end point.
    ax.plot(pos[:, 0] * dim, pos[:, 1] * dim, c='k', zorder=1)
    ax.scatter(pos[::n_t, 0] * dim, pos[::n_t, 1] * dim, c='k', zorder=1)
    ax.scatter(pos[-1, 0] * dim, pos[-1, 1] * dim, c='k', zorder=1)
    # Estimated positions in green, drawn above the true path.
    # NOTE(review): the first index is skipped ([1:]); the reason is not
    # visible in this script.
    if len(no_gps) > 0:
        ax.plot(pos_est[no_gps_line[1:], 0] * dim, pos_est[no_gps_line[1:], 1] * dim,
                c='g', zorder=10)
        ax.scatter(pos_est[no_gps[1:], 0] * dim, pos_est[no_gps[1:], 1] * dim,
                   c='g', s=est_marker_size, zorder=10)
    # Start point in white.
    ax.scatter([pos[0, 0] * dim], [pos[0, 1] * dim], c='w', zorder=1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim([0, dim])
    ax.set_ylim([0, dim])
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    fig.savefig(out_file)
    # Close explicitly so each call does not leave an open figure behind.
    plt.close(fig)


# Trajectory over the chart; estimated-position markers at size 15.
_plot_policy_figure(
    chart, -1.3, 1.3, cmocean.cm.delta,
    os.path.join(path, 'PDFs', 'policy', 'policy_gps_' + str(seed) + '.pdf'),
    est_marker_size=15,
)
# Trajectory over the value field; estimated-position markers at default size.
_plot_policy_figure(
    value, -1.0, 1.0, 'RdYlGn',
    os.path.join(path, 'PDFs', 'policy', 'policy_value_gps_' + str(seed) + '.pdf'),
)