[WIP] Benchmark downloader #449

Open · wants to merge 21 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
# tensorboard runs
runs/

# benchmark folder
download_benchmark/

# videos
#*.mp4
notebooks/videos/*
9 changes: 9 additions & 0 deletions docs/api.rst
@@ -166,6 +166,15 @@ Logging Utilities
utils.logging.set_level


Benchmarks
----------

.. autosummary::
:toctree: generated/
:template: function.rst

benchmarks.benchmark_utils.download_benchmark_from_SB3_zoo

Environment Wrappers
====================

28 changes: 28 additions & 0 deletions docs/basics/userguide/benchmarks.md
@@ -0,0 +1,28 @@
(benchmarks)=

# How to use benchmarks

rlberry provides a tool to download some benchmarks if you need to compare your agent against them.

Currently, the available benchmarks are:
- Pre-trained Reinforcement Learning agents trained with rl-baselines3-zoo and Stable Baselines3 ([here](https://github.com/DLR-RM/rl-trained-agents)).

## Download the benchmark
To download a benchmark, simply call the function matching the benchmark you want.
You need to specify the names of the agent and the environment, and whether to overwrite previously downloaded data for this combination (you can use the `output_dir` parameter if you want to download the benchmark into a specific folder).
You can find the API reference for this benchmark [here](rlberry.benchmarks.benchmark_utils.download_benchmark_from_SB3_zoo).

```python
from rlberry.benchmarks.benchmark_utils import download_benchmark_from_SB3_zoo

agent_name = "dqn"
environment_name = "PongNoFrameskip-v4_1"

path_with_downloaded_files = download_benchmark_from_SB3_zoo(
    agent_name, environment_name, overwrite=True
)
```

## How to use these benchmarks to compare your agent

in construction ...
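
While this part of the guide is under construction, here is a minimal sketch of what you could do with the downloaded files. It only relies on the folder layout created by `download_benchmark_from_SB3_zoo` (monitor CSVs, the zipped agent, `evaluations.npz`, and the `args.yml`/`config.yml`/`vecnormalize.pkl` configuration files). The `*.monitor.csv` files follow the Stable Baselines3 `Monitor` format (a `#{...}` JSON metadata line, then `r`, `l`, `t` columns); `pandas` is used here purely for illustration and is not required by rlberry for this step.

```python
import os
import pandas as pd

# Path returned by download_benchmark_from_SB3_zoo (see the example above).
benchmark_dir = path_with_downloaded_files

# Each row is one training episode: r = episodic return, l = episode length, t = elapsed time.
monitor_df = pd.read_csv(
    os.path.join(benchmark_dir, "0.monitor.csv"),
    skiprows=1,  # skip the '#{...}' JSON metadata line of the SB3 Monitor format
)

print(monitor_df[["r", "l", "t"]].describe())
```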
7 changes: 6 additions & 1 deletion docs/changelog.rst
@@ -7,7 +7,12 @@ Changelog
Dev version
-----------

* nothing


*PR #449*

* Add module to download benchmarks
- Pre-trained Reinforcement Learning agents using the rl-baselines3-zoo and Stable Baselines3.

Version 0.7.1
-------------
1 change: 1 addition & 0 deletions docs/user_guide.md
@@ -47,3 +47,4 @@ You can find more details about installation [here](installation)!
- Transfer Learning (In construction)
- [Hypothesis testing for comparison of RL agents](comparison_page)
- [Adaptive hypothesis testing for comparison of RL agents with AdaStop](adastop_userguide)
- [Using benchmarks](benchmarks)
151 changes: 151 additions & 0 deletions rlberry/benchmarks/benchmark_utils.py
@@ -0,0 +1,151 @@
import requests
from tempfile import mkdtemp
import os
import shutil


# # TODO : convert external benchmark to DataFrame that match the input of rlberry.manager.comparaison.py -> compare_agents_data()
# # TODO : Download the external benchmark to a specific folder (or new rlberrygithub?), except if they are stable (huggingface/github)

# benchmark_list = {
# "Google Atari bucket": "https://console.cloud.google.com/storage/brow",
# "SB3 zoo": "https://github.com/DLR-RM/rl-baselines3-zoo/tree/master/logs/benchmark",
# "cleanrl": "https://wandb.ai/openrlbenchmark/openrlbenchmark/reportlist",
# }


# def import_from_google_atari_bucket():
#     """Import benchmark results from the Google Atari bucket (not implemented yet)."""
#     print("TODO")


# def import_from_cleanrl():
# print("TODO")


# def import_from_huggingface():
# print("TODO")


def download_benchmark_from_SB3_zoo(
    agent_name, environment_name, overwrite, output_dir=None
):
    """
    Download a benchmark folder of pre-trained Reinforcement Learning agents produced with rl-baselines3-zoo and Stable Baselines3.
    https://github.com/DLR-RM/rl-trained-agents

    Parameters
    ----------
    agent_name : str
        agent name of the benchmark to download
    environment_name : str
        environment name of the benchmark to download
    overwrite : bool
        what to do if the combination agent_name/environment_name already exists:
        True: delete the previous folder, then download
        False: raise an error
    output_dir : str
        root path where the files are downloaded (default=None: create a temporary folder)

    Returns
    -------
    str
        Path containing the downloaded files (output_dir/agent_name/environment_name)
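
    Examples
    --------
    A minimal usage sketch (it simply mirrors the user guide example; the
    agent/environment names come from the rl-trained-agents repository, and the
    call is skipped in doctests because it downloads files from GitHub):

    >>> path = download_benchmark_from_SB3_zoo(  # doctest: +SKIP
    ...     "dqn", "PongNoFrameskip-v4_1", overwrite=True
    ... )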
"""
    if not output_dir:
        output_dir = mkdtemp()

    GITHUB_URL = "https://raw.githubusercontent.com/DLR-RM/rl-trained-agents/master/"
    base_url = GITHUB_URL + agent_name + "/" + environment_name + "/"

    output_folder = os.path.join(output_dir, agent_name, environment_name)
    environment_base_name = environment_name.split("_")[0]

    if os.path.exists(output_folder):
        if not overwrite:
            raise FileExistsError(
                "'overwrite' is False, and the combination %s / %s already exists"
                % (agent_name, environment_name)
            )
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)

    # download CSVs (0.monitor.csv, 1.monitor.csv, ... until GitHub answers "404: Not Found")
    i = 0
    while True:
        file_name_to_download = str(i) + ".monitor.csv"
        url_csv_to_download = base_url + file_name_to_download

        req = requests.get(url_csv_to_download)
        if req.content == b"404: Not Found":
            break
        with open(os.path.join(output_folder, file_name_to_download), "wb") as csv_file:
            csv_file.write(req.content)
        i = i + 1

    # download the zipped agent
    file_name_to_download = environment_base_name + ".zip"
    req = requests.get(base_url + file_name_to_download)
    with open(os.path.join(output_folder, file_name_to_download), "wb") as output_file:
        output_file.write(req.content)

    # download evaluations.npz
    file_name_to_download = "evaluations.npz"
    req = requests.get(base_url + file_name_to_download)
    with open(os.path.join(output_folder, file_name_to_download), "wb") as output_file:
        output_file.write(req.content)

    # hyperparameters and config
    config_folder = os.path.join(output_folder, environment_base_name)
    base_url_config = base_url + environment_base_name + "/"
    os.makedirs(config_folder)

    for file_name_to_download in ["args.yml", "config.yml", "vecnormalize.pkl"]:
        req = requests.get(base_url_config + file_name_to_download)
        with open(os.path.join(config_folder, file_name_to_download), "wb") as output_file:
            output_file.write(req.content)

    return output_folder
88 changes: 88 additions & 0 deletions rlberry/benchmarks/tests/test_benchmark_utils.py
@@ -0,0 +1,88 @@
import os
import shutil
from rlberry.benchmarks.benchmark_utils import download_benchmark_from_SB3_zoo
import pytest


@pytest.mark.parametrize("agent_class", ["dqn"])
@pytest.mark.parametrize("env", ["PongNoFrameskip-v4_1"])
def test_download_benchmark_from_SB3_zoo_(agent_class, env):
    # remove any previous test folder
    test_folder_path = "./tests_dl"
    if os.path.exists(test_folder_path):
        shutil.rmtree(test_folder_path)

    os.makedirs(test_folder_path)

    # download benchmark
    ret_value = download_benchmark_from_SB3_zoo(
        agent_class, env, overwrite=True, output_dir=test_folder_path
    )

    # test expected results
    environment_base_name = env.split("_")[0]
    assert str(os.path.join(test_folder_path, agent_class, env)) == ret_value
    assert os.path.exists(os.path.join(test_folder_path, agent_class, env))
    assert os.path.exists(
        os.path.join(test_folder_path, agent_class, env, "0.monitor.csv")
    )
    assert os.path.exists(
        os.path.join(test_folder_path, agent_class, env, environment_base_name + ".zip")
    )
    assert os.path.exists(
        os.path.join(test_folder_path, agent_class, env, "evaluations.npz")
    )
    assert os.path.exists(
        os.path.join(
            test_folder_path, agent_class, env, environment_base_name, "args.yml"
        )
    )
    assert os.path.exists(
        os.path.join(
            test_folder_path, agent_class, env, environment_base_name, "config.yml"
        )
    )
    assert os.path.exists(
        os.path.join(
            test_folder_path,
            agent_class,
            env,
            environment_base_name,
            "vecnormalize.pkl",
        )
    )

    if os.path.exists(test_folder_path):
        shutil.rmtree(test_folder_path)


@pytest.mark.parametrize("agent_class", ["dqn"])
@pytest.mark.parametrize("env", ["PongNoFrameskip-v4_1"])
@pytest.mark.parametrize("overwrite", [True, False])
def test_download_benchmark_from_SB3_zoo_overwrite(agent_class, env, overwrite):
    # remove any previous test folder
    test_folder_path = "./tests_dl"
    if os.path.exists(test_folder_path):
        shutil.rmtree(test_folder_path)

    os.makedirs(test_folder_path)

    # first call
    ret_value = download_benchmark_from_SB3_zoo(
        agent_class, env, overwrite=overwrite, output_dir=test_folder_path
    )

    # 'overwrite' test
    error_was_raised = False
    try:
        ret_value = download_benchmark_from_SB3_zoo(
            agent_class, env, overwrite=overwrite, output_dir=test_folder_path
        )
    except FileExistsError:
        error_was_raised = True

    if overwrite:
        assert not error_was_raised
    else:
        assert error_was_raised

    if os.path.exists(test_folder_path):
        shutil.rmtree(test_folder_path)