Skip to content

Commit

Permalink
some fixes (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Mar 25, 2024
1 parent 0914e82 commit 477a23a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ DIRS = data examples jat scripts tests
# Check that source code meets quality standards
quality:
black --check $(DIRS) setup.py
ruff $(DIRS) setup.py
ruff check $(DIRS) setup.py

# Format source code automatically
style:
black $(DIRS) setup.py
ruff $(DIRS) setup.py --fix
ruff check $(DIRS) setup.py --fix

# Run tests for the library
test:
Expand Down
23 changes: 13 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@
To get started with JAT, follow these steps:

1. Clone this repository onto your local machine.

```shell
git clone https://github.com/huggingface/jat.git
cd jat
```
git clone https://github.com/huggingface/jat.git
cd jat
```

2. Create a new virtual environment and activate it, and install required dependencies via pip.
```

```shell
python3 -m venv env
source env/bin/activate
pip install .
```

## Demonstration of the trained agent

The trained JAT agent is available [here](https://huggingface.co/jat-project/jat). The following script gives an example of the use of this agent on the Pong environment
Expand Down Expand Up @@ -65,27 +67,28 @@ env.close()
% GIF of trained agent here

## Usage Examples
Here are some examples of how you might use JAT in both evaluation and fine-tuning modes. More detailed information about each example is provided within the corresponding script files.
* **Evaluation Mode**: Evaluate pretrained JAT models on specific downstream tasks
```
```shell
python scripts/eval_jat.py --model_name_or_path jat-project/jat --tasks atari-pong --trust_remote_code
```
* **Training Mode**: Train your own JAT model from scratch
```
```shell
python scripts/train_jat.py %TODO
```
For further details regarding usage, consult the documentation included with individual script files.
## Dataset
% TODO
## Citation
Please ensure proper citations when incorporating this work into your projects.
Expand Down
4 changes: 3 additions & 1 deletion jat/modeling_jat.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,9 @@ def to_list(x):

# Context window
if context_window is not None:
self._last_key_values = tuple(tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values)
self._last_key_values = tuple(
tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values
)

# Return the predicted action
if continuous_actions is not None:
Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ line-length = 119
target-version = ['py38']

[tool.ruff]
line-length = 119

[tool.ruff.lint]
ignore = ["C901"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

[tool.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["jat"]
known-first-party = ["jat"]
4 changes: 3 additions & 1 deletion scripts/eval_jat.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def eval_rl(model, processor, task, eval_args):
done = False
model.reset_rl() # remove KV Cache
while not done:
action = model.get_next_action(processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window)
action = model.get_next_action(
processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window
)
observation, reward, termined, truncated, info = env.step(action)
done = termined or truncated

Expand Down

0 comments on commit 477a23a

Please sign in to comment.