Skip to content

Commit

Permalink
regenerated benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Mar 4, 2024
1 parent 2b0f87b commit 92a54ff
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 20 deletions.
21 changes: 10 additions & 11 deletions scripts/generate_benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,16 @@ python scripts/ruleset_generator.py \
--total_rulesets=1_000_000 \
--save_path="high_1m"


# medium + distractors
python scripts/ruleset_generator.py \
--prune_chain \
--prune_prob=0.8 \
--chain_depth=2 \
--sample_distractor_rules \
--num_distractor_rules=4 \
--num_distractor_objects=2 \
--total_rulesets=1_000_000 \
--save_path="medium_dist_1m"
## medium + distractors
#python scripts/ruleset_generator.py \
# --prune_chain \
# --prune_prob=0.8 \
# --chain_depth=2 \
# --sample_distractor_rules \
# --num_distractor_rules=4 \
# --num_distractor_objects=2 \
# --total_rulesets=1_000_000 \
# --save_path="medium_dist_1m"

# medium 3M
python scripts/ruleset_generator.py \
Expand Down
13 changes: 6 additions & 7 deletions src/xminigrid/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
DATA_PATH = os.environ.get("XLAND_MINIGRID_DATA", os.path.expanduser("~/.xland_minigrid"))

NAME2HFFILENAME = {
"trivial-1m": "trivial_1m",
"small-1m": "small_1m",
"small-dist-1m": "small_dist_1m",
"medium-1m": "medium_1m_v1",
"medium-3m": "medium_3m_v1",
"high-1m": "high_1m",
"high-3m": "high_3m",
"trivial-1m": "trivial_1m_v2",
"small-1m": "small_1m_v2",
"medium-1m": "medium_1m_v1_v2",
"medium-3m": "medium_3m_v1_v2",
"high-1m": "high_1m_v2",
"high-3m": "high_3m_v2",
}


Expand Down
3 changes: 2 additions & 1 deletion src/xminigrid/experimental/img_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.
cache_path = os.path.join(CACHE_PATH, "render_cache")
if not os.path.exists(cache_path):
os.makedirs(CACHE_PATH, exist_ok=True)

print("Building rendering cache, may take a while...")
TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE)
TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3)
TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3)

print("Done. Cache will be reused on consequent runs.")
save_bz2_pickle({"tile_cache": TILE_CACHE, "tile_agent_cache": TILE_W_AGENT_CACHE}, cache_path)

TILE_CACHE = load_bz2_pickle(cache_path)["tile_cache"]
Expand Down
5 changes: 4 additions & 1 deletion src/xminigrid/manual_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from pygame.event import Event

import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper

from .environment import Environment, EnvParamsT
from .rendering.text_render import print_ruleset
from .types import EnvCarryT
from .wrappers import GymAutoResetWrapper


class ManualControl:
Expand Down Expand Up @@ -162,6 +163,8 @@ def close(self) -> None:
if "XLand" in args.env_id:
bench = xminigrid.load_benchmark(args.benchmark_id)
env_params = env_params.replace(ruleset=bench.get_ruleset(args.ruleset_id))
print_ruleset(env_params.ruleset)
print()

control = ManualControl(env=env, env_params=env_params, agent_view=args.agent_view)
control.start()

0 comments on commit 92a54ff

Please sign in to comment.