diff --git a/.circleci/config.yml b/.circleci/config.yml
deleted file mode 100644
index 783d935b3..000000000
--- a/.circleci/config.yml
+++ /dev/null
@@ -1,753 +0,0 @@
-version: 2.1
-
-executors:
- windows-cpu:
- machine:
- resource_class: windows.xlarge
- image: windows-server-2019-vs2019:stable
- shell: bash.exe
-
- windows-gpu:
- machine:
- resource_class: windows.gpu.nvidia.medium
- image: windows-server-2019-nvidia:stable
- shell: bash.exe
-
-
-commands:
- checkout_merge:
- description: "checkout merge branch"
- steps:
- - checkout
- designate_upload_channel:
- description: "inserts the correct upload channel into ${BASH_ENV}"
- steps:
- - run:
- name: adding UPLOAD_CHANNEL to BASH_ENV
- command: |
- our_upload_channel=nightly
- # On tags upload to test instead
- if [[ -n "${CIRCLE_TAG}" ]] || [[ ${CIRCLE_BRANCH} =~ release/* ]]; then
- our_upload_channel=test
- fi
- echo "export UPLOAD_CHANNEL=${our_upload_channel}" >> ${BASH_ENV}
- apt_install:
- parameters:
- args:
- type: string
- descr:
- type: string
- default: ""
- update:
- type: boolean
- default: true
- steps:
- - run:
- name: >
- <<^ parameters.descr >> apt install << parameters.args >> < parameters.descr >>
- <<# parameters.descr >> << parameters.descr >> < parameters.descr >>
- command: |
- <<# parameters.update >> sudo apt update -qy < parameters.update >>
- sudo apt install << parameters.args >>
- pip_install:
- parameters:
- args:
- type: string
- descr:
- type: string
- default: ""
- user:
- type: boolean
- default: true
- steps:
- - run:
- name: >
- <<^ parameters.descr >> pip install << parameters.args >> < parameters.descr >>
- <<# parameters.descr >> << parameters.descr >> < parameters.descr >>
- command: >
- pip install
- <<# parameters.user >> --user < parameters.user >>
- --progress-bar=off
- << parameters.args >>
-
- install_tensordict:
- parameters:
- editable:
- type: boolean
- default: true
- steps:
- - pip_install:
- args: --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- descr: Install PyTorch from nightly releases
- - pip_install:
- args: --no-build-isolation <<# parameters.editable >> --editable < parameters.editable >> .
- descr: Install tensordict <<# parameters.editable >> in editable mode < parameters.editable >>
-
-
-binary_common: &binary_common
- parameters:
- # Edit these defaults to do a release
- build_version:
- description: "version number of release binary; by default, build a nightly"
- type: string
- default: ""
- pytorch_version:
- description: "PyTorch version to build against; by default, use a nightly"
- type: string
- default: ""
- # Don't edit these
- python_version:
- description: "Python version to build against (e.g., 3.7)"
- type: string
- cu_version:
- description: "CUDA version to build against, in CU format (e.g., cpu or cu100)"
- type: string
- default: "cpu"
- unicode_abi:
- description: "Python 2.7 wheel only: whether or not we are cp27mu (default: no)"
- type: string
- default: ""
- wheel_docker_image:
- description: "Wheel only: what docker image to use"
- type: string
- default: "pytorch/manylinux-cuda113"
- conda_docker_image:
- description: "Conda only: what docker image to use"
- type: string
- default: "pytorch/conda-builder:cpu"
- environment:
- PYTHON_VERSION: << parameters.python_version >>
- PYTORCH_VERSION: << parameters.pytorch_version >>
- UNICODE_ABI: << parameters.unicode_abi >>
- CU_VERSION: << parameters.cu_version >>
-
-smoke_test_common: &smoke_test_common
- <<: *binary_common
- docker:
- - image: tensordict/smoke_test:latest
-
-jobs:
-# circleci_consistency:
-# docker:
-# - image: circleci/python:3.7
-# steps:
-# - checkout
-# - pip_install:
-# args: jinja2 pyyaml
-# - run:
-# name: Check CircleCI config consistency
-# command: |
-# python .circleci/regenerate.py
-# git diff --exit-code || (echo ".circleci/config.yml not in sync with config.yml.in! Run .circleci/regenerate.py to update config"; exit 1)
-
- lint_python_and_config:
- docker:
- - image: circleci/python:3.8
- steps:
- - checkout
- - pip_install:
- args: pre-commit
- descr: Install lint utilities
- - run:
- name: Install pre-commit hooks
- command: pre-commit install-hooks
- - run:
- name: Lint Python code and config files
- command: pre-commit run --all-files
- - run:
- name: Required lint modifications
- when: on_fail
- command: git --no-pager diff
-
- # lint_c:
- # docker:
- # - image: circleci/python:3.7
- # steps:
- # - apt_install:
- # args: libtinfo5
- # descr: Install additional system libraries
- # - checkout
- # - run:
- # name: Install lint utilities
- # command: |
- # curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o clang-format
- # chmod +x clang-format
- # sudo mv clang-format /opt/clang-format
- # - run:
- # name: Lint C code
- # command: ./.circleci/unittest/linux/scripts/run-clang-format.py -r tensordict/csrc --clang-format-executable /opt/clang-format
- # - run:
- # name: Required lint modifications
- # when: on_fail
- # command: git --no-pager diff
-
- type_check_python:
- docker:
- - image: circleci/python:3.8
- steps:
- - checkout
- - pip_install:
- args: cmake ninja
- descr: Install CMake and Ninja
- - install_tensordict:
- editable: true
- - pip_install:
- args: mypy
- descr: Install Python type check utilities
- - run:
- name: Check Python types statically
- command: mypy --install-types --non-interactive --config-file mypy.ini
-
- binary_linux_wheel:
- <<: *binary_common
- docker:
- - image: << parameters.wheel_docker_image >>
- resource_class: 2xlarge+
- steps:
- - checkout_merge
- - designate_upload_channel
- - run: packaging/build_wheels.sh
- - store_artifacts:
- path: dist
- - persist_to_workspace:
- root: dist
- paths:
- - "*"
-
- binary_macos_wheel:
- <<: *binary_common
- macos:
- xcode: "14.0"
- steps:
- - checkout_merge
- - designate_upload_channel
- - run:
- # Cannot easily deduplicate this as source'ing activate
- # will set environment variables which we need to propagate
- # to build_wheel.sh
- command: |
- curl -o conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
- sh conda.sh -b
- source $HOME/miniconda3/bin/activate
- packaging/build_wheels.sh
- - store_artifacts:
- path: dist
- - persist_to_workspace:
- root: dist
- paths:
- - "*"
-
- unittest_linux_cpu:
- <<: *binary_common
-
- docker:
- - image: "pytorch/manylinux-cuda113"
- resource_class: 2xlarge+
-
- environment:
- TAR_OPTIONS: --no-same-owner
- PYTHON_VERSION: << parameters.python_version >>
- CU_VERSION: << parameters.cu_version >>
-
- steps:
- - checkout
- - designate_upload_channel
- - run:
- name: Generate cache key
- # This will refresh cache on Sundays, nightly build should generate new cache.
- command: echo "$(date +"%Y-%U")" > .circleci-weekly
- - restore_cache:
- keys:
- - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
- - run:
- name: Setup
- command: .circleci/unittest/linux/scripts/setup_env.sh
-
- - save_cache:
-
- key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- paths:
- - conda
- - env
- - run:
- name: Install tensordict
- command: .circleci/unittest/linux/scripts/install.sh
- - run:
- name: Run tests
- command: .circleci/unittest/linux/scripts/run_test.sh
- - run:
- name: Codecov upload
- command: |
- curl -Os https://uploader.codecov.io/latest/linux/codecov
- chmod +x codecov
- ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-cpu
- - run:
- name: Post process
- command: .circleci/unittest/linux/scripts/post_process.sh
- - store_test_results:
- path: test-results
-
- unittest_linux_gpu:
- <<: *binary_common
- machine:
- image: ubuntu-2004-cuda-11.4:202110-01
- resource_class: gpu.nvidia.medium
- environment:
- image_name: "pytorch/manylinux-cuda113"
- TAR_OPTIONS: --no-same-owner
- PYTHON_VERSION: << parameters.python_version >>
- CU_VERSION: << parameters.cu_version >>
-
- steps:
- - checkout
- - designate_upload_channel
- - run:
- name: Generate cache key
- # This will refresh cache on Sundays, nightly build should generate new cache.
- command: echo "$(date +"%Y-%U")" > .circleci-weekly
- - restore_cache:
-
- keys:
- - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- - run:
- name: Setup
- command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
- - save_cache:
-
- key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- paths:
- - conda
- - env
-# - run:
-# # Here we create an envlist file that contains some env variables that we want the docker container to be aware of.
-# # Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables.
-# # They're available in all the other workflows (OSX and Windows).
-# # But here, we're running the unittest_linux_gpu workflows in a docker container, where those variables aren't accessible.
-# # So instead we dump the variables we need in env.list and we pass that file when invoking "docker run".
-# name: export CIRCLECI env var
-# command: echo "CIRCLECI=true" >> ./env.list
- - run:
- name: Install tensordict
-# command: bash .circleci/unittest/linux/scripts/install.sh
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux/scripts/install.sh
- - run:
- name: Run tests
- command: bash .circleci/unittest/linux/scripts/run_test.sh
-# command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh
- - run:
- name: Codecov upload
- command: |
- curl -Os https://uploader.codecov.io/latest/linux/codecov
- chmod +x codecov
- ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-gpu
- - run:
- name: Post Process
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/post_process.sh
- - store_test_results:
- path: test-results
-
- unittest_linux_torchrec_gpu:
- <<: *binary_common
- machine:
- image: ubuntu-2004-cuda-11.4:202110-01
- resource_class: gpu.nvidia.medium
- environment:
- image_name: "pytorch/manylinux-cuda113"
- TAR_OPTIONS: --no-same-owner
- PYTHON_VERSION: << parameters.python_version >>
- CU_VERSION: << parameters.cu_version >>
-
- steps:
- - checkout
- - designate_upload_channel
- - run:
- name: Generate cache key
- # This will refresh cache on Sundays, nightly build should generate new cache.
- command: echo "$(date +"%Y-%U")" > .circleci-weekly
- - restore_cache:
-
- keys:
- - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_torchrec/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- - run:
- name: Setup
- command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_torchrec/scripts/setup_env.sh
- - save_cache:
-
- key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_torchrec/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- paths:
- - conda
- - env
-# - run:
-# # Here we create an envlist file that contains some env variables that we want the docker container to be aware of.
-# # Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables.
-# # They're available in all the other workflows (OSX and Windows).
-# # But here, we're running the unittest_linux_gpu workflows in a docker container, where those variables aren't accessible.
-# # So instead we dump the variables we need in env.list and we pass that file when invoking "docker run".
-# name: export CIRCLECI env var
-# command: echo "CIRCLECI=true" >> ./env.list
- - run:
- name: Install tensordict
-# command: bash .circleci/unittest/linux_torchrec/scripts/install.sh
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_torchrec/scripts/install.sh
- - run:
- name: Run tests
- command: bash .circleci/unittest/linux_torchrec/scripts/run_test.sh
-# command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_torchrec/scripts/run_test.sh
- - run:
- name: Codecov upload
- command: |
- curl -Os https://uploader.codecov.io/latest/linux/codecov
- chmod +x codecov
- ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-torchrec-gpu
- - run:
- name: Post Process
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_torchrec/scripts/post_process.sh
- - store_test_results:
- path: test-results
-
- unittest_linux_stable_cpu:
- <<: *binary_common
-
- docker:
- - image: "pytorch/manylinux-cuda113"
- resource_class: 2xlarge+
-
- environment:
- TAR_OPTIONS: --no-same-owner
- PYTHON_VERSION: << parameters.python_version >>
- CU_VERSION: << parameters.cu_version >>
-
- steps:
- - checkout
- - designate_upload_channel
- - run:
- name: Generate cache key
- # This will refresh cache on Sundays, nightly build should generate new cache.
- command: echo "$(date +"%Y-%U")" > .circleci-weekly
- - restore_cache:
-
- keys:
- - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- - run:
- name: Setup
- command: .circleci/unittest/linux_stable/scripts/setup_env.sh
-
- - save_cache:
-
- key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- paths:
- - conda
- - env
- - run:
- name: Install tensordict
- command: .circleci/unittest/linux_stable/scripts/install.sh
- - run:
- name: Run tests
- command: .circleci/unittest/linux_stable/scripts/run_test.sh
- - run:
- name: Codecov upload
- command: |
- curl -Os https://uploader.codecov.io/latest/linux/codecov
- chmod +x codecov
- ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-stable-cpu
- - run:
- name: Post process
- command: .circleci/unittest/linux_stable/scripts/post_process.sh
- - store_test_results:
- path: test-results
-
- unittest_linux_stable_gpu:
- <<: *binary_common
- machine:
- image: ubuntu-2004-cuda-11.4:202110-01
- resource_class: gpu.nvidia.medium
- environment:
- image_name: "pytorch/manylinux-cuda113"
- TAR_OPTIONS: --no-same-owner
- PYTHON_VERSION: << parameters.python_version >>
- CU_VERSION: << parameters.cu_version >>
-
- steps:
- - checkout
- - designate_upload_channel
- - run:
- name: Generate cache key
- # This will refresh cache on Sundays, nightly build should generate new cache.
- command: echo "$(date +"%Y-%U")" > .circleci-weekly
- - restore_cache:
-
- keys:
- - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- - run:
- name: Setup
- command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/setup_env.sh
- - save_cache:
-
- key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- paths:
- - conda
- - env
- - run:
- name: Install tensordict
-# command: bash .circleci/unittest/linux_stable/scripts/install.sh
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux/scripts/install.sh
- - run:
- name: Run tests
- command: bash .circleci/unittest/linux_stable/scripts/run_test.sh
-# command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh
- - run:
- name: Codecov upload
- command: |
- curl -Os https://uploader.codecov.io/latest/linux/codecov
- chmod +x codecov
- ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-stable-gpu
- - run:
- name: Post Process
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/post_process.sh
- - store_test_results:
- path: test-results
-
- unittest_macos_cpu:
- <<: *binary_common
- macos:
- xcode: "13.0"
-
- resource_class: large
- steps:
- - checkout
- - designate_upload_channel
- - run:
- name: Install wget
- command: HOMEBREW_NO_AUTO_UPDATE=1 brew install wget
- # Disable brew auto update which is very slow
- - run:
- name: Generate cache key
- # This will refresh cache on Sundays, nightly build should generate new cache.
- command: echo "$(date +"%Y-%U")" > .circleci-weekly
- - restore_cache:
-
- keys:
- - env-v3-macos-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- - run:
- name: Setup
- command: .circleci/unittest/linux/scripts/setup_env.sh
- - save_cache:
-
- key: env-v3-macos-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- paths:
- - conda
- - env
- - run:
- name: Install tensordict
- command: .circleci/unittest/linux/scripts/install.sh
- - run:
- name: Run tests
- command: .circleci/unittest/linux/scripts/run_test.sh
- - run:
- name: Codecov upload
- command: |
- curl -Os https://uploader.codecov.io/latest/macos/codecov
- chmod +x codecov
- ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-macos-cpu
- - run:
- name: Post process
- command: .circleci/unittest/linux/scripts/post_process.sh
- - store_test_results:
- path: test-results
-
- unittest_rl_linux_optdeps_gpu:
- # A hopefully temporary pipeline
- <<: *binary_common
- machine:
- image: ubuntu-2004-cuda-11.4:202110-01
- resource_class: gpu.nvidia.medium
- environment:
- image_name: "nvidia/cudagl:11.4.0-base"
- TAR_OPTIONS: --no-same-owner
- PYTHON_VERSION: << parameters.python_version >>
- CU_VERSION: << parameters.cu_version >>
-
- steps:
- - checkout
- - designate_upload_channel
- - run:
- name: Generate cache key
- # This will refresh cache on Sundays, nightly build should generate new cache.
- command: echo "$(date +"%Y-%U")" > .circleci-weekly
- - restore_cache:
-
- keys:
- - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/rl_linux_optdeps/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- - run:
- name: Setup
- command: .circleci/unittest/rl_linux_optdeps/scripts/setup_env.sh
- - save_cache:
-
- key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/rl_linux_optdeps/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-
- paths:
- - conda
- - env
-# - run:
-# # Here we create an envlist file that contains some env variables that we want the docker container to be aware of.
-# # Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables.
-# # They're available in all the other workflows (OSX and Windows).
-# # But here, we're running the unittest_linux_gpu workflows in a docker container, where those variables aren't accessible.
-# # So instead we dump the variables we need in env.list and we pass that file when invoking "docker run".
-# name: export CIRCLECI env var
-# command: echo "CIRCLECI=true" >> ./env.list
- - run:
- name: Install torchrl
-# command: bash .circleci/unittest/rl_linux_optdeps/scripts/install.sh
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/rl_linux_optdeps/scripts/install.sh
- - run:
- name: Run tests
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/rl_linux_optdeps/scripts/run_test.sh
-# command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh
- - run:
- name: Codecov upload
- command: |
- curl -Os https://uploader.codecov.io/latest/linux/codecov
- chmod +x codecov
- ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-outdeps-gpu
- - run:
- name: Post Process
- command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/rl_linux_optdeps/scripts/post_process.sh
- - store_test_results:
- path: test-results
-
-workflows:
- lint:
- jobs:
-# - circleci_consistency
- - lint_python_and_config
-# - lint_c
-# - type_check_python
-
- build:
- jobs:
- - binary_linux_wheel:
- conda_docker_image: pytorch/conda-builder:cpu
- cu_version: cpu
- name: binary_linux_wheel_py3.8_cpu
- python_version: '3.8'
- wheel_docker_image: pytorch/manylinux-cuda102
-
- - binary_linux_wheel:
- conda_docker_image: pytorch/conda-builder:cpu
- cu_version: cpu
- name: binary_linux_wheel_py3.9_cpu
- python_version: '3.9'
- wheel_docker_image: pytorch/manylinux-cuda102
-
- - binary_linux_wheel:
- conda_docker_image: pytorch/conda-builder:cpu
- cu_version: cpu
- name: binary_linux_wheel_py3.10_cpu
- python_version: '3.10'
- wheel_docker_image: pytorch/manylinux-cuda102
-
- - binary_macos_wheel:
- conda_docker_image: pytorch/conda-builder:cpu
- cu_version: cpu
- name: binary_macos_wheel_py3.8_cpu
- python_version: '3.8'
- wheel_docker_image: pytorch/manylinux-cuda102
-
- - binary_macos_wheel:
- conda_docker_image: pytorch/conda-builder:cpu
- cu_version: cpu
- name: binary_macos_wheel_py3.9_cpu
- python_version: '3.9'
- wheel_docker_image: pytorch/manylinux-cuda102
-
- - binary_macos_wheel:
- conda_docker_image: pytorch/conda-builder:cpu
- cu_version: cpu
- name: binary_macos_wheel_py3.10_cpu
- python_version: '3.10'
- wheel_docker_image: pytorch/manylinux-cuda102
-
- unittest:
- jobs:
- - unittest_macos_cpu:
- cu_version: cpu
- name: unittest_macos_cpu_py3.8
- python_version: '3.8'
- - unittest_linux_cpu:
- cu_version: cpu
- name: unittest_linux_cpu_py3.8
- python_version: '3.8'
- - unittest_linux_gpu:
- cu_version: cu113
- name: unittest_linux_gpu_py3.8
- python_version: '3.8'
- - unittest_linux_stable_cpu:
- cu_version: cpu
- name: unittest_linux_stable_cpu_py3.8
- python_version: '3.8'
- - unittest_linux_stable_gpu:
- cu_version: cu113
- name: unittest_linux_stable_gpu_py3.8
- python_version: '3.8'
-
- - unittest_macos_cpu:
- cu_version: cpu
- name: unittest_macos_cpu_py3.9
- python_version: '3.9'
- - unittest_linux_cpu:
- cu_version: cpu
- name: unittest_linux_cpu_py3.9
- python_version: '3.9'
- - unittest_linux_gpu:
- cu_version: cu113
- name: unittest_linux_gpu_py3.9
- python_version: '3.9'
- - unittest_linux_stable_cpu:
- cu_version: cpu
- name: unittest_linux_stable_cpu_py3.9
- python_version: '3.9'
- - unittest_linux_stable_gpu:
- cu_version: cu113
- name: unittest_linux_stable_gpu_py3.9
- python_version: '3.9'
- - unittest_macos_cpu:
- cu_version: cpu
- name: unittest_macos_cpu_py3.10
- python_version: '3.10'
- - unittest_linux_cpu:
- cu_version: cpu
- name: unittest_linux_cpu_py3.10
- python_version: '3.10'
- - unittest_linux_gpu:
- cu_version: cu113
- name: unittest_linux_gpu_py3.10
- python_version: '3.10'
- - unittest_linux_stable_cpu:
- cu_version: cpu
- name: unittest_linux_stable_cpu_py3.10
- python_version: '3.10'
- - unittest_linux_stable_gpu:
- cu_version: cu113
- name: unittest_linux_stable_gpu_py3.10
- python_version: '3.10'
-
- - unittest_rl_linux_optdeps_gpu:
- cu_version: cu117
- name: unittest_rl_linux_optdeps_gpu_py3.9
- python_version: '3.9'
- filters:
- branches:
- ignore: main
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index fa228beb1..39447b08f 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -59,5 +59,5 @@ If you know or suspect the reason for this bug, paste the code lines and suggest
## Checklist
- [ ] I have checked that there is no similar issue in the repo (**required**)
-- [ ] I have read the [documentation](https://github.com/pytorch/rl/tree/main/docs/) (**required**)
+- [ ] I have read the [documentation](https://github.com/pytorch/tensordict/tree/main/docs/) (**required**)
- [ ] I have provided a minimal working example to reproduce the bug (**required**)
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 4be152bfb..6ca7443b2 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -8,7 +8,7 @@ Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213
-- [ ] I have raised an issue to propose this change ([required](https://github.com/pytorch/rl/issues) for new features and bug fixes)
+- [ ] I have raised an issue to propose this change ([required](https://github.com/pytorch/tensordict/issues) for new features and bug fixes)
## Types of changes
@@ -25,7 +25,7 @@ What types of changes does your code introduce? Remove all that do not apply:
Go over all the following points, and put an `x` in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!
-- [ ] I have read the [CONTRIBUTION](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) guide (**required**)
+- [ ] I have read the [CONTRIBUTION](https://github.com/pytorch/tensordict/blob/main/CONTRIBUTING.md) guide (**required**)
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
similarity index 100%
rename from .circleci/unittest/linux/scripts/environment.yml
rename to .github/unittest/linux/scripts/environment.yml
diff --git a/.circleci/unittest/linux/scripts/install.sh b/.github/unittest/linux/scripts/install.sh
similarity index 52%
rename from .circleci/unittest/linux/scripts/install.sh
rename to .github/unittest/linux/scripts/install.sh
index 5c42aee90..65904ffe1 100755
--- a/.circleci/unittest/linux/scripts/install.sh
+++ b/.github/unittest/linux/scripts/install.sh
@@ -6,6 +6,7 @@ unset PYTORCH_VERSION
# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.
set -e
+set -v
eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
@@ -25,17 +26,28 @@ fi
git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
-if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+ else
+ python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
+ fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ else
+ python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$CU_VERSION
+ fi
else
- pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu113
+ printf "Failed to install pytorch"
+ exit 1
fi
printf "* Installing tensordict\n"
-pip3 install -e .
+python setup.py develop
# install torchsnapshot nightly
-pip3 install git+https://github.com/pytorch/torchsnapshot
+python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation
# smoke test
python -c "import functorch;import torchsnapshot"
diff --git a/.circleci/unittest/linux/scripts/post_process.sh b/.github/unittest/linux/scripts/post_process.sh
similarity index 100%
rename from .circleci/unittest/linux/scripts/post_process.sh
rename to .github/unittest/linux/scripts/post_process.sh
diff --git a/.circleci/unittest/linux/scripts/run-clang-format.py b/.github/unittest/linux/scripts/run-clang-format.py
similarity index 100%
rename from .circleci/unittest/linux/scripts/run-clang-format.py
rename to .github/unittest/linux/scripts/run-clang-format.py
diff --git a/.circleci/unittest/linux/scripts/run_test.sh b/.github/unittest/linux/scripts/run_test.sh
similarity index 100%
rename from .circleci/unittest/linux/scripts/run_test.sh
rename to .github/unittest/linux/scripts/run_test.sh
diff --git a/.circleci/unittest/linux_torchrec/scripts/setup_env.sh b/.github/unittest/linux/scripts/setup_env.sh
similarity index 89%
rename from .circleci/unittest/linux_torchrec/scripts/setup_env.sh
rename to .github/unittest/linux/scripts/setup_env.sh
index 609604386..b3a61419b 100755
--- a/.circleci/unittest/linux_torchrec/scripts/setup_env.sh
+++ b/.github/unittest/linux/scripts/setup_env.sh
@@ -6,6 +6,8 @@
# Do not install PyTorch and torchvision here, otherwise they also get cached.
set -e
+set -v
+
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository"
@@ -24,7 +26,7 @@ esac
# 1. Install conda at ./conda
if [ ! -d "${conda_dir}" ]; then
printf "* Installing conda\n"
- wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
+ wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-${ARCH}.sh"
bash ./miniconda.sh -b -f -p "${conda_dir}"
fi
eval "$(${conda_dir}/bin/conda shell.bash hook)"
@@ -45,3 +47,8 @@ cat "${this_dir}/environment.yml"
pip install pip --upgrade
conda env update --file "${this_dir}/environment.yml" --prune
+
+#if [[ $OSTYPE == 'darwin'* ]]; then
+# printf "* Installing C++ for OSX\n"
+# conda install -c conda-forge cxx-compiler -y
+#fi
diff --git a/.circleci/unittest/linux_stable/scripts/environment.yml b/.github/unittest/linux_stable/scripts/environment.yml
similarity index 100%
rename from .circleci/unittest/linux_stable/scripts/environment.yml
rename to .github/unittest/linux_stable/scripts/environment.yml
diff --git a/.circleci/unittest/linux_stable/scripts/install.sh b/.github/unittest/linux_stable/scripts/install.sh
similarity index 100%
rename from .circleci/unittest/linux_stable/scripts/install.sh
rename to .github/unittest/linux_stable/scripts/install.sh
diff --git a/.circleci/unittest/linux_stable/scripts/post_process.sh b/.github/unittest/linux_stable/scripts/post_process.sh
similarity index 100%
rename from .circleci/unittest/linux_stable/scripts/post_process.sh
rename to .github/unittest/linux_stable/scripts/post_process.sh
diff --git a/.circleci/unittest/linux_stable/scripts/run-clang-format.py b/.github/unittest/linux_stable/scripts/run-clang-format.py
similarity index 100%
rename from .circleci/unittest/linux_stable/scripts/run-clang-format.py
rename to .github/unittest/linux_stable/scripts/run-clang-format.py
diff --git a/.circleci/unittest/linux_stable/scripts/run_test.sh b/.github/unittest/linux_stable/scripts/run_test.sh
similarity index 100%
rename from .circleci/unittest/linux_stable/scripts/run_test.sh
rename to .github/unittest/linux_stable/scripts/run_test.sh
diff --git a/.circleci/unittest/linux_stable/scripts/setup_env.sh b/.github/unittest/linux_stable/scripts/setup_env.sh
similarity index 100%
rename from .circleci/unittest/linux_stable/scripts/setup_env.sh
rename to .github/unittest/linux_stable/scripts/setup_env.sh
diff --git a/.circleci/unittest/linux_torchrec/scripts/environment.yml b/.github/unittest/linux_torchrec/scripts/environment.yml
similarity index 100%
rename from .circleci/unittest/linux_torchrec/scripts/environment.yml
rename to .github/unittest/linux_torchrec/scripts/environment.yml
diff --git a/.circleci/unittest/linux_torchrec/scripts/install.sh b/.github/unittest/linux_torchrec/scripts/install.sh
similarity index 100%
rename from .circleci/unittest/linux_torchrec/scripts/install.sh
rename to .github/unittest/linux_torchrec/scripts/install.sh
diff --git a/.circleci/unittest/linux_torchrec/scripts/post_process.sh b/.github/unittest/linux_torchrec/scripts/post_process.sh
similarity index 100%
rename from .circleci/unittest/linux_torchrec/scripts/post_process.sh
rename to .github/unittest/linux_torchrec/scripts/post_process.sh
diff --git a/.circleci/unittest/linux_torchrec/scripts/run-clang-format.py b/.github/unittest/linux_torchrec/scripts/run-clang-format.py
similarity index 100%
rename from .circleci/unittest/linux_torchrec/scripts/run-clang-format.py
rename to .github/unittest/linux_torchrec/scripts/run-clang-format.py
diff --git a/.circleci/unittest/linux_torchrec/scripts/run_test.sh b/.github/unittest/linux_torchrec/scripts/run_test.sh
similarity index 100%
rename from .circleci/unittest/linux_torchrec/scripts/run_test.sh
rename to .github/unittest/linux_torchrec/scripts/run_test.sh
diff --git a/.circleci/unittest/linux/scripts/setup_env.sh b/.github/unittest/linux_torchrec/scripts/setup_env.sh
similarity index 100%
rename from .circleci/unittest/linux/scripts/setup_env.sh
rename to .github/unittest/linux_torchrec/scripts/setup_env.sh
diff --git a/.circleci/unittest/rl_linux_optdeps/scripts/environment.yml b/.github/unittest/rl_linux_optdeps/scripts/environment.yml
similarity index 100%
rename from .circleci/unittest/rl_linux_optdeps/scripts/environment.yml
rename to .github/unittest/rl_linux_optdeps/scripts/environment.yml
diff --git a/.circleci/unittest/rl_linux_optdeps/scripts/install.sh b/.github/unittest/rl_linux_optdeps/scripts/install.sh
similarity index 100%
rename from .circleci/unittest/rl_linux_optdeps/scripts/install.sh
rename to .github/unittest/rl_linux_optdeps/scripts/install.sh
diff --git a/.circleci/unittest/rl_linux_optdeps/scripts/post_process.sh b/.github/unittest/rl_linux_optdeps/scripts/post_process.sh
similarity index 100%
rename from .circleci/unittest/rl_linux_optdeps/scripts/post_process.sh
rename to .github/unittest/rl_linux_optdeps/scripts/post_process.sh
diff --git a/.circleci/unittest/rl_linux_optdeps/scripts/run-clang-format.py b/.github/unittest/rl_linux_optdeps/scripts/run-clang-format.py
similarity index 100%
rename from .circleci/unittest/rl_linux_optdeps/scripts/run-clang-format.py
rename to .github/unittest/rl_linux_optdeps/scripts/run-clang-format.py
diff --git a/.circleci/unittest/rl_linux_optdeps/scripts/run_test.sh b/.github/unittest/rl_linux_optdeps/scripts/run_test.sh
similarity index 100%
rename from .circleci/unittest/rl_linux_optdeps/scripts/run_test.sh
rename to .github/unittest/rl_linux_optdeps/scripts/run_test.sh
diff --git a/.circleci/unittest/rl_linux_optdeps/scripts/setup_env.sh b/.github/unittest/rl_linux_optdeps/scripts/setup_env.sh
similarity index 100%
rename from .circleci/unittest/rl_linux_optdeps/scripts/setup_env.sh
rename to .github/unittest/rl_linux_optdeps/scripts/setup_env.sh
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 000000000..0818bd228
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,75 @@
+name: Lint
+
+on:
+ pull_request:
+ push:
+ branches:
+ - nightly
+ - main
+ - release/*
+ workflow_dispatch:
+
+concurrency:
+ # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
+ # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke.
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
+ cancel-in-progress: true
+
+jobs:
+ python-source-and-configs:
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ repository: pytorch/tensordict
+ script: |
+ set -euo pipefail
+
+ echo '::group::Setup environment'
+ CONDA_PATH=$(which conda)
+ eval "$(${CONDA_PATH} shell.bash hook)"
+ conda create --name ci --quiet --yes python=3.8 pip
+ conda activate ci
+ echo '::endgroup::'
+
+ echo '::group::Install lint tools'
+ pip install --progress-bar=off pre-commit
+ echo '::endgroup::'
+
+ echo '::group::Lint Python source and configs'
+ set +e
+ pre-commit run --all-files
+
+ if [ $? -ne 0 ]; then
+ git --no-pager diff
+ exit 1
+ fi
+ echo '::endgroup::'
+
+ c-source:
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ repository: pytorch/tensordict
+ script: |
+ set -euo pipefail
+
+ echo '::group::Setup environment'
+ CONDA_PATH=$(which conda)
+ eval "$(${CONDA_PATH} shell.bash hook)"
+ conda create --name ci --quiet --yes -c conda-forge python=3.8 ncurses=5 libgcc
+ conda activate ci
+ export LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}"
+ echo '::endgroup::'
+
+ echo '::group::Install lint tools'
+ curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o ./clang-format
+ chmod +x ./clang-format
+ echo '::endgroup::'
+
+ echo '::group::Lint C source'
+ set +e
+ ./.github/unittest/linux/scripts/run-clang-format.py -r torchrl/csrc --clang-format-executable ./clang-format
+
+ if [ $? -ne 0 ]; then
+ git --no-pager diff
+ exit 1
+ fi
+ echo '::endgroup::'
diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml
index c45f15d42..31eaf72df 100644
--- a/.github/workflows/nightly_build.yml
+++ b/.github/workflows/nightly_build.yml
@@ -30,7 +30,7 @@ concurrency:
jobs:
build-wheel-linux:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
runs-on: ubuntu-20.04
strategy:
matrix:
@@ -69,7 +69,7 @@ jobs:
build-wheel-mac:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
runs-on: macos-latest
strategy:
matrix:
@@ -101,7 +101,7 @@ jobs:
test-wheel-mac:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
needs: build-wheel-mac
runs-on: macos-latest
strategy:
@@ -150,7 +150,7 @@ jobs:
upload-wheel-linux:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
needs: test-wheel-linux
runs-on: ubuntu-20.04
strategy:
@@ -181,7 +181,7 @@ jobs:
upload-wheel-mac:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
needs: test-wheel-mac
runs-on: macos-latest
strategy:
@@ -209,7 +209,7 @@ jobs:
test-wheel-linux:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
needs: build-wheel-linux
runs-on: ubuntu-20.04
strategy:
@@ -269,7 +269,7 @@ jobs:
build-wheel-windows:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
runs-on: windows-latest
strategy:
matrix:
@@ -301,7 +301,7 @@ jobs:
test-wheel-windows:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
needs: build-wheel-windows
runs-on: windows-latest
strategy:
@@ -355,7 +355,7 @@ jobs:
upload-wheel-windows:
# Don't run on forked repos.
- if: github.repository_owner == 'pytorch-labs'
+ if: github.repository_owner == 'pytorch'
needs: test-wheel-windows
runs-on: windows-latest
strategy:
diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml
new file mode 100644
index 000000000..c15842f9e
--- /dev/null
+++ b/.github/workflows/test-linux.yml
@@ -0,0 +1,137 @@
+name: Unit-tests on Linux
+
+on:
+ pull_request:
+ push:
+ branches:
+ - nightly
+ - main
+ - release/*
+ workflow_dispatch:
+
+env:
+ CHANNEL: "nightly"
+
+concurrency:
+ # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
+ # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke.
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
+ cancel-in-progress: true
+
+jobs:
+ test-gpu:
+ strategy:
+ matrix:
+ python_version: ["3.8"]
+ cuda_arch_version: ["12.1"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ runner: linux.g5.4xlarge.nvidia.gpu
+ repository: pytorch/tensordict
+ gpu-arch-type: cuda
+ gpu-arch-version: ${{ matrix.cuda_arch_version }}
+ script: |
+ # Set env vars from matrix
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ # Commenting these out for now because the GPU test are not working inside docker
+ export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }}
+ export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}"
+ export TORCH_VERSION=nightly
+ # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
+ #export CU_VERSION="cpu"
+ export ARCH=x86_64
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/linux/scripts/setup_env.sh
+ bash .github/unittest/linux/scripts/install.sh
+ bash .github/unittest/linux/scripts/run_test.sh
+ bash .github/unittest/linux/scripts/post_process.sh
+
+ test-cpu:
+ strategy:
+ matrix:
+ python_version: ["3.8", "3.9", "3.10", "3.11"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ runner: linux.12xlarge
+ repository: pytorch/tensordict
+ timeout: 90
+ script: |
+ # Set env vars from matrix
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ export CU_VERSION="cpu"
+ export TORCH_VERSION=nightly
+ export ARCH=x86_64
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/linux/scripts/setup_env.sh
+ bash .github/unittest/linux/scripts/install.sh
+ bash .github/unittest/linux/scripts/run_test.sh
+ bash .github/unittest/linux/scripts/post_process.sh
+
+ test-stable-gpu:
+ strategy:
+ matrix:
+ python_version: ["3.8"]
+ cuda_arch_version: ["12.1"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ runner: linux.g5.4xlarge.nvidia.gpu
+ repository: pytorch/tensordict
+ gpu-arch-type: cuda
+ gpu-arch-version: ${{ matrix.cuda_arch_version }}
+ timeout: 90
+ script: |
+ # Set env vars from matrix
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ # Commenting these out for now because the GPU test are not working inside docker
+ export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }}
+ export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}"
+ export TORCH_VERSION=stable
+ # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
+ #export CU_VERSION="cpu"
+ export ARCH=x86_64
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/linux/scripts/setup_env.sh
+ bash .github/unittest/linux/scripts/install.sh
+ bash .github/unittest/linux/scripts/run_test.sh
+ bash .github/unittest/linux/scripts/post_process.sh
+
+ test-stable-cpu:
+ strategy:
+ matrix:
+ python_version: ["3.8", "3.11"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ runner: linux.12xlarge
+ repository: pytorch/tensordict
+ timeout: 90
+ script: |
+ # Set env vars from matrix
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ export CU_VERSION="cpu"
+ export TORCH_VERSION=stable
+ export ARCH=x86_64
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/linux/scripts/setup_env.sh
+ bash .github/unittest/linux/scripts/install.sh
+ bash .github/unittest/linux/scripts/run_test.sh
+ bash .github/unittest/linux/scripts/post_process.sh
diff --git a/.github/workflows/test-macos.yml b/.github/workflows/test-macos.yml
new file mode 100644
index 000000000..fc8218c9d
--- /dev/null
+++ b/.github/workflows/test-macos.yml
@@ -0,0 +1,77 @@
+name: Unit-tests on MacOS
+
+on:
+ pull_request:
+ push:
+ branches:
+ - nightly
+ - main
+ - release/*
+ workflow_dispatch:
+
+env:
+ CHANNEL: "nightly"
+
+concurrency:
+ # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
+ # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke.
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
+ cancel-in-progress: true
+
+jobs:
+ tests-intel:
+ strategy:
+ matrix:
+ python_version: ["3.8", "3.11"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+ with:
+ repository: pytorch/tensordict
+ timeout: 120
+ script: |
+ # Set env vars from matrix
+ set -e
+ set -v
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ export CU_VERSION="cpu"
+ export SYSTEM_VERSION_COMPAT=0
+ export TORCH_VERSION=nightly
+ export ARCH=x86_64
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/linux/scripts/setup_env.sh
+ bash .github/unittest/linux/scripts/install.sh
+ bash .github/unittest/linux/scripts/run_test.sh
+ bash .github/unittest/linux/scripts/post_process.sh
+
+ tests-silicon:
+ strategy:
+ matrix:
+ python_version: ["3.9"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+ with:
+ runner: macos-m1-12
+ repository: pytorch/tensordict
+ timeout: 120
+ script: |
+ # Set env vars from matrix
+ set -e
+ set -v
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ export CU_VERSION="cpu"
+ export SYSTEM_VERSION_COMPAT=0
+ export TORCH_VERSION=nightly
+ export ARCH=arm64
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/linux/scripts/setup_env.sh
+ bash .github/unittest/linux/scripts/install.sh
+ bash .github/unittest/linux/scripts/run_test.sh
+ bash .github/unittest/linux/scripts/post_process.sh
diff --git a/.github/workflows/test-rl-gpu.yml b/.github/workflows/test-rl-gpu.yml
new file mode 100644
index 000000000..ab284d754
--- /dev/null
+++ b/.github/workflows/test-rl-gpu.yml
@@ -0,0 +1,51 @@
+name: Unit-tests (RL) on Linux
+
+on:
+ pull_request:
+ push:
+ branches:
+ - nightly
+ - main
+ - release/*
+ workflow_dispatch:
+
+env:
+ CHANNEL: "nightly"
+
+concurrency:
+ # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
+ # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke.
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
+ cancel-in-progress: true
+
+jobs:
+ test-gpu:
+ strategy:
+ matrix:
+ python_version: ["3.8"]
+ cuda_arch_version: ["12.1"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ runner: linux.g5.4xlarge.nvidia.gpu
+ repository: pytorch/tensordict
+ gpu-arch-type: cuda
+ gpu-arch-version: ${{ matrix.cuda_arch_version }}
+ script: |
+ # Set env vars from matrix
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ # Commenting these out for now because the GPU test are not working inside docker
+ export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }}
+ export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}"
+ export TORCH_VERSION=nightly
+ # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
+ #export CU_VERSION="cpu"
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/rl_linux_optdeps/scripts/setup_env.sh
+ bash .github/unittest/rl_linux_optdeps/scripts/install.sh
+ bash .github/unittest/rl_linux_optdeps/scripts/run_test.sh
+ bash .github/unittest/rl_linux_optdeps/scripts/post_process.sh
diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml
index fef883d26..860ac2541 100644
--- a/.github/workflows/wheels.yml
+++ b/.github/workflows/wheels.yml
@@ -4,7 +4,7 @@ on:
types: [opened, synchronize, reopened]
push:
branches:
- - release/0.1.3
+ - release/0.2.0
concurrency:
# Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
@@ -32,7 +32,7 @@ jobs:
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install wheel
- BUILD_VERSION=0.1.3 python3 setup.py bdist_wheel
+ BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel
# NB: wheels have the linux_x86_64 tag so we rename to manylinux1
# find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
@@ -72,7 +72,7 @@ jobs:
run: |
export CC=clang CXX=clang++
python3 -mpip install wheel
- BUILD_VERSION=0.1.3 python3 setup.py bdist_wheel
+ BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
with:
@@ -104,7 +104,7 @@ jobs:
shell: bash
run: |
python3 -mpip install wheel
- BUILD_VERSION=0.1.3 python3 setup.py bdist_wheel
+ BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
with:
diff --git a/README.md b/README.md
index a7f5e8cb2..a6fddbcb5 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
[![Docs - GitHub.io](https://img.shields.io/static/v1?logo=github&style=flat&color=pink&label=docs&message=tensordict)][#docs-package]
[![Benchmarks](https://img.shields.io/badge/Benchmarks-blue.svg)][#docs-package-benchmark]
@@ -10,18 +10,18 @@
[![Downloads](https://static.pepy.tech/personalized-badge/tensordict?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads)][#pepy-package]
[![Downloads](https://static.pepy.tech/personalized-badge/tensordict-nightly?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads%20(nightly))][#pepy-package-nightly]
-[![codecov](https://codecov.io/gh/pytorch-labs/tensordict/branch/main/graph/badge.svg?token=9QTUG6NAGQ)][#codecov-package]
-[![circleci](https://circleci.com/gh/pytorch-labs/tensordict.svg?style=shield)][#circleci-package]
+[![codecov](https://codecov.io/gh/pytorch/tensordict/branch/main/graph/badge.svg?token=9QTUG6NAGQ)][#codecov-package]
+[![circleci](https://circleci.com/gh/pytorch/tensordict.svg?style=shield)][#circleci-package]
[![Conda - Platform](https://img.shields.io/conda/pn/conda-forge/tensordict?logo=anaconda&style=flat)][#conda-forge-package]
[![Conda (channel only)](https://img.shields.io/conda/vn/conda-forge/tensordict?logo=anaconda&style=flat&color=orange)][#conda-forge-package]
-[#docs-package]: https://pytorch-labs.github.io/tensordict/
-[#docs-package-benchmark]: https://pytorch-labs.github.io/tensordict/dev/bench/
-[#github-license]: https://github.com/pytorch-labs/tensordict/blob/main/LICENSE
+[#docs-package]: https://pytorch.github.io/tensordict/
+[#docs-package-benchmark]: https://pytorch.github.io/tensordict/dev/bench/
+[#github-license]: https://github.com/pytorch/tensordict/blob/main/LICENSE
[#pepy-package]: https://pepy.tech/project/tensordict
[#pepy-package-nightly]: https://pepy.tech/project/tensordict-nightly
-[#codecov-package]: https://codecov.io/gh/pytorch-labs/tensordict
-[#circleci-package]: https://circleci.com/gh/pytorch-labs/tensordict
+[#codecov-package]: https://codecov.io/gh/pytorch/tensordict
+[#circleci-package]: https://circleci.com/gh/pytorch/tensordict
[#conda-forge-package]: https://anaconda.org/conda-forge/tensordict
@@ -142,7 +142,7 @@ counterparts:
```
When nodes share a common scratch space, the
-[`MemmapTensor` backend](https://pytorch-labs.github.io/tensordict/tutorials/tensordict_memory.html)
+[`MemmapTensor` backend](https://pytorch.github.io/tensordict/tutorials/tensordict_memory.html)
can be used
to seamlessly send, receive and read a huge amount of data.
diff --git a/docs/source/_static/js/tensordict_theme.js b/docs/source/_static/js/tensordict_theme.js
index 9b29c6c57..3f7566317 100644
--- a/docs/source/_static/js/tensordict_theme.js
+++ b/docs/source/_static/js/tensordict_theme.js
@@ -944,10 +944,10 @@ var downloadNote = $(".sphx-glr-download-link-note.admonition.note");
if (downloadNote.length >= 1) {
var tutorialUrlArray = $("#tutorial-type").text().split('/');
- var githubLink = "https://github.com/pytorch-labs/tensordict/tree/main/sphinx-tutorials/" + tutorialUrlArray[tutorialUrlArray.length - 1] + ".py",
+ var githubLink = "https://github.com/pytorch/tensordict/tree/main/sphinx-tutorials/" + tutorialUrlArray[tutorialUrlArray.length - 1] + ".py",
notebookLink = $(".reference.download")[1].href,
notebookDownloadPath = notebookLink.split('_downloads')[1],
- colabLink = "https://colab.research.google.com/github/pytorch-labs/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
+ colabLink = "https://colab.research.google.com/github/pytorch/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
$("#google-colab-link").wrap("");
$("#download-notebook-link").wrap("");
@@ -1373,16 +1373,16 @@ $("table").removeAttr("border");
// with the 3 download buttons at the top of the page
var downloadNote = $(".sphx-glr-download-link-note.admonition.note");
-var githubLink = "https://github.com/pytorch-labs/tensordict/tree/main/tutorials/sphinx-tutorials/" + tutorialUrlArray.join("/") + ".py";
+var githubLink = "https://github.com/pytorch/tensordict/tree/main/tutorials/sphinx-tutorials/" + tutorialUrlArray.join("/") + ".py";
$("#github-view-link").wrap("");
// if (downloadNote.length >= 1) {
// var tutorialUrlArray = $("#tutorial-type").text().split('/');
-// var githubLink = "https://github.com/pytorch-labs/tensordict/tree/main/tutorials/" + tutorialUrlArray.join("/") + ".py",
+// var githubLink = "https://github.com/pytorch/tensordict/tree/main/tutorials/" + tutorialUrlArray.join("/") + ".py",
// notebookLink = $(".reference.download")[1].href,
// notebookDownloadPath = notebookLink.split('_downloads')[1],
-// colabLink = "https://colab.research.google.com/github/pytorch-labs/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
+// colabLink = "https://colab.research.google.com/github/pytorch/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
// $("#google-colab-link").wrap("");
// $("#download-notebook-link").wrap("");
diff --git a/docs/source/_static/js/theme.js b/docs/source/_static/js/theme.js
index fa20bdbc7..82664c699 100644
--- a/docs/source/_static/js/theme.js
+++ b/docs/source/_static/js/theme.js
@@ -945,10 +945,10 @@ if (downloadNote.length >= 1) {
var tutorialUrlArray = $("#tutorial-type").text().split('/');
tutorialUrlArray[0] = tutorialUrlArray[0] + "/sphinx_tuto"
- var githubLink = "https://github.com/pytorch-labs/tensordict/blob/main/" + tutorialUrlArray.join("/") + ".py",
+ var githubLink = "https://github.com/pytorch/tensordict/blob/main/" + tutorialUrlArray.join("/") + ".py",
notebookLink = $(".reference.download")[1].href,
notebookDownloadPath = notebookLink.split('_downloads')[1],
- colabLink = "https://colab.research.google.com/github/pytorch-labs/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
+ colabLink = "https://colab.research.google.com/github/pytorch/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
$("#google-colab-link").wrap("");
$("#download-notebook-link").wrap("");
@@ -2073,10 +2073,10 @@ if (downloadNote.length >= 1) {
var tutorialUrlArray = $("#tutorial-type").text().split('/');
tutorialUrlArray[0] = tutorialUrlArray[0] + "/sphinx_tuto"
- var githubLink = "https://github.com/pytorch-labs/tensordict/blob/main/" + tutorialUrlArray.join("/") + ".py",
+ var githubLink = "https://github.com/pytorch/tensordict/blob/main/" + tutorialUrlArray.join("/") + ".py",
notebookLink = $(".reference.download")[1].href,
notebookDownloadPath = notebookLink.split('_downloads')[1],
- colabLink = "https://colab.research.google.com/github/pytorch-labs/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
+ colabLink = "https://colab.research.google.com/github/pytorch/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
$("#google-colab-link").wrap("");
$("#download-notebook-link").wrap("");
@@ -3201,10 +3201,10 @@ if (downloadNote.length >= 1) {
var tutorialUrlArray = $("#tutorial-type").text().split('/');
tutorialUrlArray[0] = tutorialUrlArray[0] + "/sphinx_tuto"
- var githubLink = "https://github.com/pytorch-labs/tensordict/blob/main/" + tutorialUrlArray.join("/") + ".py",
+ var githubLink = "https://github.com/pytorch/tensordict/blob/main/" + tutorialUrlArray.join("/") + ".py",
notebookLink = $(".reference.download")[1].href,
notebookDownloadPath = notebookLink.split('_downloads')[1],
- colabLink = "https://colab.research.google.com/github/pytorch-labs/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
+ colabLink = "https://colab.research.google.com/github/pytorch/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
$("#google-colab-link").wrap("");
$("#download-notebook-link").wrap("");
@@ -4328,10 +4328,10 @@ require=(function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c=
if (downloadNote.length >= 1) {
var tutorialUrlArray = $("#tutorial-type").text().split('/');
- var githubLink = "https://github.com/pytorch-labs/tensordict/tree/main/tutorials/sphinx_tuto/" + tutorialUrlArray[tutorialUrlArray.length - 1] + ".py",
+ var githubLink = "https://github.com/pytorch/tensordict/tree/main/tutorials/sphinx_tuto/" + tutorialUrlArray[tutorialUrlArray.length - 1] + ".py",
notebookLink = $(".reference.download")[1].href,
notebookDownloadPath = notebookLink.split('_downloads')[1],
- colabLink = "https://colab.research.google.com/github/pytorch-labs/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
+ colabLink = "https://colab.research.google.com/github/pytorch/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
$("#google-colab-link").wrap("");
$("#download-notebook-link").wrap("");
@@ -4760,10 +4760,10 @@ require=(function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c=
if (downloadNote.length >= 1) {
var tutorialUrlArray = $("#tutorial-type").text().split('/');
- var githubLink = "https://github.com/pytorch-labs/tensordict/tree/main/tutorials/sphinx_tuto/" + tutorialUrlArray.join("/") + ".py",
+ var githubLink = "https://github.com/pytorch/tensordict/tree/main/tutorials/sphinx_tuto/" + tutorialUrlArray.join("/") + ".py",
notebookLink = $(".reference.download")[1].href,
notebookDownloadPath = notebookLink.split('_downloads')[1],
- colabLink = "https://colab.research.google.com/github/pytorch-labs/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
+ colabLink = "https://colab.research.google.com/github/pytorch/tensordict/blob/gh-pages/_downloads" + notebookDownloadPath;
$("#google-colab-link").wrap("");
$("#download-notebook-link").wrap("");
diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst
index fa55636d5..16c0b2358 100644
--- a/docs/source/distributed.rst
+++ b/docs/source/distributed.rst
@@ -81,7 +81,7 @@ Operating on Memory-mapped tensors across nodes
We provide a simple example of a distributed script where one process creates a
memory-mapped tensor, and sends its reference to another worker that is responsible of
updating it. You will find this example in the
-`benchmark directory `_.
+`benchmark directory `_.
In short, our goal is to show how to handle read and write operations on big
tensors when nodes have access to a shared physical storage. The steps involve:
diff --git a/docs/source/saving.rst b/docs/source/saving.rst
index 06a26abc6..f03d7b62b 100644
--- a/docs/source/saving.rst
+++ b/docs/source/saving.rst
@@ -79,17 +79,17 @@ that we will re-populate with the saved data.
Again, two lines of code are sufficient to save the data:
>>> app_state = {
- ... "state": torchsnapshot.StateDict(tensordict=tensordict_source.state_dict())
+ ... "state": torchsnapshot.StateDict(tensordict=tensordict_source.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path="/path/to/my/snapshot")
We have been using :obj:`torchsnapshot.StateDict` and we explicitly called
-:obj:`my_tensordict_source.state_dict()`, unlike the previous example.
+:obj:`my_tensordict_source.state_dict(keep_vars=True)`, unlike the previous example.
Now, to load this onto a destination tensordict:
>>> snapshot = Snapshot(path="/path/to/my/snapshot")
>>> app_state = {
- ... "state": torchsnapshot.StateDict(tensordict=tensordict_target.state_dict())
+ ... "state": torchsnapshot.StateDict(tensordict=tensordict_target.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
@@ -117,7 +117,7 @@ Here is a full example:
>>> assert isinstance(td["b", "c"], MemmapTensor)
>>>
>>> app_state = {
- ... "state": torchsnapshot.StateDict(tensordict=td.state_dict())
+ ... "state": torchsnapshot.StateDict(tensordict=td.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
@@ -126,7 +126,7 @@ Here is a full example:
>>> td_dest.memmap_()
>>> assert isinstance(td_dest["b", "c"], MemmapTensor)
>>> app_state = {
- ... "state": torchsnapshot.StateDict(tensordict=td_dest.state_dict())
+ ... "state": torchsnapshot.StateDict(tensordict=td_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>> # sanity check
@@ -157,7 +157,7 @@ Finally, tensorclass also supports this feature. The code is fairly similar to t
>>> assert isinstance(tc.y.x, MemmapTensor)
>>>
>>> app_state = {
- ... "state": torchsnapshot.StateDict(tensordict=tc.state_dict())
+ ... "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
@@ -165,7 +165,7 @@ Finally, tensorclass also supports this feature. The code is fairly similar to t
>>> tc_dest.memmap_()
>>> assert isinstance(tc_dest.y.x, MemmapTensor)
>>> app_state = {
- ... "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict())
+ ... "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>>
diff --git a/setup.cfg b/setup.cfg
index ef057bb52..6e2c14d79 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,4 +27,4 @@ add-ignore = D100, D104, D105, D107, D102
ignore-decorators =
test_*
; test/*.py
-; .circleci/*
+; .github/*
diff --git a/setup.py b/setup.py
index 4b9b5f508..0fcae2fa3 100644
--- a/setup.py
+++ b/setup.py
@@ -164,7 +164,7 @@ def _main(argv):
version=version,
author="tensordict contributors",
author_email="vmoens@fb.com",
- url="https://github.com/pytorch-labs/tensordict",
+ url="https://github.com/pytorch/tensordict",
long_description=long_description,
long_description_content_type="text/markdown",
license="BSD",
diff --git a/tensordict/nn/distributions/continuous.py b/tensordict/nn/distributions/continuous.py
index 105656306..75ee36271 100644
--- a/tensordict/nn/distributions/continuous.py
+++ b/tensordict/nn/distributions/continuous.py
@@ -247,7 +247,7 @@ def sample(
) -> torch.Tensor:
if sample_shape is None:
sample_shape = torch.Size([])
- return self.param.expand(*sample_shape, *self.param.shape)
+ return self.param.expand((*sample_shape, *self.param.shape))
def rsample(
self,
@@ -255,7 +255,7 @@ def rsample(
) -> torch.Tensor:
if sample_shape is None:
sample_shape = torch.Size([])
- return self.param.expand(*sample_shape, *self.param.shape)
+ return self.param.expand((*sample_shape, *self.param.shape))
@property
def mode(self) -> torch.Tensor:
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index a6a602bb0..3685661c7 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -10,9 +10,10 @@
import re
from copy import copy
from functools import wraps
-from typing import Any, Callable, Iterator, Sequence
+from typing import Any, Callable, Iterator, OrderedDict, Sequence
import torch
+from functorch import dim as ftdim
from tensordict import TensorDictBase
from tensordict.nn.utils import Buffer
@@ -72,7 +73,7 @@ def _get_args_dict(func, args, kwargs):
def _maybe_make_param(tensor):
if (
- isinstance(tensor, Tensor)
+ isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
):
@@ -82,7 +83,7 @@ def _maybe_make_param(tensor):
def _maybe_make_param_or_buffer(tensor):
if (
- isinstance(tensor, Tensor)
+ isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
):
@@ -319,7 +320,7 @@ def __torch_function__(
if kwargs is None:
kwargs = {}
if func not in TDPARAM_HANDLED_FUNCTIONS or not all(
- issubclass(t, (Tensor, TensorDictBase)) for t in types
+ issubclass(t, (Tensor, ftdim.Tensor, TensorDictBase)) for t in types
):
return NotImplemented
return TDPARAM_HANDLED_FUNCTIONS[func](*args, **kwargs)
@@ -744,6 +745,53 @@ def values(
continue
yield self._apply_get_post_hook(v)
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ sd = self._param_td.flatten_keys(".").state_dict(
+ destination=destination, prefix=prefix, keep_vars=keep_vars
+ )
+ return sd
+
+ def load_state_dict(
+ self, state_dict: OrderedDict[str, Any], strict=True, assign=False
+ ):
+ state_dict_tensors = {}
+ state_dict = dict(state_dict)
+ for k, v in list(state_dict.items()):
+ if isinstance(v, torch.Tensor):
+ del state_dict[k]
+ state_dict_tensors[k] = v
+ state_dict_tensors = dict(
+ TensorDict(state_dict_tensors, []).unflatten_keys(".")
+ )
+ self.data.load_state_dict(
+ {**state_dict_tensors, **state_dict}, strict=True, assign=False
+ )
+ return self
+
+ def _load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ data = (
+ TensorDict(
+ {
+ key: val
+ for key, val in state_dict.items()
+ if key.startswith(prefix) and val is not None
+ },
+ [],
+ )
+ .unflatten_keys(".")
+ .get(prefix[:-1])
+ )
+ self.data.load_state_dict(data)
+
def items(
self, include_nested: bool = False, leaves_only: bool = False
) -> Iterator[CompatibleType]:
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 2113a81e2..657cb814e 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -740,14 +740,22 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417
self._tensordict._batch_size_setter(new_size)
-def _state_dict(self) -> dict[str, Any]:
+def _state_dict(
+ self, destination=None, prefix="", keep_vars=False, flatten=False
+) -> dict[str, Any]:
"""Returns a state_dict dictionary that can be used to save and load data from a tensorclass."""
- state_dict = {"_tensordict": self._tensordict.state_dict()}
+ state_dict = {
+ "_tensordict": self._tensordict.state_dict(
+ destination=destination, prefix=prefix, keep_vars=keep_vars, flatten=flatten
+ )
+ }
state_dict["_non_tensordict"] = copy(self._non_tensordict)
return state_dict
-def _load_state_dict(self, state_dict: dict[str, Any]):
+def _load_state_dict(
+ self, state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False
+):
"""Loads a state_dict attemptedly in-place on the destination tensorclass."""
for key, item in state_dict.items():
# keys will never be nested which facilitates everything, but let's
@@ -778,7 +786,9 @@ def _load_state_dict(self, state_dict: dict[str, Any]):
f"Key '{sub_key}' wasn't expected in the state-dict."
)
- self._tensordict.load_state_dict(item)
+ self._tensordict.load_state_dict(
+ item, strict=strict, assign=assign, from_flatten=from_flatten
+ )
else:
raise KeyError(f"Key '{key}' wasn't expected in the state-dict.")
diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 4ed333d7b..4a83e24e1 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -32,11 +32,13 @@
TypeVar,
Union,
)
+
from warnings import warn
import numpy as np
import torch
+from functorch import dim as ftdim
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict.memmap import memmap_tensor_as_tensor, MemmapTensor
from tensordict.utils import (
@@ -69,7 +71,7 @@
NestedKey,
prod,
)
-from torch import distributed as dist, multiprocessing as mp, Tensor
+from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch.utils._pytree import tree_map
try:
@@ -376,7 +378,7 @@ def from_module(module, as_module: bool = False):
"""Copies the params and buffers of a module in a tensordict.
Args:
- as_module (bool, optional): if ``True``, a :class:`tensordict.nn.TensorDictParams`
+ as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams`
instance will be returned which can be used to store parameters
within a :class:`torch.nn.Module`. Defaults to ``False``.
@@ -408,6 +410,24 @@ def from_module(module, as_module: bool = False):
return TensorDictParams(td, no_convert=True)
return td
+ def to_module(self, module):
+ from tensordict.nn.functional_modules import set_tensor_dict
+
+ __base__setattr__ = nn.Module.__setattr__
+ # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
+ __dict__ = module.__dict__
+
+ for key, value in self.items():
+ cls = value.__class__
+ if _is_tensor_collection(cls) or issubclass(cls, dict):
+ value.to_module(__dict__["_modules"][key])
+ else:
+ if module.__class__.__setattr__ is __base__setattr__:
+ set_tensor_dict(__dict__, module, key, value)
+ else:
+ # use specialized __setattr__ if needed
+ setattr(module, key, value)
+
@property
def shape(self) -> torch.Size:
"""See :obj:`TensorDictBase.batch_size`."""
@@ -691,12 +711,54 @@ def is_shared(self) -> bool:
return self.device.type == "cuda" or self._is_shared
return self._is_shared
- def state_dict(self) -> OrderedDict[str, Any]:
+ def state_dict(
+ self,
+ destination=None,
+ prefix="",
+ keep_vars=False,
+ flatten=False,
+ ) -> OrderedDict[str, Any]:
+ """Produces a state_dict from the tensordict. The structure of the state-dict will still be nested, unless ``flatten`` is set to ``True``.
+
+ A tensordict state-dict contains all the tensors and meta-data needed
+ to rebuild the tensordict (names are currently not supported).
+
+ Args:
+ destination (dict, optional): If provided, the state of tensordict will
+ be updated into the dict and the same object is returned.
+ Otherwise, an ``OrderedDict`` will be created and returned.
+ Default: ``None``.
+ prefix (str, optional): a prefix added to tensor
+ names to compose the keys in state_dict. Default: ``''``.
+ keep_vars (bool, optional): by default the :class:`torch.Tensor` s
+ returned in the state dict are detached from autograd. If it's
+ set to ``True``, detaching will not be performed.
+ Default: ``False``.
+ flatten (bool, optional): whether the structure should be flattened
+ with the ``"."`` character or not.
+ Defaults to ``False``.
+
+ Examples:
+ >>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
+ >>> sd = data.state_dict()
+ >>> print(sd)
+ OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3', OrderedDict([('3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])), ('__batch_size', torch.Size([])), ('__device', None)])
+ >>> sd = data.state_dict(flatten=True)
+ OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3.3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])
+
+ """
out = collections.OrderedDict()
- for key, item in self.apply(memmap_tensor_as_tensor).items():
- out[key] = (
- item if not _is_tensor_collection(item.__class__) else item.state_dict()
- )
+ source = self.apply(memmap_tensor_as_tensor)
+ if flatten:
+ source = source.flatten_keys(".")
+ for key, item in source.items():
+ if not _is_tensor_collection(item.__class__):
+ if not keep_vars:
+ out[prefix + key] = item.detach().clone()
+ else:
+ out[prefix + key] = item
+ else:
+ out[prefix + key] = item.state_dict(keep_vars=keep_vars)
if "__batch_size" in out:
raise KeyError(
"Cannot retrieve the state_dict of a TensorDict with `'__batch_size'` key"
@@ -705,26 +767,95 @@ def state_dict(self) -> OrderedDict[str, Any]:
raise KeyError(
"Cannot retrieve the state_dict of a TensorDict with `'__batch_size'` key"
)
- out["__batch_size"] = self.batch_size
- out["__device"] = self.device
+ out[prefix + "__batch_size"] = source.batch_size
+ out[prefix + "__device"] = source.device
+ if destination is not None:
+ destination.update(out)
+ return destination
return out
- def load_state_dict(self, state_dict: OrderedDict[str, Any]) -> T:
+ def load_state_dict(
+ self,
+ state_dict: OrderedDict[str, Any],
+ strict=True,
+ assign=False,
+ from_flatten=False,
+ ) -> T:
+ """Loads a state-dict, formatted as in :meth:`~.state_dict`, into the tensordict.
+
+ Args:
+ state_dict (OrderedDict): the state_dict of to be copied.
+ strict (bool, optional): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this tensordict's
+ :meth:`torch.nn.Module.state_dict` function. Default: ``True``
+ assign (bool, optional): whether to assign items in the state
+ dictionary to their corresponding keys in the tensordict instead
+ of copying them inplace into the tensordict's current tensors.
+ When ``False``, the properties of the tensors in the current
+ module are preserved while when ``True``, the properties of the
+ Tensors in the state dict are preserved.
+ Default: ``False``
+ from_flatten (bool, optional): if ``True``, the input state_dict is
+ assumed to be flattened.
+ Defaults to ``False``.
+
+ Examples:
+ >>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
+ >>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
+ >>> sd = data.state_dict()
+ >>> data_zeroed.load_state_dict(sd)
+ >>> print(data_zeroed["3", "3"])
+ tensor(3)
+ >>> # with flattening
+ >>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
+ >>> data_zeroed.load_state_dict(data.state_dict(flatten=True), from_flatten=True)
+ >>> print(data_zeroed["3", "3"])
+ tensor(3)
+
+
+ """
+ if from_flatten:
+ self_flatten = self.flatten_keys(".")
+ self_flatten.load_state_dict(state_dict, strict=strict, assign=assign)
+ if not assign:
+ # modifications are done in-place so we should be fine returning self
+ return self
+ else:
+ # run a check over keys, if we any key with a '.' in name we're doomed
+ DOT_ERROR = "Cannot use load_state_dict(..., from_flatten=True, assign=True) when some keys contain a dot character."
+ for key in self.keys(True, True):
+ if isinstance(key, tuple):
+ for subkey in key:
+ if "." in subkey:
+ raise RuntimeError(DOT_ERROR)
+ elif "." in key:
+ raise RuntimeError(DOT_ERROR)
+ return self.update(self_flatten.unflatten_keys("."))
# copy since we'll be using pop
state_dict = copy(state_dict)
self.batch_size = state_dict.pop("__batch_size")
- device = state_dict.pop("__device")
- if device is not None:
- self.to(device)
+ device = state_dict.pop("__device", None)
+ if device is not None and self.device is not None and device != self.device:
+ raise RuntimeError("Loading data from another device is not yet supported.")
+
for key, item in state_dict.items():
if isinstance(item, dict):
self.set(
key,
- self.get(key, default=TensorDict({}, [])).load_state_dict(item),
- inplace=True,
+ self.get(key, default=TensorDict({}, [])).load_state_dict(
+ item, assign=assign, strict=strict
+ ),
+ inplace=not assign,
)
else:
- self.set(key, item, inplace=True)
+ self.set(key, item, inplace=not assign)
+ if strict and set(state_dict.keys()) != set(self.keys()):
+ set_sd = set(state_dict.keys())
+ set_td = set(self.keys())
+ raise RuntimeError(
+ "Cannot load state-dict because the key sets don't match: got "
+ f"state_dict extra keys \n{set_sd-set_td}\n and tensordict extra keys\n{set_td-set_sd}\n"
+ )
return self
def is_memmap(self) -> bool:
@@ -1940,7 +2071,15 @@ def sorted_keys(self) -> list[NestedKey]:
"""
return sorted(self.keys())
+ @overload
def expand(self, *shape: int) -> T:
+ ...
+
+ @overload
+ def expand(self, shape: torch.Size) -> T:
+ ...
+
+ def expand(self, *args: int | torch.Size) -> T:
"""Expands each tensors of the tensordict according to the torch.expand function.
In practice, this amends to: :obj:`tensor.expand(*shape, *tensor.shape)`.
@@ -1958,8 +2097,11 @@ def expand(self, *shape: int) -> T:
d = {}
tensordict_dims = self.batch_dims
- if len(shape) == 1 and isinstance(shape[0], Sequence):
- shape = tuple(shape[0])
+ if len(args) == 1 and isinstance(args[0], Sequence):
+ shape = tuple(args[0])
+ else:
+ # we don't check that all elements are int to reduce overhead
+ shape = args
# new shape dim check
if len(shape) < len(self.shape):
@@ -2412,7 +2554,7 @@ def to_h5(
**kwargs: kwargs to be passed to :meth:`h5py.File.create_dataset`.
Returns:
- A :class:`PersitentTensorDict` instance linked to the newly created file.
+ A :class:`~.tensordict.PersitentTensorDict` instance linked to the newly created file.
Examples:
>>> import tempfile
@@ -2550,7 +2692,7 @@ def clone(self, recurse: bool = True) -> T:
TensorDict will be copied too. Default is `True`.
.. note::
- For some TensorDictBase subtypes, such as :class:`SubTensorDict`, cloning
+ For some TensorDictBase subtypes, such as :class:`~.tensordict.SubTensorDict`, cloning
recursively makes little sense (in this specific case it would involve
copying the parent tensordict too). In those cases, :meth:`~.clone` will
fall back onto :meth:`~.to_tensordict`.
@@ -2622,7 +2764,7 @@ def to(self, *args, **kwargs) -> T:
other (TensorDictBase, optional): TensorDict instance whose dtype
and device are the desired dtype and device for all tensors
in this TensorDict.
- .. note:: Since :class:`TensorDictBase` instances do not have
+ .. note:: Since :class:`~tensordict.TensorDictBase` instances do not have
a dtype, the dtype is gathered from the example leaves.
If there are more than one dtype, then no dtype
casting is undertook.
@@ -3345,7 +3487,6 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
if key in keys and (
not is_tensor_collection(out.get(key)) or not out.get(key).is_empty()
):
- print(out.get(key))
raise KeyError(
"Unflattening key(s) in tensordict will override existing unflattened key"
)
@@ -3394,6 +3535,8 @@ def _get_names_idx(self, idx):
else:
def is_boolean(idx):
+ if isinstance(idx, ftdim.Dim):
+ return None
if isinstance(idx, tuple) and len(idx) == 1:
return is_boolean(idx[0])
if hasattr(idx, "dtype") and idx.dtype is torch.bool:
@@ -3765,6 +3908,7 @@ def type(self, dst_type):
Tensor,
MemmapTensor,
TensorDictBase,
+ ftdim.Tensor,
]
if _has_torchrec:
_ACCEPTED_CLASSES += [KeyedJaggedTensor]
@@ -3930,7 +4074,7 @@ def __init__(
@classmethod
def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
- """Returns a TensorDict created from a dictionary or another :class:`TensorDict`.
+ """Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`.
If ``batch_size`` is not specified, returns the maximum batch size possible.
@@ -4174,7 +4318,15 @@ def pin_mem(tensor):
return self.apply(pin_mem)
+ @overload
def expand(self, *shape: int) -> T:
+ ...
+
+ @overload
+ def expand(self, shape: torch.Size) -> T:
+ ...
+
+ def expand(self, *args: int | torch.Size) -> T:
"""Expands every tensor with `(*shape, *tensor.shape)` and returns the same tensordict with new tensors with expanded shapes.
Supports iterables to specify the shape.
@@ -4183,8 +4335,11 @@ def expand(self, *shape: int) -> T:
d = {}
tensordict_dims = self.batch_dims
- if len(shape) == 1 and isinstance(shape[0], Sequence):
- shape = tuple(shape[0])
+ if len(args) == 1 and isinstance(args[0], Sequence):
+ shape = tuple(args[0])
+ else:
+ # we don't check that all elements are int to reduce overhead
+ shape = args
# new shape dim check
if len(shape) < len(self.shape):
@@ -4209,9 +4364,9 @@ def expand(self, *shape: int) -> T:
tensor_dims = len(value.shape)
last_n_dims = tensor_dims - tensordict_dims
if last_n_dims > 0:
- d[key] = value.expand(*shape, *value.shape[-last_n_dims:])
+ d[key] = value.expand((*shape, *value.shape[-last_n_dims:]))
else:
- d[key] = value.expand(*shape)
+ d[key] = value.expand(shape)
out = TensorDict(
source=d,
batch_size=torch.Size(shape),
@@ -4452,11 +4607,6 @@ def memmap_(
raise RuntimeError(
"memmap and shared memory are mutually exclusive features."
)
- # if not self._tensordict.keys():
- # raise Exception(
- # "memmap_() must be called when the TensorDict is (partially) "
- # "populated. Set a tensor first."
- # )
for key, value in self.items():
if value.requires_grad:
raise Exception(
@@ -5867,8 +6017,8 @@ def clone(self, recurse: bool = True) -> SubTensorDict:
Args:
recurse (bool, optional): if ``True`` (default), a regular
- :class:`TensorDict` instance will be created from the :class:`SubTensorDict`.
- Otherwise, another :class:`SubTensorDict` with identical content
+ :class:`~.tensordict.TensorDict` instance will be created from the :class:`~.tensordict.SubTensorDict`.
+ Otherwise, another :class:`~.tensordict.SubTensorDict` with identical content
will be returned.
Examples:
@@ -5929,9 +6079,11 @@ def select(self, *keys: str, inplace: bool = False, strict: bool = True) -> T:
return self
return self._source.select(*keys, strict=strict)[self.idx]
- def expand(self, *shape: int, inplace: bool = False) -> T:
- if len(shape) == 1 and isinstance(shape[0], Sequence):
- shape = tuple(shape[0])
+ def expand(self, *args: int, inplace: bool = False) -> T:
+ if len(args) == 1 and isinstance(args[0], Sequence):
+ shape = tuple(args[0])
+ else:
+ shape = args
return self.apply(
lambda x: x.expand((*shape, *x.shape[self.ndim :])), batch_size=shape
)
@@ -6396,7 +6548,13 @@ def _split_index(self, index):
continue
if cursor == self.stack_dim:
# we need to check which tds need to be indexed
- if isinstance(idx, slice) or _is_number(idx):
+ if isinstance(idx, ftdim.Dim):
+ raise ValueError(
+ "Cannot index a lazy stacked tensordict along the stack dimension with "
+ "a first-class dimension index. Consider consolidating the tensordict first "
+ "using `tensordict.contiguous()`."
+ )
+ elif isinstance(idx, slice) or _is_number(idx):
selected_td_idx = range(len(self.tensordicts))[idx]
if not isinstance(selected_td_idx, range):
isinteger = True
@@ -6428,6 +6586,7 @@ def _split_index(self, index):
idx,
(
int,
+ ftdim.Dim,
slice,
list,
range,
@@ -7241,54 +7400,6 @@ def __getitem__(self, index: IndexType) -> T:
out._td_dim_name = self._td_dim_name
return out
- # index_dict = _convert_index_lazystack(index, self.stack_dim, self.batch_size)
- # if index_dict is None:
- # # then we use a sub-tensordict
- # return self.get_sub_tensordict(index)
- # td_index = index_dict["remaining_index"]
- # stack_index = index_dict["stack_index"]
- # new_stack_dim = index_dict["new_stack_dim"]
- # if new_stack_dim is not None:
- # if isinstance(stack_index, slice):
- # # we can't iterate but we can index the list directly
- # out = LazyStackedTensorDict(
- # *[td[td_index] for td in self.tensordicts[stack_index]],
- # stack_dim=new_stack_dim,
- # )
- # elif isinstance(stack_index, (list, range)):
- # # then we can iterate
- # out = LazyStackedTensorDict(
- # *[self.tensordicts[idx][td_index] for idx in stack_index],
- # stack_dim=new_stack_dim,
- # )
- # elif isinstance(stack_index, Tensor):
- # # td_index is a nested tuple that mimics the shape of stack_index
- # def _nested_stack(t: list, stack_idx: Tensor, td_index):
- # if stack_idx.ndim:
- # out = LazyStackedTensorDict(
- # *[
- # _nested_stack(t, _idx, td_index[i])
- # for i, _idx in enumerate(stack_idx.unbind(0))
- # ],
- # stack_dim=new_stack_dim,
- # )
- # return out
- # return t[stack_idx][td_index]
- #
- # # print(index, td_index, stack_index)
- # out = _nested_stack(self.tensordicts, stack_index, td_index)
- # else:
- # raise TypeError("Invalid index used for stack dimension.")
- # out._td_dim_name = self._td_dim_name
- # return out
- # out = self.tensordicts[stack_index]
- # if td_index:
- # return out[td_index]
- # return out
-
- # def __hash__(self):
- # return hash(self.tensordicts)
-
def __eq__(self, other):
if is_tensorclass(other):
return other == self
@@ -7531,12 +7642,14 @@ def load_memmap(cls, prefix: str) -> LazyStackedTensorDict:
metadata = torch.load(prefix / "meta.pt")
return cls(*tensordicts, stack_dim=metadata["stack_dim"])
- def expand(self, *shape: int, inplace: bool = False) -> T:
- if len(shape) == 1 and isinstance(shape[0], Sequence):
- shape = tuple(shape[0])
+ def expand(self, *args: int, inplace: bool = False) -> T:
+ if len(args) == 1 and isinstance(args[0], Sequence):
+ shape = tuple(args[0])
+ else:
+ shape = args
stack_dim = len(shape) + self.stack_dim - self.ndimension()
new_shape_tensordicts = [v for i, v in enumerate(shape) if i != stack_dim]
- tensordicts = [td.expand(*new_shape_tensordicts) for td in self.tensordicts]
+ tensordicts = [td.expand(new_shape_tensordicts) for td in self.tensordicts]
if inplace:
self.tensordicts = tensordicts
self.stack_dim = stack_dim
@@ -8126,8 +8239,8 @@ def clone(self, recurse: bool = True) -> T:
Args:
recurse (bool, optional): if ``True`` (default), a regular
- :class:`TensorDict` instance will be returned.
- Otherwise, another :class:`SubTensorDict` with identical content
+ :class:`~.tensordict.TensorDict` instance will be returned.
+ Otherwise, another :class:`~.tensordict.SubTensorDict` with identical content
will be returned.
"""
if not recurse:
@@ -8810,17 +8923,17 @@ def dense_stack_tds(
td_list: Sequence[TensorDictBase] | LazyStackedTensorDict,
dim: int = None,
) -> T:
- """Densely stack a list of :class:`tensordict.TensorDictBase` objects (or a :class:`tensordict.LazyStackedTensorDict`) given that they have the same structure.
+ """Densely stack a list of :class:`~tensordict.TensorDictBase` objects (or a :class:`~tensordict.LazyStackedTensorDict`) given that they have the same structure.
- This function is called with a list of :class:`tensordict.TensorDictBase` (either passed directly or obtrained from
- a :class:`tensordict.LazyStackedTensorDict`).
- Instead of calling ``torch.stack(td_list)``, which would return a :class:`tensordict.LazyStackedTensorDict`,
+ This function is called with a list of :class:`~tensordict.TensorDictBase` (either passed directly or obtrained from
+ a :class:`~tensordict.LazyStackedTensorDict`).
+ Instead of calling ``torch.stack(td_list)``, which would return a :class:`~tensordict.LazyStackedTensorDict`,
this function expands the first element of the input list and stacks the input list onto that element.
This works only when all the elements of the input list have the same structure.
- The :class:`tensordict.TensorDictBase` returned will have the same type of the elements of the input list.
+ The :class:`~tensordict.TensorDictBase` returned will have the same type of the elements of the input list.
- This function is useful when some of the :class:`tensordict.TensorDictBase` objects that need to be stacked
- are :class:`tensordict.LazyStackedTensorDict` or have :class:`tensordict.LazyStackedTensorDict`
+ This function is useful when some of the :class:`~tensordict.TensorDictBase` objects that need to be stacked
+ are :class:`~tensordict.LazyStackedTensorDict` or have :class:`~tensordict.LazyStackedTensorDict`
among entries (or nested entries).
In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible.
Thus, this function provides an alternative for densely stacking the list provided.
@@ -8951,7 +9064,7 @@ def _clone_value(value: CompatibleType, recurse: bool) -> CompatibleType:
def _is_number(item):
- if isinstance(item, Number):
+ if isinstance(item, (Number, ftdim.Dim)):
return True
if isinstance(item, Tensor) and item.ndim == 0:
return True
diff --git a/tensordict/utils.py b/tensordict/utils.py
index d5f190ce7..78da3281f 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -20,6 +20,7 @@
import numpy as np
import torch
+from functorch import dim as ftdim
from packaging.version import parse
from tensordict._tensordict import ( # noqa: F401
@@ -150,7 +151,7 @@ def _getitem_batch_size(batch_size, index):
out.extend(bs_shape)
bs_shape = None
continue
- elif isinstance(idx, int):
+ elif isinstance(idx, (int, ftdim.Dim)):
# could be spared for efficiency
continue
elif isinstance(idx, slice):
@@ -506,7 +507,7 @@ def expand_right(
tensor_expand = tensor
while tensor_expand.ndimension() < len(shape):
tensor_expand = tensor_expand.unsqueeze(-1)
- tensor_expand = tensor_expand.expand(*shape)
+ tensor_expand = tensor_expand.expand(shape)
return tensor_expand
@@ -761,9 +762,12 @@ def _is_shared(tensor: torch.Tensor) -> bool:
if torch._C._functorch.is_batchedtensor(tensor):
return None
return tensor.is_shared()
+ if isinstance(tensor, ftdim.Tensor):
+ return None
elif isinstance(tensor, KeyedJaggedTensor):
return False
else:
+ print(type(tensor))
return tensor.is_shared()
diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 3990af874..06ece52d8 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -236,7 +236,7 @@ def td_params(self, device):
def expand_list(list_of_tensors, *dims):
n = len(list_of_tensors)
td = TensorDict({str(i): tensor for i, tensor in enumerate(list_of_tensors)}, [])
- td = td.expand(*dims).contiguous()
+ td = td.expand(dims).contiguous()
return [td[str(i)] for i in range(n)]
diff --git a/test/test_nn.py b/test/test_nn.py
index e9f76114e..26395fef7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3159,6 +3159,46 @@ def test_add_scale_sequence(self, num_outputs=4):
assert (scale > 0).all()
+class TestStateDict:
+ @pytest.mark.parametrize("detach", [True, False])
+ def test_sd_params(self, detach):
+ td = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
+ td = TensorDictParams(td)
+ if detach:
+ sd = td.detach().clone().zero_().state_dict()
+ else:
+ sd = td.state_dict()
+ sd = {
+ k: v if not isinstance(v, torch.Tensor) else v * 0
+ for k, v in sd.items()
+ }
+ # do some op to create a graph
+ td.apply(lambda x: x + 1)
+ # load the data
+ td.load_state_dict(sd)
+ # check that data has been loaded
+ assert (td == 0).all()
+
+ def test_sd_module(self):
+ td = TensorDict({"1": 1.0, "2": 2.0, "3": {"3": 3.0}}, [])
+ td = TensorDictParams(td)
+ module = nn.Linear(3, 4)
+ module.td = td
+
+ sd = module.state_dict()
+ assert "td.1" in sd
+ assert "td.3.3" in sd
+ sd = {k: v * 0 if isinstance(v, torch.Tensor) else v for k, v in sd.items()}
+
+ # load the data
+ module.load_state_dict(sd)
+
+ # check that data has been loaded
+ assert (module.td == 0).all()
+ for val in td.values(True, True):
+ assert isinstance(val, nn.Parameter)
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index e43686fd1..fb44fa949 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -1352,7 +1352,9 @@ class MyClass:
assert isinstance(tc.y.x, MemmapTensor)
assert tc.z == z
- app_state = {"state": torchsnapshot.StateDict(tensordict=tc.state_dict())}
+ app_state = {
+ "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True))
+ }
snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=str(tmp_path))
tc_dest = MyClass(
@@ -1363,7 +1365,9 @@ class MyClass:
)
tc_dest.memmap_()
assert isinstance(tc_dest.y.x, MemmapTensor)
- app_state = {"state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict())}
+ app_state = {
+ "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True))
+ }
snapshot.restore(app_state=app_state)
assert (tc_dest == tc).all()
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index fa7303738..8ec9edd9f 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -12,6 +12,7 @@
import pytest
import torch
+from tensordict.nn import TensorDictParams
try:
import torchsnapshot
@@ -30,6 +31,7 @@
_has_h5py = False
from _utils_internal import decompose, get_available_devices, prod, TestTensorDictsBase
+from functorch import dim as ftdim
from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict
from tensordict.tensordict import (
@@ -1380,6 +1382,34 @@ def test_cpu_cuda(self, td_name, device):
assert td_device.device == torch.device("cuda")
assert td_back.device == torch.device("cpu")
+ def test_state_dict(self, td_name, device):
+ torch.manual_seed(1)
+ td = getattr(self, td_name)(device)
+ sd = td.state_dict()
+ td_zero = td.clone().detach().zero_()
+ td_zero.load_state_dict(sd)
+ assert_allclose_td(td, td_zero)
+
+ def test_state_dict_strict(self, td_name, device):
+ torch.manual_seed(1)
+ td = getattr(self, td_name)(device)
+ sd = td.state_dict()
+ td_zero = td.clone().detach().zero_()
+ del sd["a"]
+ td_zero.load_state_dict(sd, strict=False)
+ with pytest.raises(RuntimeError):
+ td_zero.load_state_dict(sd, strict=True)
+
+ def test_state_dict_assign(self, td_name, device):
+ torch.manual_seed(1)
+ td = getattr(self, td_name)(device)
+ sd = td.state_dict()
+ td_zero = td.clone().detach().zero_()
+ shallow_copy = td_zero.clone(False)
+ td_zero.load_state_dict(sd, assign=True)
+ assert (shallow_copy == 0).all()
+ assert_allclose_td(td, td_zero)
+
@pytest.mark.parametrize("dim", range(4))
def test_unbind(self, td_name, device, dim):
if td_name not in ["sub_td", "idx_td", "td_reset_bs"]:
@@ -5115,7 +5145,11 @@ def test_inplace(self, save_name):
td.memmap_()
assert isinstance(td["b", "c"], MemmapTensor)
- app_state = {"state": torchsnapshot.StateDict(**{save_name: td.state_dict()})}
+ app_state = {
+ "state": torchsnapshot.StateDict(
+ **{save_name: td.state_dict(keep_vars=True)}
+ )
+ }
path = f"/tmp/{uuid.uuid4()}"
snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=path)
@@ -5132,7 +5166,9 @@ def test_inplace(self, save_name):
td_dest.memmap_()
assert isinstance(td_dest["b", "c"], MemmapTensor)
app_state = {
- "state": torchsnapshot.StateDict(**{save_name: td_dest.state_dict()})
+ "state": torchsnapshot.StateDict(
+ **{save_name: td_dest.state_dict(keep_vars=True)}
+ )
}
snapshot.restore(app_state=app_state)
@@ -6205,6 +6241,140 @@ def _pool_fixt():
yield pool
+class TestFCD(TestTensorDictsBase):
+ """Test stack for first-class dimension."""
+
+ @pytest.mark.parametrize(
+ "td_name",
+ [
+ "td",
+ "stacked_td",
+ "sub_td",
+ "sub_td2",
+ "idx_td",
+ "memmap_td",
+ "unsqueezed_td",
+ "squeezed_td",
+ "td_reset_bs",
+ "nested_td",
+ "nested_tensorclass",
+ "permute_td",
+ "nested_stacked_td",
+ "td_params",
+ pytest.param(
+ "td_h5",
+ marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
+ ),
+ ],
+ )
+ @pytest.mark.parametrize("device", get_available_devices())
+ def test_fcd(self, td_name, device):
+ td = getattr(self, td_name)(device)
+ d0 = ftdim.dims(1)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[d0]
+ else:
+ assert td[d0].shape == td.shape[1:]
+ d0, d1 = ftdim.dims(2)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[d0, d1]
+ else:
+ assert td[d0, d1].shape == td.shape[2:]
+ d0, d1, d2 = ftdim.dims(3)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[d0, d1, d2]
+ else:
+ assert td[d0, d1, d2].shape == td.shape[3:]
+ d0 = ftdim.dims(1)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[:, d0]
+ else:
+ assert td[:, d0].shape == torch.Size((td.shape[0], *td.shape[2:]))
+
+ @pytest.mark.parametrize(
+ "td_name",
+ [
+ "td",
+ "stacked_td",
+ "idx_td",
+ "memmap_td",
+ "td_reset_bs",
+ "nested_td",
+ "nested_tensorclass",
+ "nested_stacked_td",
+ "td_params",
+ pytest.param(
+ "td_h5",
+ marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
+ ),
+ # these tds cannot see their dim names edited:
+ # "sub_td",
+ # "sub_td2",
+ # "unsqueezed_td",
+ # "squeezed_td",
+ # "permute_td",
+ ],
+ )
+ @pytest.mark.parametrize("device", get_available_devices())
+ def test_fcd_names(self, td_name, device):
+ td = getattr(self, td_name)(device)
+ td.names = ["a", "b", "c", "d"]
+ d0 = ftdim.dims(1)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[d0]
+ else:
+ assert td[d0].names == ["b", "c", "d"]
+ d0, d1 = ftdim.dims(2)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[d0, d1]
+ else:
+ assert td[d0, d1].names == ["c", "d"]
+ d0, d1, d2 = ftdim.dims(3)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[d0, d1, d2]
+ else:
+ assert td[d0, d1, d2].names == ["d"]
+ d0 = ftdim.dims(1)
+ if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
+ with pytest.raises(ValueError, match="Cannot index"):
+ td[:, d0]
+ else:
+ assert td[:, d0].names == ["a", "c", "d"]
+
+ @pytest.mark.parametrize("as_module", [False, True])
+ def test_modules(self, as_module):
+ modules = [
+ lambda: nn.Linear(3, 4),
+ lambda: nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)),
+ lambda: nn.Transformer(16, 4, 2, 2, 8),
+ lambda: nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 4, 3)),
+ ]
+ inputs = [
+ lambda: (torch.randn(2, 3),),
+ lambda: (torch.randn(2, 3),),
+ lambda: (torch.randn(2, 3, 16), torch.randn(2, 3, 16)),
+ lambda: (torch.randn(2, 3, 16, 16),),
+ ]
+ param_batch = 5
+ for make_module, make_input in zip(modules, inputs):
+ module = make_module()
+ td = TensorDict.from_module(module, as_module=as_module)
+ td = td.expand(param_batch).clone()
+ d0 = ftdim.dims(1)
+ td = TensorDictParams(td)[d0]
+ td.to_module(module)
+ y = module(*make_input())
+ assert y.dims == (d0,)
+ assert y._tensor.shape[0] == param_batch
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/tutorials/sphinx_tuto/tensorclass_imagenet.py b/tutorials/sphinx_tuto/tensorclass_imagenet.py
index 7bbf468c9..65ce02517 100644
--- a/tutorials/sphinx_tuto/tensorclass_imagenet.py
+++ b/tutorials/sphinx_tuto/tensorclass_imagenet.py
@@ -166,11 +166,11 @@ def forward(self, x):
index0 = torch.randint(x.shape[-2] - self.h, (*batch, 1), device=x.device)
index0 = index0 + torch.arange(self.h, device=x.device)
index0 = (
- index0.unsqueeze(1).unsqueeze(-1).expand(*batch, 3, self.h, x.shape[-1])
+ index0.unsqueeze(1).unsqueeze(-1).expand((*batch, 3), self.h, x.shape[-1])
)
index1 = torch.randint(x.shape[-1] - self.w, (*batch, 1), device=x.device)
index1 = index1 + torch.arange(self.w, device=x.device)
- index1 = index1.unsqueeze(1).unsqueeze(-2).expand(*batch, 3, self.h, self.w)
+ index1 = index1.unsqueeze(1).unsqueeze(-2).expand((*batch, 3), self.h, self.w)
return x.gather(-2, index0).gather(-1, index1)
@@ -402,11 +402,11 @@ def __call__(self, x: ImageNetData):
##############################################################################
# This shows that much of the overhead is coming from i/o operations rather than the
# transforms, and hence explains how the memory-mapped array helps us load data more
-# efficiently. Check out the `distributed example `__
+# efficiently. Check out the `distributed example `__
# for more context about the other results from these charts.
#
# We can get even better performance with the TensorClass approach by using multiple
# workers to load batches from the memory-mapped array, though this comes with some
# added complexity. See `this example in our benchmarks
-# `__
+# `__
# for an example of how this could work.
diff --git a/version.txt b/version.txt
index d917d3e26..0ea3a944b 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.1.2
+0.2.0