-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
112 lines (96 loc) · 3.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import importlib
from popgym.wrappers import (
Antialias,
PreviousAction,
Flatten,
DiscreteAction,
)
import gymnasium as gym
import jax
import jax.numpy as jnp
import optax
import numpy as np
def scale_by_norm(scale: float = 1.0, eps: float = 1e-6):
    """Build an optax transformation that rescales updates by their global norm.

    Every leaf of the update tree is divided by
    ``max(global_norm(updates) + eps, 1 / scale)``. As a consequence:

    * when the global norm is below ``1 / scale`` the updates are simply
      multiplied by ``scale``;
    * otherwise the updates are divided by their own global norm (plus
      ``eps``), i.e. brought to roughly unit global norm.

    Args:
        scale: Multiplier applied while the global norm is small. Must be
            non-zero, since ``1 / scale`` is evaluated on every update.
        eps: Small constant added to the global norm before the maximum.

    Returns:
        An ``optax.GradientTransformation`` with an empty state.
    """

    def init_fn(params):
        del params  # The transformation keeps no state.
        # Use the public alias rather than the private ``optax._src`` path.
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params
        divisor = jnp.maximum(optax.global_norm(updates) + eps, 1 / scale)
        updates = jax.tree_util.tree_map(lambda t: t / divisor, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)
def load_popgym_env(config, eval=False):
    """Construct the environment selected by ``config`` and seed its action space.

    Three construction paths are supported, chosen from ``config["collect"]``:

    * ``popgym_env`` truthy: ``env`` is a dotted ``module.Class`` path that is
      imported and instantiated with ``env_kwargs``. The instance is
      optionally wrapped in ``PreviousAction`` (when ``env_prev_action`` is
      truthy), then in ``Antialias`` and ``Flatten``, and finally in
      ``DiscreteAction`` if its action space is ``MultiDiscrete``.
    * ``memorygym_env`` truthy: ``env`` is a dotted ``module.Class`` path that
      is imported and instantiated without arguments; ``env_kwargs`` are
      merged into the instance's ``default_reset_parameters``.
    * otherwise: ``env`` is passed to ``gym.make`` together with
      ``env_kwargs``; when both ``collect.atari_env`` and ``model.atari_cnn``
      are truthy the result is wrapped in ``AtariPreprocessing``.

    Args:
        config: Nested mapping with at least the ``"collect"`` and ``"seed"``
            entries (and ``"model"`` on the ``gym.make`` path).
        eval: When truthy, the action-space seed is offset by 1000 so that
            evaluation and training environments sample differently.

    Returns:
        The constructed (and possibly wrapped) environment.
    """
    collect = config["collect"]
    env_kwargs = collect.get("env_kwargs", {})

    if collect["popgym_env"]:
        module, cls = collect["env"].rsplit(".", 1)
        mod = importlib.import_module(module)
        instance = getattr(mod, cls)(**env_kwargs)
        if collect["env_prev_action"]:
            instance = PreviousAction(instance)
        instance = Flatten(Antialias(instance))
        if isinstance(instance.action_space, gym.spaces.MultiDiscrete):
            instance = DiscreteAction(instance)
    elif collect.get("memorygym_env"):
        module, cls = collect["env"].rsplit(".", 1)
        mod = importlib.import_module(module)
        instance = getattr(mod, cls)()
        instance.default_reset_parameters.update(env_kwargs)
    else:
        instance = gym.make(collect["env"], **env_kwargs)
        if collect.get("atari_env") and config["model"].get("atari_cnn"):
            instance = gym.wrappers.AtariPreprocessing(
                instance, frame_skip=1, scale_obs=True
            )

    # Single exit point: every path seeds the action space exactly once.
    instance.action_space.seed(config["seed"] + eval * 1000)
    return instance
def filter_inf(log_dict):
    """Return a copy of ``log_dict`` without entries whose value is ``-inf``.

    The value (not the key) is compared against negative infinity. Values
    whose comparison cannot be reduced to a single boolean (for example
    multi-element arrays) are kept unchanged.

    Args:
        log_dict: Mapping of log names to values.

    Returns:
        A new ``dict`` with the ``-inf`` entries removed, in the original
        iteration order.
    """
    neg_inf = float("-inf")

    def is_neg_inf(value):
        try:
            return bool(value == neg_inf)
        except (TypeError, ValueError):
            # Ambiguous truth value (e.g. non-scalar array): not a scalar -inf.
            return False

    return {k: v for k, v in log_dict.items() if not is_neg_inf(v)}
@jax.jit
def expand_right(src, shape):
    """Append trailing singleton axes to ``src`` until it has ``len(shape)`` dims.

    Only the length of ``shape`` is used. If ``src`` already has at least
    that many dimensions it is returned with its shape unchanged.

    Args:
        src: Array to expand.
        shape: Sequence whose length gives the target number of dimensions.

    Returns:
        ``src`` reshaped to ``src.shape + (1,) * k`` where
        ``k = max(len(shape) - src.ndim, 0)``.
    """
    # Number of axes still missing on the right; a negative count yields [].
    missing = len(shape) - len(src.shape)
    return src.reshape(*src.shape, *([1] * missing))
def get_summary_info(model):
"""An alternative repr useful for initial debugging"""
import pandas as pd
def get_info(v):
info = dict()
info['type'] = type(v).__name__
info['dtype'] = v.dtype.name if hasattr(v, 'dtype') else None
info['shape'] = jnp.shape(v)
info['size'] = jnp.size(v)
#info['nancount'] = np.isnan(v).sum()
#info['zerocount'] = np.size(v) - np.count_nonzero(v)
info['min'] = jnp.min(v).item()
info['max'] = jnp.max(v).item()
info['mean'] = jnp.mean(v).item()
info['std'] = jnp.std(v).item()
info['norm'] = jnp.linalg.norm(v).item()
return info
d_ = {jax.tree_util.keystr(k): get_info(v) for k, v in jax.tree_util.tree_leaves_with_path(model) if isinstance(v, (jax.Array, float))}
return pd.DataFrame(d_).T
def get_wandb_model_info(model):
    """Collect mean, std and norm of every ``jax.Array`` leaf in a flat dict.

    Keys have the form ``"params/model" + <key path> + ".mean" | ".std" |
    ".norm"``; values are the corresponding ``jnp`` reductions of the leaf.
    Leaves that are not ``jax.Array`` instances are skipped.
    """
    prefix = "params/model"
    stats = {}
    for path, leaf in jax.tree_util.tree_leaves_with_path(model):
        if not isinstance(leaf, jax.Array):
            continue
        base = prefix + jax.tree_util.keystr(path)
        stats[base + '.mean'] = jnp.mean(leaf)
        stats[base + '.std'] = jnp.std(leaf)
        stats[base + '.norm'] = jnp.linalg.norm(leaf)
    return stats
def elementwise_grad(g):
    """Wrap ``g`` so that it returns the pullback of an all-ones cotangent.

    The returned function takes ``(x, *rest)``, differentiates
    ``g(x, *rest)`` with respect to ``x`` only via ``jax.vjp``, and returns
    the vector-Jacobian product with ``ones_like`` of the output. When ``g``
    acts elementwise on ``x`` this is the elementwise derivative.
    """

    def wrapped(x, *rest):
        def g_of_x(first):
            # Close over the remaining arguments; only ``first`` is differentiated.
            return g(first, *rest)

        y, pullback = jax.vjp(g_of_x, x)
        cotangents = pullback(jnp.ones_like(y))
        (x_bar,) = cotangents
        return x_bar

    return wrapped