Skip to content

Commit

Permalink
[CI] Fix windows CI (#1746)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 20, 2023
1 parent 4d3a0c6 commit f0b4814
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 112 deletions.
10 changes: 7 additions & 3 deletions .github/unittest/windows_optdepts/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ conda activate ./env
if [ "${CU_VERSION:-}" == cpu ] ; then
cudatoolkit="cpuonly"
version="cpu"
torch_cuda="False"
else
if [[ ${#CU_VERSION} -eq 4 ]]; then
CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
Expand All @@ -39,9 +40,9 @@ git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
if $torch_cuda ; then
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
else
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
fi

torch_cuda=$(python -c "import torch; print(torch.cuda.is_available())")
Expand All @@ -57,7 +58,10 @@ fi
#python -m pip install pip --upgrade

# install tensordict
pip3 install git+https://github.com/pytorch/tensordict
git clone https://github.com/pytorch/tensordict
cd tensordict
python setup.py develop
cd ..

# smoke test
python -c """
Expand Down
3 changes: 0 additions & 3 deletions .github/unittest/windows_optdepts/scripts/post_process.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#!/usr/bin/env bash

set -e

eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')"
conda activate ./env
4 changes: 4 additions & 0 deletions .github/unittest/windows_optdepts/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,9 @@ conda activate ./env
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
source "$this_dir/set_cuda_envs.sh"

# we don't use torchsnapshot
export CKPT_BACKEND=torch
export MAX_IDLE_COUNT=60

python -m torch.utils.collect_env
pytest --junitxml=test-results/junit.xml -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py
21 changes: 12 additions & 9 deletions .github/unittest/windows_optdepts/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@ fi
eval "$(${conda_dir}/Scripts/conda.exe 'shell.bash' 'hook')"

# 2. Create test environment at ./env
if [ ! -d "${env_dir}" ]; then
printf "* Creating a test environment\n"
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
fi
printf "* Creating a test environment\n"
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"

printf "* Activating the environment"
conda deactivate
conda activate "${env_dir}"

# 3. Install Conda dependencies
printf "* Installing dependencies (except PyTorch)\n"
conda env update --file "${this_dir}/environment.yml" --prune
printf "Python version"
echo $(which python)
echo $(python --version)
echo $(conda info -e)

#conda env update --file "${this_dir}/environment.yml" --prune

# we don't use torchsnapshot
conda env config vars set CKPT_BACKEND=torch
python -m pip install hypothesis future cloudpickle pytest pytest-cov pytest-mock pytest-instafail pytest-rerunfailures expecttest pyyaml scipy coverage
51 changes: 0 additions & 51 deletions .github/workflows/test-windows-optdepts-gpu.yml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Unit-tests on Windows CPU
name: Unit-tests on Windows

on:
pull_request:
Expand All @@ -16,17 +16,18 @@ concurrency:
cancel-in-progress: true

jobs:
unittests:
unittests-cpu:
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
with:
runner: "windows.4xlarge"
repository: pytorch/rl
timeout: 120
timeout: 40
script: |
set -euxo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="cpu"
export torch_cuda="False"
# TODO: Port this to pytorch/test-infra/.github/workflows/windows_job.yml
export PATH="/c/Jenkins/Miniconda3/Scripts:${PATH}"
Expand All @@ -45,3 +46,40 @@ jobs:
## post_process.sh
./.github/unittest/windows_optdepts/scripts/post_process.sh
unittests-gpu:
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
with:
runner: "windows.g5.4xlarge.nvidia.gpu"
repository: pytorch/rl
timeout: 40
script: |
set -euxo pipefail
export PYTHON_VERSION="3.9"
export CUDA_VERSION="11.6"
export CU_VERSION="cu116"
export torch_cuda="True"
# TODO: Port this to pytorch/test-infra/.github/workflows/windows_job.yml
export PATH="/c/Jenkins/Miniconda3/Scripts:${PATH}"
echo "PYTHON_VERSION: $PYTHON_VERSION"
## setup_env.sh
./.github/unittest/windows_optdepts/scripts/setup_env.sh
## Install CUDA
packaging/windows/internal/cuda_install.bat
## Update CUDA Driver
packaging/windows/internal/driver_update.bat
## install.sh
./.github/unittest/windows_optdepts/scripts/install.sh
## run_test.sh
./.github/unittest/windows_optdepts/scripts/run_test.sh
## post_process.sh
./.github/unittest/windows_optdepts/scripts/post_process.sh
1 change: 0 additions & 1 deletion examples/rlhf/train_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def main(cfg):
always_save_checkpoint = train_cfg.always_save_checkpoint

device = cfg.sys.device
dtype = cfg.sys.dtype
compile_ = cfg.sys.compile

ctx = setup(cfg.sys)
Expand Down
1 change: 1 addition & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
_os_is_windows = sys.platform == "win32"
RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
BATCHED_PIPE_TIMEOUT = float(os.environ.get("RL_WARNINGS", "60.0"))


class timeit:
Expand Down
52 changes: 28 additions & 24 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

_TIMEOUT = 1.0
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue.
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", 1000))

DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM
Expand Down Expand Up @@ -1371,32 +1372,35 @@ def shutdown(self) -> None:
self._shutdown_main()

def _shutdown_main(self) -> None:
if self.closed:
return
_check_for_faulty_process(self.procs)
self.closed = True
for idx in range(self.num_workers):
if not self.procs[idx].is_alive():
continue
try:
self.pipes[idx].send((None, "close"))
try:
if self.closed:
return
_check_for_faulty_process(self.procs)
self.closed = True
for idx in range(self.num_workers):
if not self.procs[idx].is_alive():
continue
try:
self.pipes[idx].send((None, "close"))

if self.pipes[idx].poll(10.0):
msg = self.pipes[idx].recv()
if msg != "closed":
raise RuntimeError(f"got {msg} but expected 'close'")
else:
if self.pipes[idx].poll(10.0):
msg = self.pipes[idx].recv()
if msg != "closed":
raise RuntimeError(f"got {msg} but expected 'close'")
else:
continue
except BrokenPipeError:
continue
except BrokenPipeError:
continue

for proc in self.procs:
exitcode = proc.join(1.0)
if exitcode is None:
proc.terminate()
self.queue_out.close()
for pipe in self.pipes:
pipe.close()
self.queue_out.close()
for pipe in self.pipes:
pipe.close()
for proc in self.procs:
proc.join(1.0)
finally:
for proc in self.procs:
if proc.is_alive():
proc.terminate()

def set_seed(self, seed: int, static_seed: bool = False) -> int:
"""Sets the seeds of the environments stored in the DataCollector.
Expand Down Expand Up @@ -2165,7 +2169,7 @@ def _main_async_collector(
f"if this is expected via the environment variable MAX_IDLE_COUNT "
f"(current value is {_MAX_IDLE_COUNT})."
f"\nIf this occurs at the end of a function or program, it means that your collector has not been "
f"collected, consider calling `collector.shutdown()` or `del collector` before ending the program."
f"collected, consider calling `collector.shutdown()` before ending the program."
)
continue
if msg in ("continue", "continue_random"):
Expand Down
51 changes: 33 additions & 18 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,23 +1026,28 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

@_check_start
def _shutdown_workers(self) -> None:
if self.is_closed:
raise RuntimeError(
"calling {self.__class__.__name__}._shutdown_workers only allowed when env.is_closed = False"
)
for i, channel in enumerate(self.parent_channels):
if self._verbose:
print(f"closing {i}")
channel.send(("close", None))
self._events[i].wait()
self._events[i].clear()

del self.shared_tensordicts, self.shared_tensordict_parent

for channel in self.parent_channels:
channel.close()
for proc in self._workers:
proc.join()
try:
if self.is_closed:
raise RuntimeError(
"calling {self.__class__.__name__}._shutdown_workers only allowed when env.is_closed = False"
)
for i, channel in enumerate(self.parent_channels):
if self._verbose:
print(f"closing {i}")
channel.send(("close", None))
self._events[i].wait()
self._events[i].clear()

del self.shared_tensordicts, self.shared_tensordict_parent

for channel in self.parent_channels:
channel.close()
for proc in self._workers:
proc.join(timeout=1.0)
finally:
for proc in self._workers:
if proc.is_alive():
proc.terminate()
del self._workers
del self.parent_channels
self._cuda_events = None
Expand Down Expand Up @@ -1156,13 +1161,23 @@ def _run_worker_pipe_shared_mem(
del env_fun

i = -1
import torchrl

_timeout = torchrl._utils.BATCHED_PIPE_TIMEOUT

initialized = False

child_pipe.send("started")

while True:
try:
cmd, data = child_pipe.recv()
if child_pipe.poll(_timeout):
cmd, data = child_pipe.recv()
else:
raise TimeoutError(
f"Worker timed out after {_timeout}s, "
f"increase timeout if needed throught the BATCHED_PIPE_TIMEOUT environment variable."
)
except EOFError as err:
raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err
if cmd == "seed":
Expand Down

0 comments on commit f0b4814

Please sign in to comment.