diff --git a/Pyrado/pyrado/domain_randomization/domain_parameter.py b/Pyrado/pyrado/domain_randomization/domain_parameter.py index 331ed895b5..02eeb06750 100644 --- a/Pyrado/pyrado/domain_randomization/domain_parameter.py +++ b/Pyrado/pyrado/domain_randomization/domain_parameter.py @@ -403,19 +403,11 @@ def from_domain_randomizer(domain_randomizer, *, target_cov_factor=1.0, init_cov :param init_cov_factor: scaling of the randomizer's variance to get the init variance; defaults to `1/100` :return: the self-paced domain parameter """ - ( - name, - target_mean, - target_cov_flat, - init_mean, - init_cov_flat, - ) = ( - [], - [], - [], - [], - [], - ) + name = [] + target_mean = [] + target_cov_flat = [] + init_mean = [] + init_cov_flat = [] for domain_param in domain_randomizer.domain_params: if not isinstance(domain_param, NormalDomainParam): raise pyrado.TypeErr( diff --git a/Pyrado/pyrado/plotting/heatmap.py b/Pyrado/pyrado/plotting/heatmap.py index de2e3bd970..fbf1049b8a 100644 --- a/Pyrado/pyrado/plotting/heatmap.py +++ b/Pyrado/pyrado/plotting/heatmap.py @@ -190,6 +190,10 @@ def draw_heatmap( :return: handles to the heat map and the color bar figures (`None` if not existent) """ if isinstance(data, pd.DataFrame): + if not data.index.is_numeric(): + raise pyrado.TypeErr(given=data.index, msg="expected numeric index") + if not data.index.is_numeric(): + raise pyrado.TypeErr(given=data.columns, msg="expected numeric index") # Extract the data x = data.columns y = data.index diff --git a/Pyrado/tests/test_domain_randomization.py b/Pyrado/tests/test_domain_randomization.py index 8809f6e4ac..1b794090a0 100644 --- a/Pyrado/tests/test_domain_randomization.py +++ b/Pyrado/tests/test_domain_randomization.py @@ -33,16 +33,31 @@ import torch as to from tests.conftest import m_needs_bullet, m_needs_mujoco +import pyrado from pyrado.domain_randomization.domain_parameter import ( BernoulliDomainParam, MultivariateNormalDomainParam, NormalDomainParam, + SelfPacedDomainParam, UniformDomainParam, ) +from pyrado.domain_randomization.domain_randomizer import DomainRandomizer from pyrado.domain_randomization.utils import param_grid from pyrado.environments.sim_base import SimEnv +def assert_is_close(param, info_expected): + info_actual = param.info() + assert list(sorted(info_actual.keys())) == list(sorted(info_expected.keys())) + for key, expected in info_expected.items(): + if key in ["name", "clip_lo", "clip_up"]: + assert info_actual[key] == expected + else: + assert to.allclose( + info_actual[key], expected + ), f"key: {key}, actual: {info_actual[key]}, expected: {expected}" + + @pytest.mark.parametrize( "dp", [ @@ -66,6 +81,42 @@ def test_domain_param(dp, num_samples): assert len(s) == num_samples +def test_self_paced_domain_param_make_broadening(): + param = SelfPacedDomainParam.make_broadening(["a"], [1.0], 0.0004, 0.04) + assert_is_close( + param, + { + "name": ["a"], + "target_mean": to.tensor([1.0]).double(), + "target_cov_chol": to.tensor([0.2]).double(), + "init_mean": to.tensor([1.0]).double(), + "init_cov_chol": to.tensor([0.02]).double(), + "clip_lo": -pyrado.inf, + "clip_up": pyrado.inf, + }, + ) + + +def test_self_paced_domain_param_from_domain_randomizer(): + param = SelfPacedDomainParam.from_domain_randomizer( + DomainRandomizer(NormalDomainParam(name="a", mean=1.0, std=1.0)), + init_cov_factor=0.0004, + target_cov_factor=0.04, + ) + assert_is_close( + param, + { + "name": ["a"], + "target_mean": to.tensor([1.0]).double(), + "target_cov_chol": to.tensor([0.2]).double(), + "init_mean": to.tensor([1.0]).double(), + "init_cov_chol": to.tensor([0.02]).double(), + "clip_lo": -pyrado.inf, + "clip_up": pyrado.inf, + }, + ) + + def test_randomizer(default_randomizer): print(default_randomizer) # Generate 7 samples