Skip to content

Commit

Permalink
address review issues
Browse files Browse the repository at this point in the history
  • Loading branch information
fdamken committed Oct 22, 2023
1 parent 3836256 commit 7af231c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
18 changes: 5 additions & 13 deletions Pyrado/pyrado/domain_randomization/domain_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,19 +403,11 @@ def from_domain_randomizer(domain_randomizer, *, target_cov_factor=1.0, init_cov
:param init_cov_factor: scaling of the randomizer's variance to get the init variance; defaults to `1/100`
:return: the self-paced domain parameter
"""
(
name,
target_mean,
target_cov_flat,
init_mean,
init_cov_flat,
) = (
[],
[],
[],
[],
[],
)
name = []
target_mean = []
target_cov_flat = []
init_mean = []
init_cov_flat = []
for domain_param in domain_randomizer.domain_params:
if not isinstance(domain_param, NormalDomainParam):
raise pyrado.TypeErr(
Expand Down
4 changes: 4 additions & 0 deletions Pyrado/pyrado/plotting/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ def draw_heatmap(
:return: handles to the heat map and the color bar figures (`None` if not existent)
"""
if isinstance(data, pd.DataFrame):
if not data.index.is_numeric():
raise pyrado.TypeErr(given=data.index, msg="expected numeric index")
if not data.index.is_numeric():
raise pyrado.TypeErr(given=data.columns, msg="expected numeric index")
# Extract the data
x = data.columns
y = data.index
Expand Down
51 changes: 51 additions & 0 deletions Pyrado/tests/test_domain_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,31 @@
import torch as to
from tests.conftest import m_needs_bullet, m_needs_mujoco

import pyrado
from pyrado.domain_randomization.domain_parameter import (
BernoulliDomainParam,
MultivariateNormalDomainParam,
NormalDomainParam,
SelfPacedDomainParam,
UniformDomainParam,
)
from pyrado.domain_randomization.domain_randomizer import DomainRandomizer
from pyrado.domain_randomization.utils import param_grid
from pyrado.environments.sim_base import SimEnv


def assert_is_close(param, info_expected):
info_actual = param.info()
assert list(sorted(info_actual.keys())) == list(sorted(info_expected.keys()))
for key, expected in info_expected.items():
if key in ["name", "clip_lo", "clip_up"]:
assert info_actual[key] == expected
else:
assert to.allclose(
info_actual[key], expected
), f"key: {key}, actual: {info_actual[key]}, expected: {expected}"


@pytest.mark.parametrize(
"dp",
[
Expand All @@ -66,6 +81,42 @@ def test_domain_param(dp, num_samples):
assert len(s) == num_samples


def test_self_paced_domain_param_make_broadening():
param = SelfPacedDomainParam.make_broadening(["a"], [1.0], 0.0004, 0.04)
assert_is_close(
param,
{
"name": ["a"],
"target_mean": to.tensor([1.0]).double(),
"target_cov_chol": to.tensor([0.2]).double(),
"init_mean": to.tensor([1.0]).double(),
"init_cov_chol": to.tensor([0.02]).double(),
"clip_lo": -pyrado.inf,
"clip_up": pyrado.inf,
},
)


def test_self_paced_domain_param_from_domain_randomizer():
param = SelfPacedDomainParam.from_domain_randomizer(
DomainRandomizer(NormalDomainParam(name="a", mean=1.0, std=1.0)),
init_cov_factor=0.0004,
target_cov_factor=0.04,
)
assert_is_close(
param,
{
"name": ["a"],
"target_mean": to.tensor([1.0]).double(),
"target_cov_chol": to.tensor([0.2]).double(),
"init_mean": to.tensor([1.0]).double(),
"init_cov_chol": to.tensor([0.02]).double(),
"clip_lo": -pyrado.inf,
"clip_up": pyrado.inf,
},
)


def test_randomizer(default_randomizer):
print(default_randomizer)
# Generate 7 samples
Expand Down

0 comments on commit 7af231c

Please sign in to comment.