diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100755 index 0000000..3291ea0 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,42 @@ +{ + "name": "ChaiLab", + "build": { + "context": "..", + "dockerfile": "../Dockerfile.chailab", + "target": "chailab-baseimage" + }, + "runArgs": [ + // by default use all GPUs, can be overridden by envvar + "--gpus=${localEnv:DEVBOX_GPU_SPEC:all}", + "--ipc=host", + "-v=/data/instance:/data/instance", + // default container name is chai-lab-container + "--name=chai-lab-${localEnv:DEVBOX_USER:container}", + // set restrictions on CPU and RAM usage + "--cpus=60.0", + "--memory=1000g" + ], + "shutdownAction": "none", + "postCreateCommand": "uv pip install -e . && pre-commit install -f", + "customizations": { + "vscode": { + "settings": { + "python.defaultInterpreterPath": "/opt/venv/bin/python" + }, + "extensions": [ + "ms-azuretools.vscode-docker", + "ms-python.python", + "ms-python.vscode-pylance", + "ms-python.mypy-type-checker", + "charliermarsh.ruff", + "ms-toolsai.jupyter", + "arianjamasb.protein-viewer", + "redhat.vscode-yaml", + // very optional git-specific stuff + "arturock.gitstash", + "mhutchie.git-graph", + "GitHub.vscode-pull-request-github" + ] + } + } +} \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..9d8f2d0 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..59a823b --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,9 @@ +## Description + + +## Motivation + + + +## Test plan + diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 0000000..1c0af86 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,27 @@ +name: Mypy +on: + # Triggered whenever a commit is added to the main branch + push: + branches: + - main + # Triggered whenever a PR is opened or updated + pull_request: +jobs: + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: "pip" + - name: Install dependencies + run: | + # install uv and cpu-only torch + pip install --no-deps uv -r <( cat requirements.in | grep torch) --extra-index-url https://download.pytorch.org/whl/cpu + # install requirements, except torch and potentially nvidia-related stuff + uv pip install --system -r <( cat requirements.in | grep -v nvidia | grep -v torch ) + uv pip install --system --no-deps -e . + - name: Run mypy + run: mypy . 
diff --git a/.github/workflows/prettier_yaml.yml b/.github/workflows/prettier_yaml.yml new file mode 100644 index 0000000..82dba85 --- /dev/null +++ b/.github/workflows/prettier_yaml.yml @@ -0,0 +1,27 @@ +name: Prettier + +on: + # Triggered whenever a commit is added to the main branch + push: + branches: + - main + # Triggered whenever a PR is opened or updated + pull_request: +jobs: + yaml: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + + - name: Install dependencies + run: npm install --save-dev --save-exact prettier + + - name: Run Prettier to check YAML format + run: npx prettier --check "**/*.yml" "**/*.yaml" diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml new file mode 100644 index 0000000..431edf2 --- /dev/null +++ b/.github/workflows/publish-to-pypi.yml @@ -0,0 +1,27 @@ +name: Deploy to pypi + +on: + release: + types: [created] + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install . && pip install hatch + - name: Publish to PyPi + env: + HATCH_INDEX_USER: "__token__" + HATCH_INDEX_AUTH: ${{ secrets.PYPI_TOKEN }} + run: | + hatch build --clean && hatch publish diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..dde0f6c --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,27 @@ +name: Ruff +on: + # Triggered whenever a commit is added to the main branch + push: + branches: + - main + # Triggered whenever a PR is opened or updated + pull_request: +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - name: Install + run: pip install pre-commit + - name: Run pre-commit checks on all files + # run specific ruff pre-commit hooks on all files + run: > + pip install pre-commit + && pre-commit install -f + && pre-commit run ruff --all-files + && pre-commit run ruff-format --all-files diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..72c088d --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm/JetBrains +.idea/ + +# outputs from the model +outputs/ + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a07ede4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +--- +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.6.3 + hooks: + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..08cde26 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,15 @@ +// We use Dev Containers to run the code in a container. As a result, +// VS Code runs extensions in one of two places: locally on the UI / +// client side, or in the container. When adding extensions, take care +// to ensure you add them to the right list: container extensions should +// be added to the list in ../.devcontainer/devcontainer.json, and UI +// extensions should be added to the list below. 
See +// https://code.visualstudio.com/docs/devcontainers/containers#_managing-extensions +{ + "recommendations": [ + "ms-vscode-remote.remote-containers", + "ms-vscode-remote.remote-ssh", + "ms-vscode-remote.remote-ssh-edit", + "ms-vscode.remote-explorer" + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..672221d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,16 @@ +{ + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "python.analysis.autoImportCompletions": true, + "terminal.integrated.scrollback": 30000, + "mypy-type-checker.cwd": "${workspaceFolder}", + // pytest setup + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, +} \ No newline at end of file diff --git a/Dockerfile.chailab b/Dockerfile.chailab new file mode 100755 index 0000000..0f616dd --- /dev/null +++ b/Dockerfile.chailab @@ -0,0 +1,79 @@ +FROM ubuntu:22.04 AS chailab-baseimage + +ENV \ + LANG=C.UTF-8 \ + LC_ALL=C.UTF-8 \ + # config for apt + DEBIAN_FRONTEND=noninteractive \ + # default editor for git cli + EDITOR=vim \ + # keep (large) mypy cache outside of working tree + MYPY_CACHE_DIR='/tmp/.chai_lab_mypy_cache' \ + # always flush output from python + PYTHONUNBUFFERED=TRUE \ + # enable fault handler (print tracebacks even after segfault or NCCL errors). + PYTHONFAULTHANDLER=1 \ + # keep __pycache__ out of working tree + PYTHONPYCACHEPREFIX='/tmp/.chai_lab_pycache' + + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get -qq update \ + && apt-get -qq install -y \ + # common things + gnupg ca-certificates wget git curl aria2 lsb-release tzdata \ + rsync sudo tree htop tmux unzip \ + clang \ + # for direct ssh into container + openssh-server socat \ + # provides `fuser` command + psmisc \ + # RDMA/InfiniBand + libibverbs1 librdmacm1 \ + # text editors, needed by git cli + nano vim \ + build-essential libstdc++6 \ + # (run continues) + # stop git from complaining about dubious ownership. + && git config --global --add safe.directory "*" \ + # + # cuda softlinking is needed in podman, but not docker + && ln -s /lib/x86_64-linux-gnu/libcuda.so.1 /lib/x86_64-linux-gnu/libcuda.so \ + && ldconfig /lib/x86_64-linux-gnu/ \ + # setup timezone, to $TZ, ubuntu-specific + # && ln -fs /usr/share/zoneinfo/$TZ /etc/localtime \ + && dpkg-reconfigure --frontend noninteractive tzdata \ + # change default shell to bash (has no effect during building) + && chsh -s /bin/bash + + +ENV \ + # expose CUDA libraries. Now that we don't build anything this is likely redundant + LD_LIBRARY_PATH="/usr/local/cuda/lib64/stubs/:$LD_LIBRARY_PATH" \ + # Set uv timeout to larger value to account for slow download time of nvidia-cudnn-cu12 + UV_HTTP_TIMEOUT=1000 \ + # where virtual env will be installed + VIRTUAL_ENV=/opt/venv + +# Install dependencies in virtualenv +COPY ./requirements.in /tmp/requirements.in +# from https://pythonspeed.com/articles/activate-virtualenv-dockerfile/ +# a trick to have virtualenv "always activated" +ENV PATH="$VIRTUAL_ENV/bin:$PATH" +RUN --mount=type=cache,target=/root/.cache/uv \ + # Install uv + curl -LsSf https://astral.sh/uv/install.sh | sh \ + && $HOME/.cargo/bin/uv venv --python 3.11 $VIRTUAL_ENV \ + # this is sh, not bash, so . not source + && . 
$VIRTUAL_ENV/bin/activate \ + && $HOME/.cargo/bin/uv pip install uv pip -r /tmp/requirements.in + + +# making sure envvars are set in all shells +RUN echo "PATH=\"$PATH\"" >> /etc/environment \ + && echo "LANG=\"$LANG\"" >> /etc/environment \ + && echo "LC_ALL=\"$LC_ALL\"" >> /etc/environment \ + && echo "LD_LIBRARY_PATH=\"$LD_LIBRARY_PATH\"" >> /etc/environment \ + && echo "EDITOR=\"$EDITOR\"" >> /etc/environment + +# no startup command. \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..7059b35 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,387 @@ +Please read this Chai Discovery Community License Agreement (the +“**License**”) carefully before using the Chai-1 Model software code and +model weights (the “**AI Model**”) and any “**Outputs**” (as defined +below) which is offered by Chai Discovery, Inc. (“**Chai**”) and made +available at the following link +, as they may be updated and +amended from time to time. + +THIS LICENSE GRANTS RIGHTS ONLY TO USE THE ai MODEL, outputs, and any +derivative works (AS DEFINED BELOW) SOLELY FOR NON-COMMERCIAL PURPOSES +(AS DEFINED BELOW). YOU MAY NOT USE THE AI MODEL, OUTPUT, OR ANY +DERIVATIVE WORKS UNDER THE TERMS OF THIS LICENSE FOR ANY COMMERCIAL +PURPOSES OR AS PART OF A SERVICE OFFERING. PLEASE REVIEW SECTION “use +restrictions and aup” below CAREFULLY BEFORE USING THE ai model or any +output. + +By downloading the AI Model, or otherwise using the AI Model or +exercising any of the rights granted hereunder in any manner, You agree +that You have read and agree to be bound by the terms of this License +and that You will use the AI Model only for Non-Commercial Purposes (as +defined below). If You are accessing the AI Model on behalf of an +organization or entity, You represent and warrant that You are +authorized to enter into this License on that organization’s or entity’s +behalf and bind them to the terms of this License (in which case, the +references to “You” and “Your” in this License, except for in this +sentence, refer to that organization or entity) and that such entity is +not a Commercial Entity (as defined below). No rights are granted under +this License to a Commercial Entity. Use of the AI Model is expressly +conditioned upon Your assent to all terms of this License. + + +## **1. Definitions.** + +In addition to other terms defined elsewhere in this License, the terms +below have the following meanings. + +1. “**Commercial Entity**” means any entity engaged, in whole or in + part, in any activity intended for or directed toward commercial + advantage or monetary compensation, including but not limited to the + development of any product or service intended to be sold or made + available for a fee or other economic consideration. For the purpose + of this License, references to a Commercial Entity expressly exclude + any universities, non-profit organizations, non-profit research + institutes, and non-profit educational and government bodies. + +2. “**Contribution**” means any work of authorship, including the + original version of the AI Model and any modifications or additions + to that AI Model or Derivative Works thereof, that is intentionally + submitted to Chai for inclusion in the AI Model by the copyright + owner or by an individual or legal entity authorized to submit on + behalf of the copyright owner. 
For the purposes of this definition, + "submitted" means any form of electronic, verbal, or written + communication sent to Chai or its representatives, including but not + limited to communication on electronic mailing lists, source code + control systems, and issue tracking systems that are managed by, or + on behalf of, Chai for the purpose of discussing and improving the + AI Model, but excluding Outputs and all communications that are + conspicuously marked or otherwise designated in writing by the + copyright owner as "Not a Contribution." + +3. “**Contributor**” means Chai and any individual or legal entity on + behalf of whom a Contribution has been received by Chai and + subsequently incorporated within the AI Model. + +4. “**Derivative Work**” means any work, whether in Source or Object + form, that is based on (or derived from) the AI Model and for which + the revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the + purposes of this License, Derivative Works shall not include works + that remain separable from, or merely link (or bind by name) to the + interfaces of, the AI Model and Derivative Works thereof. + +5. “**Non-Commercial Purposes**” means uses not intended for or + directed toward commercial advantage or monetary compensation, or + the facilitation of development of any product or service to be sold + or made available for a fee or other economic consideration. For the + avoidance of doubt, the provision of Outputs or Output Derivatives + as a service, or the provision of any other service that utilizes + the AI Model, Derivative Works thereof, Outputs or Derivative + Outputs (even if the service does not provide Outputs or Output + Derivatives), is not a Non-Commercial Purpose, whether or not for a + fee or other economic consideration. + +6. “**Object**” means any form resulting from mechanical transformation + or translation of a Source form, including but not limited to + compiled object code, generated documentation, and conversions to + other media types. + +7. “**Output**” means any output that is made available to You by the + functionality of the AI Model, including but not limited to any + protein sequence, structure prediction, functional annotation, + molecule, descriptions of a molecule, structure predictions, + confidence rankings, intermediate model states, model, model + embeddings, sequence, text, and/or images. + +8. “**Output Derivatives**” means any enhancements, modifications and + derivative works of Outputs (including, but not limited to, any + derivative sequences, structures, or molecules). + +9. “**Source**” means the preferred form for making modifications, + including but not limited to AI Model source code, documentation + source, and configuration files. + +10. “**You**” or “**Your**” means the individual entering into this + License or the organization or entity on whose behalf such + individual is entering into this License. + + +## **2. Grant of License.** + +1. **Copyright License**. 
Subject to the terms and conditions of this + License, each Contributor hereby grants to You a limited, + non-exclusive, worldwide, royalty-free, non-transferable, + non-sublicensable copyright license to reproduce, prepare Derivative + Works of, publicly display, publicly perform, and distribute the AI + Model and such Derivative Works in Source or Object form solely for + Your Non-Commercial Purposes and subject to the restrictions set + forth in Sections 3 (“Use Restrictions and AUP”) and 4 (“Sharing and + Distribution”). + +2. **Patent License**. Subject to the terms and conditions of this + License, each Contributor hereby grants to You a limited, + non-exclusive, worldwide, royalty-free, non-transferable, + non-sublicensable patent license to make, have made, use, import, + and otherwise transfer the AI Model solely for Your Non-Commercial + Purposes and subject to the restrictions set forth in Sections 3 + (“Use Restrictions and AUP”) and 4 (“Sharing and Distribution”), + where such license applies only to those patent claims licensable by + such Contributor that are necessarily infringed by its + Contribution(s) alone or by combination of its Contribution(s) with + the AI Model to which such Contribution(s) was submitted. + + 1. If You institute patent litigation against any entity (including + a cross-claim or counterclaim in a lawsuit) alleging that the AI + Model or a Contribution incorporated within the AI Model + constitutes direct or contributory patent infringement, then any + patent licenses granted to You under this License for that AI + Model shall terminate as of the date such litigation is filed + and may be reinstituted only by a separate grant of a patent + license in writing from the Contributor. + + +## **3. Use Restrictions and AUP.** + +1. **No Commercial Use**. You may use the AI Model, Contributions, + Derivative Works, Outputs and Output Derivatives only for + Non-Commercial Purposes. Any commercial use of any of the foregoing, + including but not limited to any use by, on behalf of or for any + Commercial Entity or to facilitate the development of any product or + service to be sold or made available for a fee or other economic + consideration, is strictly prohibited under this License. + +2. **Drug Discovery.** You may use the AI Model, Contributions, + Derivative Works, Outputs and Output Derivatives in connection with + drug development or discovery, including but not limited to: (i) the + development (at any stage) or discovery of any drug, medication, + therapeutic, or pharmaceutical of any kind; (ii) any molecular or + biological target, hit or lead identification; (iii) drug candidate + selection; or (iv) lead optimization. + +3. **No Service Offerings**. You may not use the AI Model or any + Contributions, Derivative Works, Outputs or Output Derivatives in or + in connection with the provision of any service offering to third + parties (such as in connection with a hosted service offering that + provides Outputs or Output Derivatives to third parties), regardless + of whether or not such service requires monetary compensation or + other consideration. + +4. **Acceptable Use Policy**. Your use of the AI Model, Derivative + Works, Outputs and Output Derivatives is further subject to the Chai + Discovery Acceptable Use Policy available at + and any additional + use restrictions that may be communicated to You through the AI Model, + as may be updated and amended from time to time (the “**AUP**”), the + terms of which are incorporated herein by reference. 
In the event of + any conflict between the terms of this License and the terms of the + AUP, the terms that are more restrictive of Your use of the AI + Model, Derivative Works, Outputs and Output Derivatives, as + applicable, shall govern and control. For the purpose of clarity, + the AUP includes, among other things, restrictions that the AI + Model, Derivative Works, Outputs and Output Derivatives may not be + used to train, optimize, improve or otherwise influence the + functionality or performance of any: (i) neural network, tool, + platform and/or artificial intelligence or machine learning models + with more than 10,000 trainable parameters; or (ii) technology for + protein structure prediction or protein, drug, or enzyme design. + + +## **4. Sharing and Distribution.** + +Subject to Section “Use Restrictions and AUP”, You may reproduce and +distribute copies of the AI Model or Derivative Works thereof, with or +without modifications, and in Source or Object form solely for Your +Non-Commercial Purposes, provided that You meet the following +conditions: + +1. You must not distribute copies of the AI Model, Contributions, + Derivative Works, Output, and Output Derivatives, or allow the use + of any reproductions or copies thereof by, on behalf of or for, any + Commercial Entity; + +2. You must restrict the usage of any copies of the AI Model, + Contributions, Derivative Works, Output, and Output Derivatives to + usage for Non-Commercial Purposes; + +3. You must give any other recipients of the AI Model, Contributions, + Derivative Works, Output, and Output Derivatives a copy of this + License; + +4. You must cause any modified files of the AI Model, Contributions, + Derivative Works, Output, and Output Derivatives to carry prominent + notices stating that You changed the files; + +5. You must retain, in the AI Model, Contributions, Derivative Works, + Output, and Output Derivatives that You distribute, all copyright, + patent, trademark, and attribution notices which are included in the + version of the AI Model, Contributions, Derivative Works, Output, + and Output Derivatives provided to You (collectively, “**Attribution + Notices**”), excluding those portions of the Attribution Notices + that do not pertain to any part of the Derivative Works or Output + Derivatives that you distribute, You must include the pertinent + portions of the Attribution Notices in at least one of the following + places: within a NOTICE text file distributed as part of the + Derivative Works or Output Derivatives; within the Source form or + documentation, if provided along with the Derivative Works or Output + Derivatives; or, within a display generated by the Derivative Works, + if and wherever such third-party notices normally appear. The + contents of such Attribution Notices are for informational purposes + only and do not modify this License. You may add Your own + attribution notices within Derivative Works or Output Derivatives + that You distribute, alongside or as an addendum to the pertinent + Attribution Notices, provided that such additional attribution + notices cannot be construed as modifying this License. 
+ +You may add Your own copyright statement to Your modifications and may +provide additional or different license terms and conditions for use, +reproduction, or distribution of Your modifications, or for any such +Derivative Works as a whole, or for Your Services, provided Your use, +reproduction, and distribution of the AI Model, Derivative Works, and +Your Services otherwise complies with the conditions stated in this +License. + + +## **5. Submission of Contributions.** + +Unless You explicitly state otherwise, any Contribution intentionally +submitted for inclusion in the AI Model by You to Chai shall be under +the terms and conditions of this License, without any additional terms +or conditions. Notwithstanding the above, nothing herein shall supersede +or modify the terms of any separate license agreement you may have +executed with Chai regarding such Contributions. + + +## **6. Trademarks.** + +This License does not grant permission to use the trade names, +trademarks, service marks, or product names of Chai, except for +reasonable and customary use in describing the origin of the AI Model +and reproducing the content of the NOTICE file. + + +## **7. Term and Termination.** + +This License applies for so long as the rights licensed in Section 2 +hereunder remain protected by copyright and/or patent law, as +applicable. However, if You fail to comply with this License, then Your +rights under this License terminate automatically. + +1. For the avoidance of doubt, this Section (“Term and Termination”) + does not affect any right that Chai may have to seek remedies for + Your violations of this License. + +2. For the avoidance of doubt, Chai may also offer the AI Model under + separate terms or conditions or stop distributing the AI Model at + any time; however, doing so will not terminate this License. + +3. This sentence of Section “Term and Termination” and Sections + “Submission of Contributions,” “Trademarks,” “Disclaimer of + Warranty,” Limitation of Liability,” and “General” survive + termination of this License. + + +## **8. Disclaimer of Warranty.** + +CHAI PROVIDES THE AI MODEL AND ITS OUTPUTS (AND EACH CONTRIBUTOR +PROVIDES ITS CONTRIBUTIONS) ON AN "AS IS" BASIS, WITHOUT WARRANTY OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY +IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, +QUIET ENJOYMENT AND NON-INFRINGEMENT, AND ANY WARRANTIES ARISING OUT OF +COURSE OF DEALING OR USAGE OF TRADE, ALL OF WHICH ARE HEREBY DISCLAIMED. +CHAI AND ITS CONTRIBUTORS MAKE NO WARRANTY (1) THAT THE AI MODEL, +DERIVATIVE WORKS, OUTPUTS, AND/OR OUTPUT DERIVATIVES WILL MEET YOUR +REQUIREMENTS OR BE AVAILABLE ON AN UNINTERRUPTED, SECURE, OR ERROR-FREE +BASIS, OR (2) REGARDING THE QUALITY, ACCURACY, TIMELINESS, TRUTHFULNESS, +COMPLETENESS OR RELIABILITY OF ANY OUTPUTS. YOU ARE SOLELY RESPONSIBLE +FOR DETERMINING THE APPROPRIATENESS OF USING THE AI MODEL, DERIVATIVE +WORKS, OUTPUTS, AND/OR OUTPUT DERIVATIVES AND ASSUME ANY RISKS +ASSOCIATED WITH YOUR EXERCISE OF PERMISSIONS UNDER THIS AGREEMENT. + + +## **9. 
Limitation of Liability.** + +TO THE MAXIMUM EXTENT PERMITTED BY LAW, NEITHER CHAI NOR ANY +CONTRIBUTORS WILL BE LIABLE FOR ANY DIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY OR CONSEQUENTIAL DAMAGES, OR DAMAGES FOR LOST PROFITS, LOST +REVENUES, LOST SAVINGS, LOST BUSINESS OPPORTUNITY, LOSS OF DATA OR +GOODWILL, SERVICE INTERRUPTION, COMPUTER DAMAGE OR SYSTEM FAILURE OR THE +COST OF SUBSTITUTE SERVICES OF ANY KIND ARISING OUT OF OR IN CONNECTION +WITH THESE TERMS OR FROM THE USE OF OR INABILITY TO USE THE SERVICES OR +OUTPUT, WHETHER BASED ON WARRANTY, CONTRACT, TORT (INCLUDING +NEGLIGENCE), PRODUCT LIABILITY OR ANY OTHER LEGAL THEORY, AND WHETHER OR +NOT CHAI OR the CONTRIBUTORS HAVE BEEN INFORMED OF THE POSSIBILITY OF +SUCH DAMAGE, EVEN IF A LIMITED REMEDY SET FORTH HEREIN IS FOUND TO HAVE +FAILED OF ITS ESSENTIAL PURPOSE. THE EXCLUSIONS AND LIMITATIONS OF +DAMAGES SET FORTH ABOVE ARE FUNDAMENTAL ELEMENTS OF THE BASIS OF THE +BARGAIN BETWEEN Chai, THE CONTRIBUTORS, AND YOU. + + +## **10. General.** + +1. Entire Agreement. This License constitutes the entire + agreement between You and Chai relating to the subject matter hereof + and supersedes all proposals, understandings, or discussions, + whether written or oral, relating to the subject matter of this + License and all past dealing or industry custom. The failure of + either party to enforce its rights under this License at any time + for any period shall not be construed as a waiver of such rights. + Chai may amend or modify this License from time to time and will use + reasonable efforts to provide You with notice of any material + changes that may negatively impact Your use of the AI Model via the + github page for the AI Model at + , or through another + means made available to You. No other changes, modifications or + waivers to this License will be effective unless in writing and + signed by both parties. + +2. Relationship of Parties. Nothing in this License will be + construed to create a partnership, joint venture or agency + relationship between the parties. Neither party will have the power + to bind the other or to incur obligations on the other’s behalf + without such other party’s prior written consent. Unless otherwise + expressly provided, no provisions of this License are intended or + will be construed to confer upon or give to any person or entity, + other than the parties, any rights, remedies or other benefits under + or by reason of this License.. + +3. Export Control. You will comply fully with all applicable + export laws and regulations of the United States + (“**Export Laws**”) to ensure that neither the AI Model, + Contributions, Derivative Works, Outputs, or Output Derivatives, nor + any technical data related thereto is: (i) exported or re-exported + directly or indirectly in violation of Export Laws; or (ii) used for + any purposes prohibited by the Export Laws, including, but not + limited to, nuclear, chemical, or biological weapons proliferation. + +4. Assignment. This License and the rights and obligations + herein may not be assigned or transferred, in whole or in part, by + You without the prior written consent of Chai. Any assignment in + violation of this provision is void. Chai may freely assign or + transfer this License, in whole or in part. This License shall be + binding upon, and inure to the benefit of, the successors and + permitted assigns of the parties. + +5. Governing Law. 
This License shall be governed by and + construed under the laws of the State of California and the United + States without regard to conflicts of laws provisions thereof, and + without regard to the Uniform Computer Information Transactions Act. + Any legal action or proceeding arising under this License will be + brought exclusively in the federal or state courts located in the + Northern District of California and the parties irrevocably consent + to the personal jurisdiction and venue therein. + +6. Severability. If any provision of this License is held to be + invalid, illegal or unenforceable in any respect, that provision + shall be limited or eliminated to the minimum extent necessary so + that this License otherwise remains in full force and effect and + enforceable. + + +## **11. Additional License Rights.** + +If You are interested in using the AI Model or Outputs for purposes +beyond the rights granted under this License (for example, if you would +like to use the AI Model or Outputs for commercial purposes), you may +contact Chai at . Any such use in excess +of the rights granted herein to You must be subject to a written +agreement between Chai and You. diff --git a/README.md b/README.md new file mode 100644 index 0000000..83f50e0 --- /dev/null +++ b/README.md @@ -0,0 +1,64 @@ +# Chai-1 + +Chai-1 is a multi-modal foundation model for molecular structure prediction that performs at the state-of-the-art across a variety of benchmarks. Chai-1 enables unified prediction of proteins, small molecules, DNA, RNA, glycosylations, and more. + +

+<p align="center"><img src="assets/performance_barplot.png" alt="Chai-1 performance across benchmarks"></p>

+ +For more information on the model's performance and capabilities, see our [technical report](https://chaiassets.com/chai-1/paper/technical_report.pdf). + +## Installation + +```shell +pip install chai_lab +``` + +This Python package requires Linux and a GPU with CUDA. + + +## Running the model + +The model accepts inputs in the FASTA file format, and allows you to specify the number of trunk recycles and diffusion timesteps via the `chai_lab.chai1.run_inference` function. By default, the model generates five sample predictions, and uses embeddings without MSAs or templates. + +The following script demonstrates how to provide inputs to the model, and obtain a list of PDB files for downstream analysis: + +```shell +python examples/predict_structure.py +``` + +For more advanced use cases, we also expose the `chai_lab.chai1.run_folding_on_context` function, which allows users to construct an `AllAtomFeatureContext` manually. This lets users specify their own templates, MSAs, embeddings, and constraints. We currently provide an example of how to construct an embeddings context, and will be releasing helper methods to build MSA and template contexts soon. + +## ⚡ Try it online + +We provide a [web server](https://lab.chaidiscovery.com) so you can test the Chai-1 model right from your browser, without any setup. + +

+<p align="center"><img src="assets/chailab_online_screenshot.png" alt="Screenshot of the Chai-1 web server"></p>
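As a companion to the "Running the model" section above, the snippet below sketches how `chai_lab.chai1.run_inference` might be called directly from Python. The keyword arguments mirror the `run_inference` signature added in `chai_lab/chai1.py` in this diff; the FASTA path and output directory are placeholder values to replace with your own.

```python
from pathlib import Path

import torch

from chai_lab.chai1 import run_inference

# Placeholder paths: point these at your own FASTA file and output folder.
fasta_path = Path("example.fasta")
output_dir = Path("outputs/example")
output_dir.mkdir(parents=True, exist_ok=True)

# Returns a list of PDB paths, one per diffusion sample (five by default).
pdb_paths = run_inference(
    fasta_file=fasta_path,
    output_dir=output_dir,
    use_esm_embeddings=True,
    num_trunk_recycles=3,
    num_diffn_timesteps=200,
    seed=42,
    device=torch.device("cuda:0"),
)
print(pdb_paths)
```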

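For the advanced path mentioned under "Running the model", the sketch below builds an `AllAtomFeatureContext` by hand and passes it to `chai_lab.chai1.run_folding_on_context`. It simply mirrors what `run_inference` does internally in `chai_lab/chai1.py` (empty MSA, template, and constraint contexts plus ESM embeddings), so treat it as a starting point rather than a canonical recipe; the file paths are again placeholders.

```python
from pathlib import Path

import torch

from chai_lab.chai1 import run_folding_on_context
from chai_lab.data.dataset.all_atom_feature_context import (
    MAX_MSA_DEPTH,
    MAX_NUM_TEMPLATES,
    AllAtomFeatureContext,
)
from chai_lab.data.dataset.constraints.constraint_context import ConstraintContext
from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context
from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.structure.all_atom_structure_context import (
    AllAtomStructureContext,
)
from chai_lab.data.dataset.templates.context import TemplateContext

device = torch.device("cuda:0")

# Parse chains from a FASTA file (placeholder path) and merge their structure contexts.
chains = load_chains_from_raw(read_inputs(Path("example.fasta"), length_limit=None))
structure_context = AllAtomStructureContext.merge([c.structure_context for c in chains])
n_tokens = structure_context.num_tokens

feature_context = AllAtomFeatureContext(
    chains=chains,
    structure_context=structure_context,
    # Empty MSA and template contexts; this is where custom ones would be plugged in.
    msa_context=MSAContext.create_empty(n_tokens=n_tokens, depth=MAX_MSA_DEPTH),
    main_msa_context=MSAContext.create_empty(n_tokens=n_tokens, depth=MAX_MSA_DEPTH),
    template_context=TemplateContext.empty(n_tokens=n_tokens, n_templates=MAX_NUM_TEMPLATES),
    embedding_context=get_esm_embedding_context(chains, device=device),
    constraint_context=ConstraintContext.empty(),
)

pdb_paths, confidence_scores, rankings = run_folding_on_context(
    feature_context,
    output_dir=Path("outputs/example"),
    num_trunk_recycles=3,
    num_diffn_timesteps=200,
    seed=42,
    device=device,
)
```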
+ +## 💬 Feedback + +Found a 🐞? Please report it in GitHub [issues](https://github.com/chaidiscovery/chai-lab/issues). + +We welcome community testing and feedback. To share observations about the model's performance, please reach out via [GitHub discussions](https://github.com/chaidiscovery/chai-lab/discussions), or [via email](mailto:feedback@chaidiscovery.com). + +## 🛠️ Development + +We use [devcontainers](https://code.visualstudio.com/docs/devcontainers/containers) in development, which helps us ensure we work in identical environments. We recommend working inside a devcontainer if you want to make a contribution to this repository. + +Devcontainers work on a local Linux setup, and on remote machines over an SSH connection. + +## Status + +Since this is an initial release, we expect to make some breaking changes to the API and are not guaranteeing backwards compatibility. We recommend pinning the current version in your requirements, i.e.: + +``` +chai_lab==0.0.1 +``` + +## Licence + +See [LICENSE.md](LICENSE.md). + +To discuss commercial use of our models, reach us [via email](mailto:partnerships@chaidiscovery.com). diff --git a/assets/chailab_online_screenshot.png b/assets/chailab_online_screenshot.png new file mode 100644 index 0000000..8ad446f Binary files /dev/null and b/assets/chailab_online_screenshot.png differ diff --git a/assets/performance_barplot.png b/assets/performance_barplot.png new file mode 100644 index 0000000..b996807 Binary files /dev/null and b/assets/performance_barplot.png differ diff --git a/chai_lab/__init__.py b/chai_lab/__init__.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/chai_lab/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py new file mode 100644 index 0000000..e28e293 --- /dev/null +++ b/chai_lab/chai1.py @@ -0,0 +1,682 @@ +# %% +import math +from dataclasses import dataclass +from pathlib import Path + +import torch +import torch.export +from einops import einsum, rearrange, repeat +from torch import Tensor +from tqdm import tqdm + +from chai_lab.data.collate.collate import Collate +from chai_lab.data.collate.utils import AVAILABLE_MODEL_SIZES +from chai_lab.data.dataset.all_atom_feature_context import ( + MAX_MSA_DEPTH, + MAX_NUM_TEMPLATES, + AllAtomFeatureContext, +) +from chai_lab.data.dataset.constraints.constraint_context import ConstraintContext +from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext +from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context +from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs +from chai_lab.data.dataset.msas.msa_context import MSAContext +from chai_lab.data.dataset.structure.all_atom_structure_context import ( + AllAtomStructureContext, +) +from chai_lab.data.dataset.templates.context import TemplateContext +from chai_lab.data.features.feature_factory import FeatureFactory +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.atom_element import AtomElementOneHot +from chai_lab.data.features.generators.atom_name import AtomNameOneHot +from chai_lab.data.features.generators.base import EncodingType +from chai_lab.data.features.generators.blocked_atom_pair_distances import ( + BlockedAtomPairDistances, + BlockedAtomPairDistogram, +) +from chai_lab.data.features.generators.docking import DockingConstraintGenerator +from chai_lab.data.features.generators.esm_generator import ESMEmbeddings +from 
chai_lab.data.features.generators.identity import Identity +from chai_lab.data.features.generators.is_cropped_chain import ChainIsCropped +from chai_lab.data.features.generators.missing_chain_contact import MissingChainContact +from chai_lab.data.features.generators.msa import ( + IsPairedMSAGenerator, + MSADataSourceGenerator, + MSADeletionMeanGenerator, + MSADeletionValueGenerator, + MSAFeatureGenerator, + MSAHasDeletionGenerator, + MSAProfileGenerator, +) +from chai_lab.data.features.generators.ref_pos import RefPos +from chai_lab.data.features.generators.relative_chain import RelativeChain +from chai_lab.data.features.generators.relative_entity import RelativeEntity +from chai_lab.data.features.generators.relative_sep import RelativeSequenceSeparation +from chai_lab.data.features.generators.relative_token import RelativeTokenSeparation +from chai_lab.data.features.generators.residue_type import ResidueType +from chai_lab.data.features.generators.structure_metadata import ( + IsDistillation, + TokenBFactor, + TokenPLDDT, +) +from chai_lab.data.features.generators.templates import ( + TemplateDistogramGenerator, + TemplateMaskGenerator, + TemplateResTypeGenerator, + TemplateUnitVectorGenerator, +) +from chai_lab.data.features.generators.token_dist_restraint import ( + TokenDistanceRestraint, +) +from chai_lab.data.features.generators.token_pair_pocket_restraint import ( + TokenPairPocketRestraint, +) +from chai_lab.data.io.pdb_utils import write_pdbs_from_outputs +from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule +from chai_lab.model.utils import center_random_augmentation +from chai_lab.ranking.frames import get_frames_and_mask +from chai_lab.ranking.rank import SampleRanking, rank +from chai_lab.utils.paths import chai1_component +from chai_lab.utils.tensor_utils import move_data_to_device, set_seed, und_self +from chai_lab.utils.typing import Float, typecheck + + +class UnsupportedInputError(RuntimeError): + pass + + +def load_exported(comp_key: str, device: torch.device) -> torch.nn.Module: + local_path = chai1_component(comp_key) + exported_program = torch.export.load(local_path) + return exported_program.module().to(device) + + +# %% +# Create feature factory + +feature_generators = dict( + RelativeSequenceSeparation=RelativeSequenceSeparation(sep_bins=None), + RelativeTokenSeparation=RelativeTokenSeparation(r_max=32), + RelativeEntity=RelativeEntity(), + RelativeChain=RelativeChain(), + ResidueType=ResidueType( + min_corrupt_prob=0.0, + max_corrupt_prob=0.0, + num_res_ty=32, + key="token_residue_type", + ), + ESMEmbeddings=ESMEmbeddings(), # TODO: this can probably be the identity + BlockedAtomPairDistogram=BlockedAtomPairDistogram(), + InverseSquaredBlockedAtomPairDistances=BlockedAtomPairDistances( + transform="inverse_squared", + encoding_ty=EncodingType.IDENTITY, + ), + AtomRefPos=RefPos(), + AtomRefCharge=Identity( + key="inputs/atom_ref_charge", + ty=FeatureType.ATOM, + dim=1, + can_mask=False, + ), + AtomRefMask=Identity( + key="inputs/atom_ref_mask", + ty=FeatureType.ATOM, + dim=1, + can_mask=False, + ), + AtomRefElement=AtomElementOneHot(max_atomic_num=128), + AtomNameOneHot=AtomNameOneHot(), + TemplateMask=TemplateMaskGenerator(), + TemplateUnitVector=TemplateUnitVectorGenerator(), + TemplateResType=TemplateResTypeGenerator(), + TemplateDistogram=TemplateDistogramGenerator(), + TokenDistanceRestraint=TokenDistanceRestraint( + include_probability=0.0, + size=0.33, + min_dist=6.0, + max_dist=30.0, + num_rbf_radii=6, + ), + 
DockingConstraintGenerator=DockingConstraintGenerator( + include_probability=0.0, + structure_dropout_prob=0.75, + chain_dropout_prob=0.75, + ), + TokenPairPocketRestraint=TokenPairPocketRestraint( + size=0.33, + include_probability=0.0, + min_dist=6.0, + max_dist=20.0, + coord_noise=0.0, + num_rbf_radii=6, + ), + MSAProfile=MSAProfileGenerator(), + MSADeletionMean=MSADeletionMeanGenerator(), + IsDistillation=IsDistillation(), + TokenBFactor=TokenBFactor(include_prob=0.0), + TokenPLDDT=TokenPLDDT(include_prob=0.0), + ChainIsCropped=ChainIsCropped(), + MissingChainContact=MissingChainContact(contact_threshold=6.0), + MSAOneHot=MSAFeatureGenerator(), + MSAHasDeletion=MSAHasDeletionGenerator(), + MSADeletionValue=MSADeletionValueGenerator(), + IsPairedMSA=IsPairedMSAGenerator(), + MSADataSource=MSADataSourceGenerator(), +) +feature_factory = FeatureFactory(feature_generators) + +# %% +# Config + + +# Load input fasta +example_fasta1 = """>protein|7WJ3_A +AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGPEYWDRETQKYKRQAQTDRVSLRNLRGYYNQSEAGSHTLQWMFGCDLGPDGRLLRGYDQSAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAAREAEQRRAYLEGTCVEWLRRYLENGKETLQRAEHPKTHVTHHPVSDHEATLRCWALGFYPAEITLTWQWDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPEPLTLRWEP +""" + +example_fasta2 = """>protein|protein-one +[ACE]GE[AIB][AIB][AIB][AIB]KE[AIB][AIB][AIB][AIB]KE[AIB][AIB][AIB][AIB]KE[AIB][AIB][AIB][AIB]K[AIB][AIB][AIB]WKG[NH2] +>protein|protein-two +[KCJ][SEP][PPN][B3S][BAL][PPN]K[NH2] +>ligand|8HA0_J||prot-7VVJ_F|prot-8HAO_H +CC(C)CCCC(C)C1CCC2C1(CCC3C2CC=C4C3(CCC(C4)O)C)C +>protein|7WJ3_A||prot-7SR4_A|prot-short-GAAL +AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGP +>ligand|some-smiles +CC +>rna|some-rna +AUGCGAUACGUA +>dna|some-dna +ATGCGTACGTAC +""" + + +class DiffusionConfig: + S_churn: float = 80 + S_tmin: float = 4e-4 + S_tmax: float = 80.0 + S_noise: float = 1.003 + sigma_data: float = 16.0 + second_order: bool = True + + +# %% +# Input validation + + +def raise_if_too_many_tokens(n_actual_tokens: int): + if n_actual_tokens > max(AVAILABLE_MODEL_SIZES): + raise UnsupportedInputError( + f"Too many tokens in input: {n_actual_tokens} > {max(AVAILABLE_MODEL_SIZES)}. " + "Please limit the length of the input sequence." + ) + + +def raise_if_too_many_templates(n_actual_templates: int): + if n_actual_templates > MAX_NUM_TEMPLATES: + raise UnsupportedInputError( + f"Too many templates in input: {n_actual_templates} > {MAX_NUM_TEMPLATES}. " + "Please limit the number of templates." + ) + + +def raise_if_msa_too_deep(msa_depth: int): + if msa_depth > MAX_MSA_DEPTH: + raise UnsupportedInputError( + f"MSA too deep: {msa_depth} > {MAX_MSA_DEPTH}. " + "Please limit the MSA depth." 
+ ) + + +# %% +# Inference logic + + +@torch.no_grad() +def run_inference( + fasta_file: Path, + output_dir: Path, + use_esm_embeddings: bool = True, + # expose some params for easy tweaking + num_trunk_recycles: int = 3, + num_diffn_timesteps: int = 200, + seed: int | None = None, + device: torch.device | None = None, +) -> list[Path]: + # Prepare inputs + assert fasta_file.exists(), fasta_file + fasta_inputs = read_inputs(fasta_file, length_limit=None) + assert len(fasta_inputs) > 0, "No inputs found in fasta file" + + # Load structure context + chains = load_chains_from_raw(fasta_inputs) + contexts = [c.structure_context for c in chains] + merged_context = AllAtomStructureContext.merge(contexts) + n_actual_tokens = merged_context.num_tokens + raise_if_too_many_tokens(n_actual_tokens) + + # Load MSAs + msa_context = MSAContext.create_empty( + n_tokens=n_actual_tokens, + depth=MAX_MSA_DEPTH, + ) + main_msa_context = MSAContext.create_empty( + n_tokens=n_actual_tokens, + depth=MAX_MSA_DEPTH, + ) + + # Load templates + template_context = TemplateContext.empty( + n_tokens=n_actual_tokens, + n_templates=MAX_NUM_TEMPLATES, + ) + + # Load ESM embeddings + if use_esm_embeddings: + embedding_context = get_esm_embedding_context(chains, device=device) + else: + embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens) + + # Constraints + constraint_context = ConstraintContext.empty() + + # Build final feature context + feature_context = AllAtomFeatureContext( + chains=chains, + structure_context=merged_context, + msa_context=msa_context, + main_msa_context=main_msa_context, + template_context=template_context, + embedding_context=embedding_context, + constraint_context=constraint_context, + ) + + output_paths, scores, ranking_data = run_folding_on_context( + feature_context, + output_dir=output_dir, + num_trunk_recycles=num_trunk_recycles, + num_diffn_timesteps=num_diffn_timesteps, + seed=seed, + device=device, + ) + return output_paths + + +def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor: + return torch.linspace(min_bin, max_bin, 2 * no_bins + 1)[1::2] + + +@typecheck +@dataclass(frozen=True) +class ConfidenceScores: + # Predicted aligned error (PAE) + pae: Float[Tensor, "bs num_tokens num_tokens"] + + # Predicted distance error (PDE) + pde: Float[Tensor, "bs num_tokens num_tokens"] + + # Predicted local distance difference test (pLDDT) + plddt: Float[Tensor, "bs num_tokens"] + + +@torch.no_grad() +def run_folding_on_context( + feature_context: AllAtomFeatureContext, + output_dir: Path, + # expose some params for easy tweaking + num_trunk_recycles: int = 3, + num_diffn_timesteps: int = 200, + seed: int | None = None, + device: torch.device | None = None, +) -> tuple[list[Path], ConfidenceScores, list[SampleRanking]]: + """ + Function for in-depth explorations. + User completely controls folding inputs. 
+ """ + # Set seed + if seed is not None: + set_seed([seed]) + + if device is None: + device = torch.device("cuda:0") + + ## + ## Validate inputs + ## + + n_actual_tokens = feature_context.structure_context.num_tokens + raise_if_too_many_tokens(n_actual_tokens) + raise_if_too_many_templates(feature_context.template_context.num_templates) + raise_if_msa_too_deep(feature_context.msa_context.depth) + raise_if_msa_too_deep(feature_context.main_msa_context.depth) + + ## + ## Prepare batch + ## + + # Collate inputs into batch + collator = Collate( + feature_factory=feature_factory, + num_key_atoms=128, + num_query_atoms=32, + ) + + feature_contexts = [feature_context] + batch_size = len(feature_contexts) + batch = collator(feature_contexts) + batch = move_data_to_device(batch, device=device) + + # Get features and inputs from batch + features = {name: feature for name, feature in batch["features"].items()} + inputs = batch["inputs"] + block_indices_h = inputs["block_atom_pair_q_idces"] + block_indices_w = inputs["block_atom_pair_kv_idces"] + atom_single_mask = inputs["atom_exists_mask"] + atom_token_indices = inputs["atom_token_index"].long() + token_single_mask = inputs["token_exists_mask"] + token_pair_mask = und_self(token_single_mask, "b i, b j -> b i j") + token_reference_atom_index = inputs["token_ref_atom_index"] + atom_within_token_index = inputs["atom_within_token_index"] + msa_mask = inputs["msa_mask"] + template_input_masks = und_self( + inputs["template_mask"], "b t n1, b t n2 -> b t n1 n2" + ) + block_atom_pair_mask = inputs["block_atom_pair_mask"] + + ## + ## Load exported models + ## + + # Model is size-specific + model_size = min(x for x in AVAILABLE_MODEL_SIZES if n_actual_tokens <= x) + + feature_embedding = load_exported(f"{model_size}/feature_embedding.pt2", device) + token_input_embedder = load_exported( + f"{model_size}/token_input_embedder.pt2", device + ) + trunk = load_exported(f"{model_size}/trunk.pt2", device) + diffusion_module = load_exported(f"{model_size}/diffusion_module.pt2", device) + confidence_head = load_exported(f"{model_size}/confidence_head.pt2", device) + + ## + ## Run the features through the feature embedder + ## + + embedded_features = feature_embedding.forward(**features) + token_single_input_feats = embedded_features["TOKEN"] + token_pair_input_feats, token_pair_structure_input_features = embedded_features[ + "TOKEN_PAIR" + ].chunk(2, dim=-1) + atom_single_input_feats, atom_single_structure_input_features = embedded_features[ + "ATOM" + ].chunk(2, dim=-1) + block_atom_pair_input_feats, block_atom_pair_structure_input_feats = ( + embedded_features["ATOM_PAIR"].chunk(2, dim=-1) + ) + template_input_feats = embedded_features["TEMPLATES"] + msa_input_feats = embedded_features["MSA"] + + ## + ## Run the inputs through the token input embedder + ## + + token_input_embedder_outputs: tuple[Tensor, ...] 
= token_input_embedder.forward( + token_single_input_feats=token_single_input_feats, + token_pair_input_feats=token_pair_input_feats, + atom_single_input_feats=atom_single_input_feats, + block_atom_pair_feat=block_atom_pair_input_feats, + block_atom_pair_mask=block_atom_pair_mask, + block_indices_h=block_indices_h, + block_indices_w=block_indices_w, + atom_single_mask=atom_single_mask, + atom_token_indices=atom_token_indices, + ) + token_single_initial_repr, token_single_structure_input, token_pair_initial_repr = ( + token_input_embedder_outputs + ) + + ## + ## Run the input representations through the trunk + ## + + # Recycle the representations by feeding the output back into the trunk as input for + # the subsequent recycle + token_single_trunk_repr = token_single_initial_repr + token_pair_trunk_repr = token_pair_initial_repr + for _ in tqdm(range(num_trunk_recycles), desc="Trunk recycles"): + (token_single_trunk_repr, token_pair_trunk_repr) = trunk.forward( + token_single_trunk_initial_repr=token_single_initial_repr, + token_pair_trunk_initial_repr=token_pair_initial_repr, + token_single_trunk_repr=token_single_trunk_repr, # recycled + token_pair_trunk_repr=token_pair_trunk_repr, # recycled + msa_input_feats=msa_input_feats, + msa_mask=msa_mask, + template_input_feats=template_input_feats, + template_input_masks=template_input_masks, + token_single_mask=token_single_mask, + token_pair_mask=token_pair_mask, + ) + + ## + ## Denoise the trunk representation by passing it through the diffusion module + ## + + def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor: + atom_noised_coords = rearrange( + atom_pos, "(b s) ... -> b s ...", s=s + ).contiguous() + noise_sigma = repeat(sigma, " -> b s", b=batch_size, s=s) + return diffusion_module.forward( + token_single_initial_repr=token_single_structure_input.float(), + token_pair_initial_repr=token_pair_structure_input_features.float(), + token_single_trunk_repr=token_single_trunk_repr.float(), + token_pair_trunk_repr=token_pair_trunk_repr.float(), + atom_single_input_feats=atom_single_structure_input_features.float(), + atom_block_pair_input_feats=block_atom_pair_structure_input_feats.float(), + atom_single_mask=atom_single_mask, + atom_block_pair_mask=block_atom_pair_mask, + token_single_mask=token_single_mask, + block_indices_h=block_indices_h, + block_indices_w=block_indices_w, + atom_noised_coords=atom_noised_coords.float(), + noise_sigma=noise_sigma.float(), + atom_token_indices=atom_token_indices, + ) + + num_diffn_samples = 5 # Fixed at export time + inference_noise_schedule = InferenceNoiseSchedule( + s_max=DiffusionConfig.S_tmax, + s_min=4e-4, + p=7.0, + sigma_data=DiffusionConfig.sigma_data, + ) + sigmas = inference_noise_schedule.get_schedule( + device=device, num_timesteps=num_diffn_timesteps + ) + gammas = torch.where( + (sigmas >= DiffusionConfig.S_tmin) & (sigmas <= DiffusionConfig.S_tmax), + min(DiffusionConfig.S_churn / num_diffn_timesteps, math.sqrt(2) - 1), + 0.0, + ) + + sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) + + # Initial atom positions + _, num_atoms = atom_single_mask.shape + atom_pos = sigmas[0] * torch.randn( + batch_size * num_diffn_samples, num_atoms, 3, device=device + ) + + for sigma_curr, sigma_next, gamma_curr in tqdm( + sigmas_and_gammas, desc="Diffusion steps" + ): + # Center coords + atom_pos = center_random_augmentation( + atom_pos, + atom_single_mask=repeat( + atom_single_mask, + "b a -> (b s) a", + s=num_diffn_samples, + ), + ) + + # Alg 2. 
lines 4-6 + noise = DiffusionConfig.S_noise * torch.randn( + atom_pos.shape, device=atom_pos.device + ) + sigma_hat = sigma_curr + gamma_curr * sigma_curr + atom_pos_noise = (sigma_hat**2 - sigma_curr**2).clamp_min(1e-6).sqrt() + atom_pos_hat = atom_pos + noise * atom_pos_noise + + # Lines 7-8 + denoised_pos = _denoise( + atom_pos=atom_pos_hat, + sigma=sigma_hat, + s=num_diffn_samples, + ) + d_i = (atom_pos_hat - denoised_pos) / sigma_hat + atom_pos = atom_pos_hat + (sigma_next - sigma_hat) * d_i + + # Lines 9-11 + if sigma_next != 0 and DiffusionConfig.second_order: # second order update + denoised_pos = _denoise( + atom_pos, + sigma=sigma_next, + s=num_diffn_samples, + ) + d_i_prime = (atom_pos - denoised_pos) / sigma_next + atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2) + + ## + ## Run the confidence model + ## + + confidence_outputs: list[tuple[Tensor, ...]] = [ + confidence_head.forward( + token_single_input_repr=token_single_initial_repr, + token_single_trunk_repr=token_single_trunk_repr, + token_pair_trunk_repr=token_pair_trunk_repr, + token_single_mask=token_single_mask, + atom_single_mask=atom_single_mask, + atom_coords=atom_pos[s : s + 1], + token_reference_atom_index=token_reference_atom_index, + atom_token_index=atom_token_indices, + atom_within_token_index=atom_within_token_index, + ) + for s in range(num_diffn_samples) + ] + + pae_logits = torch.cat( + [x[0] for x in confidence_outputs], + ) + pde_logits = torch.cat( + [x[1] for x in confidence_outputs], + ) + plddt_logits = torch.cat( + [x[2] for x in confidence_outputs], + ) + + assert atom_pos.shape[0] == num_diffn_samples + assert pae_logits.shape[0] == num_diffn_samples + + ## + ## Write the outputs + ## + + output_paths: list[Path] = [] + for idx in range(num_diffn_samples): + sample_atom_pos = atom_pos[idx : idx + 1] + # sample_confidence = confidences["plddt_values"][idx : idx + 1] + # trunk_sample_idx = preds["trunk_sample_index"][idx].item() + trunk_sample_idx = 0 + out_basename = f"pred.model_trunk_{trunk_sample_idx}_idx_{idx}.pdb" + pdb_out_path = output_dir / out_basename + print(f"Writing output to {pdb_out_path}") + write_pdbs_from_outputs( + coords=sample_atom_pos, + # bfactors=sample_confidence, + output_batch=move_data_to_device(inputs, torch.device("cpu")), + write_path=pdb_out_path, + ) + output_paths.append(pdb_out_path) + + def softmax_einsum_and_cpu( + logits: Tensor, bin_mean: Tensor, pattern: str + ) -> Tensor: + # utility to compute score from bin logits + res = einsum( + logits.float().softmax(dim=-1), bin_mean.to(logits.device), pattern + ) + return res.to(device="cpu") + + token_mask_1d = rearrange(token_single_mask, "1 b -> b") + + pae_scores = softmax_einsum_and_cpu( + pae_logits[:, token_mask_1d, :, :][:, :, token_mask_1d, :], + _bin_centers(0.0, 32.0, 64), + "b n1 n2 d, d -> b n1 n2", + ) + + pde_scores = softmax_einsum_and_cpu( + pde_logits[:, token_mask_1d, :, :][:, :, token_mask_1d, :], + _bin_centers(0.0, 32.0, 64), + "b n1 n2 d, d -> b n1 n2", + ) + + plddt_scores_atom = softmax_einsum_and_cpu( + plddt_logits, + _bin_centers(0, 1, plddt_logits.shape[-1]), + "b a d, d -> b a", + ) + + # converting to per-token + [mask] = atom_single_mask.cpu() + [indices] = atom_token_indices.cpu() + + def avg_1d(x): + n = torch.bincount(indices[mask], weights=x[mask]) + d = torch.bincount(indices[mask]).clamp(min=1) + return n / d + + plddt_scores = torch.stack([avg_1d(x) for x in plddt_scores_atom]) + + confidence_scores = ConfidenceScores( + pae=pae_scores, + pde=pde_scores, + 
plddt=plddt_scores, + ) + + ranking_data: list[SampleRanking] = [] + + for s in range(atom_pos.shape[0]): + _, valid_frames_mask = get_frames_and_mask( + atom_pos[s : s + 1], + inputs["token_asym_id"], + inputs["token_residue_index"], + inputs["token_backbone_frame_mask"], + inputs["token_centre_atom_index"], + inputs["token_exists_mask"], + inputs["atom_exists_mask"], + inputs["token_backbone_frame_index"], + inputs["atom_token_index"], + ) + + ranking_data.append( + rank( + atom_pos[s : s + 1], + atom_mask=inputs["atom_exists_mask"], + atom_token_index=inputs["atom_token_index"], + token_exists_mask=inputs["token_exists_mask"], + token_asym_id=inputs["token_asym_id"], + token_entity_type=inputs["token_entity_type"], + token_valid_frames_mask=valid_frames_mask, + lddt_logits=plddt_logits[s : s + 1], + lddt_bin_centers=_bin_centers(0, 1, plddt_logits.shape[-1]).to( + plddt_logits.device + ), + pae_logits=pae_logits[s : s + 1], + pae_bin_centers=_bin_centers(0.0, 32.0, 64).to(pae_logits.device), + ) + ) + + return output_paths, confidence_scores, ranking_data diff --git a/chai_lab/data/__init__.py b/chai_lab/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/collate/__init__.py b/chai_lab/data/collate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/collate/collate.py b/chai_lab/data/collate/collate.py new file mode 100644 index 0000000..30e5f53 --- /dev/null +++ b/chai_lab/data/collate/collate.py @@ -0,0 +1,93 @@ +import dataclasses +import logging +from typing import Any + +import torch + +from chai_lab.data.collate.utils import get_pad_sizes +from chai_lab.data.dataset.all_atom_feature_context import AllAtomFeatureContext +from chai_lab.data.features.feature_factory import FeatureFactory +from chai_lab.model.utils import ( + get_block_atom_pair_mask, + get_qkv_indices_for_blocks, +) +from chai_lab.utils.dict import list_dict_to_dict_list + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class Collate: + feature_factory: FeatureFactory + num_query_atoms: int + num_key_atoms: int + + def __call__( + self, + feature_contexts: list[AllAtomFeatureContext], + ) -> dict[str, Any]: + raw_batch = self._collate(feature_contexts) + prepared_batch = self._post_collate(raw_batch) + return prepared_batch + + def _collate( + self, + feature_contexts: list[AllAtomFeatureContext], + ) -> dict[str, Any]: + # Get the pad sizes, finding the max number of tokens/atoms/bonds in the batch. 
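+        # n_tokens snaps up to the nearest entry in AVAILABLE_MODEL_SIZES and
+        # n_atoms is derived from it (23 * n_tokens), so batches only ever see a
+        # small, fixed set of tensor shapes.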
+ pad_sizes = get_pad_sizes([p.structure_context for p in feature_contexts]) + + # Pad each feature context to the max sizes + padded_feature_contexts = [ + feature_context.pad( + n_tokens=pad_sizes.n_tokens, + n_atoms=pad_sizes.n_atoms, + ) + for feature_context in feature_contexts + ] + + # Convert all the input data into dicts, for each feature context + inputs_per_context = [e.to_dict() for e in padded_feature_contexts] + + # Stack the dict inputs into a single batch dict, across all feature contexts + batched_inputs = { + k: (torch.stack(v, dim=0) if isinstance(v[0], torch.Tensor) else v) + for k, v in list_dict_to_dict_list(inputs_per_context).items() + } + + # Make a batch dict + batch = dict(inputs=batched_inputs) + return batch + + def _post_collate(self, raw_batch: dict[str, Any]) -> dict[str, Any]: + """ + takes a list of processed multi-chain systems, + returns a dictionary with batched tensors to feed in the model forward method + and any other necessary data for the task/losses + """ + raw_b_i = raw_batch["inputs"] + + # prepare atom pair block data: + atom_exists_mask = raw_b_i["atom_exists_mask"] + block_q_atom_idces, block_kv_atom_idces, kv_mask = get_qkv_indices_for_blocks( + atom_exists_mask.shape[1], + self.num_query_atoms, + self.num_key_atoms, + atom_exists_mask.device, + ) + block_atom_pair_mask = get_block_atom_pair_mask( + atom_single_mask=raw_b_i["atom_ref_mask"], + q_idx=block_q_atom_idces, + kv_idx=block_kv_atom_idces, + kv_is_wrapped_mask=kv_mask, + ) + raw_b_i |= dict( + block_atom_pair_q_idces=block_q_atom_idces, + block_atom_pair_kv_idces=block_kv_atom_idces, + block_atom_pair_mask=block_atom_pair_mask, + ) + + # Compute features + raw_batch["features"] = self.feature_factory.generate(raw_batch) + + return raw_batch diff --git a/chai_lab/data/collate/utils.py b/chai_lab/data/collate/utils.py new file mode 100644 index 0000000..0cb7431 --- /dev/null +++ b/chai_lab/data/collate/utils.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass + +from chai_lab.data.dataset.structure.all_atom_structure_context import ( + AllAtomStructureContext, +) + +# static graph is exported for different n_tokens, +# we pad to the closest one +AVAILABLE_MODEL_SIZES = [256, 384, 512, 768, 1024, 2048] + + +@dataclass(frozen=True) +class PadSizes: + n_tokens: int + n_atoms: int + + +def pad_size(max_in_batch: int, allowed_sizes: list[int]) -> int: + """pads to the smallest allowed size""" + max_allowed_size = allowed_sizes[-1] + if max_in_batch > max_allowed_size: + raise ValueError(f"{max_in_batch=} > {max_allowed_size=}") + return min(n for n in allowed_sizes if n >= max_in_batch) + + +def get_pad_sizes(contexts: list[AllAtomStructureContext]) -> PadSizes: + max_n_tokens = max(context.num_tokens for context in contexts) + n_tokens = pad_size(max_n_tokens, AVAILABLE_MODEL_SIZES) + + max_n_atoms = max(context.num_atoms for context in contexts) + n_atoms = 23 * n_tokens + assert max_n_atoms <= n_atoms + + return PadSizes(n_tokens=n_tokens, n_atoms=n_atoms) diff --git a/chai_lab/data/dataset/__init__.py b/chai_lab/data/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/dataset/all_atom_feature_context.py b/chai_lab/data/dataset/all_atom_feature_context.py new file mode 100644 index 0000000..6523f9a --- /dev/null +++ b/chai_lab/data/dataset/all_atom_feature_context.py @@ -0,0 +1,92 @@ +import logging +from dataclasses import dataclass +from typing import Any, Final + +from chai_lab.data.dataset.constraints.constraint_context import 
ConstraintContext +from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext +from chai_lab.data.dataset.msas.msa_context import MSAContext +from chai_lab.data.dataset.structure.all_atom_structure_context import ( + AllAtomStructureContext, +) +from chai_lab.data.dataset.structure.chain import Chain +from chai_lab.data.dataset.templates.context import TemplateContext + +logger = logging.getLogger(__name__) + +MAX_MSA_DEPTH: Final[int] = 16_384 +MAX_NUM_TEMPLATES: Final[int] = 4 + + +@dataclass +class AllAtomFeatureContext: + """ + Feature contexts are produced by datasets. Multiple feature contexts are passed to + collator, which transforms them into a batch (by padding and stacking them). + """ + + # Metadata: these are not padded and batched + chains: list[Chain] + # Contexts: these are what get padded and batched + structure_context: AllAtomStructureContext + msa_context: MSAContext + main_msa_context: MSAContext + template_context: TemplateContext + embedding_context: EmbeddingContext | None + constraint_context: ConstraintContext + + def __str__(self) -> str: + chains_info = [str(chain) for chain in self.chains] + return f"{self.__class__.__name__}(chains={chains_info})" + + def pad( + self, + n_tokens: int, + n_atoms: int, + ) -> "AllAtomFeatureContext": + return AllAtomFeatureContext( + # Metadata + chains=self.chains, + # Contexts + structure_context=self.structure_context.pad( + n_tokens=n_tokens, + n_atoms=n_atoms, + ), + msa_context=self.msa_context.pad( + max_num_tokens=n_tokens, + max_msa_depth=MAX_MSA_DEPTH, + ), + main_msa_context=self.main_msa_context.pad( + max_num_tokens=n_tokens, + max_msa_depth=MAX_MSA_DEPTH, + ), + template_context=self.template_context.pad( + max_tokens=n_tokens, + max_templates=MAX_NUM_TEMPLATES, + ), + embedding_context=( + self.embedding_context.pad(max_tokens=n_tokens) + if self.embedding_context is not None + else None + ), + constraint_context=self.constraint_context.pad(max_tokens=n_tokens), + ) + + def to_dict(self) -> dict[str, Any]: + msa_context_dict = dict( + msa_tokens=self.msa_context.tokens, + msa_mask=self.msa_context.mask, + msa_deletion_matrix=self.msa_context.deletion_matrix, + msa_species=self.msa_context.species, + msa_sequence_source=self.msa_context.sequence_source, + main_msa_tokens=self.main_msa_context.tokens, + main_msa_mask=self.main_msa_context.mask, + main_msa_deletion_matrix=self.main_msa_context.deletion_matrix, + paired_msa_depth=self.msa_context.paired_msa_depth, + ) + return { + **self.structure_context.to_dict(), + **msa_context_dict, + **self.template_context.to_dict(), + **(self.embedding_context.to_dict() if self.embedding_context else {}), + **self.constraint_context.to_dict(), + } diff --git a/chai_lab/data/dataset/constraints/__init__.py b/chai_lab/data/dataset/constraints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/dataset/constraints/constraint_context.py b/chai_lab/data/dataset/constraints/constraint_context.py new file mode 100644 index 0000000..37c4cbc --- /dev/null +++ b/chai_lab/data/dataset/constraints/constraint_context.py @@ -0,0 +1,58 @@ +from dataclasses import asdict, dataclass +from typing import Any + +from chai_lab.data.features.generators.docking import ( + ConstraintGroup as DockingConstraint, +) +from chai_lab.data.features.generators.token_dist_restraint import ( + ConstraintGroup as ContactConstraint, +) +from chai_lab.data.features.generators.token_pair_pocket_restraint import ( + ConstraintGroup as PocketConstraint, +) +from 
chai_lab.utils.typing import typecheck + + +@typecheck +@dataclass +class ConstraintContext: + docking_constraints: list[DockingConstraint] | None + contact_constraints: list[ContactConstraint] | None + pocket_constraints: list[PocketConstraint] | None + + def __str__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"\n\tdocking_constraints {self.docking_constraints})" + f"\n\tcontact_constraints {self.contact_constraints}" + f"\n\tpocket_constraints {self.pocket_constraints}\n)" + ) + + def pad(self, *args, **kwargs) -> "ConstraintContext": + # No-op + return ConstraintContext( + docking_constraints=self.docking_constraints, + contact_constraints=self.contact_constraints, + pocket_constraints=self.pocket_constraints, + ) + + def to_dict(self) -> dict[str, Any]: + return dict( + docking_constraints=[asdict(c) for c in self.docking_constraints] + if self.docking_constraints is not None + else [None], + contact_constraints=[asdict(c) for c in self.contact_constraints] + if self.contact_constraints is not None + else [None], + pocket_constraints=[asdict(c) for c in self.pocket_constraints] + if self.pocket_constraints is not None + else [None], + ) + + @classmethod + def empty(cls) -> "ConstraintContext": + return cls( + docking_constraints=None, + contact_constraints=None, + pocket_constraints=None, + ) diff --git a/chai_lab/data/dataset/embeddings/__init__.py b/chai_lab/data/dataset/embeddings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/dataset/embeddings/embedding_context.py b/chai_lab/data/dataset/embeddings/embedding_context.py new file mode 100644 index 0000000..4c09ff7 --- /dev/null +++ b/chai_lab/data/dataset/embeddings/embedding_context.py @@ -0,0 +1,47 @@ +from dataclasses import asdict, dataclass + +import torch +from torch import Tensor + +from chai_lab.utils.typing import Float, typecheck + + +@typecheck +@dataclass +class EmbeddingContext: + esm_embeddings: Float[Tensor, "num_tokens d_emb"] + + def __str__(self) -> str: + return ( + f"{self.__class__.__name__}(esm_embeddings of {self.esm_embeddings.shape})" + ) + + @property + def num_tokens(self) -> int: + (num_tokens, _) = self.esm_embeddings.shape + return num_tokens + + def pad(self, max_tokens: int) -> "EmbeddingContext": + assert self.num_tokens <= max_tokens + + pad_dims_token = (0, max_tokens - self.num_tokens) + pad_dims_emb = (0, 0) + + padded_embeddings = torch.nn.functional.pad( + self.esm_embeddings, + pad_dims_emb + pad_dims_token, + value=0, + ) + + return EmbeddingContext( + esm_embeddings=padded_embeddings, + ) + + def to_dict(self) -> dict[str, torch.Tensor]: + return asdict(self) + + @classmethod + def empty(cls, n_tokens: int, d_emb: int = 2560) -> "EmbeddingContext": + return cls( + esm_embeddings=torch.zeros(n_tokens, d_emb), + ) diff --git a/chai_lab/data/dataset/embeddings/esm.py b/chai_lab/data/dataset/embeddings/esm.py new file mode 100644 index 0000000..dd12ea8 --- /dev/null +++ b/chai_lab/data/dataset/embeddings/esm.py @@ -0,0 +1,97 @@ +import os +from contextlib import contextmanager + +import torch +from transformers import logging as tr_logging + +from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext +from chai_lab.data.dataset.structure.chain import Chain +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.utils.tensor_utils import move_data_to_device +from chai_lab.utils.typing import typecheck + +_esm_model: list = [] # persistent in-process container + 
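+# Clear the cached model in forked child processes (e.g. dataloader workers) so
+# each child lazily re-loads its own copy on first use.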
+os.register_at_fork(after_in_child=lambda: _esm_model.clear()) + + +# unfortunately huggingface complains on pooler layer in ESM being non-initialized. +# Did not find a way to filter specifically that logging message :/ +tr_logging.set_verbosity_error() + + +@contextmanager +def esm_model(model_name: str, device): + """Context transiently keeps ESM model on specified device.""" + from transformers import EsmModel + + if len(_esm_model) == 0: + # lazy loading of the model + _esm_model.append(EsmModel.from_pretrained(model_name)) + + [model] = _esm_model + model.to(device) + model.eval() + yield model + # move model back to CPU + model.to("cpu") + + +def embedding_context_from_sequence(seq: str, device) -> EmbeddingContext: + # local import, requires huggingface transformers + from transformers import EsmTokenizer + + model_name = "facebook/esm2_t36_3B_UR50D" + tokenizer = EsmTokenizer.from_pretrained(model_name) + + inputs = tokenizer(seq, return_tensors="pt") + inputs = move_data_to_device(dict(**inputs), device=device) + + with torch.no_grad(): + with esm_model(model_name=model_name, device=device) as model: + outputs = model(**inputs) + + # remove BOS/EOS, back to CPU + esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu") + seq_len, _emb_dim = esm_embeddings.shape + assert seq_len == len(seq) + return EmbeddingContext(esm_embeddings=esm_embeddings) + + +@typecheck +def get_esm_embedding_context(chains: list[Chain], device) -> EmbeddingContext: + # device is used for computing, but result is still on CPU + chain_embs = [] + + for chain in chains: + if chain.entity_data.entity_type == EntityType.PROTEIN: + emb = embedding_context_from_sequence( + # modified residues represented as X + seq=chain.entity_data.sequence, + device=device, + ) + chain_embs.append(emb) + else: + # embed non-proteins with zeros + chain_embs.append( + EmbeddingContext.empty(n_tokens=chain.structure_context.num_tokens) + ) + + exploded_embs = [ + embedding.esm_embeddings[chain.structure_context.token_residue_index, :] + for embedding, chain in zip(chain_embs, chains, strict=True) + ] + + # don't crop any chains during inference + cropped_embs = exploded_embs + + # if we had to crop, we'd need some logic like below: + # crop_idces: list[torch.Tensor] + # cropped_embs = [ + # embedding[crop_idx, :] for embedding, crop_idx in zip(exploded_embs, crop_idces) + # ] + + # Merge the embeddings along the tokens dimension (i.e. 
merge the chains) + merged_embs = torch.cat(cropped_embs, dim=0) + + return EmbeddingContext(esm_embeddings=merged_embs) diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py new file mode 100644 index 0000000..27e37ce --- /dev/null +++ b/chai_lab/data/dataset/inference_dataset.py @@ -0,0 +1,216 @@ +import logging +import string +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import gemmi + +from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import ( + AllAtomResidueTokenizer, + _make_sym_ids, +) +from chai_lab.data.dataset.structure.chain import Chain +from chai_lab.data.parsing.fasta import parse_modified_fasta_sequence, read_fasta +from chai_lab.data.parsing.structure.all_atom_entity_data import AllAtomEntityData +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.data.parsing.structure.residue import Residue, get_restype +from chai_lab.data.residue_constants import ( + new_ligand_residue_name, + residue_types_with_nucleotides_order, +) +from chai_lab.data.sources.rdkit import RefConformerGenerator + +logger = logging.getLogger(__name__) + + +@dataclass +class Input: + sequence: str + entity_type: int + + +def get_lig_residues( + smiles: str, +) -> list[Residue]: + return [ + Residue( + name=new_ligand_residue_name, + label_seq=0, + restype=residue_types_with_nucleotides_order["X"], + residue_index=0, + is_missing=False, + b_factor_or_plddt=0.0, + conformer_data=None, + smiles=smiles, + ) + ] + + +def get_polymer_residues( + residue_names: list[str], + entity_type: EntityType, +): + residues = [] + for i, residue_name in enumerate(residue_names): + residues.append( + Residue( + name=residue_name, + label_seq=i, + restype=get_restype( + gemmi.find_tabulated_residue(residue_name), entity_type + ), + residue_index=i, + is_missing=False, + b_factor_or_plddt=1.0, + conformer_data=None, + ) + ) + return residues + + +def _synth_subchain_id(idx: int) -> str: + n = len(string.ascii_uppercase) + retval = "" + while idx >= 0: + retval = string.ascii_uppercase[idx % n] + retval + idx = idx // n - 1 + return retval + + +def raw_inputs_to_entitites_data( + inputs: list[Input], identifier: str = "test" +) -> list[AllAtomEntityData]: + entities = [] + + # track unique entities + entity_to_index: dict[tuple[EntityType, tuple[str, ...]], int] = {} + + for i, input in enumerate(inputs): + # Parse residues based on entity type + residues = None + match entity_type := EntityType(input.entity_type): + case EntityType.LIGAND: + residues = get_lig_residues(smiles=input.sequence) + + case EntityType.PROTEIN | EntityType.RNA | EntityType.DNA: + parsed_sequence: list = parse_modified_fasta_sequence( + input.sequence, entity_type + ) + residues = get_polymer_residues(parsed_sequence, entity_type) + case _: + raise NotImplementedError + assert residues is not None + + # Determine the entity id (unique integer for each distinct sequence) + # NOTE very important for recognizing things like homo polymers + seq: tuple[str, ...] 
= tuple(res.name for res in residues) + entity_key: tuple[EntityType, tuple[str, ...]] = (entity_type, seq) + if entity_key in entity_to_index: + entity_id = entity_to_index[entity_key] + else: + entity_id = len(entity_to_index) + entity_to_index[entity_key] = entity_id + + entities.append( + AllAtomEntityData( + residues, + full_sequence=[residue.name for residue in residues], + resolution=0.0, + release_datetime=datetime.now(), + pdb_id=identifier, + source_pdb_chain_id=_synth_subchain_id(i), + entity_name=f"entity_{i}_{entity_type.name}", + entity_id=entity_id, + method="none", + entity_type=entity_type, + subchain_id=_synth_subchain_id(i), + ) + ) + + assert len(entities) == len(inputs) + return entities + + +def load_chains_from_raw( + inputs: list[Input], + identifier: str = "test", + tokenizer: AllAtomResidueTokenizer | None = None, +) -> list[Chain]: + """ + loads and tokenizes each input chain + """ + + if tokenizer is None: + conformer_generator = RefConformerGenerator() + tokenizer = AllAtomResidueTokenizer(conformer_generator) + + # Extract the entity data from the gemmi structure. + entities: list[AllAtomEntityData] = raw_inputs_to_entitites_data( + inputs, + identifier=identifier, + ) + + # Tokenize the entity data + structure_contexts = [] + sym_ids = _make_sym_ids([x.entity_id for x in entities]) + for idx, (entity_data, sym_id) in enumerate(zip(entities, sym_ids)): + try: + tok = tokenizer._tokenize_entity( + entity_data, + chain_id=idx + 1, + sym_id=sym_id, + ) + structure_contexts.append(tok) + except Exception: + logger.exception(f"Failed to tokenize input {inputs[idx]}") + + # Join the untokenized entity data with the tokenized chain data, removing + # chains we failed to tokenize + chains = [ + Chain(entity_data=entity_data, structure_context=structure_context) + for entity_data, structure_context in zip(entities, structure_contexts) + if structure_context is not None + ] + + return chains + + +def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list[Input]: + """Read inputs from a fasta file. + + If the total length of sequences' character count is greater than length limit, + return an empty list. Note that character count is not the same as token count, but + is an easy approximation (smiles length is somewhat proportion to number of atoms in + a ligand, number of residues approximates number of tokens with modified amino acids + adding to it, etc.). 
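+
+    The entity type is taken from the header text before the first "|"; recognized
+    values are protein, ligand, rna and dna. A purely illustrative example (names
+    and sequences below are made up):
+
+        >protein|example-chain-A
+        MADQLTEEQIAEFKEAFSLF
+        >ligand|example-small-molecule
+        CC(=O)OC1=CC=CC=C1C(=O)O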
+ """ + sequences = read_fasta(fasta_file) + + retval: list[Input] = [] + total_length: int = 0 + for desc, sequence in sequences: + logger.info(f"[fasta] [{fasta_file}] {desc} {len(sequence)}") + # get the type of the sequence + entity_str = desc.split("|")[0].strip() + match entity_str: + case "protein": + entity_type = EntityType.PROTEIN + case "ligand": + entity_type = EntityType.LIGAND + case "rna": + entity_type = EntityType.RNA + case "dna": + entity_type = EntityType.DNA + case _: + raise ValueError(f"{entity_str} is not a valid entity type") + retval.append(Input(sequence, entity_type.value)) + total_length += len(sequence) + + if length_limit is not None and total_length > length_limit: + logger.warning( + f"[fasta] [{fasta_file}] too many chars ({total_length} > {length_limit}); no inputs" + ) + return [] + + return retval diff --git a/chai_lab/data/dataset/msas/__init__.py b/chai_lab/data/dataset/msas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/dataset/msas/msa_context.py b/chai_lab/data/dataset/msas/msa_context.py new file mode 100644 index 0000000..ca4dfd2 --- /dev/null +++ b/chai_lab/data/dataset/msas/msa_context.py @@ -0,0 +1,209 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from chai_lab.data.parsing.msas.data_source import ( + MSADataSource, + msa_dataset_source_to_int, +) +from chai_lab.data.parsing.msas.species import UNKNOWN_SPECIES +from chai_lab.data.residue_constants import residue_types_with_nucleotides_order +from chai_lab.utils.defaults import default +from chai_lab.utils.typing import Bool, Int32, UInt8, typecheck + + +@typecheck +@dataclass +class MSAContext: + # MSA-level + dataset_source: MSADataSource + + # token level + tokens: UInt8[Tensor, "msa_depth n_tokens"] + species: Int32[Tensor, "msa_depth n_tokens"] + deletion_matrix: UInt8[Tensor, "msa_depth n_tokens"] + mask: Bool[Tensor, "msa_depth n_tokens"] + sequence_source: UInt8[Tensor, "msa_depth n_tokens"] + is_paired_mask: Bool[Tensor, "msa_depth"] + + @property + def depth(self) -> int: + depth, _ = self._dims + return depth + + @property + def num_tokens(self) -> int: + _, num_tokens = self._dims + return num_tokens + + @property + def _dims(self) -> torch.Size: + return self.tokens.shape + + @property + def paired_msa_depth(self) -> Int32[Tensor, "b"]: + return (self.mask.any(dim=-1) & self.is_paired_mask).sum(dim=-1) + + def __getitem__(self, subscript: tuple) -> "MSAContext": + # enforce typing on item + if not ( + isinstance(subscript, tuple) + and ((len(subscript) == 2) or subscript[0] is Ellipsis) + ): + raise TypeError( + "Subscript must be a tuple with 2 elements or have an ellipsis." 
+ ) + + is_paired_mask = repeat( + self.is_paired_mask, + "msa_depth -> msa_depth n_tokens", + n_tokens=self.num_tokens, + ) + return MSAContext( + dataset_source=self.dataset_source, + tokens=self.tokens[subscript], + species=self.species[subscript], + deletion_matrix=self.deletion_matrix[subscript], + sequence_source=self.sequence_source[subscript], + mask=self.mask[subscript], + is_paired_mask=is_paired_mask[subscript].any(dim=-1), + ) + + def pad( + self, + max_num_tokens: int | None = None, + max_msa_depth: int | None = None, + ) -> "MSAContext": + max_num_tokens = default(max_num_tokens, self.num_tokens) + assert self.num_tokens <= max_num_tokens + + max_msa_depth = default(max_msa_depth, self.depth) + assert self.depth <= max_msa_depth + + pad_dims = (0, max_num_tokens - self.num_tokens, 0, max_msa_depth - self.depth) + return MSAContext( + dataset_source=self.dataset_source, + tokens=torch.nn.functional.pad( + self.tokens, + pad_dims, + value=residue_types_with_nucleotides_order[":"], + ), + species=torch.nn.functional.pad( + self.species, + pad_dims, + value=UNKNOWN_SPECIES, + ), + deletion_matrix=torch.nn.functional.pad( + self.deletion_matrix, + pad_dims, + value=0, # No deletions + ), + mask=torch.nn.functional.pad( + self.mask, + pad_dims, + value=False, + ), + sequence_source=torch.nn.functional.pad( + self.sequence_source, + pad_dims, + value=msa_dataset_source_to_int[MSADataSource.NONE], + ), + is_paired_mask=torch.nn.functional.pad( + self.is_paired_mask, + (0, max_msa_depth - self.depth), + value=False, + ), + ) + + @typecheck + def apply_mask(self, mask: Bool[Tensor, "msa_depth n_tokens"]) -> "MSAContext": + return MSAContext( + dataset_source=self.dataset_source, + tokens=self.tokens.masked_fill( + ~mask, residue_types_with_nucleotides_order[":"] + ), + species=self.species.masked_fill(~mask, UNKNOWN_SPECIES), + deletion_matrix=self.deletion_matrix.masked_fill(~mask, 0), + mask=self.mask.masked_fill(~mask, False), + sequence_source=self.sequence_source.masked_fill( + ~mask, value=msa_dataset_source_to_int[MSADataSource.NONE] + ), + is_paired_mask=self.is_paired_mask.masked_fill(~mask.any(dim=-1), False), + ) + + @classmethod + def cat( + cls, + msas: list["MSAContext"], + dataset_source: MSADataSource | None = None, + dim=-1, + ) -> "MSAContext": + if dataset_source is None: + dataset_sources = set([msa.dataset_source for msa in msas]) + assert len(dataset_sources) == 1 or dataset_sources == { + MSADataSource.MAIN, + MSADataSource.PAIRED, + }, "all MSAs must have the same datasource or be MAIN and PAIRED" + dataset_source = dataset_sources.pop() + + assert dim == -1 or dim >= 0, "dim < 0 not implemented except for -1" + if 0 <= dim < 1: + is_paired_mask = torch.cat([msa.is_paired_mask for msa in msas], dim=dim) + else: + assert len(msas) > 0 + is_paired_mask = msas[0].is_paired_mask + + return MSAContext( + dataset_source=dataset_source, + tokens=torch.cat([msa.tokens for msa in msas], dim=dim), + species=torch.cat([msa.species for msa in msas], dim=dim), + deletion_matrix=torch.cat([msa.deletion_matrix for msa in msas], dim=dim), + sequence_source=torch.cat([msa.sequence_source for msa in msas], dim=dim), + mask=torch.cat([msa.mask for msa in msas], dim=dim), + is_paired_mask=is_paired_mask, + ) + + @classmethod + @typecheck + def create( + cls, + dataset_source: MSADataSource, + tokens: UInt8[Tensor, "n_tokens"], + ) -> "MSAContext": + """ + Creates an MSA comprised of a single sequence. 
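+        The result has depth 1, with unknown species, an all-True mask, an empty
+        deletion matrix, and is_paired_mask set to False.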
+ """ + tokens_for_msa = rearrange(tokens, "n_tokens -> 1 n_tokens") + return MSAContext( + dataset_source=dataset_source, + tokens=tokens_for_msa, + species=torch.full_like(tokens_for_msa, UNKNOWN_SPECIES, dtype=torch.int32), + deletion_matrix=torch.zeros_like(tokens_for_msa, dtype=torch.uint8), + mask=torch.ones_like(tokens_for_msa, dtype=torch.bool), + sequence_source=torch.full_like( + tokens_for_msa, + fill_value=msa_dataset_source_to_int[dataset_source], + ), + is_paired_mask=torch.zeros((1,), dtype=torch.bool), + ) + + @classmethod + def create_empty(cls, n_tokens: int, depth: int = 0) -> "MSAContext": + dims = (depth, n_tokens) + return MSAContext( + dataset_source=MSADataSource.NONE, + tokens=torch.full( + dims, residue_types_with_nucleotides_order[":"], dtype=torch.uint8 + ), + species=torch.full(dims, UNKNOWN_SPECIES, dtype=torch.int32), + deletion_matrix=torch.zeros(dims, dtype=torch.uint8), # No deletions + mask=torch.zeros(dims, dtype=torch.bool), + sequence_source=torch.full( + dims, + fill_value=msa_dataset_source_to_int[MSADataSource.NONE], + dtype=torch.uint8, + ), + is_paired_mask=torch.zeros((depth,), dtype=torch.bool), + ) diff --git a/chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py b/chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py new file mode 100644 index 0000000..320352f --- /dev/null +++ b/chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py @@ -0,0 +1,632 @@ +import logging +from dataclasses import dataclass +from itertools import chain + +import torch +from einops import repeat +from torch import Tensor + +from chai_lab.data.dataset.structure import utils +from chai_lab.data.dataset.structure.all_atom_structure_context import ( + AllAtomStructureContext, +) +from chai_lab.data.dataset.structure.utils import ( + backbone_atoms_all_present, + backbone_atoms_indices, + get_centre_atom_index, + get_reference_atom_index, +) +from chai_lab.data.parsing.structure.all_atom_entity_data import AllAtomEntityData +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.data.parsing.structure.residue import ConformerData, Residue +from chai_lab.data.residue_constants import standard_residue_pdb_codes +from chai_lab.data.sources.rdkit import ( + RefConformerGenerator, + conformer_data_to_rdkit_mol, +) +from chai_lab.utils.tensor_utils import string_to_tensorcode, unique_indexes +from chai_lab.utils.typing import Bool, Float, Int, typecheck + +logger = logging.getLogger(__name__) + + +# jaxtyping on residue-level objects is extremely slow. 
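+# The shape annotations on the dataclass below are therefore not runtime-checked;
+# the tokenizer methods further down keep @typecheck commented out for the same
+# reason and suggest enabling it only for debugging.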
+@dataclass(frozen=True) +class TokenSpan: + restype: Int[Tensor, "n_tokens"] + residue_index: Int[Tensor, "n_tokens"] + centre_atom_index: Int[Tensor, "n_tokens"] + reference_atom_index: Int[Tensor, "n_tokens"] + backbone_frame_mask: Bool[Tensor, "n_tokens"] + backbone_frame_index: Int[Tensor, "n_tokens 3"] + atom_gt_coords: Float[Tensor, "n_atoms 3"] + atom_exists_mask: Bool[Tensor, "n_atoms"] + atom_token_index: Int[Tensor, "n_atoms"] + ref_pos: Float[Tensor, "n_atoms 3"] + ref_mask: Bool[Tensor, "n_atoms"] + ref_element: Int[Tensor, "n_atoms"] + ref_charge: Int[Tensor, "n_atoms"] + atom_names: list[str] + # Consistent atom ordering witin each token + atom_within_token_indices: Int[Tensor, "n_atoms"] + residue_names: list[str] + symmetries: Int[Tensor, "n_atoms n_symm"] + b_factor_or_plddt: Float[Tensor, "n_tokens"] + + @classmethod + def concatenate(cls, spans: list["TokenSpan"]) -> "TokenSpan": + # offset bond indices: + tokens_per_span = torch.tensor([span.restype.shape[0] for span in spans]) + token_count = torch.cumsum(tokens_per_span, dim=0).roll(1, 0) + token_count[0] = 0 + + # offsets indices of centre atoms: + atoms_per_span = torch.tensor( + [span.atom_exists_mask.shape[0] for span in spans] + ) + atom_offsets = torch.cumsum(atoms_per_span, dim=0).roll(1, 0) + atom_offsets[0] = 0 + + centre_atom_index = torch.cat( + [ + span.centre_atom_index + offset + for span, offset in zip(spans, atom_offsets) + ] + ) + reference_atom_index = torch.cat( + [ + span.reference_atom_index + offset + for span, offset in zip(spans, atom_offsets) + ] + ) + + atom_token_index = ( + torch.cumsum( + torch.cat([x.atom_token_index for x in spans]), + dim=0, + dtype=torch.int, + ) + - 1 + ) + backbone_frame_index = torch.cat( + [ + span.backbone_frame_index + offset + for span, offset in zip(spans, atom_offsets) + ] + ) + + # concatenate symmetric permutations at the atom level + # make sure that trailing shape is the same + # NOTE: we store the *local* permutation indices, not the global ones + # i.e. 
the permutation indices are relative to the residue + atom_symms = [span.symmetries for span in spans] + max_symms = max(x.shape[-1] for x in atom_symms) + atom_symms = [ + torch.nn.functional.pad(x, (0, max_symms - x.shape[-1]), value=-1) + for x in atom_symms + ] + return cls( + restype=torch.cat([x.restype for x in spans]), + residue_index=torch.cat([x.residue_index for x in spans]), + centre_atom_index=centre_atom_index, + reference_atom_index=reference_atom_index, + backbone_frame_mask=torch.cat([x.backbone_frame_mask for x in spans]), + backbone_frame_index=backbone_frame_index, + atom_gt_coords=torch.cat([x.atom_gt_coords for x in spans]), + atom_exists_mask=torch.cat([x.atom_exists_mask for x in spans]), + atom_token_index=atom_token_index, + ref_pos=torch.cat([x.ref_pos for x in spans]), + ref_mask=torch.cat([x.ref_mask for x in spans]), + ref_element=torch.cat([x.ref_element for x in spans]), + ref_charge=torch.cat([x.ref_charge for x in spans]), + atom_names=list(chain.from_iterable([x.atom_names for x in spans])), + atom_within_token_indices=torch.cat( + [x.atom_within_token_indices for x in spans] + ), + residue_names=list(chain.from_iterable([x.residue_names for x in spans])), + symmetries=torch.cat(atom_symms, dim=0), + b_factor_or_plddt=torch.cat([x.b_factor_or_plddt for x in spans]), + ) + + +class AllAtomResidueTokenizer: + ref_conformer_generator: RefConformerGenerator + + def __init__(self, ref_conformer_generator: RefConformerGenerator): + self.ref_conformer_generator = ref_conformer_generator + + def tokenize_residue( + self, + residue: Residue, + entity_type: EntityType, + ) -> TokenSpan | None: + ref_conformer_data = self._get_ref_conformer_data(residue) + if ref_conformer_data.num_atoms == 0: + # avoid dealing with empty tensors in downstream processing + # this should only happen when residue is sole hydrogen + # or when residue code is not in CCD dictionary and + # the residue has 0 coords in the PDB structure + logger.warning( + f"skipping residue {residue.name} {residue.label_seq} as reference conformer has 0 heavy atoms" + ) + return None + + # Keep only the atoms from the ground truth conformer that are present in + # reference conformer. + # + # If we don't have a reference conformer, we fall back to using the ground truth + # conformer names, i.e. we keep all atoms in the ground truth conformer. + # When a true conformer data is not provided, use reference conformer directly + gt_conformer_data = residue.conformer_data + + if gt_conformer_data is not None: + atom_gt_coords, atom_exists_mask = gt_conformer_data.gather_atom_positions( + ref_conformer_data.atom_names + ) + else: + atom_gt_coords = ref_conformer_data.position + atom_exists_mask = torch.ones( + atom_gt_coords.shape[0], dtype=torch.bool, device=atom_gt_coords.device + ) + + # Tokenization is by residue if it is a standard amino acid or standard + # nucleotide; all ligands and all modified residues are tokenized per atom. 
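+        # (So a modified amino acid, e.g. a phosphorylated residue, gets one token
+        # per heavy atom, just like a ligand.)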
+ tokenize_fn = ( + self._tokenize_per_residue + if ( + residue.name in standard_residue_pdb_codes + and entity_type != EntityType.LIGAND + ) + else self._tokenize_per_atom + ) + + return tokenize_fn( + restype=torch.tensor([residue.restype], dtype=torch.int), + residue_index=torch.tensor([residue.residue_index], dtype=torch.int), + atom_gt_coords=atom_gt_coords, + atom_exists_mask=atom_exists_mask, + ref_pos=ref_conformer_data.position, + ref_mask=torch.ones_like(atom_exists_mask), + ref_element=ref_conformer_data.element, + ref_charge=ref_conformer_data.charge, + atom_names=ref_conformer_data.atom_names, + residue_name=residue.name, + bonds=ref_conformer_data.bonds, + symmetries=ref_conformer_data.symmetries, + b_factor_or_plddt=torch.tensor([residue.b_factor_or_plddt]), + ) + + @staticmethod + def filter_atom_symmetries( + symmetries: Int[Tensor, "n_atoms n_symm"], + atom_exists_mask: Bool[Tensor, "n_atoms"], + ) -> Int[Tensor, "n_atoms filtered_n_symm"]: + n_atoms, _ = symmetries.shape + + # Create a mask for non-trivial symmetries + atom_indices = torch.arange(n_atoms).unsqueeze(-1) + non_trivial_symmetries = (symmetries >= 0) & (symmetries != atom_indices) + + masked_atoms = ~atom_exists_mask.unsqueeze(-1) + + # Check if any of the masked-out atoms have non-trivial symmetries + violations = torch.any(masked_atoms & non_trivial_symmetries, dim=1) + + # If any invalid symmetries are found, replace with identity permutation + if torch.any(violations): + return atom_indices + + # Otherwise, return the original symmetries + return symmetries + + # jaxtyping on residue-level objects is very slow, + # use for debug only + # @typecheck + def _tokenize_per_residue( + self, + restype: Int[Tensor, "n_tokens"], + residue_index: Int[Tensor, "n_tokens"], + atom_gt_coords: Float[Tensor, "n_atoms 3"], + atom_exists_mask: Bool[Tensor, "n_atoms"], + ref_pos: Float[Tensor, "n_atoms 3"], + ref_mask: Bool[Tensor, "n_atoms"], + ref_element: Int[Tensor, "n_atoms"], + ref_charge: Int[Tensor, "n_atoms"], + atom_names: list[str], + residue_name: str, + bonds: list[tuple[int, int]], + symmetries: Int[Tensor, "n_atoms n_symm"], + b_factor_or_plddt: Float[Tensor, "n_tokens"], + ) -> TokenSpan: + centre_atom_index = get_centre_atom_index( + atom_names, + residue_name, + ) + reference_atom_index = get_reference_atom_index( + atom_names, + residue_name, + ) + backbone_frame_mask = backbone_atoms_all_present( + atom_names, + residue_name, + ) + backbone_indices = backbone_atoms_indices(atom_names, residue_name).unsqueeze(0) + + # to 1 token + atom_token_index = torch.zeros_like(atom_exists_mask, dtype=torch.int) + atom_token_index[0] = 1 + + residue_names = [residue_name] + + # Find atom ordering; these should always be available because per residue + # tokenization works only on standard residues. 
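+        # (Ordering follows the atom37 / nucleotide atom indexing implemented by
+        # atom_names_to_atom37_indices later in this module.)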
+ atom_within_token_index = atom_names_to_atom37_indices( + atom_names=atom_names, + residue_name=residue_name, + ) + + return TokenSpan( + restype=restype, + residue_index=residue_index, + centre_atom_index=centre_atom_index, + reference_atom_index=reference_atom_index, + backbone_frame_mask=backbone_frame_mask, + backbone_frame_index=backbone_indices, + atom_gt_coords=atom_gt_coords, + atom_exists_mask=atom_exists_mask, + atom_token_index=atom_token_index, + ref_pos=ref_pos, + ref_mask=ref_mask, + ref_element=ref_element, + ref_charge=ref_charge, + atom_names=atom_names, + atom_within_token_indices=atom_within_token_index, + residue_names=residue_names, + symmetries=self.filter_atom_symmetries(symmetries, atom_exists_mask), + b_factor_or_plddt=b_factor_or_plddt, + ) + + # jaxtyping on residue-level objects is very slow, + # use for debug only + # @typecheck + def _tokenize_per_atom( + self, + restype: Int[Tensor, "n_tokens"], + residue_index: Int[Tensor, "n_tokens"], + atom_gt_coords: Float[Tensor, "n_atoms 3"], + atom_exists_mask: Bool[Tensor, "n_atoms"], + ref_pos: Float[Tensor, "n_atoms 3"], + ref_mask: Bool[Tensor, "n_atoms"], + ref_element: Int[Tensor, "n_atoms"], + ref_charge: Int[Tensor, "n_atoms"], + atom_names: list[str], + residue_name: str, + bonds: list[tuple[int, int]], + symmetries: Int[Tensor, "n_atoms n_symm"], + b_factor_or_plddt: Float[Tensor, "n_tokens"], + ) -> TokenSpan: + # to n_atoms tokens + n_atoms = atom_gt_coords.shape[0] + restype = repeat(restype, "1 -> a", a=n_atoms) + residue_index = repeat(residue_index, "1 -> a", a=n_atoms) + b_factor_or_plddt = repeat(b_factor_or_plddt, "1 -> a", a=n_atoms) + + # centre of the token is the first and only atom in each token + # when tokenizing per-atom + centre_atom_index = torch.arange(n_atoms, dtype=torch.int) + reference_atom_index = torch.arange(n_atoms, dtype=torch.int) + backbone_frame_mask = torch.zeros((n_atoms,), dtype=torch.bool) + backbone_indices = ( + torch.arange(n_atoms, dtype=torch.int).unsqueeze(1).expand(-1, 3) + ) + + atom_token_index = torch.ones_like(atom_exists_mask, dtype=torch.int) + + residue_names = [residue_name] * n_atoms + + # Each atom is alone in its own token + atom_within_token_index = torch.zeros(n_atoms, dtype=torch.int) + + return TokenSpan( + restype=restype, + residue_index=residue_index, + centre_atom_index=centre_atom_index, + reference_atom_index=reference_atom_index, + backbone_frame_mask=backbone_frame_mask, + backbone_frame_index=backbone_indices, + atom_gt_coords=atom_gt_coords, + atom_exists_mask=atom_exists_mask, + atom_token_index=atom_token_index, + ref_pos=ref_pos, + ref_mask=ref_mask, + ref_element=ref_element, + ref_charge=ref_charge, + atom_names=atom_names, + atom_within_token_indices=atom_within_token_index, + residue_names=residue_names, + symmetries=self.filter_atom_symmetries(symmetries, atom_exists_mask), + b_factor_or_plddt=b_factor_or_plddt, + ) + + def tokenize_entity( + self, entity_data: AllAtomEntityData + ) -> AllAtomStructureContext | None: + return self.tokenize_entities([entity_data])[0] + + def tokenize_entities( + self, + entities_data: list[AllAtomEntityData], + ) -> list[AllAtomStructureContext | None]: + sym_ids = _make_sym_ids([x.entity_id for x in entities_data]) + + return [ + self._tokenize_entity( + entity_data, + chain_id=idx + 1, + sym_id=sym_id, + ) + for idx, (entity_data, sym_id) in enumerate(zip(entities_data, sym_ids)) + ] + + def _tokenize_entity( + self, + entity_data: AllAtomEntityData, + chain_id: int = 1, + sym_id: int = 1, + ) -> 
AllAtomStructureContext | None: + tokenized_residues = [ + self.tokenize_residue(residue, entity_data.entity_type) + for residue in entity_data.residues + ] + + valid_residues = [x for x in tokenized_residues if x is not None] + if len(valid_residues) == 0: + return None + + tokens = TokenSpan.concatenate(valid_residues) + + num_tokens = tokens.restype.shape[0] + token_index = torch.arange(num_tokens, dtype=torch.int) + + # mask indicating if a token has >=1 atom with known coordinates + token_exists_mask = (tokens.atom_token_index == token_index[..., None]).sum( + dim=-1 + ) > 0 + + # checks on atom mask and positions: + # max 1 atom per-example has zero coordinates + if ( + torch.sum( + torch.all(tokens.atom_gt_coords[tokens.atom_exists_mask] == 0, dim=-1) + ) + > 1 + ): + raise ValueError( + f"Zero coordinates found in unmasked atoms for {entity_data.pdb_id}" + ) + + # construct asym_id, entity_id, sym_id + asym_id = chain_id + entity_id = entity_data.entity_id + + # Create unique ids to identify atoms which belong to same residue in same chain + # here assume we featurize a single chain + atom_residue_index = torch.gather( + tokens.residue_index, + dim=0, + index=tokens.atom_token_index.long(), + ) + + atom_ref_space_uid = atom_residue_index + + residue_names = tokens.residue_names + + match entity_data.entity_type: + case EntityType.PROTEIN: + if tokens.residue_index[0].item() != 0: + logger.error( + f"Protein residue index should start at zero, {entity_data}" + ) + + if not torch.all(torch.diff(tokens.residue_index) <= 1): + logger.error( + f"Protein residue index should be contiguous (no gaps), {entity_data}" + ) + + _, unique_indices = unique_indexes(tokens.residue_index) + res_seq = [residue_names[i.item()] for i in unique_indices] + if res_seq != entity_data.full_sequence: + logger.error( + f"Protein residue names should match entity data full sequence, {entity_data}" + ) + + return AllAtomStructureContext( + # token-level + token_residue_type=tokens.restype, + token_residue_index=tokens.residue_index, + token_centre_atom_index=tokens.centre_atom_index, + token_ref_atom_index=tokens.reference_atom_index, + token_index=token_index, + token_exists_mask=token_exists_mask, + token_backbone_frame_mask=tokens.backbone_frame_mask, + token_backbone_frame_index=tokens.backbone_frame_index, + token_asym_id=_id_to_token_tensor(asym_id, num_tokens), + token_entity_id=_id_to_token_tensor(entity_id, num_tokens), + token_sym_id=_id_to_token_tensor(sym_id, num_tokens), + token_entity_type=entity_type_to_tensor( + entity_data.entity_type, + num_tokens, + ), + # token res name is padded to 8 characters + token_residue_name=torch.stack( + [string_to_tensorcode(x, 8) for x in residue_names], + dim=0, + ), + token_b_factor_or_plddt=tokens.b_factor_or_plddt, + # atom-level + atom_token_index=tokens.atom_token_index, + atom_within_token_index=tokens.atom_within_token_indices, + atom_ref_pos=tokens.ref_pos, + atom_ref_mask=tokens.ref_mask, + atom_ref_element=tokens.ref_element, + atom_ref_charge=tokens.ref_charge, + atom_ref_name=tokens.atom_names, + atom_ref_name_chars=_atom_names_to_tensor(tokens.atom_names), + atom_ref_space_uid=atom_ref_space_uid, + atom_is_not_padding_mask=torch.ones_like( + tokens.atom_exists_mask, + dtype=torch.bool, + ), + # supervision only + atom_gt_coords=tokens.atom_gt_coords, + atom_exists_mask=tokens.atom_exists_mask, + # structure-only + pdb_id=repeat( + # PDB ids are only 4 characters long, but AFDB ids can be longer + string_to_tensorcode(entity_data.pdb_id, 
pad_to_length=32), + "length -> num_tokens length", + num_tokens=num_tokens, + ), + source_pdb_chain_id=repeat( + string_to_tensorcode(entity_data.source_pdb_chain_id, pad_to_length=4), + "length -> num_tokens length", + num_tokens=num_tokens, + ), + subchain_id=repeat( + string_to_tensorcode(entity_data.subchain_id, pad_to_length=4), + "length -> num_tokens length", + num_tokens=num_tokens, + ), + resolution=torch.tensor( + [entity_data.resolution], + dtype=torch.float32, + ), + is_distillation=torch.tensor( + [entity_data.is_distillation], + dtype=torch.bool, + ), + symmetries=tokens.symmetries, + ) + + def _get_ref_conformer_data(self, residue: Residue) -> ConformerData: + """ + Returns the reference conformer data for the residue. We determine the reference + conformer according to the following logic: + 1. conformer_generator is available and a reference + conformer exists for the residue name => we return the cached reference + conformer via the conformer generator + 2. conformer_generator is available and a smiles is given for the residue => + we generate a reference conformer using Rdkit via the conformer generator + 3. conformer_generator is available, the reference conformer can't be + found and no smiles is given => we convert the Residue to an RDKit molecule + and load full conformer data with the residue's atom positions as coordinates. + 4. conformer generator is not available => we set reference conformer to + the ground truth conformer + """ + # The reference conformer tells us: + # - which atoms we should expect in this ligand / residue, and how many of them + # - what are the ideal coordinates of these atoms if the ligand or residue was + # assembled alone in the void + ref_conformer = self.ref_conformer_generator.get(residue.name) + + if ref_conformer is not None: + if residue.name in standard_residue_pdb_codes: + return ref_conformer + else: + return ref_conformer.center_random_augment() + + # When we can't find a reference conformer, and a smiles is given, + # generate a reference conformer using rdkit + if residue.smiles is not None: + logger.info( + f"Generating ref conformer for {residue.name}, {residue.smiles}" + ) + return self.ref_conformer_generator.generate(residue.smiles) + + # When we can't find a reference conformer, attempt to use the ground + # truth conformer data as the reference conformer. + logger.warning( + f"No reference conformer found for residue {residue.name}," + "using training example conformer" + ) + assert residue.conformer_data is not None + + try: + # Rather than just setting the reference conformer to the ground truth, we + # make a fake RDKit molecule from the ground truth data and then convert + # back into a conformer data so that we can extract inter-atom aymmetries + # bond and info + rdkit_mol = conformer_data_to_rdkit_mol(residue.conformer_data) + gt_conformer = RefConformerGenerator._load_ref_conformer_from_rdkit( + rdkit_mol + ) + except Exception as e: + # Occasionally _load_ref_conformer_from_rdkit fails on unknown ligands e.g. + # rdkit.Chem.rdchem.AtomValenceException's can be raised or ValueError: + # can't infer bonds for Ligand. due to inexact connectivity. + logger.warning( + f"Caught error for {residue.name=} while loading reference conformer " + f"from RDKit, {(type(e).__name__)}. Using ground truth conformer instead." 
+ ) + gt_conformer = residue.conformer_data + + return gt_conformer.center_random_augment() + + +@typecheck +def _atom_names_to_tensor(atom_names: list[str]) -> Int[Tensor, "n_atoms 4"]: + ords = torch.tensor( + [[ord(c) - 32 for c in atom_name.ljust(4, " ")] for atom_name in atom_names], + dtype=torch.int, + ) + return ords[:, :4] + + +@typecheck +def _id_to_token_tensor(id: int, num_tokens: int) -> Int[Tensor, "n"]: + return id * torch.ones((num_tokens,), dtype=torch.int) + + +@typecheck +def entity_type_to_tensor(entity_type: EntityType, num_tokens: int) -> Int[Tensor, "n"]: + return torch.full((num_tokens,), fill_value=entity_type.value, dtype=torch.int) + + +def _make_sym_ids(entity_ids_per_chain: list[int]) -> list[int]: + entities_dict: dict[int, int] = dict() + sym_ids = [] + + for entity_id in entity_ids_per_chain: + sym_id = entities_dict.get(entity_id, 0) + sym_ids.append(sym_id) + entities_dict[entity_id] = sym_id + 1 + + return sym_ids + + +def atom_names_to_atom37_indices( + atom_names: list[str], residue_name: str +) -> Int[Tensor, "n_atoms"]: + """ + Returns a tensor of indices into the token-level atom names. + """ + # Proteins use the atom37 ordering and indexing + # nucleotides use the 36 atom ordering and indexing + # - DNA is written as DA DG DC DT + # - RNA is given as A G C U + + precomputed_idces = utils.atom_37_atom_indices() + + if residue_name == "UNK": + retval = torch.arange(len(atom_names), dtype=torch.int) + + elif residue_name in standard_residue_pdb_codes: + idx = [precomputed_idces[(residue_name, atom_name)] for atom_name in atom_names] + retval = torch.tensor(idx, dtype=torch.int) + else: + raise ValueError( + f"Unknown residue name {residue_name} (atom names: {atom_names})" + ) + + assert retval.max() <= 36, f"Out of bounds ordering {atom_names} in {residue_name}" + return retval diff --git a/chai_lab/data/dataset/structure/all_atom_structure_context.py b/chai_lab/data/dataset/structure/all_atom_structure_context.py new file mode 100644 index 0000000..4459479 --- /dev/null +++ b/chai_lab/data/dataset/structure/all_atom_structure_context.py @@ -0,0 +1,286 @@ +import logging +from dataclasses import asdict, dataclass +from functools import cached_property, partial + +import torch +from torch import Tensor + +from chai_lab.utils.tensor_utils import ( + batch_tensorcode_to_string, + tensorcode_to_string, +) +from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck + +logger = logging.getLogger(__name__) + + +@typecheck +@dataclass +class AllAtomStructureContext: + # token-level + token_residue_type: Int[Tensor, "n_tokens"] + token_residue_index: Int[Tensor, "n_tokens"] + token_index: Int[Tensor, "n_tokens"] + token_centre_atom_index: Int[Tensor, "n_tokens"] + token_ref_atom_index: Int[Tensor, "n_tokens"] + token_exists_mask: Bool[Tensor, "n_tokens"] + token_backbone_frame_mask: Bool[Tensor, "n_tokens"] + token_backbone_frame_index: Int[Tensor, "n_tokens 3"] + token_asym_id: Int[Tensor, "n_tokens"] + token_entity_id: Int[Tensor, "n_tokens"] + token_sym_id: Int[Tensor, "n_tokens"] + token_entity_type: Int[Tensor, "n_tokens"] + token_residue_name: UInt8[Tensor, "n_tokens 8"] + token_b_factor_or_plddt: Float[Tensor, "n_tokens"] + # atom-level + atom_token_index: Int[Tensor, "n_atoms"] + atom_within_token_index: Int[Tensor, "n_atoms"] # consistent atom ordering + atom_ref_pos: Float[Tensor, "n_atoms 3"] + atom_ref_mask: Bool[Tensor, "n_atoms"] + atom_ref_element: Int[Tensor, "n_atoms"] + atom_ref_charge: Int[Tensor, "n_atoms"] + atom_ref_name: 
list[str] + atom_ref_name_chars: Int[Tensor, "n_atoms 4"] + atom_ref_space_uid: Int[Tensor, "n_atoms"] + atom_is_not_padding_mask: Bool[Tensor, "n_atoms"] + # supervision only + atom_gt_coords: Float[Tensor, "n_atoms 3"] + atom_exists_mask: Bool[Tensor, "n_atoms"] + # structure-level + pdb_id: UInt8[Tensor, "n_tokens 32"] + # source_pdb_chain_id corresponds to auth_asym_id in pdb + # can be the same for two different asym_id values + # (we split protein and ligand for example) + source_pdb_chain_id: UInt8[Tensor, "n_tokens 4"] + # subchain_id is label_asym_id in pdb + # it is assigned by the PDB and separates different + # chemical entities (protein, ligand) + # should be a 1-1 mapping to asym_id + subchain_id: UInt8[Tensor, "n_tokens 4"] + resolution: Float[Tensor, "1"] + is_distillation: Bool[Tensor, "1"] + # symmetric atom swap indices + symmetries: Int[Tensor, "n_atoms n_symmetries"] + + def __post_init__(self): + # Resolved residues filter should eliminate PDBs with missing residues, but that + # we can still have atom_exists mask set to False at every position if we have a + # bad crop so we log examples with no valid coordinates + if self.num_atoms > 0 and not torch.any(self.atom_exists_mask): + pdb_id = tensorcode_to_string(self.pdb_id[0]) + logger.error(f"No valid coordinates found in any atoms for {pdb_id}") + + # Check that atom and token masks are compatible. Anywhere that the atom mask is + # true, the token mask should also be true + if self.num_atoms > 0 and not torch.all( + self.token_exists_mask[self.atom_token_index][self.atom_exists_mask] + ): + pdb_id = tensorcode_to_string(self.pdb_id[0]) + logger.error(f"Incompatible masks for {pdb_id}") + + @cached_property + def residue_names(self) -> list[str]: + return batch_tensorcode_to_string(self.token_residue_name) + + def pad( + self, + n_tokens: int, + n_atoms: int, + ) -> "AllAtomStructureContext": + assert n_tokens >= self.num_tokens + pad_tokens_func = partial(_pad_func, pad_size=n_tokens - self.num_tokens) + + assert n_atoms >= self.num_atoms + pad_atoms_func = partial(_pad_func, pad_size=n_atoms - self.num_atoms) + + return AllAtomStructureContext( + # token-level + token_residue_type=pad_tokens_func(self.token_residue_type), + token_residue_index=pad_tokens_func(self.token_residue_index), + token_index=pad_tokens_func(self.token_index), + token_centre_atom_index=pad_tokens_func(self.token_centre_atom_index), + token_ref_atom_index=pad_tokens_func(self.token_ref_atom_index), + token_exists_mask=pad_tokens_func(self.token_exists_mask), + token_backbone_frame_mask=pad_tokens_func(self.token_backbone_frame_mask), + token_backbone_frame_index=torch.cat( + [ + pad_tokens_func(self.token_backbone_frame_index[..., i]).unsqueeze( + -1 + ) + for i in range(3) + ], + dim=-1, + ), + token_asym_id=pad_tokens_func(self.token_asym_id), + token_entity_id=pad_tokens_func(self.token_entity_id), + token_sym_id=pad_tokens_func(self.token_sym_id), + token_entity_type=pad_tokens_func(self.token_entity_type), + token_residue_name=pad_tokens_func(self.token_residue_name), + token_b_factor_or_plddt=pad_tokens_func(self.token_b_factor_or_plddt), + # atom-level + atom_token_index=pad_atoms_func(self.atom_token_index), + atom_within_token_index=pad_atoms_func(self.atom_within_token_index), + atom_ref_pos=pad_atoms_func(self.atom_ref_pos), + atom_ref_mask=pad_atoms_func(self.atom_ref_mask), + atom_ref_element=pad_atoms_func(self.atom_ref_element), + atom_ref_charge=pad_atoms_func(self.atom_ref_charge), + atom_ref_name=self.atom_ref_name, + 
atom_ref_name_chars=pad_atoms_func(self.atom_ref_name_chars), + atom_ref_space_uid=pad_atoms_func(self.atom_ref_space_uid, pad_value=-1), + atom_is_not_padding_mask=pad_atoms_func(self.atom_is_not_padding_mask), + # supervision-only + atom_gt_coords=pad_atoms_func(self.atom_gt_coords), + atom_exists_mask=pad_atoms_func(self.atom_exists_mask), + # structure-level + pdb_id=pad_tokens_func(self.pdb_id), + source_pdb_chain_id=pad_tokens_func(self.source_pdb_chain_id), + subchain_id=pad_tokens_func(self.subchain_id), + resolution=self.resolution, + is_distillation=self.is_distillation, + symmetries=pad_atoms_func(self.symmetries, pad_value=-1), + ) + + @typecheck + @classmethod + def merge( + cls, + contexts: list["AllAtomStructureContext"], + ) -> "AllAtomStructureContext": + # indexes: + token_offsets = _exclusive_cum_lengths([x.token_residue_type for x in contexts]) + atom_offsets = _exclusive_cum_lengths([x.atom_token_index for x in contexts]) + + atom_token_index = torch.cat( + [x.atom_token_index + count for x, count in zip(contexts, token_offsets)] + ) + + token_centre_atom_index = torch.cat( + [ + x.token_centre_atom_index + count + for x, count in zip(contexts, atom_offsets) + ] + ) + token_ref_atom_index = torch.cat( + [x.token_ref_atom_index + count for x, count in zip(contexts, atom_offsets)] + ) + token_backbone_frame_index = torch.cat( + [ + x.token_backbone_frame_index + count + for x, count in zip(contexts, token_offsets) + ] + ) + + n_tokens = sum(x.num_tokens for x in contexts) + token_index = torch.arange(n_tokens, dtype=torch.int) + + # re-index the reference space from 0..n_tokens-1. + zero_indexed_ref_uids = [ + torch.unique_consecutive(x.atom_ref_space_uid, return_inverse=True)[1] + for x in contexts + ] + + ref_space_uids_offsets = _exclusive_cum_lengths( + [x.atom_ref_space_uid for x in contexts] + ) + atom_ref_space_uid = torch.cat( + [ + x + count + for x, count in zip(zero_indexed_ref_uids, ref_space_uids_offsets) + ], + ) + + # pad symmetric permutations to have same length + max_symms = max(x.symmetries.shape[-1] for x in contexts) + padded_symms = [ + torch.nn.functional.pad( + x.symmetries, (0, max_symms - x.symmetries.shape[-1]), value=-1 + ) + for x in contexts + ] + # offset symmetries by number of atoms in each chain + symm_mask = torch.cat([x >= 0 for x in padded_symms]) + symmetries = torch.cat(padded_symms) + symmetries = symmetries.masked_fill(~symm_mask, -1) + + return cls( + # token-level + token_residue_type=torch.cat([x.token_residue_type for x in contexts]), + token_residue_index=torch.cat([x.token_residue_index for x in contexts]), + token_index=token_index, + token_centre_atom_index=token_centre_atom_index, + token_ref_atom_index=token_ref_atom_index, + token_exists_mask=torch.cat([x.token_exists_mask for x in contexts]), + token_backbone_frame_mask=torch.cat( + [x.token_backbone_frame_mask for x in contexts] + ), + token_backbone_frame_index=token_backbone_frame_index, + token_asym_id=torch.cat([x.token_asym_id for x in contexts]), + token_entity_id=torch.cat([x.token_entity_id for x in contexts]), + token_sym_id=torch.cat([x.token_sym_id for x in contexts]), + token_entity_type=torch.cat([x.token_entity_type for x in contexts]), + token_residue_name=torch.cat([x.token_residue_name for x in contexts]), + token_b_factor_or_plddt=torch.cat( + [x.token_b_factor_or_plddt for x in contexts] + ), + # atom-level + atom_token_index=atom_token_index, + atom_within_token_index=torch.cat( + [x.atom_within_token_index for x in contexts] + ), + 
atom_ref_pos=torch.cat([x.atom_ref_pos for x in contexts]), + atom_ref_mask=torch.cat([x.atom_ref_mask for x in contexts]), + atom_ref_element=torch.cat([x.atom_ref_element for x in contexts]), + atom_ref_charge=torch.cat([x.atom_ref_charge for x in contexts]), + atom_ref_name=[x for context in contexts for x in context.atom_ref_name], + atom_ref_name_chars=torch.cat([x.atom_ref_name_chars for x in contexts]), + atom_ref_space_uid=atom_ref_space_uid, + atom_is_not_padding_mask=torch.cat( + [x.atom_is_not_padding_mask for x in contexts] + ), + # supervision only + atom_gt_coords=torch.cat([x.atom_gt_coords for x in contexts]), + atom_exists_mask=torch.cat([x.atom_exists_mask for x in contexts]), + # structure-level + pdb_id=torch.cat([x.pdb_id for x in contexts]), + source_pdb_chain_id=torch.cat([x.source_pdb_chain_id for x in contexts]), + subchain_id=torch.cat([x.subchain_id for x in contexts]), + resolution=torch.max( + torch.stack([x.resolution for x in contexts]), 0 + ).values, + is_distillation=torch.max( + torch.stack([x.is_distillation for x in contexts]), 0 + ).values, + symmetries=symmetries, + ) + + def to(self, device: torch.device | str) -> "AllAtomStructureContext": + dict_ = { + k: v.to(device) if torch.is_tensor(v) else v + for k, v in asdict(self).items() + } + return AllAtomStructureContext(**dict_) + + @property + def num_tokens(self) -> int: + (n_tokens,) = self.token_index.shape + return n_tokens + + @property + def num_atoms(self) -> int: + (n_atoms,) = self.atom_token_index.shape + return n_atoms + + def to_dict(self) -> dict[str, torch.Tensor]: + return asdict(self) + + +def _pad_func(x: Tensor, pad_size: int, pad_value: float | None = None) -> Tensor: + sizes = [0, 0] * (x.ndim - 1) + [0, pad_size] + return torch.nn.functional.pad(x, sizes, value=pad_value) + + +def _exclusive_cum_lengths(tensors: list[Int[Tensor, "n"]]): + lengths = torch.tensor([t.shape[0] for t in tensors]) + cum_lengths = torch.cumsum(lengths, dim=0).roll(1, 0) + cum_lengths[0] = 0 + return cum_lengths diff --git a/chai_lab/data/dataset/structure/chain.py b/chai_lab/data/dataset/structure/chain.py new file mode 100644 index 0000000..91908ef --- /dev/null +++ b/chai_lab/data/dataset/structure/chain.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +from chai_lab.data.dataset.structure.all_atom_structure_context import ( + AllAtomStructureContext, +) +from chai_lab.data.parsing.structure.all_atom_entity_data import AllAtomEntityData + + +@dataclass +class Chain: + # The untokenized entity data + entity_data: AllAtomEntityData + + # The tokenized chain, derived from the entity data + structure_context: AllAtomStructureContext + + def __str__(self) -> str: + return f"{self.__class__.__name__}(entity_data={self.entity_data})" + + @property + def num_tokens(self) -> int: + return self.structure_context.num_tokens diff --git a/chai_lab/data/dataset/structure/utils.py b/chai_lab/data/dataset/structure/utils.py new file mode 100644 index 0000000..fb68408 --- /dev/null +++ b/chai_lab/data/dataset/structure/utils.py @@ -0,0 +1,153 @@ +from functools import lru_cache + +import torch +from torch import Tensor + +import chai_lab.data.residue_constants as rc +from chai_lab.utils.typing import Bool, Int + + +def get_centre_atom_name(residue_name: str) -> str: + if residue_name not in rc.standard_residue_pdb_codes: + raise ValueError(f"Residue {residue_name} is not a standard residue") + + if residue_name in { + "A", + "G", + "C", + "U", + "DA", + "DG", + "DC", + "DT", + }: + return "C1'" + else: + return 
"CA" + + +def get_reference_atom_name(residue_name: str) -> str: + if residue_name not in rc.standard_residue_pdb_codes: + raise ValueError(f"Residue {residue_name} is not a standard residue") + + if residue_name == "GLY": + return "CA" + elif residue_name in {"A", "G", "DA", "DG"}: + return "C4" + elif residue_name in {"C", "U", "DC", "DT"}: + return "C2" + else: + return "CB" + + +def get_centre_atom_index(atom_names: list[str], residue_name: str) -> Int[Tensor, "1"]: + # centre of the token is Calpha or C1' + name = get_centre_atom_name(residue_name) + + if name in atom_names: + idx = atom_names.index(name) + else: + raise ValueError( + f"Residue {residue_name} marked as standard, " + f"but reference conformer misses centre atom {name}. " + "Either the residue is not standard or reference conformer is wrong." + ) + + return torch.tensor([idx], dtype=torch.int) + + +def get_reference_atom_index( + atom_names: list[str], residue_name: str +) -> Int[Tensor, "1"]: + name = get_reference_atom_name(residue_name) + if name in atom_names: + idx = atom_names.index(name) + else: + raise ValueError( + f"Residue {residue_name} marked as standard, " + f"but reference conformer misses reference atom {name}. " + "Either the residue is not standard or reference conformer is wrong." + ) + + return torch.tensor([idx], dtype=torch.int) + + +def get_backbone_frame_atom_names(residue_name: str) -> tuple[str, str, str]: + """Return names of the 3 atoms used in canonical token frame.""" + if residue_name in { + "A", + "G", + "C", + "U", + "DA", + "DG", + "DC", + "DT", + }: + return "C1'", "C3'", "C4'" + if residue_name in rc.residue_atoms: + return "N", "CA", "C" + return "", "", "" + + +def backbone_atoms_all_present( + atom_names: list[str], residue_name: str +) -> Bool[Tensor, "1"]: + """Check if all *protein* backbone atoms are present in the list of atom names.""" + backbone_frame_atoms = get_backbone_frame_atom_names(residue_name) + if all(a == "" for a in backbone_frame_atoms): + # Not a nucleic acid or a protein residue + all_present = False + else: + all_present = all(name in atom_names for name in backbone_frame_atoms) + return torch.tensor([all_present], dtype=torch.bool) + + +def backbone_atoms_indices( + atom_names: list[str], residue_name: str +) -> Int[Tensor, "3"]: + """Return indices of backbone atoms N, Ca, C in the list of atom names.""" + backbone_frame_atom_names = get_backbone_frame_atom_names(residue_name) + + if backbone_atoms_all_present(atom_names, residue_name): + indices = [atom_names.index(name) for name in backbone_frame_atom_names] + else: + indices = [0, 0, 0] + + return torch.tensor(indices, dtype=torch.int) + + +@lru_cache(maxsize=1) +def atom_37_atom_indices() -> dict[tuple[str, str | None], int]: + num_protein_atoms = 37 + protein_res_atom_to_index: dict[tuple[str, str | None], int] = { + (residue_name, atom_name): atom_index + for residue_name in rc.residue_atoms.keys() + for atom_name, atom_index in rc.atom_order.items() + } + assert max(protein_res_atom_to_index.values()) == num_protein_atoms - 1 + + num_rna_atoms = 36 + # note: convert RNA residues to R{} to match residue names from residue_constants.py + rna_res_atom_to_index = { + (residue_name, atom_name): atom_index + for residue_name in {"A", "C", "G", "U"} + for atom_index, atom_name in enumerate( + rc.nucleic_acid_atoms[f"R{residue_name}"] + ) + } + assert max(rna_res_atom_to_index.values()) == num_rna_atoms - 1 + + num_dna_atoms = 36 + dna_res_atom_to_index = { + (residue_name, atom_name): atom_index + for 
residue_name in {"DA", "DC", "DG", "DT"} + for atom_index, atom_name in enumerate(rc.nucleic_acid_atoms[residue_name]) + } + assert max(dna_res_atom_to_index.values()) == num_dna_atoms - 1 + + return { + **protein_res_atom_to_index, + **rna_res_atom_to_index, + **dna_res_atom_to_index, + } diff --git a/chai_lab/data/dataset/templates/context.py b/chai_lab/data/dataset/templates/context.py new file mode 100644 index 0000000..f0ccefc --- /dev/null +++ b/chai_lab/data/dataset/templates/context.py @@ -0,0 +1,215 @@ +import logging +from dataclasses import asdict, dataclass + +import torch +from torch import Tensor +from torch.nn import functional as F + +from chai_lab.data import residue_constants as rc +from chai_lab.utils.defaults import default +from chai_lab.utils.typing import Bool, Float, Int, typecheck + +logger = logging.getLogger(__name__) + + +@typecheck +@dataclass(frozen=True) +class TemplateContext: + """Context for templates; always aligned by construction.""" + + template_restype: Int[Tensor, "n_templates n_tokens"] + template_pseudo_beta_mask: Bool[Tensor, "n_templates n_tokens"] + template_backbone_frame_mask: Bool[Tensor, "n_templates n_tokens"] + template_distances: Float[Tensor, "n_templates n_tokens n_tokens"] + template_unit_vector: Float[Tensor, "n_templates n_tokens n_tokens 3"] + + def __str__(self) -> str: + return ( + f"TemplateContext(num_templates={self.num_templates}, " + f"num_nonnull_templates={self.num_nonnull_templates}, " + f"num_tokens={self.num_tokens})" + ) + + @property + def num_tokens(self) -> int: + return self.template_restype.shape[1] + + @property + def num_templates(self) -> int: + return self.template_restype.shape[0] + + @property + def num_nonnull_templates(self) -> int: + """Number of templates that aren't all null padding templates.""" + template_exists = self.template_mask.any(dim=-1).int() + return int(template_exists.sum().item()) + + @property + def template_mask(self) -> Bool[Tensor, "n_templates n_tokens"]: + return self.template_restype != rc.residue_types_with_nucleotides_order["-"] + + def to_dict(self) -> dict[str, torch.Tensor]: + retval = asdict(self) + retval.update( + { + "num_templates": torch.tensor(self.num_nonnull_templates), + "template_mask": self.template_mask, + } + ) + return retval + + @classmethod + def empty(cls, n_templates: int, n_tokens: int) -> "TemplateContext": + return cls( + template_restype=torch.full( + (n_templates, n_tokens), + fill_value=rc.residue_types_with_nucleotides_order["-"], + dtype=torch.int32, + ), + template_pseudo_beta_mask=torch.zeros( + n_templates, n_tokens, dtype=torch.bool + ), + template_backbone_frame_mask=torch.zeros( + n_templates, n_tokens, dtype=torch.bool + ), + template_distances=torch.zeros( + n_templates, n_tokens, n_tokens, dtype=torch.float32 + ), + template_unit_vector=torch.zeros( + n_templates, n_tokens, n_tokens, 3, dtype=torch.float32 + ), + ) + + def index_select(self, idxs: Int[Tensor, "n"]) -> "TemplateContext": + return TemplateContext( + template_restype=self.template_restype[:, idxs], + template_pseudo_beta_mask=self.template_pseudo_beta_mask[:, idxs], + template_backbone_frame_mask=self.template_backbone_frame_mask[:, idxs], + template_distances=self.template_distances[:, idxs][:, :, idxs], + template_unit_vector=self.template_unit_vector[:, idxs][:, :, idxs], + ) + + # @classmethod + # def merge( + # cls, + # templates: list["TemplateContext"], + # ) -> "TemplateContext": + # """Merge template contexts along the template dimensions.""" + # # n_token can be simply 
concatenated + # logger.debug(f"Merging {len(templates)} templates") + + # # Handle case where we get an empty list (no templates to merge) + # if len(templates) == 0: + # return cls.empty(n_templates=4, n_tokens=1) + + # # Pad each template_restype's template_dimension to match the largest + # # NOTE count num_templates here, NOT num_nonnull_templates + # n_templates_new: int = max(t.num_templates for t in templates) + # padded_templates = [t.pad(max_templates=n_templates_new) for t in templates] + # new_template_restype = torch.cat( + # [t.template_restype for t in padded_templates], + # dim=1, # Concat on sequence dim + # ) + # new_template_pseudo_beta_mask = torch.cat( + # [t.template_pseudo_beta_mask for t in padded_templates], + # dim=1, + # ) + # new_template_backbone_frame_mask = torch.cat( + # [t.template_backbone_frame_mask for t in padded_templates], + # dim=1, + # ) + + # # Number of tokens after concatenation along token dim + # n_token_new = new_template_restype.shape[1] + + # # n_token x n_token must be tiled into a square matrix + # # These indices like [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 3, ...] indicate the region + # # of the square matrix that corresponds to each template. + # template_indices = torch.repeat_interleave( + # input=torch.arange(len(templates), device=new_template_restype.device), + # repeats=torch.tensor([t.template_restype.shape[-1] for t in templates]), + # ) + # # Sample template and token dim + # assert template_indices.shape[0] == n_token_new + + # new_template_distances = torch.zeros( + # n_templates_new, n_token_new, n_token_new, dtype=torch.float32 + # ) + # new_template_unit_vector = torch.zeros( + # n_templates_new, n_token_new, n_token_new, 3, dtype=torch.float32 + # ) + + # # For each template, find the block that it corresponds to and copy in the data + # for i, t in enumerate(templates): + # m = template_indices == i + # mask = m[:, None] * m[None, :] + # idx = torch.arange(t.template_distances.shape[0]) + # new_template_distances[idx.unsqueeze(1), mask] = ( + # t.template_distances.flatten(1, 2) + # ) + # new_template_unit_vector[idx.unsqueeze(1), mask] = ( + # t.template_unit_vector.flatten(1, 2) + # ) + + # return cls( + # template_restype=new_template_restype, + # template_pseudo_beta_mask=new_template_pseudo_beta_mask, + # template_backbone_frame_mask=new_template_backbone_frame_mask, + # template_distances=new_template_distances, + # template_unit_vector=new_template_unit_vector, + # ) + + def pad( + self, + max_templates: int | None = None, + max_tokens: int | None = None, + ) -> "TemplateContext": + """Pad to the given number of templates and tokens.""" + max_templates = default(max_templates, self.num_templates) + assert ( + self.num_templates <= max_templates + ), f"Cannot pad templates containing {self.num_templates} templates to {max_templates} templates" + n_pad_templates = max_templates - self.num_templates + + max_tokens = default(max_tokens, self.num_tokens) + assert ( + self.num_tokens <= max_tokens + ), f"Cannot pad templates containing {self.num_tokens} tokens to {max_tokens} tokens" + n_pad_tokens = max_tokens - self.num_tokens + + if n_pad_templates == 0 and n_pad_tokens == 0: # Exact match yay + return self + + logger.debug(f"Padding templates by {n_pad_templates=} {n_pad_tokens=}") + + # Padding works from last dim forward in pairs of padding (left, right) + # - (0, n_pad_tokens) = pad nothing on left, pad by n_pad_tokens on right for + # last dim + # - (0, 0, 0, n_pad_tokens, 0, n_pad_tokens) = pad nothing on last dim, but 
pad + # next two dims + pad_dims_template = (0, n_pad_templates) + pad_dims_token = (0, n_pad_tokens) + return TemplateContext( + template_restype=F.pad( + self.template_restype, + pad=pad_dims_token + pad_dims_template, + value=rc.residue_types_with_nucleotides_order["-"], + ), + template_pseudo_beta_mask=F.pad( + self.template_pseudo_beta_mask, + pad=pad_dims_token + pad_dims_template, + ), + template_backbone_frame_mask=F.pad( + self.template_backbone_frame_mask, + pad=pad_dims_token + pad_dims_template, + ), + template_distances=F.pad( + self.template_distances, + pad=pad_dims_token + pad_dims_token + pad_dims_template, + ), + template_unit_vector=F.pad( + self.template_unit_vector, + # This field has a final dimension of size 3, which we shouldn't pad + pad=(0, 0) + pad_dims_token + pad_dims_token + pad_dims_template, + ), + ) diff --git a/chai_lab/data/features/__init__.py b/chai_lab/data/features/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/features/feature_factory.py b/chai_lab/data/features/feature_factory.py new file mode 100644 index 0000000..a7baa41 --- /dev/null +++ b/chai_lab/data/features/feature_factory.py @@ -0,0 +1,22 @@ +"""Helper methods for generating model input features""" + +import logging + +from torch import Tensor + +from chai_lab.data.features.generators.base import FeatureGenerator + +logger = logging.getLogger(__name__) + + +class FeatureFactory: + generators: dict[str, FeatureGenerator] + + def __init__(self, generators: dict[str, FeatureGenerator]): + self.generators = generators + + def generate(self, batch) -> dict[str, Tensor]: + return {name: gen.generate(batch) for name, gen in self.generators.items()} + + def __repr__(self) -> str: + return f"Feature factory, {len(self.generators)=}" diff --git a/chai_lab/data/features/feature_type.py b/chai_lab/data/features/feature_type.py new file mode 100644 index 0000000..49eb42e --- /dev/null +++ b/chai_lab/data/features/feature_type.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class FeatureType(Enum): + RESIDUE = "RESIDUE" + PAIR = "PAIR" + MSA = "MSA" + TEMPLATES = "TEMPLATES" + TOKEN = "TOKEN" + TOKEN_PAIR = "TOKEN_PAIR" + ATOM = "ATOM" + ATOM_PAIR = "ATOM_PAIR" diff --git a/chai_lab/data/features/feature_utils.py b/chai_lab/data/features/feature_utils.py new file mode 100644 index 0000000..ef25512 --- /dev/null +++ b/chai_lab/data/features/feature_utils.py @@ -0,0 +1,27 @@ +"""Utility classes and functions for feature representations""" + +from chai_lab.utils.typing import typecheck + + +@typecheck +def get_entry_for_key(data: dict, key: str): + """finds entry 'key' in data dictionary + + Parameters: + data: the dict to search in + key: the key to search for + + Example 1: + data=dict(foo=dict(bar="bar")) + key = "foo" + returns: dict(bar="bar") + Example 2: + data=dict(foo=dict(bar="bar")) + key = "foo/bar" + returns: "bar" + + """ + sub_keys, sub_dict = key.split("/"), data + for sub_key in sub_keys: + sub_dict = sub_dict[sub_key] + return sub_dict diff --git a/chai_lab/data/features/generators/__init__.py b/chai_lab/data/features/generators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/features/generators/atom_element.py b/chai_lab/data/features/generators/atom_element.py new file mode 100644 index 0000000..37be5cd --- /dev/null +++ b/chai_lab/data/features/generators/atom_element.py @@ -0,0 +1,30 @@ +import torch +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from 
chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Int, typecheck + + +class AtomElementOneHot(FeatureGenerator): + def __init__( + self, + max_atomic_num: int = 128, + ): + super().__init__( + ty=FeatureType.ATOM, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=max_atomic_num + 1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict(atomic_numbers=batch["inputs"]["atom_ref_element"]) + + @typecheck + def _generate(self, atomic_numbers: Int[Tensor, "b n"]) -> Tensor: + """see super class""" + return self.make_feature( + data=torch.clamp(atomic_numbers, max=self.num_classes).unsqueeze(-1), + ) diff --git a/chai_lab/data/features/generators/atom_name.py b/chai_lab/data/features/generators/atom_name.py new file mode 100644 index 0000000..2c0b742 --- /dev/null +++ b/chai_lab/data/features/generators/atom_name.py @@ -0,0 +1,27 @@ +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Int, typecheck + + +class AtomNameOneHot(FeatureGenerator): + def __init__( + self, + num_chars: int = 64, + ): + super().__init__( + ty=FeatureType.ATOM, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=num_chars, + mult=4, + ) + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict(atom_name_chars=batch["inputs"]["atom_ref_name_chars"]) + + @typecheck + def _generate(self, atom_name_chars: Int[Tensor, "b n 4"]) -> Tensor: + """see super class""" + return self.make_feature(data=atom_name_chars) diff --git a/chai_lab/data/features/generators/base.py b/chai_lab/data/features/generators/base.py new file mode 100644 index 0000000..f268937 --- /dev/null +++ b/chai_lab/data/features/generators/base.py @@ -0,0 +1,109 @@ +"""Feature Generator ABC and Default implementation""" + +from abc import ABC +from enum import Enum + +import torch +from beartype import beartype as typechecker +from torch import Tensor +from typing_extensions import assert_never + +from chai_lab.data.features.feature_type import FeatureType + + +class EncodingType(Enum): + ONE_HOT = "one-hot" + RBF = "rbf" + FOURIER = "fourier" + IDENTITY = "identity" + ESM = "esm" + OUTERSUM = "outersum" + + +def cast_feature( + feature: Tensor, + encoding_ty: EncodingType, +): + match encoding_ty: + case EncodingType.IDENTITY: + feature = feature.float() + # safety check + assert feature.abs().max() < 100, feature + return feature + case EncodingType.RBF | EncodingType.FOURIER: + assert feature.dtype in (torch.float16, torch.float32, torch.bfloat16) + return feature + case EncodingType.ONE_HOT | EncodingType.OUTERSUM: + if feature.dtype not in { + torch.long, + torch.int, + torch.int16, + torch.int8, + torch.uint8, + }: + raise ValueError( + f"dtype {feature.dtype} is not a valid type for {encoding_ty}" + ) + return feature + case EncodingType.ESM: + return feature + + assert_never(encoding_ty) # Enum exhaustiveness check + + +class FeatureGenerator(ABC): + @typechecker + def __init__( + self, + ty: FeatureType, + encoding_ty: EncodingType, + num_classes: int = -1, + mult: int = 1, + ignore_index: float = -100.0, + can_mask: bool = True, # marks existing, but unknown values (e.g. 
atom position) + ): + self.ty = ty + self.encoding_ty = encoding_ty + self.num_classes = num_classes + self.mult = mult + self.ignore_index = ignore_index + self.can_mask = can_mask + + @property + def mask_value(self) -> int | float | Tensor: + """Get value used to mask this feature""" + match self.encoding_ty: + case EncodingType.ONE_HOT | EncodingType.OUTERSUM: + return self.num_classes + case EncodingType.FOURIER | EncodingType.RBF: + return -100.0 + case EncodingType.IDENTITY: + assert self.can_mask + mask = torch.zeros(self.num_classes + int(self.can_mask)) + mask[-1] = 1 # last channel is 1 for masked-out items + return mask + case EncodingType.ESM: + return 0.0 + + assert_never(self.encoding_ty) # Enum exhaustiveness check + + def generate(self, batch) -> Tensor: + """Generate a feature""" + kwargs = self.get_input_kwargs_from_batch(batch) + feature = self._generate(**kwargs) + return feature + + def _generate(self, *args, **kwargs) -> Tensor: + """Generate a feature""" + raise NotImplementedError("implement me") + + def get_input_kwargs_from_batch(self, batch) -> dict: + """Get input keyword arguments to pass to _generate""" + raise NotImplementedError("implement me") + + def make_feature(self, data: Tensor) -> Tensor: + """Checks and converts dtype if necessary""" + return cast_feature(data, encoding_ty=self.encoding_ty) + + def __repr__(self): + return f"[FeatureGenerator] : type: {self.ty}" diff --git a/chai_lab/data/features/generators/blocked_atom_pair_distances.py b/chai_lab/data/features/generators/blocked_atom_pair_distances.py new file mode 100644 index 0000000..95a9f49 --- /dev/null +++ b/chai_lab/data/features/generators/blocked_atom_pair_distances.py @@ -0,0 +1,171 @@ +from typing import Any, Literal + +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.tensor_utils import cdist +from chai_lab.utils.typing import Bool, Float, Int, typecheck + +_VALID_ENCODING_TYPES = [ + EncodingType.IDENTITY, +] +DEFAULT_ONE_HOT_DIST_BINS = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 12.0, 16.0] +DEFAULT_RBF_DIST_BINS = [0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0] + + +class BlockedAtomPairDistances(FeatureGenerator): + transform: Literal["none", "inverse_squared"] + + def __init__( + self, + encoding_ty: EncodingType = EncodingType.IDENTITY, + transform: Literal["none", "inverse_squared"] = "inverse_squared", + ): + assert ( + encoding_ty in _VALID_ENCODING_TYPES + ), f"invalid encoding type: {encoding_ty}" + + # initialize superclass after augmenting input params =O. + super().__init__( + ty=FeatureType.ATOM_PAIR, + encoding_ty=encoding_ty, + # one of dist_bins of rbf_radii is not None. 
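+            # Descriptive note (added): identity-encoded scalar distance per atom
+            # pair; a second channel carrying the pair validity mask is appended
+            # in _generate below, so the emitted feature has two channels.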
+ num_classes=1, + mult=1, + can_mask=True, + ) + self.transform = transform + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + atom_ref_pos=batch["inputs"]["atom_ref_pos"], + atom_ref_mask=batch["inputs"]["atom_ref_mask"], + atom_ref_space_uid=batch["inputs"]["atom_ref_space_uid"], + q_idces=batch["inputs"]["block_atom_pair_q_idces"], + kv_idces=batch["inputs"]["block_atom_pair_kv_idces"], + block_atom_pair_mask=batch["inputs"]["block_atom_pair_mask"], + ) + + @typecheck + def _generate( + self, + atom_ref_pos: Float[Tensor, "b n 3"], + atom_ref_mask: Bool[Tensor, "b n"], + atom_ref_space_uid: Int[Tensor, "b n"], + q_idces: Int[Tensor, "bl bl_q"], + kv_idces: Int[Tensor, "bl bl_kv"], + block_atom_pair_mask: Bool[Tensor, "b bl bl_q bl_kv"], + ) -> Tensor: + """see super class""" + + blocked_feat, blocked_mask = get_blocked_atom_pair_dists( + atom_ref_pos, + atom_ref_space_uid, + q_idces, + kv_idces, + block_atom_pair_mask, + ) + + if self.transform == "inverse_squared": + blocked_feat = 1 / (1 + blocked_feat**2) + + # return (B, n, n, 2) where ...,0 is the feature + # and ...,1 indicates if the value is masked + # because 0.0 has a meaning as a distance + + blocked_feat = blocked_feat.unsqueeze(-1) + blocked_mask = blocked_mask.unsqueeze(-1).float() + + return self.make_feature( + torch.cat( + [blocked_feat, blocked_mask], + dim=-1, + ) + ) + + +class BlockedAtomPairDistogram(FeatureGenerator): + dist_bins: Tensor + + def __init__( + self, + dist_bins: list[float] | None = None, + encoding_ty: EncodingType = EncodingType.ONE_HOT, + ): + if dist_bins is None and encoding_ty == EncodingType.ONE_HOT: + dist_bins = DEFAULT_ONE_HOT_DIST_BINS + elif dist_bins is None and encoding_ty == EncodingType.RBF: + dist_bins = DEFAULT_RBF_DIST_BINS + assert dist_bins is not None, "must provide dist_bins" + + # initialize superclass after augmenting input params =O. + super().__init__( + ty=FeatureType.ATOM_PAIR, + encoding_ty=encoding_ty, + # one of dist_bins of rbf_radii is not None. 
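+            # Descriptive note (added): torch.searchsorted against dist_bins yields
+            # bin indices in [0, len(dist_bins)]; pairs outside the block mask are
+            # filled with mask_value in _generate.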
+ num_classes=len(dist_bins) + 1, + mult=1, + can_mask=True, + ) + self.dist_bins = torch.tensor(dist_bins) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + atom_ref_pos=batch["inputs"]["atom_ref_pos"], + atom_ref_mask=batch["inputs"]["atom_ref_mask"], + atom_ref_space_uid=batch["inputs"]["atom_ref_space_uid"], + q_idces=batch["inputs"]["block_atom_pair_q_idces"], + kv_idces=batch["inputs"]["block_atom_pair_kv_idces"], + block_atom_pair_mask=batch["inputs"]["block_atom_pair_mask"], + ) + + @typecheck + def _generate( + self, + atom_ref_pos: Float[Tensor, "b n 3"], + atom_ref_mask: Bool[Tensor, "b n"], + atom_ref_space_uid: Int[Tensor, "b n"], + q_idces: Int[Tensor, "bl bl_q"], + kv_idces: Int[Tensor, "bl bl_kv"], + block_atom_pair_mask: Bool[Tensor, "b bl bl_q bl_kv"], + ) -> Tensor: + """see super class""" + feat, mask = get_blocked_atom_pair_dists( + atom_ref_pos, + atom_ref_space_uid, + q_idces, + kv_idces, + block_atom_pair_mask, + ) + if self.encoding_ty == EncodingType.ONE_HOT: + feat = torch.searchsorted(self.dist_bins.to(atom_ref_pos.device), feat) + feat.masked_fill_(~mask, self.mask_value) + + return self.make_feature(feat.unsqueeze(-1)) + + +@typecheck +def get_blocked_atom_pair_dists( + positions: Float[Tensor, "b a 3"], + atom_ref_space_uid: Int[Tensor, "b a"], + q_idx: Int[Tensor, "bl bl_q"], + kv_idx: Int[Tensor, "bl bl_kv"], + block_atom_pair_mask: Bool[Tensor, "b bl bl_q bl_kv"], +) -> tuple[Float[Tensor, "b bl bl_q bl_kv"], Bool[Tensor, "b bl bl_q bl_kv"]]: + q_pos = positions[:, q_idx] + kv_pos = positions[:, kv_idx] + + blocked_pair_dists = cdist(q_pos, kv_pos) # b bl bl_q bl_kv + + atom_ref_space_q = atom_ref_space_uid[:, q_idx] + atom_ref_space_kv = atom_ref_space_uid[:, kv_idx] + block_same_atom_ref_space = rearrange( + atom_ref_space_q, "b bl a_q -> b bl a_q 1" + ) == rearrange(atom_ref_space_kv, "b bl a_kv -> b bl 1 a_kv") + + block_atom_pair_mask &= block_same_atom_ref_space + + return blocked_pair_dists, block_atom_pair_mask diff --git a/chai_lab/data/features/generators/docking.py b/chai_lab/data/features/generators/docking.py new file mode 100644 index 0000000..6bcc726 --- /dev/null +++ b/chai_lab/data/features/generators/docking.py @@ -0,0 +1,367 @@ +import logging +import random +from dataclasses import dataclass +from typing import Any + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.data.features.generators.token_pair_distance import TokenCenterDistance +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.model.utils import get_asym_id_from_subchain_id +from chai_lab.utils.defaults import default +from chai_lab.utils.tensor_utils import cdist, und, und_self +from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck + +logger = logging.getLogger(__name__) + + +@typecheck +@dataclass +class ConstraintGroup: + """ + Container for a docking constraint group -- + collection of chains with inter/intra distance constraints + + This class can be used to specify a set of chains to be + grouped together for the docking feature + """ + + subchain_ids: list[str] + noise_sigma: float + dropout_prob: float + atom_center_mask: list[Bool[Tensor, "_"]] + atom_center_coords: list[Float[Tensor, "_ 3"]] + + def __post_init__(self) -> None: + """Ensure params are consistent""" + assert len(self.subchain_ids) 
== len( + self.atom_center_coords + ), f"{len(self.subchain_ids)=}, {len(self.atom_center_coords)=}" + assert len(self.subchain_ids) == len( + self.atom_center_mask + ), f"{len(self.subchain_ids)=}, {len(self.atom_center_mask)=}" + assert all( + [ + len(mask) == len(coord) + for coord, mask in zip(self.atom_center_coords, self.atom_center_mask) + ] + ), ( + f"{[len(x) for x in self.atom_center_coords]=}, " + f"{[len(x) for x in self.atom_center_mask]=}" + ) + + def get_asym_ids( + self, + token_subchain_id: UInt8[Tensor, "n 4"], + token_asym_id: Int[Tensor, "n"], + ) -> list[int]: + return [ + get_asym_id_from_subchain_id( + subchain_id=subchain_id, + source_pdb_chain_id=token_subchain_id, + token_asym_id=token_asym_id, + ) + for subchain_id in self.subchain_ids + ] + + def __str__(self): + return ( + f"ConstraintGroup(subchain_ids={self.subchain_ids}, " + f"atom_center_coords.shape={[x.shape for x in self.atom_center_coords]}, " + f"atom_center_mask.shape={[x.shape for x in self.atom_center_mask]})" + ) + + +class DockingConstraintGenerator(FeatureGenerator): + """Docking Feature Generator + + Works as follows: + separate input chains into two groups by randomly + partitioning asym_id's. + Provide all token-center distances for chains within the + same asm_id group. + Mask all token-center distances for chains within the + different asm_id groups. + + """ + + def __init__( + self, + dist_bins: list[float] | None = None, + coord_noise: tuple[float, float] = (0.0, 3.0), + include_probability: float = 0.1, + structure_dropout_prob: float = 0.0, + chain_dropout_prob: float = 0.0, + entity_types: list[EntityType] | None = None, + ): + dist_bins = dist_bins if dist_bins is not None else [0.0, 4.0, 8.0, 16.0] + super().__init__( + ty=FeatureType.TOKEN_PAIR, + encoding_ty=EncodingType.ONE_HOT, + # one of dist_bins of rbf_radii is not None. 
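+            # Descriptive note (added): token-centre distances are one-hot encoded
+            # over dist_bins; pairs without constraints are set to the extra
+            # mask_value class.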
+ num_classes=len(dist_bins) + 1, + mult=1, + can_mask=True, + ) + self.token_dist_gen = TokenCenterDistance(dist_bins=dist_bins) + + # maintain consistent orders + self.coord_noise = coord_noise + self.include_probability = include_probability + self.structure_dropout_prob = structure_dropout_prob + self.chain_dropout_prob = chain_dropout_prob + self.entity_types = set( + [x.value for x in default(entity_types, [e for e in EntityType])] + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + maybe_constraint_dicts = batch["inputs"].get("docking_constraints", [[None]])[0] + docking_constraints = batch["inputs"]["docking_constraints"] = ( + [ConstraintGroup(**d) for d in maybe_constraint_dicts] + if isinstance(maybe_constraint_dicts[0], dict) + else None + ) + + return dict( + all_atom_positions=batch["inputs"]["atom_gt_coords"], + all_atom_mask=batch["inputs"]["atom_exists_mask"], + token_single_mask=batch["inputs"]["token_exists_mask"], + token_center_atom_index=batch["inputs"]["token_centre_atom_index"].long(), + token_asym_id=batch["inputs"]["token_asym_id"].long(), + token_subchain_id=batch["inputs"]["subchain_id"], + token_entity_type=batch["inputs"]["token_entity_type"].long(), + constraints=docking_constraints, + ) + + def apply_structure_dropout( + self, feature: Tensor, prob: float | None = None + ) -> Tensor: + prob = default(prob, torch.rand(1).item()) + dropout_single_mask = torch.rand_like(feature.data[..., 0, 0].float()) < prob + dropout_pair_mask = und_self(dropout_single_mask, "b i, b j -> b i j") + feature = feature.masked_fill(dropout_pair_mask.unsqueeze(-1), self.mask_value) + return feature + + def apply_chain_dropout( + self, feature: Tensor, token_asym_id: Int[Tensor, "b n"] + ) -> Tensor: + structure_masks = [] + for i in range(token_asym_id.shape[0]): + data_i, asym_i = feature.data[i], token_asym_id[i] + unique_asyms = torch.unique(asym_i[asym_i != 0]).tolist() + random.shuffle(unique_asyms) # select chains to mask at random + selected_asyms = unique_asyms[: random.randint(0, len(unique_asyms))] + if len(selected_asyms) == 0: + structure_masks.append( + torch.zeros_like(data_i[..., 0], dtype=torch.bool) + ) + continue + asyms_to_mask = torch.tensor(selected_asyms, device=data_i.device) + asym_mask = torch.any(asym_i.unsqueeze(-1) == asyms_to_mask, dim=-1) + structure_mask = und_self(asym_mask, "i, j -> i j") + structure_masks.append(structure_mask) + feature_mask = torch.stack(structure_masks, dim=0) + feature = feature.masked_fill(feature_mask.unsqueeze(-1), self.mask_value) + return feature + + @typecheck + def _generate( + self, + all_atom_positions: Float[Tensor, "b a 3"], + all_atom_mask: Bool[Tensor, "b a"], + token_single_mask: Bool[Tensor, "b n"], + token_center_atom_index: Int[Tensor, "b n"], + token_asym_id: Int[Tensor, "b n"], + token_entity_type: Int[Tensor, "b n"], + token_subchain_id: UInt8[Tensor, "b n 4"], + constraints: list[ConstraintGroup] | None = None, + ) -> Tensor: + try: + if constraints is not None: + assert all_atom_positions.shape[0] == 1 + return self._generate_from_constraints( + token_asym_id=token_asym_id, + token_subchain_id=token_subchain_id, + constraints=constraints, + ) + except Exception as e: + logger.error(f"Error {e} generating docking constraints: {constraints}") + + return self._generate_from_batch( + all_atom_positions=all_atom_positions, + all_atom_mask=all_atom_mask, + token_single_mask=token_single_mask, + token_center_atom_index=token_center_atom_index, + token_asym_id=token_asym_id, + 
token_entity_type=token_entity_type, + ) + + def _asym_to_entity_type( + self, asym_id: Int[Tensor, "n"], entity_type: Int[Tensor, "n"] + ) -> dict[int, int]: + unique_asyms: Tensor = torch.unique(asym_id[asym_id != 0]) + mapping = dict() + for asym in unique_asyms.tolist(): + asym_mask = asym_id == asym + mapping[int(asym)] = int(entity_type[asym_mask][0].item()) + return mapping + + @typecheck + def _generate_from_batch( + self, + all_atom_positions=Float[Tensor, "b a 3"], + all_atom_mask=Bool[Tensor, "b a"], + token_single_mask=Bool[Tensor, "b n"], + token_center_atom_index=Int[Tensor, "b n"], + token_entity_type=Int[Tensor, "b n"], + token_asym_id=Int[Tensor, "b n"], + ) -> Tensor: + sampled_noise = random.uniform(self.coord_noise[0], self.coord_noise[1]) + token_center_dists = self.token_dist_gen._generate( + all_atom_positions=all_atom_positions + + torch.randn_like(all_atom_positions) * sampled_noise, + all_atom_mask=all_atom_mask, + token_single_mask=token_single_mask, + token_center_atom_index=token_center_atom_index, + ).data + for i in range(token_center_dists.shape[0]): + asym_to_entity = self._asym_to_entity_type( + token_asym_id[i], token_entity_type[i] + ) + asym_include_list = [ + asym for asym, ety in asym_to_entity.items() if ety in self.entity_types + ] + asym_exclude_list = [ + asym + for asym, ety in asym_to_entity.items() + if ety not in self.entity_types + ] + # exclude other entity types + asym_exclude_mask = torch.any( + (token_asym_id[i].unsqueeze(-1) == torch.tensor(asym_exclude_list)), + dim=-1, + ) + token_center_dists[i, asym_exclude_mask] = self.mask_value + token_center_dists[i, :, asym_exclude_mask] = self.mask_value + if ( + random.random() < self.include_probability + and len(asym_include_list) > 1 + ): + # include distances between select chains + random.shuffle(asym_include_list) + partition_idx = random.randint(1, len(asym_include_list) - 1) + _group_1, _group_2 = ( + asym_include_list[:partition_idx], + asym_include_list[partition_idx:], + ) + group_1, group_2 = torch.tensor(_group_1), torch.tensor(_group_2) + # find positions of elements in first and second group + group1_mask, group2_mask = [ + torch.any((token_asym_id[i].unsqueeze(-1) == x), dim=-1) + for x in (group_1, group_2) + ] + partition_mask = und(group1_mask, group2_mask, "i, j -> i j") + token_center_dists[i] = token_center_dists[i].masked_fill( + (partition_mask | partition_mask.T).unsqueeze(-1), self.mask_value + ) + else: + mask = torch.ones_like(token_center_dists[i], dtype=torch.bool) + token_center_dists[i] = token_center_dists[i].masked_fill( + mask, self.mask_value + ) + + feature = self.make_feature(token_center_dists) + if random.random() < self.structure_dropout_prob: + feature = self.apply_structure_dropout(feature) + elif random.random() < self.chain_dropout_prob: + feature = self.apply_chain_dropout(feature, token_asym_id) + return feature + + @typecheck + def _generate_from_constraints( + self, + # constraints only supported with batch size 1 + token_asym_id: Int[Tensor, "1 n"], + token_subchain_id: UInt8[Tensor, "1 n 4"], + constraints: list[ConstraintGroup], + ) -> Tensor: + logger.info(f"Generating docking feature from constraints: {constraints}") + n, device = token_asym_id.shape[1], token_asym_id.device + constraint_mat = torch.zeros(n, n, device=device, dtype=torch.float32) + constraint_mask = torch.zeros(n, n, device=device, dtype=torch.bool) + for constraint_group in constraints: + # add constraints between members of each group + coords = [ + x + 
torch.randn_like(x) * constraint_group.noise_sigma + for x in constraint_group.atom_center_coords + ] + n_chains = len(constraint_group.subchain_ids) + l_idx, r_idx = torch.triu_indices(n_chains, n_chains) + chain_asyms = constraint_group.get_asym_ids( + token_subchain_id=rearrange(token_subchain_id, "1 ... -> ..."), + token_asym_id=rearrange(token_asym_id, "1 ... -> ..."), + ) + for i, j in zip(l_idx.tolist(), r_idx.tolist()): + constraint_mat, constraint_mask = self.add_constraint( + constraint_mat=constraint_mat, + constraint_mask=constraint_mask, + token_asym_id=rearrange(token_asym_id, "1 ... -> ..."), + chain1_asym_id=chain_asyms[i], + chain2_asym_id=chain_asyms[j], + chain1_coords=coords[i], + chain1_mask=constraint_group.atom_center_mask[i], + chain2_coords=coords[j], + chain2_mask=constraint_group.atom_center_mask[j], + ) + # encode and apply mask + feat = torch.searchsorted( + self.token_dist_gen.dist_bins.to(constraint_mat.device), constraint_mat + ) + feat = feat.masked_fill(~constraint_mask, self.mask_value) + # add back batch dim + constraint_mat = repeat(feat, "i j -> 1 i j 1") + # apply structure dropout + dropout = constraints[0].dropout_prob if len(constraints) > 0 else 0.0 + feature = self.make_feature(constraint_mat) + feature = self.apply_structure_dropout(feature, prob=dropout) + return feature + + @typecheck + def add_constraint( + self, + constraint_mat: Float[Tensor, "n n"], + constraint_mask: Bool[Tensor, "n n"], + token_asym_id: Int[Tensor, "n"], + chain1_asym_id: int, + chain2_asym_id: int, + chain1_coords: Float[Tensor, "c1 3"], + chain2_coords: Float[Tensor, "c2 3"], + chain1_mask: Bool[Tensor, "c1"], + chain2_mask: Bool[Tensor, "c2"], + ) -> tuple[Float[Tensor, "n n"], Bool[Tensor, "n n"]]: + (c1_posns,) = torch.where(token_asym_id == chain1_asym_id) + (c2_posns,) = torch.where(token_asym_id == chain2_asym_id) + # make sure we have a coordinate for each position + assert len(c1_posns) == len( + chain1_coords + ), f"{c1_posns.shape=}, {chain1_coords.shape=}" + assert len(c2_posns) == len( + chain2_coords + ), f"{c2_posns.shape=}, {chain2_coords.shape=}" + + pairwise_dists = cdist(chain1_coords, chain2_coords) + pairwise_mask = und(chain1_mask, chain2_mask, "i, j -> i j") + pairwise_dists[~pairwise_mask] = -1.0 + # mask and fill the constraint matrix + row_idxs = repeat(c1_posns, "i -> i c", c=len(c2_posns)) + col_idxs = repeat(c2_posns, "j -> r j", r=len(c1_posns)) + # fill constraints and mask + constraint_mat[row_idxs, col_idxs] = pairwise_dists + constraint_mat[col_idxs, row_idxs] = pairwise_dists + constraint_mask[row_idxs, col_idxs] = pairwise_mask + constraint_mask[col_idxs, row_idxs] = pairwise_mask + return constraint_mat, constraint_mask diff --git a/chai_lab/data/features/generators/esm_generator.py b/chai_lab/data/features/generators/esm_generator.py new file mode 100644 index 0000000..671d482 --- /dev/null +++ b/chai_lab/data/features/generators/esm_generator.py @@ -0,0 +1,30 @@ +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Float, typecheck + + +class ESMEmbeddings(FeatureGenerator): + def __init__( + self, + ty: FeatureType = FeatureType.TOKEN, + ): + super().__init__( + ty=ty, + encoding_ty=EncodingType.ESM, + can_mask=False, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + esm_embeddings=batch["inputs"]["esm_embeddings"], + ) + + @typecheck + def 
_generate( + self, + esm_embeddings: Float[Tensor, "batch num_tokens d_emb"], + ) -> Tensor: + return self.make_feature(data=esm_embeddings) diff --git a/chai_lab/data/features/generators/identity.py b/chai_lab/data/features/generators/identity.py new file mode 100644 index 0000000..f4d1107 --- /dev/null +++ b/chai_lab/data/features/generators/identity.py @@ -0,0 +1,44 @@ +import torch +from einops import rearrange +from torch import Tensor + +import chai_lab.data.features.feature_utils as futils +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator + + +class Identity(FeatureGenerator): + def __init__( + self, + key: str, + ty: FeatureType, + dim: int, + can_mask: bool = True, + ): + super().__init__( + ty=ty, + encoding_ty=EncodingType.IDENTITY, + mult=1, + num_classes=dim, + can_mask=can_mask, + ) + self.key = key + self.dim = dim + + def generate(self, batch: dict) -> Tensor: + feat = futils.get_entry_for_key(batch, self.key) + + if feat.ndim == 2: # scalar feature + assert self.dim == 1 + feat = rearrange(feat, "b n -> b n 1") + elif feat.ndim == 3: + # feature made from sequence-wise vectors (shape b,n,d) + assert self.dim == feat.shape[-1] + else: + raise ValueError( + f"Input to feature generator has ndim={feat.ndim}, shape {feat.shape}" + ) + + if self.can_mask: # append position for mask token if feat can be masked + feat = torch.cat((feat, torch.zeros_like(feat)[..., :1]), dim=-1) + return self.make_feature(data=feat) diff --git a/chai_lab/data/features/generators/is_cropped_chain.py b/chai_lab/data/features/generators/is_cropped_chain.py new file mode 100644 index 0000000..9bc3a64 --- /dev/null +++ b/chai_lab/data/features/generators/is_cropped_chain.py @@ -0,0 +1,32 @@ +import torch +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Int, typecheck + + +class ChainIsCropped(FeatureGenerator): + def __init__( + self, + ): + """Chain-level feature that indicates if a chain has been cropped""" + super().__init__( + ty=FeatureType.TOKEN, + can_mask=False, + encoding_ty=EncodingType.IDENTITY, + num_classes=1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + token_asym_id=batch["inputs"]["token_asym_id"].long(), + ) + + @typecheck + def _generate( + self, + token_asym_id: Int[Tensor, "b n"], + ) -> Tensor: + return self.make_feature(torch.zeros_like(token_asym_id).unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/missing_chain_contact.py b/chai_lab/data/features/generators/missing_chain_contact.py new file mode 100644 index 0000000..9171269 --- /dev/null +++ b/chai_lab/data/features/generators/missing_chain_contact.py @@ -0,0 +1,92 @@ +import torch +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.tensor_utils import cdist, und_self +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +class MissingChainContact(FeatureGenerator): + contact_threshold: float + + def __init__( + self, + # Use DockQ atom contact cutoff as default + contact_threshold: float = 6.0, + ): + """Token-Level feature that indicates is a chain has no tokens + in contact with tokens from another chain. 
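+
+        Contacts are computed from the ground-truth atom coordinates using the
+        contact_threshold distance cutoff; single-chain (monomer) inputs are
+        treated as having no missing contacts.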
+ """ + super().__init__( + ty=FeatureType.TOKEN, + can_mask=False, + encoding_ty=EncodingType.IDENTITY, + num_classes=1, + mult=1, + ) + self.contact_threshold = contact_threshold + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + atom_gt_coords=batch["inputs"]["atom_gt_coords"], + atom_exists_mask=batch["inputs"]["atom_exists_mask"], + token_exists_mask=batch["inputs"]["token_exists_mask"], + token_asym_id=batch["inputs"]["token_asym_id"].long(), + atom_token_index=batch["inputs"]["atom_token_index"].long(), + ) + + @typecheck + def _generate( + self, + atom_gt_coords: Float[Tensor, "b a 3"], + atom_exists_mask: Bool[Tensor, "b a"], + token_exists_mask: Bool[Tensor, "b n"], + token_asym_id: Int[Tensor, "b n"], + atom_token_index: Int[Tensor, "b a"], + ) -> Tensor: + # per-atom asym id + atom_asym_id = torch.gather(token_asym_id, dim=1, index=atom_token_index.long()) + # compute atom pair distances and mask + atom_pair_dist = cdist(atom_gt_coords) + atom_pair_mask = und_self(atom_exists_mask, "b i, b j -> b i j") + atom_pair_asym_mask = atom_asym_id.unsqueeze(-1) != atom_asym_id.unsqueeze(-2) + aggregate_mask = ( + atom_pair_mask + & atom_pair_asym_mask + & (atom_pair_dist < self.contact_threshold) + ) + # determine which atoms are in contact with some atom from another chain + atom_in_contact = aggregate_mask.any(dim=-1) + # determine if any chain has no atoms in contact with another chain + chain_contact_features: list[torch.Tensor] = [] + for b in range(atom_gt_coords.shape[0]): + unique_chain_asyms = torch.unique(token_asym_id[b][token_exists_mask[b]]) + if len(unique_chain_asyms) == 1: + # monomers are set to have no missing contacts + chain_contact_features.append( + torch.zeros_like( + token_asym_id[b].unsqueeze(-1), dtype=torch.float32 + ) + ) + continue + unique_asyms_with_contacts = torch.unique( + atom_asym_id[b][atom_in_contact[b]] + ) + unique_chain_asyms, unique_asyms_with_contacts = [ + set(x.tolist()) + for x in (unique_chain_asyms, unique_asyms_with_contacts) + ] + asyms_without_contacts = torch.tensor( + list(unique_chain_asyms - unique_asyms_with_contacts) + ) + # create feature data for this chain + feat = torch.any( + token_asym_id[b].unsqueeze(-1) == asyms_without_contacts, + dim=-1, + keepdim=True, + ) + chain_contact_features.append(feat.float()) + + # make the feature + return self.make_feature(torch.stack(chain_contact_features, dim=0)) diff --git a/chai_lab/data/features/generators/msa.py b/chai_lab/data/features/generators/msa.py new file mode 100644 index 0000000..c06e467 --- /dev/null +++ b/chai_lab/data/features/generators/msa.py @@ -0,0 +1,241 @@ +from typing import Any + +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.data.parsing.msas.data_source import msa_dataset_source_to_int +from chai_lab.data.parsing.msas.species import UNKNOWN_SPECIES +from chai_lab.data.residue_constants import residue_types_with_nucleotides_order +from chai_lab.utils.tensor_utils import masked_mean +from chai_lab.utils.typing import Bool, Int, UInt8, typecheck + + +class MSAFeatureGenerator(FeatureGenerator): + """Generates feature for one-hot encoding of processed MSA, same classes as restype.""" + + def __init__(self): + num_res_ty = len(residue_types_with_nucleotides_order) + super().__init__( + ty=FeatureType.MSA, + encoding_ty=EncodingType.ONE_HOT, + can_mask=False, + 
num_classes=num_res_ty, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + msa_tokens=batch["inputs"]["msa_tokens"], + ) + + @typecheck + def _generate( + self, + msa_tokens: UInt8[Tensor, "batch depth tokens"], + ) -> Tensor: + """Generate based on an input of one-hot encoded MSA""" + return self.make_feature(data=msa_tokens.unsqueeze(-1)) + + +class MSAHasDeletionGenerator(FeatureGenerator): + """Binary feature for if there is a deletion to the left of each position.""" + + def __init__(self): + super().__init__( + ty=FeatureType.MSA, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict(msa_deletion_matrix=batch["inputs"]["msa_deletion_matrix"]) + + @typecheck + def _generate( + self, + msa_deletion_matrix: UInt8[Tensor, "batch depth tokens"], + ) -> Tensor: + has_deletion = msa_deletion_matrix > 0 + return self.make_feature(data=has_deletion.unsqueeze(-1)) + + +class MSADeletionValueGenerator(FeatureGenerator): + """Raw deletion counts left of the current position, with addtional scaling. + Scaling is given by s(d) = 2 / pi * arctan(d / 3) + """ + + def __init__(self): + super().__init__( + ty=FeatureType.MSA, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict(msa_deletion_matrix=batch["inputs"]["msa_deletion_matrix"]) + + @typecheck + def _generate( + self, + msa_deletion_matrix: UInt8[Tensor, "batch depth tokens"], + ) -> Tensor: + d_scaled = 2.0 / torch.pi * torch.arctan(msa_deletion_matrix.float() / 3.0) + return self.make_feature(data=d_scaled.unsqueeze(-1)) + + +class MSAProfileGenerator(FeatureGenerator): + """MSA profile - distribution across residue types BEFORE processing""" + + def __init__(self): + self.num_res_ty = len(residue_types_with_nucleotides_order) + super().__init__( + ty=FeatureType.TOKEN, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=self.num_res_ty, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + main_msa_tokens=batch["inputs"]["main_msa_tokens"], + main_msa_mask=batch["inputs"]["main_msa_mask"], + ) + + @typecheck + def _generate( + self, + main_msa_tokens: UInt8[Tensor, "batch depth tokens"], + main_msa_mask: Bool[Tensor, "batch depth tokens"], + ) -> Tensor: + """Optimized implementation based on torch.scatter_add""" + batch, _, tokens = main_msa_tokens.shape + + unnormalized_profile = torch.zeros( + (batch, tokens, self.num_res_ty), dtype=main_msa_tokens.dtype + ).scatter_add( + dim=2, + index=rearrange( + main_msa_tokens.long(), "batch depth tokens -> batch tokens depth" + ), + src=rearrange( + main_msa_mask.to(main_msa_tokens.dtype), + "batch depth tokens -> batch tokens depth", + ), + ) + denom = unnormalized_profile.sum(dim=-1, keepdim=True).clamp_min_(1) + profile = unnormalized_profile / denom + + return self.make_feature(data=profile) + + +class MSADeletionMeanGenerator(FeatureGenerator): + """MSA deletion mean - mean number of deletions at each position in main MSA.""" + + def __init__(self): + super().__init__( + ty=FeatureType.TOKEN, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + main_msa_mask=batch["inputs"]["main_msa_mask"], + 
main_msa_deletion_matrix=batch["inputs"]["main_msa_deletion_matrix"], + ) + + @typecheck + def _generate( + self, + main_msa_mask: Bool[Tensor, "batch depth tokens"], + main_msa_deletion_matrix: UInt8[Tensor, "batch depth tokens"], + ) -> Tensor: + """Mean number of deletions at each position in main MSA.""" + # Average out the depth to get per-tokens + mean_deletion_matrix = masked_mean( + mask=main_msa_mask, value=main_msa_deletion_matrix.float(), dim=1 + ) + return self.make_feature(data=mean_deletion_matrix.unsqueeze(-1)) + + +class IsPairedMSAGenerator(FeatureGenerator): + """ + Relative species encoding within each MSA sequence + """ + + def __init__(self): + super().__init__( + ty=FeatureType.MSA, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + msa_mask=batch["inputs"]["msa_mask"], + msa_species=batch["inputs"]["msa_species"], + ) + + @typecheck + def _generate( + self, + msa_mask: Bool[Tensor, "batch depth tokens"], + msa_species: Int[Tensor, "batch depth tokens"], + ) -> Tensor: + first_species = msa_species[..., :1] + + is_paired = (msa_species == first_species).to(torch.uint8) + + mask = msa_mask & (msa_species != UNKNOWN_SPECIES) + is_paired = is_paired.masked_fill(~mask, 0) + + return self.make_feature(data=is_paired.unsqueeze(-1)) + + +class MSADataSourceGenerator(FeatureGenerator): + """ + MSA data source for each MSA token + """ + + def __init__( + self, + num_classes: int = 5, + ): + assert num_classes == max(msa_dataset_source_to_int.values()) + 1 + + super().__init__( + ty=FeatureType.MSA, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=num_classes, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + msa_mask=batch["inputs"]["msa_mask"], + msa_sequence_source=batch["inputs"]["msa_sequence_source"], + ) + + @typecheck + def _generate( + self, + msa_mask: Bool[Tensor, "batch depth tokens"], + msa_sequence_source: UInt8[Tensor, "batch depth tokens"], + ) -> Tensor: + msa_sequence_source = msa_sequence_source.masked_fill( + ~msa_mask, self.num_classes + ) + + return self.make_feature(data=msa_sequence_source.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/ref_pos.py b/chai_lab/data/features/generators/ref_pos.py new file mode 100644 index 0000000..f210e31 --- /dev/null +++ b/chai_lab/data/features/generators/ref_pos.py @@ -0,0 +1,28 @@ +import torch + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator + +N_COORDS = 3 # 3 coords: x, y, z + + +class RefPos(FeatureGenerator): + """Provides reference position of atom""" + + def __init__(self): + super().__init__( + ty=FeatureType.ATOM, + encoding_ty=EncodingType.IDENTITY, + mult=1, + num_classes=N_COORDS, + can_mask=False, # we expect to always have valid pos? + ) + + def generate(self, batch: dict) -> torch.Tensor: + original_pos = batch["inputs"]["atom_ref_pos"] + feat = original_pos / 10.0 # better scale for embedding + assert torch.amax(feat.norm(dim=-1)) < 100.0, "wrong scale!" 
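+        # feat holds per-atom reference coordinates of shape (batch, num_atoms, 3),
+        # scaled down by a factor of 10 before embedding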
+ assert feat.ndim == 3 + assert feat.shape[-1] == N_COORDS + + return self.make_feature(data=feat) diff --git a/chai_lab/data/features/generators/relative_chain.py b/chai_lab/data/features/generators/relative_chain.py new file mode 100644 index 0000000..47ef231 --- /dev/null +++ b/chai_lab/data/features/generators/relative_chain.py @@ -0,0 +1,50 @@ +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Int, typecheck + + +class RelativeChain(FeatureGenerator): + def __init__( + self, + s_max: int = 2, + ): + """Relative Entity Encoding + + See algorithm 5 of AF-Multimer + """ + super().__init__( + ty=FeatureType.TOKEN_PAIR, + encoding_ty=EncodingType.ONE_HOT, + num_classes=2 * s_max + 2, + can_mask=False, + ) + self.s_max = s_max + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + entity_id=batch["inputs"]["token_entity_id"].long(), + sym_id=batch["inputs"]["token_sym_id"].long(), + ) + + @typecheck + def _generate( + self, + entity_id: Int[Tensor, "b n"], + sym_id: Int[Tensor, "b n"], + ) -> Tensor: + # remap unique sym_id values to 0,n-1 + _, sym_ids_from_zero = torch.unique(sym_id, sorted=True, return_inverse=True) + + rel_entity, rel_chain = map( + lambda x: rearrange(x, "b n -> b n 1") - rearrange(x, "b n -> b 1 n"), + (entity_id, sym_ids_from_zero), + ) + # within an entity, determine relative chain + rel_chain = torch.clamp(rel_chain + self.s_max, 0, 2 * self.s_max) + # mask out inter-entity features + rel_chain[rel_entity != 0] = 2 * self.s_max + 1 + return self.make_feature(rel_chain.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/relative_entity.py b/chai_lab/data/features/generators/relative_entity.py new file mode 100644 index 0000000..8cd3605 --- /dev/null +++ b/chai_lab/data/features/generators/relative_entity.py @@ -0,0 +1,43 @@ +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Int, typecheck + + +class RelativeEntity(FeatureGenerator): + def __init__(self): + """Relative Entity Encoding + + See algorithm 5 of AF-Multimer + """ + super().__init__( + ty=FeatureType.TOKEN_PAIR, + encoding_ty=EncodingType.ONE_HOT, + num_classes=3, + can_mask=False, + ) + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + entity_id=batch["inputs"]["token_entity_id"].long(), + ) + + @typecheck + def _generate( + self, + entity_id: Int[Tensor, "b n"], + ) -> Tensor: + # remap unique sym_id values to 0,n-1 + _, entity_id_from_zero = torch.unique( + entity_id, sorted=True, return_inverse=True + ) + + rel_entity = rearrange(entity_id_from_zero, "b n -> b n 1") - rearrange( + entity_id_from_zero, "b n -> b 1 n" + ) + rel_entity = torch.clamp(rel_entity + 1, 0, 2) + assert rel_entity.dtype == torch.long + return self.make_feature(rel_entity.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/relative_sep.py b/chai_lab/data/features/generators/relative_sep.py new file mode 100644 index 0000000..2d39f24 --- /dev/null +++ b/chai_lab/data/features/generators/relative_sep.py @@ -0,0 +1,58 @@ +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from 
chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Int, typecheck + + +def get_sep_bins(max_offset: int) -> list[float]: + bins = torch.arange(-max_offset, max_offset + 1).float() + return bins.tolist() + + +SMALL_SEP_BINS = get_sep_bins(32) + + +class RelativeSequenceSeparation(FeatureGenerator): + def __init__( + self, + sep_bins: list[int] | list[float] | None = None, + num_bins: int | None = None, + ): + """Relative Sequence Separation Encoding""" + sep_bins = get_sep_bins(num_bins) if num_bins is not None else sep_bins + sep_bins = sep_bins if sep_bins is not None else SMALL_SEP_BINS + super().__init__( + ty=FeatureType.TOKEN_PAIR, + encoding_ty=EncodingType.ONE_HOT, + num_classes=len(sep_bins) + 2, + can_mask=False, + ) + self.sep_bins = torch.tensor(sep_bins) + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + residue_index=batch["inputs"]["token_residue_index"].long(), + asym_id=batch["inputs"]["token_asym_id"].long(), + ) + + @typecheck + def _generate( + self, + residue_index: Int[Tensor, "b n"], + asym_id: Int[Tensor, "b n"], + ) -> Tensor: + rel_sep, rel_chain = map( + lambda x: rearrange(x, "b n -> b n 1") - rearrange(x, "b n -> b 1 n"), + (residue_index, asym_id), + ) + encoded_feat = torch.searchsorted( + self.sep_bins.to(rel_sep.device), + rel_sep + 1e-4, # add small epsilon bc. bins are chosen by leftmost index + ) + same_chain_mask = rel_chain == 0 + # mask inter-chain sep + encoded_feat[~same_chain_mask] = self.num_classes - 1 + return self.make_feature(encoded_feat.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/relative_token.py b/chai_lab/data/features/generators/relative_token.py new file mode 100644 index 0000000..d7a2237 --- /dev/null +++ b/chai_lab/data/features/generators/relative_token.py @@ -0,0 +1,49 @@ +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Int, typecheck + + +class RelativeTokenSeparation(FeatureGenerator): + def __init__( + self, + # using 16 for default here since values beyond this are very rare. 
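        # shifted separations are clamped into 2 * r_max + 2 bins; one extra bin
        # (index 2 * r_max + 2) is reserved for token pairs from different residues
        # or chains, which is why num_classes below is 2 * r_max + 3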
+ r_max: int = 16, + ): + super().__init__( + ty=FeatureType.TOKEN_PAIR, + encoding_ty=EncodingType.ONE_HOT, + num_classes=2 * r_max + 3, + can_mask=False, + ) + self.r_max = r_max + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + token_index=batch["inputs"]["token_index"], + token_residue_index=batch["inputs"]["token_residue_index"], + token_asym_id=batch["inputs"]["token_asym_id"], + ) + + @typecheck + def _generate( + self, + token_index: Int[Tensor, "b n"], + token_residue_index: Int[Tensor, "b n"], + token_asym_id: Int[Tensor, "b n"], + ) -> Tensor: + rel_sep, rel_residue, rel_chain = map( + lambda x: rearrange(x, "b n -> b n 1") - rearrange(x, "b n -> b 1 n"), + (token_index, token_residue_index, token_asym_id), + ) + + mask = (rel_residue == 0) & (rel_chain == 0) + + rel_sep = torch.clamp(rel_sep + self.r_max, 0, 2 * self.r_max + 1) + # zero inter-residue and inter-chain + rel_sep = rel_sep.masked_fill(~mask, 2 * self.r_max + 2) + + return self.make_feature(rel_sep.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/residue_type.py b/chai_lab/data/features/generators/residue_type.py new file mode 100644 index 0000000..704da00 --- /dev/null +++ b/chai_lab/data/features/generators/residue_type.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.typing import Bool, Int, typecheck + + +class ResidueType(FeatureGenerator): + def __init__( + self, + min_corrupt_prob: float = 0.0, + max_corrupt_prob: float = 0.0, + num_res_ty: int = 22, # 20AA + gap + X + key: str = "aatype", + ): + super().__init__( + ty=FeatureType.TOKEN, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=num_res_ty, + mult=1, + ) + self.min_corrupt_prob = min_corrupt_prob + self.max_corrupt_prob = max_corrupt_prob + self.key = key + + @typecheck + def _corrupt_seq( + self, sequence: Int[Tensor, "... n"] + ) -> tuple[Int[Tensor, "... n"], Bool[Tensor, "... 
n"]]: + """Corrupt the sequence with the given probability""" + corrupt_prob = np.random.uniform( + low=self.min_corrupt_prob, high=self.max_corrupt_prob + ) + corrupt_mask = torch.rand_like(sequence.float()) < corrupt_prob + corrupt_aas = torch.randint_like( + corrupt_mask[corrupt_mask].long(), high=self.num_classes - 1 + ) + corrupt_sequence = sequence.clone() + corrupt_sequence[corrupt_mask] = corrupt_aas + return corrupt_sequence, corrupt_mask + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict(aatype=batch["inputs"][self.key].long()) + + @typecheck + def _generate(self, aatype: Int[Tensor, "b n"]) -> Tensor: + """see super class""" + seq_emb = aatype.clone() + seq_emb, _corrupt_mask = self._corrupt_seq(seq_emb) + return self.make_feature(data=seq_emb.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/structure_metadata.py b/chai_lab/data/features/generators/structure_metadata.py new file mode 100644 index 0000000..8e071c9 --- /dev/null +++ b/chai_lab/data/features/generators/structure_metadata.py @@ -0,0 +1,139 @@ +import torch +from einops import repeat +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.defaults import default +from chai_lab.utils.typing import Bool, Float, typecheck + +DEFAULT_BFACTOR_BINS = [140.0] + +DEFAULT_PLDDT_BINS = [0.3, 0.7] + + +class IsDistillation(FeatureGenerator): + def __init__(self): + super().__init__( + ty=FeatureType.TOKEN, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + is_distillation=batch["inputs"]["is_distillation"], + token_exists_mask=batch["inputs"]["token_exists_mask"], + ) + + @typecheck + def _generate( + self, + is_distillation: Bool[Tensor, "b 1"], + token_exists_mask: Bool[Tensor, "b n"], + ) -> Tensor: + _, n = token_exists_mask.shape + is_distillation = repeat(is_distillation, "b 1 -> b n 1", n=n).to(torch.uint8) + return self.make_feature(data=is_distillation) + + +class TokenBFactor(FeatureGenerator): + def __init__( + self, + include_prob: float = 1.0, + bins: list[float] | None = None, + ): + self.bins = torch.tensor(default(bins, DEFAULT_BFACTOR_BINS)) + + super().__init__( + ty=FeatureType.TOKEN, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=len(self.bins) + 1, + mult=1, + ) + self.include_prob = include_prob + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + token_b_factor=batch["inputs"]["token_b_factor_or_plddt"], + is_distillation=batch["inputs"]["is_distillation"], + token_exists_mask=batch["inputs"]["token_exists_mask"], + ) + + @typecheck + def _generate( + self, + token_b_factor: Float[Tensor, "b n"], + is_distillation: Bool[Tensor, "b 1"], + token_exists_mask: Bool[Tensor, "b n"], + ) -> Tensor: + _, n = token_exists_mask.shape + + include_mask = ( + torch.rand_like(is_distillation, dtype=torch.float) <= self.include_prob + ) + + # this feature is not defined for distillation data + mask = ( + repeat(~is_distillation, "b 1 -> b n", n=n) + & token_exists_mask + & repeat(include_mask, "b 1 -> b n", n=n) + ) + + feat = torch.searchsorted(self.bins.to(is_distillation.device), token_b_factor) + feat.masked_fill_(~mask, self.mask_value) + + return self.make_feature(data=feat.unsqueeze(-1)) + + +class TokenPLDDT(FeatureGenerator): + def __init__( + self, + include_prob: float = 1.0, + bins: 
list[float] | None = None, + ): + self.bins = torch.tensor(default(bins, DEFAULT_PLDDT_BINS)) + + super().__init__( + ty=FeatureType.TOKEN, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=len(self.bins) + 1, + mult=1, + ) + self.include_prob = include_prob + + def get_input_kwargs_from_batch(self, batch) -> dict: + return dict( + token_plddt=batch["inputs"]["token_b_factor_or_plddt"], + is_distillation=batch["inputs"]["is_distillation"], + token_exists_mask=batch["inputs"]["token_exists_mask"], + ) + + @typecheck + def _generate( + self, + token_plddt: Float[Tensor, "b n"], + is_distillation: Bool[Tensor, "b 1"], + token_exists_mask: Bool[Tensor, "b n"], + ) -> Tensor: + _, n = token_exists_mask.shape + + include_mask = ( + torch.rand_like(is_distillation, dtype=torch.float) <= self.include_prob + ) + + # this feature is defined ONLY for distillation data + mask = ( + repeat(is_distillation, "b 1 -> b n", n=n) + & token_exists_mask + & repeat(include_mask, "b 1 -> b n", n=n) + ) + + feat = torch.searchsorted(self.bins.to(is_distillation.device), token_plddt) + feat.masked_fill_(~mask, self.mask_value) + + return self.make_feature(data=feat.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/templates.py b/chai_lab/data/features/generators/templates.py new file mode 100644 index 0000000..8e638af --- /dev/null +++ b/chai_lab/data/features/generators/templates.py @@ -0,0 +1,162 @@ +""" +Feature generators for templates. This includes the following: +- Template mask (includes both the psuedo beta mask and backbone frame mask) +- Template unit vector generator +- Template residue type generator +- Template distogram generator +""" + +import logging +from typing import Any + +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.data.residue_constants import residue_types_with_nucleotides_order +from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck + +logger = logging.getLogger(__name__) + + +class TemplateMaskGenerator(FeatureGenerator): + def __init__(self): + super().__init__( + ty=FeatureType.TEMPLATES, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=2, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + template_backbone_frame_mask=batch["inputs"][ + "template_backbone_frame_mask" + ], + template_pseudo_beta_mask=batch["inputs"]["template_pseudo_beta_mask"], + asym_ids=batch["inputs"]["token_asym_id"].type(torch.int32), + ) + + def _generate( + self, + template_backbone_frame_mask: Bool[Tensor, "batch templ tokens"], + template_pseudo_beta_mask: Bool[Tensor, "batch templ tokens"], + asym_ids: Int[Tensor, "batch tokens"], + ) -> Tensor: + same_asym = rearrange(asym_ids, "b t -> b 1 t 1 1") == rearrange( + asym_ids, "b t -> b 1 1 t 1" + ) + # Line 1: backbone frame mask + # (b t n n) + bij_backbone = rearrange( + template_backbone_frame_mask, "b t n -> b t n 1 1" + ) * rearrange(template_backbone_frame_mask, "b t n -> b t 1 n 1") + + # Line 2: backbone pseudo beta mask + # (b t n n) + bij_pseudo_beta = rearrange( + template_pseudo_beta_mask, "b t n -> b t n 1 1" + ) * rearrange(template_pseudo_beta_mask, "b t n -> b t 1 n 1") + + mask_feat = torch.cat([bij_backbone, bij_pseudo_beta], dim=-1).float() + + return self.make_feature(mask_feat.float() * same_asym.float()) + + +class 
TemplateUnitVectorGenerator(FeatureGenerator): + """Generates feature for template unit vector""" + + def __init__(self): + super().__init__( + ty=FeatureType.TEMPLATES, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=3, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + template_unit_vector=batch["inputs"]["template_unit_vector"], + asym_ids=batch["inputs"]["token_asym_id"].to(torch.int32), + ) + + @typecheck + def _generate( + self, + template_unit_vector: Float[Tensor, "batch templ tokens tokens 3"], + asym_ids: Int[Tensor, "batch tokens"], + ) -> Tensor: + same_asym = rearrange(asym_ids, "b t -> b 1 t 1 1") == rearrange( + asym_ids, "b t -> b 1 1 t 1" + ) + same_asym = same_asym.to(template_unit_vector.dtype) + # mask out pairs with different asyms + template_unit_vector = template_unit_vector * same_asym + return self.make_feature(template_unit_vector) + + +class TemplateResTypeGenerator(FeatureGenerator): + """Generates feature for one-hot encoding of templates, same classes as restype.""" + + def __init__(self, embed_dim=32): + num_res_ty = len(residue_types_with_nucleotides_order) + super().__init__( + ty=FeatureType.TEMPLATES, + encoding_ty=EncodingType.OUTERSUM, + can_mask=False, + num_classes=num_res_ty, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + template_tokens=batch["inputs"]["template_restype"].type(torch.uint8), + ) + + @typecheck + def _generate( + self, + template_tokens: UInt8[Tensor, "batch templ tokens"], + ) -> Tensor: + return self.make_feature(data=template_tokens.unsqueeze(-1)) + + +class TemplateDistogramGenerator(FeatureGenerator): + """Generates feature for distogram of templates.""" + + def __init__( + self, + min_dist_bin: float = 3.25, + max_dist_bin: float = 50.75, + n_dist_bin: int = 38, + ): + super().__init__( + ty=FeatureType.TEMPLATES, + encoding_ty=EncodingType.ONE_HOT, + can_mask=True, + num_classes=n_dist_bin, + mult=1, + ) + self.dist_bins = torch.linspace(min_dist_bin, max_dist_bin, n_dist_bin)[1:] + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + template_distances=batch["inputs"]["template_distances"], + asym_ids=batch["inputs"]["token_asym_id"].to(torch.int32), + ) + + @typecheck + def _generate( + self, + template_distances: Float[Tensor, "batch templ tokens tokens"], + asym_ids: Int[Tensor, "batch tokens"], + ) -> Tensor: + discretized = torch.searchsorted(self.dist_bins, template_distances) + same_asym = rearrange(asym_ids, "b t -> b 1 t 1") == rearrange( + asym_ids, "b t -> b 1 1 t" + ) + discretized = torch.masked_fill(discretized, ~same_asym, self.mask_value) + return self.make_feature(data=discretized.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/token_dist_restraint.py b/chai_lab/data/features/generators/token_dist_restraint.py new file mode 100644 index 0000000..a3a09b3 --- /dev/null +++ b/chai_lab/data/features/generators/token_dist_restraint.py @@ -0,0 +1,367 @@ +import logging +from dataclasses import dataclass + +import numpy as np +import torch +from einops import rearrange, repeat +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.model.utils import get_asym_id_from_subchain_id +from chai_lab.utils.tensor_utils import cdist, 
tensorcode_to_string, und, und_self +from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck + +logger = logging.getLogger(__name__) + + +@typecheck +@dataclass +class ConstraintGroup: + """ + Container for a token pair distance restraint (contact) + """ + + left_residue_subchain_id: str + right_residue_subchain_id: str + left_residue_index: int + right_residue_index: int + right_residue_name: str + left_residue_name: str + distance_threshold: float + + def get_left_and_right_asym_ids( + self, + token_subchain_id: UInt8[Tensor, "n 4"], + token_asym_id: Int[Tensor, "n"], + ): + left_asym_id = get_asym_id_from_subchain_id( + subchain_id=self.left_residue_subchain_id, + source_pdb_chain_id=token_subchain_id, + token_asym_id=token_asym_id, + ) + right_asym_id = get_asym_id_from_subchain_id( + subchain_id=self.right_residue_subchain_id, + source_pdb_chain_id=token_subchain_id, + token_asym_id=token_asym_id, + ) + return left_asym_id, right_asym_id + + def __str__(self): + return ( + f"ConstraintGroup(left_residue_subchain_id={self.left_residue_subchain_id}, " + f"right_residue_subchain_id={self.right_residue_subchain_id}, " + f"left_residue_index={self.left_residue_index}, " + f"right_residue_index={self.right_residue_index}, " + f"right_residue_name={self.right_residue_name}, " + f"left_residue_name={self.left_residue_name}, " + f"distance_threshold={self.distance_threshold})" + ) + + +class TokenDistanceRestraint(FeatureGenerator): + def __init__( + self, + include_probability: float = 1.0, + size: int | float = 0.33, + min_dist: int | float = 10.0, + max_dist: int | float = 30.0, + coord_noise: float = 0.0, + num_rbf_radii: int = 5, + query_entity_types: list[EntityType] | None = None, + key_entity_types: list[EntityType] | None = None, + ): + """Randomly sample inter-chain token distance restraints + + Parameters: + include_probability: Probability with which to include restraints + for a given example. i.e. if include probability is 0.75, then 25% + of the time, we will not sample any restraints for an example. + size: Number of restraints to sample. If 0 < size < 1, then the number + of restraints will be determined as geom(size), independently for each + example. + min_dist: Minimum distance to encode restraints for + max_dist: Maximum distance to encode restraints for + coord_noise: gaussian noise with mean 0 and variance coord_noise + added to coordinates before sampling restraints. + num_rbf_radii: Number of radii to use for the radial basis function + embedding of restraints + query_entity_types: Entity types to consider when sampling "query" tokens + for restraints. Defaults to all entity types. + key_entity_types: Entity types to consider when sampling "key" tokens + for restraints. Defaults to all entity types. + + NOTE: We only sample restraints between tokens if one of the tokens is in + the query entity types and the other is in the key entity types. 
+ """ + super().__init__( + ty=FeatureType.TOKEN_PAIR, + can_mask=False, + encoding_ty=EncodingType.RBF, + num_classes=num_rbf_radii, + mult=1, + ignore_index=-1.0, + ) + self.ignore_idx = -1.0 + self.min_dist, self.max_dist = min_dist, max_dist + self.coord_noise = coord_noise + self.include_prob = include_probability + self.size = size + self.query_entity_types = torch.tensor( + ( + [e.value for e in query_entity_types] + if query_entity_types is not None + else [e.value for e in EntityType] + ) + ).long() + self.key_entity_types = torch.tensor( + [ + [e.value for e in key_entity_types] + if key_entity_types is not None + else [e.value for e in EntityType] + ] + ).long() + + def get_num_restraints(self, batch_size) -> list[int]: + if 0 < self.size < 1: + seles = np.random.geometric(self.size, size=batch_size) + include_mask = np.random.uniform(size=batch_size) < self.include_prob + seles[~include_mask] = 0 + return [int(x) for x in seles] + return [int(self.size)] * batch_size + + def get_input_kwargs_from_batch(self, batch) -> dict: + maybe_constraint_dicts = batch["inputs"].get("contact_constraints", [[None]])[0] + contact_constraints = ( + [ConstraintGroup(**d) for d in maybe_constraint_dicts] + if isinstance(maybe_constraint_dicts[0], dict) + else None + ) + return dict( + atom_gt_coords=batch["inputs"]["atom_gt_coords"], + atom_exists_mask=batch["inputs"]["atom_exists_mask"], + token_asym_id=batch["inputs"]["token_asym_id"].long(), + token_ref_atom_index=batch["inputs"]["token_ref_atom_index"].long(), + token_exists_mask=batch["inputs"]["token_exists_mask"], + token_entity_type=batch["inputs"]["token_entity_type"].long(), + token_residue_index=batch["inputs"]["token_residue_index"].long(), + token_residue_names=batch["inputs"]["token_residue_name"], + token_subchain_id=batch["inputs"]["subchain_id"], + constraints=contact_constraints, + ) + + def _sample_restraints( + self, + dists: Float[Tensor, "n n"], + num_restraints: int, + ): + sampled_restraints = torch.full_like(dists, self.ignore_idx) + # sample upper bound independently in range (min_dist, max_dist) + # for each pair of tokens + # We choose a random delta to upper bound all sampled distances with. + # We do this because larger distance restraints are more likely to be + # valid than smaller ones, and we try to reduce that bias here. 
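        # e.g. with min_dist=10 and max_dist=30, delta below is a single draw from
        # U[0, 20) shared by every pair in this example, and each pair's upper bound
        # then falls in [10, 10 + delta); on examples where delta is small, only
        # genuinely close pairs can satisfy dist < bound and be selected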
+ delta = torch.rand(1) * (self.max_dist - self.min_dist) + all_restraint_bounds = torch.rand_like(dists) * delta + self.min_dist + all_valid_restraints = dists < all_restraint_bounds + num_valid_restraints = int(all_valid_restraints.sum().item()) + if num_valid_restraints == 0 or num_restraints == 0: # no restraints to add + return sampled_restraints + num_restraints = min(num_valid_restraints, num_restraints) + # select random restraints and respective sampled bounds + sampled_restraint_mask = all_restraint_bounds.new_zeros( + num_valid_restraints, dtype=torch.bool + ) + sampled_restraint_mask[:num_restraints] = True + # select random restraints by shuffling + sampled_restraint_mask = sampled_restraint_mask[ + torch.randperm(num_valid_restraints) + ] + + # add the bounds/pairs that we sampled to the sampled restraint matrix + flat_restraint_bounds = all_restraint_bounds[all_valid_restraints] + flat_restraint_bounds[~sampled_restraint_mask] = self.ignore_idx + sampled_restraints[all_valid_restraints] = flat_restraint_bounds + + return sampled_restraints + + @typecheck + def _generate( + self, + atom_gt_coords: Float[Tensor, "b a 3"], + atom_exists_mask: Bool[Tensor, "b a"], + token_asym_id: Int[Tensor, "b n"], + token_ref_atom_index: Int[Tensor, "b n"], + token_exists_mask: Bool[Tensor, "b n"], + token_entity_type: Int[Tensor, "b n"], + token_residue_index: Int[Tensor, "b n"], + token_residue_names: UInt8[Tensor, "b n 8"], + token_subchain_id: UInt8[Tensor, "b n 4"], + constraints: list[ConstraintGroup] | None = None, + ) -> Tensor: + try: + if constraints is not None: + assert atom_gt_coords.shape[0] == 1 + return self.generate_from_constraints( + token_asym_id=token_asym_id, + token_residue_index=token_residue_index, + token_residue_names=token_residue_names, + token_subchain_id=token_subchain_id, + constraints=constraints, + ) + except Exception as e: + logger.error(f"Error {e} generating distance constraints: {constraints}") + + return self._generate_from_batch( + atom_gt_coords=atom_gt_coords, + atom_exists_mask=atom_exists_mask, + token_asym_id=token_asym_id, + token_ref_atom_index=token_ref_atom_index, + token_exists_mask=token_exists_mask, + token_entity_type=token_entity_type, + ) + + @typecheck + def _generate_from_batch( + self, + atom_gt_coords: Float[Tensor, "b a 3"], + atom_exists_mask: Bool[Tensor, "b a"], + token_asym_id: Int[Tensor, "b n"], + token_ref_atom_index: Int[Tensor, "b n"], + token_exists_mask: Bool[Tensor, "b n"], + token_entity_type: Int[Tensor, "b n"], + ) -> Tensor: + batch_size = atom_gt_coords.shape[0] + # create inter-chain contact mask + valid_token_pair_mask = und_self(token_exists_mask, "b i, b j -> b i j") + left_entity_type_mask = torch.any( + (token_entity_type.unsqueeze(-1) - self.query_entity_types) == 0, dim=-1 + ) + right_entity_type_mask = torch.any( + (token_entity_type.unsqueeze(-1) - self.key_entity_types) == 0, dim=-1 + ) + valid_entity_pair_mask = und( + left_entity_type_mask, right_entity_type_mask, "b i, b j -> b i j" + ) + diff_chain_mask = rearrange(token_asym_id, "b i -> b i 1") != rearrange( + token_asym_id, "b j -> b 1 j" + ) + ref_atom_mask = torch.gather( + atom_exists_mask, dim=1, index=token_ref_atom_index + ) + valid_token_ref_atom_mask = und_self(ref_atom_mask, "b i, b j -> b i j") + valid_contact_mask = ( + valid_token_pair_mask + & valid_entity_pair_mask + & valid_token_ref_atom_mask + & diff_chain_mask + ) + + # compute pairwise distances + token_ref_atom_coords = torch.gather( + atom_gt_coords, dim=1, 
index=repeat(token_ref_atom_index, "... -> ... 3") + ) + # optionally add noise to coordinates before computing distances + token_ref_atom_coords = ( + token_ref_atom_coords + + torch.randn_like(token_ref_atom_coords) * self.coord_noise + ) + inter_token_dists = cdist(token_ref_atom_coords) + inter_token_dists[~valid_contact_mask] = self.max_dist + 1 + # compute contacts by (1) sampling an upper bound on the distance + # and (2) selecting pairwise distances below the threshold + num_to_include = self.get_num_restraints(batch_size) + restraint_mats = [ + self._sample_restraints(inter_token_dists[i], n) + for i, n in enumerate(num_to_include) + ] + encoded_feat = torch.stack(restraint_mats, dim=0) + return self.make_feature(encoded_feat.unsqueeze(-1)) + + @typecheck + def generate_from_constraints( + self, + token_asym_id: Int[Tensor, "1 n"], + token_residue_index: Int[Tensor, "1 n"], + token_residue_names: UInt8[Tensor, "1 n 8"], + token_subchain_id: UInt8[Tensor, "1 n 4"], + constraints: list[ConstraintGroup], + ) -> Tensor: + logger.info(f"Generating distance feature from constraints: {constraints}") + n, device = token_asym_id.shape[1], token_asym_id.device + constraint_mat = torch.full( + (n, n), fill_value=self.ignore_idx, device=device, dtype=torch.float32 + ) + for constraint_group in constraints: + left_residue_asym_id, right_residue_asym_id = ( + constraint_group.get_left_and_right_asym_ids( + token_subchain_id=rearrange(token_subchain_id, "1 ... -> ..."), + token_asym_id=rearrange(token_asym_id, "1 ... -> ..."), + ) + ) + constraint_mat = self.add_distance_constraint( + constraint_mat=constraint_mat, + token_asym_id=rearrange(token_asym_id, "1 ... -> ..."), + token_residue_index=rearrange(token_residue_index, "1 ... -> ..."), + token_residue_names=rearrange(token_residue_names, "1 ... 
-> ..."), + left_residue_asym_id=left_residue_asym_id, + right_residue_asym_id=right_residue_asym_id, + left_residue_index=constraint_group.left_residue_index, + right_residue_index=constraint_group.right_residue_index, + right_residue_name=constraint_group.right_residue_name, + left_residue_name=constraint_group.left_residue_name, + distance_threshold=constraint_group.distance_threshold, + ) + # encode and apply mask + constraint_mat = repeat(constraint_mat, "i j -> 1 i j 1") + return self.make_feature(constraint_mat) + + @typecheck + def add_distance_constraint( + self, + constraint_mat: Float[Tensor, "n n"], + token_asym_id: Int[Tensor, "n"], + token_residue_index: Int[Tensor, "n"], + token_residue_names: UInt8[Tensor, "n 8"], + # asym id of the chain that binds in the pocket + left_residue_asym_id: int, + right_residue_asym_id: int, + left_residue_index: int, + right_residue_index: int, + right_residue_name: str, + left_residue_name: str, + distance_threshold: float, + ): + left_asym_mask = token_asym_id == left_residue_asym_id + right_asym_mask = token_asym_id == right_residue_asym_id + left_index_mask = token_residue_index == left_residue_index + right_index_mask = token_residue_index == right_residue_index + left_residue_mask = left_asym_mask & left_index_mask + right_residue_mask = right_asym_mask & right_index_mask + # restraint should point to single residue pair + assert torch.sum(left_residue_mask) == 1, ( + f"Expected unique residue but found {torch.sum(left_residue_mask)}\n" + f"{left_residue_asym_id=}, {left_residue_index=}, {left_residue_name=}" + ) + assert torch.sum(right_residue_mask) == 1, ( + f"Expected unique residue but found {torch.sum(right_residue_mask)}\n" + f"{right_residue_asym_id=}, {right_residue_index=}, {right_residue_name=}" + ) + # make sure the residue names in the constraint match the + # ones we parsed + left_res_name = token_residue_names[left_residue_mask] + right_res_name = token_residue_names[right_residue_mask] + expected_res_name = tensorcode_to_string(rearrange(left_res_name, "1 l -> l")) + assert expected_res_name == left_residue_name, ( + f"Expected residue name {expected_res_name} but got " f"{left_residue_name}" + ) + expected_res_name = tensorcode_to_string(rearrange(right_res_name, "1 l -> l")) + assert expected_res_name == right_residue_name, ( + f"Expected residue name {expected_res_name} but got " + f"{right_residue_name}" + ) + # add constraint + # NOTE: feature is *not* symmetric + constraint_mat[left_residue_mask, right_residue_mask] = distance_threshold + return constraint_mat diff --git a/chai_lab/data/features/generators/token_pair_distance.py b/chai_lab/data/features/generators/token_pair_distance.py new file mode 100644 index 0000000..c016ec4 --- /dev/null +++ b/chai_lab/data/features/generators/token_pair_distance.py @@ -0,0 +1,61 @@ +from typing import Any + +import torch +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.data.features.token_utils import get_centre_positions_and_mask +from chai_lab.utils.tensor_utils import cdist +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +class TokenCenterDistance(FeatureGenerator): + def __init__( + self, + dist_bins: list[float] | None = None, + ): + dist_bins = dist_bins if dist_bins is not None else [0.0, 4.0, 8.0, 12.0, 16.0] + super().__init__( + ty=FeatureType.TOKEN_PAIR, + encoding_ty=EncodingType.ONE_HOT, + # one of dist_bins 
of rbf_radii is not None. + num_classes=len(dist_bins) + 1, + mult=1, + can_mask=True, + ) + + # maintain consistent orders + self.dist_bins = torch.tensor(dist_bins) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + all_atom_positions=batch["inputs"]["atom_gt_coords"], + all_atom_mask=batch["inputs"]["atom_exists_mask"], + token_single_mask=batch["inputs"]["token_exists_mask"], + token_center_atom_index=batch["inputs"]["token_centre_atom_index"].long(), + ) + + @typecheck + def _generate( + self, + all_atom_positions=Float[Tensor, "b a 3"], + all_atom_mask=Bool[Tensor, "b a"], + token_single_mask=Bool[Tensor, "b n"], + token_center_atom_index=Int[Tensor, "b n"], + ) -> Tensor: + """see super class""" + center_atom_coords, center_atom_mask = get_centre_positions_and_mask( + atom_gt_coords=all_atom_positions, + atom_exists_mask=all_atom_mask, + token_centre_atom_index=token_center_atom_index, + token_exists_mask=token_single_mask, + ) + feat = torch.searchsorted( + self.dist_bins.to(center_atom_coords.device), cdist(center_atom_coords) + ) + center_atom_pair_exists = torch.einsum( + "b i, b j -> b i j", center_atom_mask, center_atom_mask + ) + feat.masked_fill_(~center_atom_pair_exists, self.mask_value) + return self.make_feature(feat.unsqueeze(-1)) diff --git a/chai_lab/data/features/generators/token_pair_pocket_restraint.py b/chai_lab/data/features/generators/token_pair_pocket_restraint.py new file mode 100644 index 0000000..41bb0d6 --- /dev/null +++ b/chai_lab/data/features/generators/token_pair_pocket_restraint.py @@ -0,0 +1,283 @@ +import logging +from dataclasses import dataclass + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.data.features.generators.token_dist_restraint import ( + TokenDistanceRestraint, +) +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.model.utils import get_asym_id_from_subchain_id +from chai_lab.utils.tensor_utils import tensorcode_to_string +from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck + +logger = logging.getLogger(__name__) + + +@typecheck +@dataclass +class ConstraintGroup: + """ + Container for a token pocket pair restraint group + """ + + # subchain ID of the pocket chain + pocket_chain_subchain_id: str + # subchain ID of the pocket token + pocket_token_subchain_id: str + # residue index of the pocket token + pocket_token_residue_index: int + # residue name of the pocket token + pocket_token_residue_name: str + # pocket distance threshold + pocket_distance_threshold: float + # optional subchain IDs + + def get_chain_and_token_asym_ids( + self, + token_subchain_id: UInt8[Tensor, "n 4"], + token_asym_id: Int[Tensor, "n"], + ): + pocket_chain_asym_id = get_asym_id_from_subchain_id( + subchain_id=self.pocket_chain_subchain_id, + source_pdb_chain_id=token_subchain_id, + token_asym_id=token_asym_id, + ) + pocket_token_asym_id = get_asym_id_from_subchain_id( + subchain_id=self.pocket_token_subchain_id, + source_pdb_chain_id=token_subchain_id, + token_asym_id=token_asym_id, + ) + return pocket_chain_asym_id, pocket_token_asym_id + + def __str__(self): + return ( + f"ConstraintGroup(pocket_chain_subchain_id={self.pocket_chain_subchain_id}, " + f"pocket_token_subchain_id={self.pocket_token_subchain_id}, " + f"pocket_token_residue_index={self.pocket_token_residue_index}, " + 
f"pocket_token_residue_name={self.pocket_token_residue_name}, " + f"pocket_distance_threshold={self.pocket_distance_threshold})" + ) + + +class TokenPairPocketRestraint(FeatureGenerator): + def __init__( + self, + include_probability: float = 1.0, + size: int | float = 0.33, + min_dist: int | float = 10.0, + max_dist: int | float = 30.0, + coord_noise: float = 0.0, + num_rbf_radii: int = 5, + query_entity_types: list[EntityType] | None = None, + key_entity_types: list[EntityType] | None = None, + ): + """ + Derives pocket constraints by first generating pairwise distance restraints, + and then selecting the query tokens that were assigned to some non-zero + constraint. + + NOTE: Pocket restraints will only be sampled for tokens that are in the + query entity types. + """ + super().__init__( + ty=FeatureType.TOKEN_PAIR, + can_mask=False, + encoding_ty=EncodingType.RBF, + num_classes=num_rbf_radii, + mult=1, + ignore_index=-1.0, + ) + # use distance restraint generator to sample pocket tokens/chains + self.distance_restraint_gen = TokenDistanceRestraint( + include_probability=include_probability, + size=size, + min_dist=min_dist, + max_dist=max_dist, + coord_noise=coord_noise, + num_rbf_radii=num_rbf_radii, + query_entity_types=query_entity_types, + key_entity_types=key_entity_types, + ) + self.ignore_idx = -1.0 + self.min_dist, self.max_dist = min_dist, max_dist + self.coord_noise = coord_noise + self.include_prob = include_probability + self.size = size + # override feature type + self.ty = FeatureType.TOKEN_PAIR + + def get_input_kwargs_from_batch(self, batch) -> dict: + # cast pocket constraints from dict back to dataclass + maybe_constraint_dicts = batch["inputs"].get("pocket_constraints", [[None]])[0] + pocket_constraints = batch["inputs"]["pocket_constraints"] = ( + [ConstraintGroup(**d) for d in maybe_constraint_dicts] + if isinstance(maybe_constraint_dicts[0], dict) + else None + ) + + return dict( + atom_gt_coords=batch["inputs"]["atom_gt_coords"], + atom_exists_mask=batch["inputs"]["atom_exists_mask"], + token_asym_id=batch["inputs"]["token_asym_id"].long(), + token_ref_atom_index=batch["inputs"]["token_ref_atom_index"].long(), + token_exists_mask=batch["inputs"]["token_exists_mask"], + token_entity_type=batch["inputs"]["token_entity_type"].long(), + token_residue_index=batch["inputs"]["token_residue_index"].long(), + token_residue_names=batch["inputs"]["token_residue_name"], + token_subchain_id=batch["inputs"]["subchain_id"], + constraints=pocket_constraints, + ) + + @typecheck + def _generate( + self, + atom_gt_coords: Float[Tensor, "b a 3"], + atom_exists_mask: Bool[Tensor, "b a"], + token_asym_id: Int[Tensor, "b n"], + token_ref_atom_index: Int[Tensor, "b n"], + token_exists_mask: Bool[Tensor, "b n"], + token_entity_type: Int[Tensor, "b n"], + token_residue_index: Int[Tensor, "b n"], + token_residue_names: UInt8[Tensor, "b n 8"], + token_subchain_id: UInt8[Tensor, "b n 4"], + constraints: list[ConstraintGroup] | None = None, + ) -> Tensor: + try: + if constraints is not None: + assert atom_gt_coords.shape[0] == 1 + return self.generate_from_constraints( + token_asym_id=token_asym_id, + token_residue_index=token_residue_index, + token_residue_names=token_residue_names, + token_subchain_id=token_subchain_id, + constraints=constraints, + ) + except Exception as e: + logger.error(f"Error {e} generating pocket constraints: {constraints}") + + return self._generate_from_batch( + atom_gt_coords=atom_gt_coords, + atom_exists_mask=atom_exists_mask, + token_asym_id=token_asym_id, + 
token_ref_atom_index=token_ref_atom_index, + token_exists_mask=token_exists_mask, + token_entity_type=token_entity_type, + ) + + @typecheck + def _generate_from_batch( + self, + atom_gt_coords: Float[Tensor, "b a 3"], + atom_exists_mask: Bool[Tensor, "b a"], + token_asym_id: Int[Tensor, "b n"], + token_ref_atom_index: Int[Tensor, "b n"], + token_exists_mask: Bool[Tensor, "b n"], + token_entity_type: Int[Tensor, "b n"], + ) -> Tensor: + contact_feat = self.distance_restraint_gen._generate_from_batch( + atom_gt_coords=atom_gt_coords, + atom_exists_mask=atom_exists_mask, + token_asym_id=token_asym_id, + token_ref_atom_index=token_ref_atom_index, + token_exists_mask=token_exists_mask, + token_entity_type=token_entity_type, + ).data + # derive the pocket from the contact feature + contact_feat[contact_feat == self.ignore_idx] = self.max_dist + 1 + # determine contacting asym pairs and their respective distances + contact_mask = contact_feat < self.max_dist + # batch dim, row dim, col dim + bs, rs, cs = torch.where(contact_mask.squeeze(-1)) + # determine asym ids of tokens in contact. + for b, r, c in zip(bs, rs, cs): + col_asym_mask = token_asym_id[b] == token_asym_id[b, c] + pocket_constraint = contact_feat[b, r, c] + contact_feat[b, r, col_asym_mask] = pocket_constraint + + # re-mask + contact_feat[contact_feat > self.max_dist] = self.ignore_idx + return self.make_feature(contact_feat) + + @typecheck + def generate_from_constraints( + self, + # only batch size 1 is supported + token_asym_id: Int[Tensor, "1 n"], + token_subchain_id: UInt8[Tensor, "1 n 4"], + token_residue_index: Int[Tensor, "1 n"], + token_residue_names: UInt8[Tensor, "1 n 8"], + constraints: list[ConstraintGroup], + ) -> Tensor: + logger.info(f"Generating pocket feature from constraints: {constraints}") + n, device = token_asym_id.shape[1], token_asym_id.device + constraint_mat = torch.full( + (n, n), fill_value=self.ignore_idx, device=device, dtype=torch.float32 + ) + for constraint_group in constraints: + pocket_chain_asym_id, pocket_token_asym_id = ( + constraint_group.get_chain_and_token_asym_ids( + token_subchain_id=rearrange(token_subchain_id, "1 ... -> ..."), + token_asym_id=rearrange(token_asym_id, "1 ... -> ..."), + ) + ) + constraint_mat = self.add_pocket_constraint( + constraint_mat=constraint_mat, + token_asym_id=rearrange(token_asym_id, "1 ... -> ..."), + token_residue_index=rearrange(token_residue_index, "1 ... -> ..."), + token_residue_names=rearrange(token_residue_names, "1 ... 
-> ..."), + pocket_chain_asym_id=pocket_chain_asym_id, + pocket_token_asym_id=pocket_token_asym_id, + pocket_token_residue_index=constraint_group.pocket_token_residue_index, + pocket_token_residue_name=constraint_group.pocket_token_residue_name, + pocket_distance_threshold=constraint_group.pocket_distance_threshold, + ) + # encode and apply mask + constraint_mat = repeat(constraint_mat, "i j -> 1 i j 1") + return self.make_feature(constraint_mat) + + @typecheck + def add_pocket_constraint( + self, + constraint_mat: Float[Tensor, "n n"], + token_asym_id: Int[Tensor, "n"], + token_residue_index: Int[Tensor, "n"], + token_residue_names: UInt8[Tensor, "n 8"], + # asym id of the chain that binds in the pocket + pocket_chain_asym_id: int, + # asym id of the token defining the pocket + pocket_token_asym_id: int, + # residue index of the pocket token + pocket_token_residue_index: int, + # residue name of the pocket token + pocket_token_residue_name: str, + # distance from the pocket token to pocket chain + pocket_distance_threshold: float, + ): + pocket_chain_asym_mask = token_asym_id == pocket_chain_asym_id + pocket_token_asym_mask = token_asym_id == pocket_token_asym_id + pocket_token_residue_mask = token_residue_index == pocket_token_residue_index + pocket_token_residue_mask &= pocket_token_asym_mask + assert torch.sum(pocket_token_residue_mask) == 1, ( + f"Expected unique residue but found {torch.sum(pocket_token_residue_mask)}\n" + f"{pocket_token_asym_id=}, {pocket_token_residue_index=}, " + f"{pocket_token_residue_name=}" + ) + pocket_token_res_name = token_residue_names[pocket_token_residue_mask] + pocket_token_res_name = rearrange(pocket_token_res_name, "1 l -> l") + expected_res_name = tensorcode_to_string(pocket_token_res_name) + assert expected_res_name == pocket_token_residue_name, ( + f"Expected residue name {expected_res_name} but got " + f"{pocket_token_residue_name}" + ) + # add constraints between the pocket token and all other tokens in the pocket + # chain + # NOTE: feature is not symmetric + constraint_mat[pocket_token_residue_mask, pocket_chain_asym_mask] = ( + pocket_distance_threshold + ) + return constraint_mat diff --git a/chai_lab/data/features/token_utils.py b/chai_lab/data/features/token_utils.py new file mode 100644 index 0000000..1a833ed --- /dev/null +++ b/chai_lab/data/features/token_utils.py @@ -0,0 +1,26 @@ +import torch +from einops import repeat +from torch import Tensor + +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +@typecheck +def get_centre_positions_and_mask( + atom_gt_coords: Float[Tensor, "... n_atoms 3"], + atom_exists_mask: Bool[Tensor, "... n_atoms"], + token_centre_atom_index: Int[Tensor, "... n_tokens"], + token_exists_mask: Bool[Tensor, "... n_tokens"], +) -> tuple[Float[Tensor, "... n_tokens 3"], Bool[Tensor, "... n_tokens"]]: + assert token_centre_atom_index.dtype in (torch.int32, torch.long) + center_index = token_centre_atom_index.long() + indices = repeat(center_index, "... n -> ... 
n c", c=3) + center_pos = torch.gather(atom_gt_coords, dim=-2, index=indices) + center_mask = torch.gather(atom_exists_mask, dim=-1, index=center_index) + + # because token_centre_atom_index is zero-padded, and because + # atom number 0 is probably a valid atom, we need to reapply + # the token mask + center_mask = center_mask & token_exists_mask + + return center_pos, center_mask diff --git a/chai_lab/data/io/__init__.py b/chai_lab/data/io/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/io/pdb_utils.py b/chai_lab/data/io/pdb_utils.py new file mode 100644 index 0000000..8430b0a --- /dev/null +++ b/chai_lab/data/io/pdb_utils.py @@ -0,0 +1,217 @@ +import logging +import string +from collections import defaultdict +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path + +import gemmi +from torch import Tensor + +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.utils.tensor_utils import tensorcode_to_string +from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PDBAtom: + record_type: str + atom_index: int + atom_name: str + alt_loc: str + res_name_3: str + chain_tag: str + residue_index: int + insertion_code: str + pos: list[float] + occupancy: float + b_factor: float + element: str + charge: str + + def __str__( + self, + ): + # currently this works only for single-char chain tags + atom_line = ( + f"{self.record_type:<6}{self.atom_index:>5} {self.atom_name:<4}{self.alt_loc:>1}" + f"{self.res_name_3:>3} {self.chain_tag:>1}" + f"{self.residue_index:>4}{self.insertion_code:>1} " + f"{self.pos[0]:>8.3f}{self.pos[1]:>8.3f}{self.pos[2]:>8.3f}" + f"{self.occupancy:>6.2f}{self.b_factor:>6.2f} " + f"{self.element:>2}{self.charge:>2}" + ) + return atom_line + + +def write_pdb(chain_atoms: list[list[PDBAtom]], out_path: str): + with open(out_path, "w") as f: + for chain in chain_atoms: + for atom in chain: + f.write(str(atom) + "\n") + f.write("TER\n") + f.write("END\n") + + +@typecheck +@dataclass +class PDBContext: + """Data needed to produce Posebuster input file types""" + + token_residue_index: Int[Tensor, "n_tokens"] + token_asym_id: Int[Tensor, "n_tokens"] + token_entity_type: Int[Tensor, "n_tokens"] + token_residue_names: UInt8[Tensor, "n_tokens 8"] + atom_token_index: Int[Tensor, "n_atoms"] + atom_ref_element: Int[Tensor, "n_atoms"] + atom_ref_mask: Bool[Tensor, "n_atoms"] + atom_coords: Float[Tensor, "n_atoms 3"] + atom_exists_mask: Bool[Tensor, "n_atoms"] + atom_ref_name_chars: Int[Tensor, "n_atoms 4"] + atom_bfactor_or_plddt: Float[Tensor, "n_atoms"] | None = None + + @cached_property + def token_res_names_to_string(self) -> list[str]: + return [tensorcode_to_string(x) for x in self.token_residue_names.cpu()] + + @property + def num_atoms(self) -> int: + return self.atom_coords.shape[0] + + @property + def is_protein(self) -> bool: + return self.is_entity(EntityType.PROTEIN) + + @property + def is_ligand(self) -> bool: + return self.is_entity(EntityType.LIGAND) + + @property + def first_residue_name(self) -> str: + return self.token_res_names_to_string[0].strip() + + def is_entity(self, ety: EntityType) -> bool: + return self.token_entity_type[0].item() == ety.value + + def get_pdb_atoms(self): + # warning: calling this on cuda tensors is extremely slow + atom_asym_id = self.token_asym_id[self.atom_token_index] + # atom level attributes + atom_residue_index = 
self.token_residue_index[self.atom_token_index] + atom_names = _tensor_to_atom_names(self.atom_ref_name_chars.unsqueeze(0)) + atom_res_names = self.token_residue_names[self.atom_token_index] + atom_res_names_strs = [ + tensorcode_to_string(x)[:3].ljust(3) for x in atom_res_names + ] + atom_element_names = [ + _atomic_num_to_element(int(x.item())) for x in self.atom_ref_element + ] + + pdb_atoms = [] + for atom_index in range(self.num_atoms): + if not self.atom_exists_mask[atom_index].item(): + # skip missing atoms + continue + + chain_tag_vocab = string.ascii_uppercase + string.ascii_lowercase + if int(atom_asym_id[atom_index].item()) >= len(chain_tag_vocab): + logger.warning( + f"Too many chains for PDB file: {atom_asym_id[atom_index].item()} -- wrapping around" + ) + atom = PDBAtom( + record_type="ATOM", + atom_index=atom_index, + atom_name=atom_names[atom_index], + alt_loc="", + res_name_3=atom_res_names_strs[atom_index], + chain_tag=chain_tag_vocab[ + int(atom_asym_id[atom_index].item()) % len(chain_tag_vocab) + ], + residue_index=int(atom_residue_index[atom_index].item()), + insertion_code="", + pos=self.atom_coords[atom_index].tolist(), + occupancy=1.00, + b_factor=( + 1.00 + if self.atom_bfactor_or_plddt is None + else self.atom_bfactor_or_plddt[atom_index].item() + ), + element=atom_element_names[atom_index], + charge="", + ) + pdb_atoms.append(atom) + return pdb_atoms + + # @classmethod + # def cat(cls, contexts: list["PDBContext"]) -> "PDBContext": + # """Concatenates multiple posebuster contexts into a single context""" + # cat_attrs: dict[str, Tensor] = dict() + # for attr in cls.__annotations__.keys(): + # cat_attrs[attr] = torch.cat([getattr(c, attr) for c in contexts], dim=0) + # return cls(**cat_attrs) + + +def _atomic_num_to_element(atomic_num: int) -> str: + return gemmi.Element(atomic_num).name + + +def entity_to_pdb_atoms(entity: PDBContext) -> list[list[PDBAtom]]: + """Writes a single tokenized entity to PDB file""" + pdb_atoms = entity.get_pdb_atoms() + chains = defaultdict(list) + for atom in pdb_atoms: + chains[atom.chain_tag].append(atom) + return list(chains.values()) + + +def entities_to_pdb_file(entities: list[PDBContext], path: str): + pdb_atoms: list[list[PDBAtom]] = [] + for entity in entities: + pdb_atoms = pdb_atoms + entity_to_pdb_atoms(entity) + write_pdb(pdb_atoms, path) + + +def pdb_context_from_batch( + d: dict, coords: Tensor, plddt: Tensor | None = None +) -> PDBContext: + return PDBContext( + token_residue_index=d["token_residue_index"][0], + token_asym_id=d["token_asym_id"][0], + token_entity_type=d["token_entity_type"][0], + token_residue_names=d["token_residue_name"][0], + atom_token_index=d["atom_token_index"][0], + atom_ref_element=d["atom_ref_element"][0], + atom_ref_mask=d["atom_ref_mask"][0], + atom_coords=coords[0], + atom_exists_mask=d["atom_exists_mask"][0], + atom_ref_name_chars=d["atom_ref_name_chars"][0], + atom_bfactor_or_plddt=plddt[0] if plddt is not None else None, + ) + + +def write_pdbs_from_outputs( + coords: Float[Tensor, "1 n_atoms 3"], + output_batch: dict, + write_path: Path, + bfactors: Float[Tensor, "1 n_atoms"] | None = None, +): + # save outputs + context = pdb_context_from_batch(output_batch, coords, plddt=bfactors) + write_path.parent.mkdir(parents=True, exist_ok=True) + entities_to_pdb_file( + [context], + str(write_path), + ) + logger.info(f"saved pdb file to {write_path}") + + +@typecheck +def _tensor_to_atom_names( + tensor: Int[Tensor, "*dims 4"] | UInt8[Tensor, "*dims 4"], +) -> list[str]: + return [ + 
"".join([chr(ord_val + 32) for ord_val in ords_atom]).rstrip() + for ords_atom in tensor.squeeze(0) + ] diff --git a/chai_lab/data/parsing/__init__.py b/chai_lab/data/parsing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/parsing/fasta.py b/chai_lab/data/parsing/fasta.py new file mode 100644 index 0000000..d8cbb4d --- /dev/null +++ b/chai_lab/data/parsing/fasta.py @@ -0,0 +1,80 @@ +import logging +import re +from pathlib import Path +from typing import Iterable + +import fsspec + +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.data.residue_constants import restype_1to3_with_x + +logger = logging.getLogger(__name__) + +Fasta = tuple[str, str] +Fastas = list[Fasta] + + +nucleic_acid_1_to_name: dict[tuple[str, EntityType], str] = { + ("A", EntityType.RNA): "A", + ("U", EntityType.RNA): "U", + ("G", EntityType.RNA): "G", + ("C", EntityType.RNA): "C", + ("A", EntityType.DNA): "DA", + ("T", EntityType.DNA): "DT", + ("G", EntityType.DNA): "DG", + ("C", EntityType.DNA): "DC", +} + + +def _fasta_to_str(fasta: Fasta) -> str: + header, sequence = fasta + return f">{header}\n{sequence}\n" + + +def fastas_to_str(fastas: Fastas) -> str: + return "".join(_fasta_to_str(fasta) for fasta in fastas) + + +def write_fastas(fastas: Fastas, output_path: str): + logger.debug(f"Writing {len(fastas)} sequences to {output_path}") + with fsspec.open(output_path, "w") as fp: + fp.write(fastas_to_str(fastas)) + + +def read_fasta(file_path: str | Path) -> Iterable[Fasta]: + from Bio import SeqIO + + fasta_sequences = SeqIO.parse(open(file_path), "fasta") + return [(fasta.id, str(fasta.seq)) for fasta in fasta_sequences] + + +def get_residue_name( + fasta_code: str, + entity_type: EntityType, +) -> str: + match entity_type: + case EntityType.PROTEIN: + return restype_1to3_with_x.get(fasta_code, "UNK") + case EntityType.RNA | EntityType.DNA: + # under nucleic_acid_1_to_name, DNA is mapped to D_ and RNA to _ + unk = "X" if entity_type == EntityType.RNA else "DX" + return nucleic_acid_1_to_name.get((fasta_code, entity_type), unk) + case _: + raise ValueError(f"Invalid polymer entity type {entity_type}") + + +def parse_modified_fasta_sequence(sequence: str, entity_type: EntityType) -> list[str]: + """ + Parses a fasta-like string containing modified residues in + brackets, returns a list of residue codes + """ + pattern = r"[A-Z]|\[[A-Z0-9]+\]" + residues = re.findall(pattern, sequence) + + # get full residue name if regular fasta code (not in brackets), + # otherwise return what user passed in brackets + parsed_residues = [ + get_residue_name(x, entity_type) if not x.startswith("[") else x.strip("[]") + for x in residues + ] + return parsed_residues diff --git a/chai_lab/data/parsing/msas/__init__.py b/chai_lab/data/parsing/msas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/parsing/msas/data_source.py b/chai_lab/data/parsing/msas/data_source.py new file mode 100644 index 0000000..831a385 --- /dev/null +++ b/chai_lab/data/parsing/msas/data_source.py @@ -0,0 +1,53 @@ +import logging +from enum import Enum + +logger = logging.getLogger(__name__) + + +class MSADataSource(Enum): + UNIPROT = "uniprot" + UNIREF90 = "uniref90" + BFD = "BFD" + MGNIFY = "mgnify" + PAIRED = "paired" + MAIN = "main" + BFD_UNICLUST = "bfd_uniclust" + SINGLETON = "singleton" + NONE = "none" + + # templates + PDB70 = "pdb70" + + # ran with 3 jackhmmer iterations (-N=3), + # higher quality but sloow to generate + UNIPROT_N3 = "uniprot_n3" + 
UNIREF90_N3 = "uniref90_n3" + MGNIFY_N3 = "mgnify_n3" + + @classmethod + def get_default_sources(cls): + return [ + MSADataSource.BFD_UNICLUST, + MSADataSource.MGNIFY, + MSADataSource.UNIREF90, + MSADataSource.UNIPROT, + ] + + +def encode_source_to_int(source: MSADataSource) -> int: + return msa_dataset_source_to_int.get(source, 4) + + +# This becomes a feature so changing it might break checkpoint compatibility +msa_dataset_source_to_int = { + MSADataSource.BFD_UNICLUST: 0, + MSADataSource.MGNIFY: 1, + MSADataSource.UNIREF90: 2, + MSADataSource.UNIPROT: 3, + MSADataSource.NONE: 4, + MSADataSource.UNIPROT_N3: 3, + MSADataSource.UNIREF90_N3: 2, + MSADataSource.MGNIFY_N3: 1, +} + +database_ids: set[str] = set(x.value for x in MSADataSource) diff --git a/chai_lab/data/parsing/msas/species.py b/chai_lab/data/parsing/msas/species.py new file mode 100644 index 0000000..ac684f3 --- /dev/null +++ b/chai_lab/data/parsing/msas/species.py @@ -0,0 +1,5 @@ +import logging + +logger = logging.getLogger(__name__) + +UNKNOWN_SPECIES = 0 diff --git a/chai_lab/data/parsing/structure/__init__.py b/chai_lab/data/parsing/structure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/data/parsing/structure/all_atom_entity_data.py b/chai_lab/data/parsing/structure/all_atom_entity_data.py new file mode 100644 index 0000000..a28871a --- /dev/null +++ b/chai_lab/data/parsing/structure/all_atom_entity_data.py @@ -0,0 +1,87 @@ +import logging +from dataclasses import dataclass +from datetime import datetime +from functools import cached_property + +from chai_lab.data.parsing.structure import sequence +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.data.parsing.structure.residue import Residue +from chai_lab.data.residue_constants import standard_residue_pdb_codes +from chai_lab.utils.typing import typecheck + +logger = logging.getLogger(__name__) + + +@typecheck +@dataclass +class AllAtomEntityData: + residues: list[Residue] + full_sequence: list[str] + resolution: float + release_datetime: datetime | None # None if no date found + pdb_id: str + source_pdb_chain_id: str + # Unique string identifying the entity. + entity_name: str + # Unique integer identifying the entity, starting at 0. There is a 1:1 mapping + # between entity_name and entity_index. + entity_id: int + method: str + entity_type: EntityType + subchain_id: str + is_d_polypeptide: bool = False # NOTE (mostly) exists for eval set construction + + def __post_init__(self): + assert ( + len(self.residues) == len(self.full_sequence) + ), f"{self.__class__.__name__} residues and full_sequence must be the same length" + + @property + def missing_residues(self) -> list[Residue]: + """ + Returns a list of missing residues in the entity + """ + return [residue for residue in self.residues if residue.is_missing] + + @cached_property + def has_modifications(self) -> bool: + """ + Returns True if the entity has modifications; this only applies to polymers so + is always False for ligands, waters, and unknowns. 
+ """ + if self.entity_type not in ( + EntityType.PROTEIN, + EntityType.RNA, + EntityType.DNA, + EntityType.POLYMER_HYBRID, + ): + return False + + return any(res.name not in standard_residue_pdb_codes for res in self.residues) + + @property + def is_distillation(self) -> bool: + return self.pdb_id.startswith("AF-") + + @property + def sequence(self) -> str: + """Sequence with modified residues encoded as X.""" + return sequence.protein_one_letter_sequence(self.full_sequence) + + @property + def sequence_with_mods(self) -> str: + """Sequence with modifications encoded as [FOO] where FOO is modified residue.""" + return sequence.protein_one_letter_sequence_with_mods(self.full_sequence) + + def __str__(self) -> str: + fields = ", ".join( + [ + f"pdb_id={self.pdb_id}", + f"source_pdb_chain_id={self.source_pdb_chain_id}", + f"entity_name={self.entity_name}", + f"entity_id={self.entity_id}", + f"entity_type={self.entity_type}", + f"subchain_id={self.subchain_id}", + ] + ) + return f"AllAtomEntityData({fields})" diff --git a/chai_lab/data/parsing/structure/entity_type.py b/chai_lab/data/parsing/structure/entity_type.py new file mode 100644 index 0000000..ec5249b --- /dev/null +++ b/chai_lab/data/parsing/structure/entity_type.py @@ -0,0 +1,14 @@ +import logging +from enum import Enum + +logger = logging.getLogger(__name__) + + +class EntityType(Enum): + PROTEIN = 0 + RNA = 1 + DNA = 2 + LIGAND = 3 + POLYMER_HYBRID = 4 + WATER = 5 + UNKNOWN = 6 diff --git a/chai_lab/data/parsing/structure/residue.py b/chai_lab/data/parsing/structure/residue.py new file mode 100644 index 0000000..6174c75 --- /dev/null +++ b/chai_lab/data/parsing/structure/residue.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass + +import gemmi +import torch +from torch import Tensor + +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.data.residue_constants import residue_types_with_nucleotides_order +from chai_lab.model.utils import center_random_augmentation +from chai_lab.utils.typing import Bool, Float, Int + + +@dataclass +class ConformerData: + position: Float[Tensor, "n 3"] + element: Int[Tensor, "n"] + charge: Int[Tensor, "n"] + atom_names: list[str] + bonds: list[tuple[int, int]] + symmetries: Int[Tensor, "n n_symm"] + + @property + def num_atoms(self) -> int: + num_atoms, _ = self.position.shape + assert num_atoms == len(self.atom_names) + return num_atoms + + def gather_atom_positions( + self, query_atom_names: list[str] + ) -> tuple[Float[Tensor, "n 3"], Bool[Tensor, "n"]]: + if self.num_atoms == 0: + gathered_positions = torch.zeros(len(query_atom_names), 3) + mask = torch.zeros(len(query_atom_names), dtype=torch.bool) + return gathered_positions, mask + + atom_indices = {name: i for i, name in enumerate(self.atom_names)} + indices = torch.tensor( + [atom_indices.get(name, -1) for name in query_atom_names], + dtype=torch.int, + ) + mask = indices != -1 + gathered_positions = self.position[indices] * mask.unsqueeze(-1) + + return gathered_positions, mask + + def center_random_augment( + self, + ) -> "ConformerData": + if self.num_atoms == 0: + return self + + atom_mask = torch.ones_like(self.element, dtype=torch.bool) + centered_coords = center_random_augmentation( + self.position.unsqueeze(0), atom_mask.unsqueeze(0) + )[0] + return ConformerData( + centered_coords, + self.element, + self.charge, + self.atom_names, + self.bonds, + self.symmetries, + ) + + +@dataclass +class Residue: + name: str + label_seq: int | None + restype: int + residue_index: int + is_missing: bool + 
b_factor_or_plddt: float + conformer_data: ConformerData | None + smiles: str | None = None + + +def get_restype( + residue_info: gemmi.ResidueInfo, + entity_type: EntityType, +) -> int: + """ + Encodes residues into alphabet of size 32: + 20 standards AAs + X + 4 RNA bases + RX + 4 DNA bases + DX + GAP + note: ligand residues as encoded as X + """ + + if residue_info.is_amino_acid(): + restype = residue_info.fasta_code() # encodes non-standard as X + unknown_value = residue_types_with_nucleotides_order["X"] + elif residue_info.is_nucleic_acid() and entity_type == EntityType.RNA: + restype = "R{}".format(residue_info.one_letter_code) + unknown_value = residue_types_with_nucleotides_order["RX"] + elif residue_info.is_nucleic_acid() and entity_type == EntityType.DNA: + restype = "D{}".format(residue_info.one_letter_code) + unknown_value = residue_types_with_nucleotides_order["DX"] + else: + restype = "X" + unknown_value = residue_types_with_nucleotides_order["X"] + + tokenized_restype = residue_types_with_nucleotides_order.get(restype, unknown_value) + return tokenized_restype diff --git a/chai_lab/data/parsing/structure/sequence.py b/chai_lab/data/parsing/structure/sequence.py new file mode 100644 index 0000000..dbea3fd --- /dev/null +++ b/chai_lab/data/parsing/structure/sequence.py @@ -0,0 +1,135 @@ +import logging + +import gemmi + +from chai_lab.data import residue_constants +from chai_lab.data.parsing.structure.entity_type import EntityType + +logger = logging.getLogger(__name__) + + +def fasta_one_letter_sequence(residue_codes: list[str]) -> str: + """ + Converts a list of residue names into a one-letter-code sequence + """ + return "".join( + [gemmi.find_tabulated_residue(res).fasta_code() for res in residue_codes] + ) + + +def protein_one_letter_sequence(residue_codes: list[str]) -> str: + """ + Converts a list of protein residue names into a one-letter-code sequence. + Probably equivalent to gemmi fasta_code() method but kept for consistency + with old parsing + to be explicit about how non-standard res are handled (with X) + """ + return "".join([_get_protein_only_residue_token(res) for res in residue_codes]) + + +def protein_one_letter_sequence_with_mods(residue_codes: list[str]) -> str: + """ + Convert a list of protein residue names into a one-letter code sequence, + insert non-standard residues as [FOO] where FOO corresponds to the residue code of + that non-standard residue. + + For example, 1PFH is ...APNGL[HIP]TRP... where HIP is the modified residue. 
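+
+    A doctest-style sketch (tiny hypothetical residue list; "HIP" stands in for
+    any non-standard residue code, as in the 1PFH example above):
+
+    >>> protein_one_letter_sequence_with_mods(["ALA", "HIP", "GLY"])
+    'A[HIP]G'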
+ """ + return "".join( + [ + _get_protein_only_residue_token(res, mods_in_brackets=True) + for res in residue_codes + ] + ) + + +def _get_protein_only_residue_token( + three_letter_code: str, + mods_in_brackets: bool = False, +) -> str: + """Encodes everything that is not a standard amino acid as X if nonstandard_as_X is + True, otherwise return nonstandard FOO as [FOO]""" + residue_info = gemmi.find_tabulated_residue(three_letter_code) + # Standard amino acids are always given as single letters + if residue_info.is_amino_acid() and residue_info.is_standard(): + single_letter = residue_info.one_letter_code + single_letter = single_letter.upper() + # non-standard residues derived from a parent std residue are lowercase + single_letter = ( + single_letter if single_letter in residue_constants.restypes else "X" + ) + return single_letter + else: + if mods_in_brackets: + return f"[{three_letter_code}]" + else: + # non-standard residues derived from a parent std residue may have a + # lowercase one-letter code; make this upper case. + single_letter = residue_info.one_letter_code.upper() + return single_letter if single_letter in residue_constants.restypes else "X" + + +def _get_residue_token( + three_letter_code: str, + entity_type: EntityType, +) -> str: + """ + Encodes amino-acids and nucleic acids into corresponding tokens + 20 standard AAs + X + 4 RNA bases + RX + 4 DNA bases + DX + """ + residue_info = gemmi.find_tabulated_residue(three_letter_code) + if residue_info.is_amino_acid(): + single_letter = residue_info.one_letter_code + single_letter = single_letter.upper() + # non-standard residues derived from a parent std residue are lowercase + single_letter = ( + single_letter if single_letter in residue_constants.restypes else "X" + ) + return single_letter + + elif residue_info.is_nucleic_acid() and entity_type == EntityType.RNA: + return "R{}".format(residue_info.one_letter_code) + + elif residue_info.is_nucleic_acid() and entity_type == EntityType.DNA: + return "D{}".format(residue_info.one_letter_code) + + else: + # more properties at https://gemmi.readthedocs.io/en/latest/mol.html#built-in-data + return "X" + + +def get_residue_codes(subchain: gemmi.ResidueSpan, entity: gemmi.Entity) -> list[str]: + """ + Get list of residue codes (3-letter for protein residues, + 1 to 3 letters/digits for ligands, 1 or 2 letters for RNA/DNA) + for a gemmi subchain + """ + # entity.full_sequence comes from SEQRES, so it might be missing in PDB files + if entity.full_sequence is not None and len(entity.full_sequence) > 0: + return [ + gemmi.Entity.first_mon(item) # Ignore point mutations + for item in entity.full_sequence + ] + # this infers the sequence from the set of residues in the structure + return [res.name for res in subchain.first_conformer()] + + +def tokenize_sequence( + subchain: gemmi.ResidueSpan, entity: gemmi.Entity, entity_type: EntityType +) -> list[str]: + three_letter_sequence = get_residue_codes(subchain, entity) + + match entity_type: + case EntityType.PROTEIN: + return [ + _get_protein_only_residue_token(three_letter_code) + for three_letter_code in three_letter_sequence + ] + case EntityType.RNA | EntityType.DNA: + return [ + _get_residue_token(three_letter_code, entity_type) + for three_letter_code in three_letter_sequence + ] + case _: + raise NotImplementedError diff --git a/chai_lab/data/residue_constants.py b/chai_lab/data/residue_constants.py new file mode 100644 index 0000000..fdd6e11 --- /dev/null +++ b/chai_lab/data/residue_constants.py @@ -0,0 +1,597 @@ +# Copyright 
2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from enum import Enum + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + "ALA": ["C", "CA", "CB", "N", "O"], + "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], + "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], + "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], + "CYS": ["C", "CA", "CB", "N", "O", "SG"], + "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], + "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], + "GLY": ["C", "CA", "N", "O"], + "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], + "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], + "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], + "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"], + "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], + "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], + "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], + "SER": ["C", "CA", "CB", "N", "O", "OG"], + "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], + "TRP": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + "N", + "NE1", + "O", + ], + "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"], + "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], +} + +# nucleic acid atoms from rosettafold-all-atoms +# we prefix nucleic acid tokens with "R" for RNA and "D" for DNA +# we add an unknown token RX for RNA-unknown and DX for DNA-unknown +nucleic_acid_atoms: dict[str, tuple[str | None, ...]] = { + "DA": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "N9", + "C4", + "N3", + "C2", + "N1", + "C6", + "C5", + "N7", + "C8", + "N6", + None, + None, + "H5''", + "H5'", + "H4'", + "H3'", + "H2''", + "H2'", + "H1'", + "H2", + "H61", + "H62", + "H8", + None, + None, + ), + "DC": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "N1", + "C2", + "O2", + "N3", + "C4", + "N4", + "C5", + "C6", + None, + None, + None, + None, + "H5''", + "H5'", + "H4'", + "H3'", + "H2''", + "H2'", + "H1'", + "H42", + "H41", + "H5", + "H6", + None, + None, + ), + "DG": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "N9", + "C4", + "N3", + "C2", + "N1", + "C6", + "C5", + "N7", + "C8", + "N2", + "O6", + None, + "H5''", + "H5'", + "H4'", + "H3'", + "H2''", + "H2'", + "H1'", + "H1", + "H22", + "H21", + "H8", + None, + None, + ), + "DT": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "N1", + "C2", + "O2", + "N3", + "C4", + "O4", + "C5", + "C7", + "C6", + None, + None, + None, + "H5''", + "H5'", + "H4'", + "H3'", + "H2''", + "H2'", + "H1'", + "H3", + "H71", + "H72", + "H73", + "H6", + None, + ), + "DX": ( + "O4'", + 
"C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + "N1", + "C2", + "N3", + "C4", + "C5", + "C6", + "N6", + "N7", + "C8", + "N9", + None, + "H5'", + "H5''", + "H4'", + "H3'", + "H2'", + "HO2'", + "H1'", + "H2", + "H61", + "H62", + "H8", + None, + None, + ), + "RA": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + "N1", + "C2", + "N3", + "C4", + "C5", + "C6", + "N6", + "N7", + "C8", + "N9", + None, + "H5'", + "H5''", + "H4'", + "H3'", + "H2'", + "HO2'", + "H1'", + "H2", + "H61", + "H62", + "H8", + None, + None, + ), + "RC": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + "N1", + "C2", + "O2", + "N3", + "C4", + "N4", + "C5", + "C6", + None, + None, + None, + "H5'", + "H5''", + "H4'", + "H3'", + "H2'", + "HO2'", + "H1'", + "H42", + "H41", + "H5", + "H6", + None, + None, + ), + "RG": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + "N1", + "C2", + "N2", + "N3", + "C4", + "C5", + "C6", + "O6", + "N7", + "C8", + "N9", + "H5'", + "H5''", + "H4'", + "H3'", + "H2'", + "HO2'", + "H1'", + "H1", + "H22", + "H21", + "H8", + None, + None, + ), + "RU": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + "N1", + "C2", + "O2", + "N3", + "C4", + "O4", + "C5", + "C6", + None, + None, + None, + "H5'", + "H5''", + "H4'", + "H3'", + "H2'", + "HO2'", + "H1'", + "H3", + "H5", + "H6", + None, + None, + None, + ), + "RX": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + "H5'", + "H5''", + "H4'", + "H3'", + "H2'", + "HO2'", + "H1'", + None, + None, + None, + None, + None, + None, + ), +} + + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_indices = Enum( + "atom_indices", + [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + "OXT", + ], + start=0, +) +atom_types = [atom_index.name for atom_index in atom_indices] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", +] + +residue_types_with_nucleotides = ( + restypes + + ["X"] + + ["RA", "RC", "RG", "RU", "RX"] + + ["DA", "DC", "DG", "DT", "DX"] + + ["-"] # gap + + [":"] # non-existent (i.e. should get masked) +) + +residue_types_with_nucleotides_order = { + restype: i for i, restype in enumerate(residue_types_with_nucleotides) +} + + +# Residue names as found in mmcif/ gemmi parsed data +# that indicate a residue will be tokenized by residue +# and not by atom. 
+standard_residue_pdb_codes = { + "ALA", + "ARG", + "ASN", + "ASP", + "CYS", + "GLN", + "GLU", + "GLY", + "HIS", + "ILE", + "LEU", + "LYS", + "MET", + "PHE", + "PRO", + "SER", + "THR", + "TRP", + "TYR", + "VAL", + "UNK", + "A", + "G", + "C", + "U", + "DA", + "DG", + "DC", + "DT", +} + +# we reserve this residue name for ligands +# it is not assigned to any chemical in the PDB +# it should never have a cached ref conformer +new_ligand_residue_name = "LIG" + + +restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} + +restype_1to3_with_x = {**restype_1to3, "X": "UNK"} diff --git a/chai_lab/data/sources/rdkit.py b/chai_lab/data/sources/rdkit.py new file mode 100644 index 0000000..be6df3a --- /dev/null +++ b/chai_lab/data/sources/rdkit.py @@ -0,0 +1,253 @@ +import logging +from pathlib import Path + +import antipickle +import torch +from rdkit import Chem +from rdkit.Chem import AllChem + +# for some reason calling Chem.rdDetermineBonds doesnt work +from rdkit.Chem.rdDetermineBonds import DetermineBonds +from rdkit.Geometry import Point3D +from rdkit.rdBase import BlockLogs +from tqdm import tqdm + +from chai_lab.data.parsing.structure.residue import ConformerData +from chai_lab.data.residue_constants import ( + new_ligand_residue_name, + standard_residue_pdb_codes, +) +from chai_lab.utils import paths +from chai_lab.utils.pickle import TorchAntipickleAdapter +from chai_lab.utils.timeout import timeout + +# important to set this flag otherwise atom properties such as +# "name" will be lost when pickling +# See https://github.com/rdkit/rdkit/issues/1320 +Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps) + +logger = logging.getLogger(__name__) + + +class RefConformerGenerator: + def __init__( + self, + leaving_atoms_cache_file: str | None = None, + ): + """ + N.B. in almost all cases, you want to use RefConformerGenerator.make() rather + than initializing the object directly, since constructor the conformer generator + is expensive, and we want to cache the result. + + Caches idealized 3D coordinates and list of atoms for residues that exist in the PDB + This is needed to create empty atom coordinates and mask for missing residues + and ensure the number of tokens and atoms is the same for chains with the same entity_id + """ + # Mapping of molecule names to (atom_names, leaving_atoms); leaving atoms + # correspond to True. 
See the following file for how this was constructed: + # src/scripts/small_molecule_preprocess/leaving_atoms.py + self.leaving_atoms: dict[str, tuple[list[str], list[bool]]] = dict() + if leaving_atoms_cache_file is not None: + self.leaving_atoms = antipickle.load(leaving_atoms_cache_file) + + # download conformers' cache if needed + conformers_cache_file = paths.cached_conformers.get_path().as_posix() + # load cached conformers after leaving atoms cache is generated in + # case we need to re-generate the cache + self.cached_conformers = self._load_apkl_conformers(conformers_cache_file) + + if new_ligand_residue_name in self.cached_conformers: + self.cached_conformers.pop(new_ligand_residue_name) + + assert len(self.cached_conformers) > 0 + + def _load_apkl_conformers(self, path: str) -> dict[str, ConformerData]: + assert path.endswith(".apkl") + assert Path(path).exists() + return antipickle.load(path, adapters=_get_adapters()) + + def _load_cached_conformers(self, path: str) -> dict[str, ConformerData]: + block = BlockLogs() + with Chem.SDMolSupplier(path) as suppl: + mols = [m for m in suppl if m is not None] + del block + logger.info(f"Loaded {len(mols)} cached conformers") + + residues_dict = { + m.GetProp("_Name"): self._load_ref_conformer_from_rdkit(m) + for m in tqdm(mols) + } + + # check at least standard residues were loaded + # otherwise missing protein residues cannot be handled + for res_name in standard_residue_pdb_codes: + assert ( + res_name in residues_dict + ), f"Standard residue {res_name} should have a reference conformer loaded" + + return residues_dict + + @classmethod + def _load_ref_conformer_from_rdkit(self, mol: Chem.Mol) -> ConformerData: + mol = Chem.RemoveAllHs(mol) + + ref_pos = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float) + + ref_atom_names = [atom.GetProp("name") for atom in mol.GetAtoms()] + + ref_atom_charge = torch.tensor( + [atom.GetFormalCharge() for atom in mol.GetAtoms()], dtype=torch.int + ) + ref_atom_element = torch.tensor( + [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.int + ) + + bonds = [ + (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds() + ] + + symms = get_intra_res_atom_symmetries(mol) + + symmetries = ( + torch.stack([torch.tensor(x) for x in symms], dim=-1) + if len(symms) > 0 + else torch.arange(len(ref_atom_names)).unsqueeze(-1) + ) + + return ConformerData( + position=ref_pos, + element=ref_atom_element, + charge=ref_atom_charge, + atom_names=ref_atom_names, + bonds=bonds, + symmetries=symmetries, + ) + + def get(self, residue_name: str) -> ConformerData | None: + """ + Returns an rdkit reference conformer if residue is in CCD and conformer + generation succeeded. Otherwise, returns None. + + N.B. we should _not_ add more post-processing logic to this method, since we + call this for every residue and want cache lookups to be fast for large + chains. If you need to modify the conformer data, do that when building the + cache instead. 
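+
+        Minimal usage sketch (assumes the conformer cache has already been
+        downloaded; residue names are ordinary CCD codes):
+
+        >>> generator = RefConformerGenerator()
+        >>> generator.get("ALA") is not None   # standard residues are cached
+        True
+        >>> generator.get("NOT_A_REAL_CODE") is None
+        True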
+ """ + return self.cached_conformers.get(residue_name) + + def generate(self, smiles: str) -> ConformerData: + """Generates a conformer for a ligand from its SMILES string.""" + mol = Chem.MolFromSmiles(smiles) + assert mol is not None, f"Invalid smiles {smiles}" + + mol_with_hs = Chem.AddHs(mol) + + params = AllChem.ETKDGv3() + params.useSmallRingTorsions = True + params.randomSeed = 123 + params.useChirality = True + # below params were added after facing 'Value Error: Bad Conformer id' + # https://github.com/rdkit/rdkit/issues/1433#issuecomment-305097888 + params.maxAttempts = 10_000 + params.useRandomCoords = True + + AllChem.EmbedMultipleConfs(mol_with_hs, numConfs=1, params=params) + AllChem.RemoveHs(mol_with_hs) + for atom in mol_with_hs.GetAtoms(): + atom.SetProp("name", atom.GetSymbol()) + retval = self._load_ref_conformer_from_rdkit(mol_with_hs) + retval.atom_names = [a.upper() for a in retval.atom_names] + return retval + + +def _get_adapters(): + ## adapters define how antipickle should serialize unknown types + from antipickle.adapters import DataclassAdapter + + return [TorchAntipickleAdapter(), DataclassAdapter(dict(conf=ConformerData))] + + +def conformer_data_to_rdkit_mol(conformer: ConformerData) -> Chem.Mol: + """Convert ConformerData to RDKit Mol + RDKit Molecules can be used infer bonds (often better than the PDB) and compute + intra-residue atom symmetries. + """ + + # Create an editable molecule object and add atoms + editable_mol = Chem.RWMol() + + # Add atoms to the molecule + for atom_type, atom_name in zip(conformer.element, conformer.atom_names): + atom = Chem.Atom(atom_type.item()) + atom.SetProp("name", atom_name) + editable_mol.AddAtom(atom) + + # Create a conformer to hold the 3D coordinates + rd_conformer = Chem.Conformer(len(conformer.element)) + + # Set the coordinates for each atom + for i, pos in enumerate(conformer.position.tolist()): + rd_conformer.SetAtomPosition(i, Point3D(*pos)) + + # Add the conformer and convert back to standard molecule instance + editable_mol.AddConformer(rd_conformer) + # add bonds + mol = editable_mol.GetMol() + mol = maybe_add_bonds(mol) + return mol + + +def maybe_add_bonds(mol: Chem.Mol, timeout_after: float = 1.0) -> Chem.Mol: + """Attempts to add bonds to a molecule. Returns original molecule if not + successful + + The RDKit determineBonds function is known to hang for certain molecules. + This function wraps the call in a timeout. + + """ + + @timeout(timeout_after) + def _add_bonds(mol): + # hard-to-find function for inferring bond information + # https://rdkit.org/docs/source/rdkit.Chem.rdDetermineBonds.html + # We wrap this in a timeout because this function is known to hang + # for some molecules. See Issue + # (https://github.com/rdkit/rdkit/discussions/7289#discussioncomment-8930333) + DetermineBonds(mol) + return mol + + try: + mol = _add_bonds(mol) + except ValueError as e: + # ValueError is caused by rdKit, e.g. + # - "could not find valid bond ordering" + # - "determineBondOrdering() does not work with element Os" + logger.warning(f"Failed to determine bonds for {Chem.MolToSmiles(mol)}, {e}") + except TimeoutError as e: + # TimoutError is cause by bug in rdkit + logger.warning(f"Failed to determine bonds for {Chem.MolToSmiles(mol)}, {e}") + + return mol + + +def get_intra_res_atom_symmetries( + mol: Chem.Mol, max_symmetries: int = 1000, timeout_after: float = 1.0 +) -> tuple[tuple[int, ...]]: + """Attempts to compute full set of intra-residue atom symmetries. 
Returns identity + permutation of atoms if not successful""" + + @timeout(timeout_after) + def _get_symmetries(): + return mol.GetSubstructMatches( + mol, uniquify=False, maxMatches=max_symmetries, useChirality=False + ) + + try: + symms = _get_symmetries() + except TimeoutError: + # Issues of hangup have been reported for certain ligand pairs + # Issue(https://github.com/michellab/BioSimSpace/issues/100) + # NOTE: this function calls MCS algorithm described in linked issue. + symms = (tuple(range(mol.GetNumAtoms())),) + + return symms diff --git a/chai_lab/model/__init__.py b/chai_lab/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/model/diffusion_schedules.py b/chai_lab/model/diffusion_schedules.py new file mode 100644 index 0000000..4149115 --- /dev/null +++ b/chai_lab/model/diffusion_schedules.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor + +from chai_lab.utils.typing import Float, typecheck + + +@dataclass(frozen=True) +class InferenceNoiseSchedule: + s_max: float = 160.0 + s_min: float = 4e-4 + p: float = 7.0 + sigma_data: float = 16.0 + + @typecheck + def get_schedule( + self, + device, + num_timesteps: int = 200, + ) -> Float[Tensor, "{num_timesteps}"]: + times = torch.linspace(0, 1, 2 * num_timesteps + 1, device=device)[1::2] + return self.get_noise_for_times(times) + + @typecheck + def get_noise_for_times( + self, times: Float[Tensor, "n_samples"] + ) -> Float[Tensor, "n_samples"]: + if times.min() < 0 or times.max() > 1: + raise ValueError("times must be in [0, 1]") + + sigmas = self.sigma_data * _power_interpolation( + times, val_0=self.s_max, val_1=self.s_min, p=self.p + ) + return sigmas + + +@typecheck +def _power_interpolation( + t: Float[Tensor, "n_samples"], val_0: float, val_1: float, p: float +) -> Float[Tensor, "n_samples"]: + # val0 at t=0, and val1 at t=1 + assert t.min() >= 0 and t.max() <= 1, f"0 <= t <= 1, but {t=}" + return (t * val_1 ** (1 / p) + (1 - t) * val_0 ** (1 / p)) ** p diff --git a/chai_lab/model/utils.py b/chai_lab/model/utils.py new file mode 100644 index 0000000..5e3828c --- /dev/null +++ b/chai_lab/model/utils.py @@ -0,0 +1,212 @@ +from typing import Any + +import torch +from einops import rearrange, reduce, repeat +from torch import Tensor + +from chai_lab.utils.tensor_utils import string_to_tensorcode, und +from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck + + +def get_qkv_indices_for_blocks( + sequence_length: int, + stride: int, + kv_block_size: int, + device: Any, +) -> tuple[ + Int[torch.Tensor, "bl bl_q"], + Int[torch.Tensor, "bl bl_kv"], + Bool[torch.Tensor, "bl bl_kv"], +]: + """Gets q, kv indices for local attention blocks.""" + sequence_length + # from now on pretend q and kv are different axes + num_blocks = sequence_length // stride + assert ( + sequence_length == num_blocks * stride + ), f"only seqlens divisible by {stride=} are supported, not {sequence_length=}" + q_indices = torch.arange(sequence_length, device=device) + q_indices = rearrange( + q_indices, "(bl bl_q) -> bl bl_q", bl=num_blocks, bl_q=stride + ) # bl bl_q -> q + kv_indices = q_indices[:, :1] + (stride - kv_block_size) // 2 + kv_indices = kv_indices + torch.arange( + kv_block_size, device=kv_indices.device + ) # bl bl_kv -> kv + # mask out positions where kv_indices gets wrapped + # Rationale: the local attention block should allways process + # local blocks (i.e. same rel-positional encodings for each block.) 
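+    # Illustrative example with small hypothetical sizes: for sequence_length=8,
+    # stride=4, kv_block_size=6 the tensors above are
+    #   q_indices  = [[0, 1, 2, 3], [4, 5, 6, 7]]
+    #   kv_indices = [[-1, 0, 1, 2, 3, 4], [3, 4, 5, 6, 7, 8]]  (before masking)
+    # so the out-of-range entries (-1 and 8) are masked out below and then
+    # wrapped by the modulo, keeping every gathered index valid.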
+ kv_mask = (kv_indices < sequence_length) & (kv_indices >= 0) + # Use of % not .clamp is important for short sequences + kv_indices = kv_indices % sequence_length + # q_idx is returned for reference, downstream code uses reshapes instead + return q_indices, kv_indices, kv_mask + + +@typecheck +def get_block_atom_pair_mask( + atom_single_mask: Bool[Tensor, "b a"], + q_idx: Int[Tensor, "bl bl_q"], + kv_idx: Int[Tensor, "bl bl_kv"], + kv_is_wrapped_mask: Bool[Tensor, "bl bl_kv"], +) -> Bool[Tensor, "b bl bl_q bl_kv"]: + atom_q_mask = atom_single_mask[:, q_idx] + atom_kv_mask = atom_single_mask[:, kv_idx] + + block_atom_pair_mask = und( + atom_q_mask, atom_kv_mask, "b bl bl_q, b bl bl_kv -> b bl bl_q bl_kv" + ) + + block_atom_pair_mask &= rearrange(kv_is_wrapped_mask, "bl bl_kv -> 1 bl 1 bl_kv") + return block_atom_pair_mask + + +@typecheck +def calc_centroid( + coords: Float[Tensor, "b a 3"], + mask: Bool[Tensor, "#b a"], + weights: Float[Tensor, "b a"] | None = None, +) -> Float[Tensor, "b 3"]: + # mean-center coordinates + masked_weights = weights * mask if weights is not None else mask.to(coords.dtype) + masked_weights /= reduce(masked_weights, "b a -> b 1", "sum").clamp(min=1e-4) + # not using einsum to avoid autocasting + return reduce(coords * masked_weights[:, :, None], "b a c -> b c", "sum") + + +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Transform from: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Transform from: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def random_quaternions( + n: int, dtype: torch.dtype | None = None, device: Any | None = None +) -> torch.Tensor: + """ + Transform from: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + + Returns: + Quaternions as tensor of shape (N, 4). 
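+
+        Sketch of the returned invariants (illustrative, not an exact doctest):
+
+        >>> q = random_quaternions(5)
+        >>> bool(torch.allclose(q.norm(dim=-1), torch.ones(5)))
+        True
+        >>> bool((q[:, 0] >= 0).all())  # nonnegative real part
+        True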
+ """ + if isinstance(device, str): + device = torch.device(device) + o = torch.randn((n, 4), dtype=dtype, device=device) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: torch.dtype | None = None, device: Any = None +) -> torch.Tensor: + """ + Transform from: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions(n, dtype=dtype, device=device) + return quaternion_to_matrix(quaternions) + + +@torch.no_grad() +@typecheck +def center_random_augmentation( + atom_coords: Float[Tensor, "b a 3"], + atom_single_mask: Bool[Tensor, "#b a"], + s_trans: float = 1.0, + rotations: Float[Tensor, "b 3 3"] | None = None, +) -> Float[Tensor, "b a 3"]: + centroid = calc_centroid(atom_coords, mask=atom_single_mask) + centroid = rearrange(centroid, "b c -> b 1 c") + atom_coords = atom_coords - centroid + # randomly rotate + if rotations is None: + rotations = random_rotations(atom_coords.shape[0], device=atom_coords.device) + rotated_coords = torch.einsum("b i j, b a j -> b a i", rotations, atom_coords) + random_translation = torch.randn_like(centroid) # b 1 c=3 + return rotated_coords + s_trans * random_translation + + +@typecheck +def get_asym_id_from_subchain_id( + subchain_id: str, + source_pdb_chain_id: UInt8[Tensor, "n_tokens 4"], + token_asym_id: Int[Tensor, "n"], +): + # encde the subchain ids and perform lookup in context features + chain_id_tensorcode = string_to_tensorcode(subchain_id, pad_to_length=4) + chain_id_tensorcode = chain_id_tensorcode.to(token_asym_id.device) + # create masks + chain_id_tensorcode = repeat(chain_id_tensorcode, "c -> 1 c") + chain_id_mask = torch.all(chain_id_tensorcode == source_pdb_chain_id, dim=-1) + # check uniqueness + chain_id_asyms = torch.unique(token_asym_id[chain_id_mask]) + + assert len(chain_id_asyms) == 1, ( + f"Expected only one token asym, but got {len(chain_id_asyms)} " + f"asyms: {chain_id_asyms}" + ) + return chain_id_asyms[0].item() diff --git a/chai_lab/py.typed b/chai_lab/py.typed new file mode 100644 index 0000000..cd62ab2 --- /dev/null +++ b/chai_lab/py.typed @@ -0,0 +1 @@ +# marker that this package is compatible with python typing \ No newline at end of file diff --git a/chai_lab/ranking/__init__.py b/chai_lab/ranking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/ranking/clashes.py b/chai_lab/ranking/clashes.py new file mode 100644 index 0000000..fd3b569 --- /dev/null +++ b/chai_lab/ranking/clashes.py @@ -0,0 +1,261 @@ +from dataclasses import dataclass +from itertools import combinations + +import torch +from einops import rearrange, reduce, repeat +from torch import Tensor + +import chai_lab.ranking.utils as rutils +from chai_lab.ranking.utils import ( + get_chain_masks_and_asyms, +) +from chai_lab.utils.tensor_utils import cdist, und_self +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +@typecheck +@dataclass +class ClashScores: + """ + total_clashes: total number of clashes in the complex + total_inter_chain_clashes: total number of inter-chain clashes in the complex, + i.e. 
inter-chain clashes summed over all chain pairs + per_chain_intra_clashes: number of intra-chain clashes for each chain in the complex + per_chain_pair_clashes: number of inter-chain clashes for each chain pair in the complex + """ + + total_clashes: Float[Tensor, "..."] + total_inter_chain_clashes: Float[Tensor, "..."] + per_chain_intra_clashes: Float[Tensor, "... n_chains"] + per_chain_pair_clashes: Float[Tensor, "... n_chains n_chains"] + has_clashes: Bool[Tensor, "..."] + + +@typecheck +def _compute_clashes( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + clash_threshold: float = 1.1, +) -> Bool[Tensor, "... a a"]: + pairwise_dists = cdist(atom_coords) + valid_mask = und_self(atom_mask, "... i, ... j -> ... i j") + valid_mask = valid_mask & ~torch.eye( + atom_coords.shape[-2], device=atom_coords.device, dtype=torch.bool + ) + return valid_mask & (pairwise_dists < clash_threshold) + + +@typecheck +def maybe_compute_clashes( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + clash_matrix: Bool[Tensor, "... a a"] | None = None, + clash_threshold: float = 1.1, +) -> Bool[Tensor, "... a a"]: + if clash_matrix is None: + return _compute_clashes(atom_coords, atom_mask, clash_threshold) + else: + return clash_matrix + + +@typecheck +def total_clashes( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + clash_matrix: Bool[Tensor, "... a a"] | None = None, + clash_threshold: float = 1.1, +) -> Float[Tensor, "..."]: + """ + Computes the total number of clashes in the complex + """ + clash_matrix = maybe_compute_clashes( + atom_coords, atom_mask, clash_matrix, clash_threshold + ) + # clash matrix is symmetric + return reduce(clash_matrix, "... a1 a2 -> ...", "sum") / 2 + + +@typecheck +def total_inter_chain_clashes( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + asym_id: Int[Tensor, "... a"], + clash_matrix: Bool[Tensor, "... a a"] | None = None, + clash_threshold: float = 1.1, +) -> Float[Tensor, "..."]: + """Compute total number of inter-chain clashes in the complex""" + clash_matrix = maybe_compute_clashes( + atom_coords, atom_mask, clash_matrix, clash_threshold + ).clone() # don't overwrite an input + # clash matrix is symmetric + clash_matrix &= rearrange(asym_id, "... a -> ... a 1") != rearrange( + asym_id, "... a -> ... 1 a" + ) + # account for double counting + return reduce(clash_matrix, "... a1 a2 -> ...", "sum") / 2 + + +@typecheck +def per_chain_intra_clashes( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + asym_id: Int[Tensor, "... a"], + clash_matrix: Bool[Tensor, "... a a"] | None = None, + clash_threshold: float = 1.1, +) -> tuple[Float[Tensor, "... n_chains"], Int[Tensor, "n_chains"]]: + clash_matrix = maybe_compute_clashes( + atom_coords, atom_mask, clash_matrix, clash_threshold + ).clone() # don't overwrite an input + # clash matrix is symmetric + clash_matrix &= rearrange(asym_id, "... a -> ... a 1") == rearrange( + asym_id, "... a -> ... 1 a" + ) + per_atom_clashes = reduce(clash_matrix, "... a -> ...", "sum") / 2 + # add dimension for chains + per_atom_clashes = rearrange(per_atom_clashes, "... a -> ... 1 a") + chain_masks, asyms = get_chain_masks_and_asyms(asym_id, atom_mask) + return reduce(per_atom_clashes * chain_masks, "... c a -> ... c", "sum"), asyms + + +@typecheck +def per_chain_pair_clashes( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + asym_id: Int[Tensor, "... 
a"], + clash_matrix: Bool[Tensor, "... a a"] | None = None, + clash_threshold: float = 1.1, +) -> tuple[Float[Tensor, "... n_chains n_chains"], Int[Tensor, "n_chains"]]: + """ + Compute the number of inter-chain clashes for each chain in the complex + """ + clash_matrix = maybe_compute_clashes( + atom_coords, atom_mask, clash_matrix, clash_threshold + ).clone() # don't overwrite an input + clash_matrix &= rearrange(asym_id, "... a -> ... a 1") != rearrange( + asym_id, "... a -> ... 1 a" + ) + chain_masks, asyms = get_chain_masks_and_asyms(asym_id, atom_mask) + per_chain_clashes = torch.zeros( + *chain_masks.shape[:-2], + len(asyms), + len(asyms), + device=atom_coords.device, + dtype=torch.float32, + ) + # compute in loop to minimize peak memory + for i, j in combinations(range(len(asyms)), 2): + chain_pair_mask = torch.einsum( + "...i,...j->...ij", chain_masks[..., i, :], chain_masks[..., j, :] + ) + # chain_pair_mask is triangular, so don't need to account for double counting + per_chain_clashes[..., i, j] = reduce( + clash_matrix * chain_pair_mask, "... i j -> ...", "sum" + ) + symm_clashes = per_chain_clashes + rearrange( + per_chain_clashes, "... i j -> ... j i" + ) + return symm_clashes, asyms + + +@typecheck +def has_clashes( + atom_mask: Bool[Tensor, "... a"], + atom_asym_id: Int[Tensor, "... a"], + atom_entity_type: Int[Tensor, "... a"], + per_chain_pair_clashes: Float[Tensor, "... n_chains n_chains"], + max_clashes: int = 100, + max_clash_ratio: float = 0.5, +) -> Bool[Tensor, "..."]: + """ + Determine if the complex has inter-chain clashes. + Criteria: + (1) If a chain pair has more than `max_clashes` clashes, then consider it a clash + (2) If a chain pair has less than `max_clashes` clashes, but the total number of + clashes is more than `max_clash_ratio` of the smaller chain's total atoms, + then also consider it a clash + (3) The chain pairs must be both be polymers + + """ + has_clashes = per_chain_pair_clashes >= max_clashes + + atoms_per_chain = rutils.num_atoms_per_chain( + atom_mask=atom_mask, + asym_id=atom_asym_id, + ) + + # if a chain pair has less than max_clashes clashes, butmore than + # max_clash_ratio of the smaller chain's total atoms, then also + # consider it a clash + c = atoms_per_chain.shape[-1] + atoms_per_chain_row = repeat(atoms_per_chain, "... c -> ... (c k)", k=c) + atoms_per_chain_col = repeat(atoms_per_chain, "... c -> ... (k c)", k=c) + min_atoms_per_chain_pair, _ = torch.min( + torch.stack([atoms_per_chain_row, atoms_per_chain_col], dim=-1), dim=-1 + ) + min_atoms_per_chain_pair = rearrange( + min_atoms_per_chain_pair, + "... (c_row c_col) -> ... c_row c_col", + c_row=c, + ) + has_clashes |= ( + per_chain_pair_clashes / torch.clamp(min_atoms_per_chain_pair, min=1) + ) >= max_clash_ratio + + # only consider clashes between pairs of polymer chains + polymer_chains = rutils.chain_is_polymer( + asym_id=atom_asym_id, + mask=atom_mask, + entity_type=atom_entity_type, + ) + is_polymer_pair = rearrange(polymer_chains, "... c -> ... c 1") & rearrange( + polymer_chains, "... c -> ... 1 c" + ) + # reduce over all chain pairs + return torch.any(has_clashes & is_polymer_pair, dim=(-1, -2)) + + +@typecheck +def get_scores( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + atom_asym_id: Int[Tensor, "... a"], + atom_entity_type: Int[Tensor, "... 
a"], + clash_threshold: float = 1.1, + max_clashes: int = 100, + max_clash_ratio: float = 0.5, +) -> ClashScores: + clash_matrix = _compute_clashes(atom_coords, atom_mask, clash_threshold) + _per_chain_pair_clashes = per_chain_pair_clashes( + atom_coords, atom_mask, atom_asym_id, clash_matrix, clash_threshold + )[0] + return ClashScores( + total_clashes=total_clashes( + atom_coords=atom_coords, + atom_mask=atom_mask, + clash_matrix=clash_matrix, + clash_threshold=clash_threshold, + ), + total_inter_chain_clashes=total_inter_chain_clashes( + atom_coords=atom_coords, + atom_mask=atom_mask, + asym_id=atom_asym_id, + clash_matrix=clash_matrix, + clash_threshold=clash_threshold, + ), + per_chain_intra_clashes=per_chain_intra_clashes( + atom_coords=atom_coords, + atom_mask=atom_mask, + asym_id=atom_asym_id, + clash_matrix=clash_matrix, + clash_threshold=clash_threshold, + )[0], + per_chain_pair_clashes=_per_chain_pair_clashes, + has_clashes=has_clashes( + atom_mask=atom_mask, + atom_asym_id=atom_asym_id, + atom_entity_type=atom_entity_type, + per_chain_pair_clashes=_per_chain_pair_clashes, + max_clashes=max_clashes, + max_clash_ratio=max_clash_ratio, + ), + ) diff --git a/chai_lab/ranking/frames.py b/chai_lab/ranking/frames.py new file mode 100644 index 0000000..0387dcd --- /dev/null +++ b/chai_lab/ranking/frames.py @@ -0,0 +1,168 @@ +import torch +from einops import rearrange, repeat +from torch import Tensor + +from chai_lab.data.features.token_utils import get_centre_positions_and_mask +from chai_lab.utils.tensor_utils import cdist, und_self +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +@typecheck +def abc_is_colinear( + atoms_a: Float[Tensor, "b n_triplets 3"], + atoms_b: Float[Tensor, "b n_triplets 3"], + atoms_c: Float[Tensor, "b n_triplets 3"], +) -> Bool[Tensor, "b n_triplets"]: + """Check to see if each triplet of 3 atoms (a, b, c) is co-linear.""" + w1 = atoms_a - atoms_b + w1 /= torch.linalg.norm(w1, dim=-1, keepdim=True) + w2 = atoms_c - atoms_b + w2 /= torch.linalg.norm(w2, dim=-1, keepdim=True) + + cos_sim = torch.sum(w1 * w2, dim=-1) + cos_sim = torch.clamp(cos_sim, -1.0, 1.0) + angle = torch.acos(cos_sim) # radians + + # Colinearity should cover cases that are very small acute angles and cases of large + # obtuse angles that are close to 180 degrees. 
+ colinear = ( + torch.isnan(angle) + | (angle < 25 / 180 * torch.pi) + | (angle > 155 / 180 * torch.pi) + ) + return colinear + + +@typecheck +def get_single_atom_frames( + atom_coords: Float[Tensor, "b n_atoms 3"], + token_asym_id: Int[Tensor, "b n_tokens"], + token_residue_index: Int[Tensor, "b n_tokens"], + token_backbone_frame_mask: Bool[Tensor, "b n_tokens"], + token_centre_atom_index: Int[Tensor, "b n_tokens"], + token_exists_mask: Bool[Tensor, "b n_tokens"], + atom_exists_mask: Bool[Tensor, "b n_atoms"], + atom_token_index: Int[Tensor, "b n_atoms"], +) -> tuple[Int[Tensor, "b n_tokens 3"], Bool[Tensor, "b n_tokens"]]: + """Makes frames for everything that is tokenized per-atom""" + # For tokens that are one atom per token, a_i, b_i, c_i for frame is: + # - token atom is assigned as b_i + # - closest atom to the token atom is a_i + # - second closest atom is c_i + + # Compute distances; n_tokens size + centre_coords, centre_mask = get_centre_positions_and_mask( + atom_coords, + atom_exists_mask, + token_centre_atom_index, + token_exists_mask, + ) + + asym_mask = rearrange(token_asym_id, "b i -> b i 1") == rearrange( + token_asym_id, "b j -> b 1 j" + ) + res_idx_mask = rearrange(token_residue_index, "b i -> b i 1") == rearrange( + token_residue_index, "b j -> b 1 j" + ) + dists = cdist(centre_coords) # Symmetric (tokens x tokens) + # Mask out distances that don't exist + centre_mask_square = und_self(centre_mask, "b i, b j -> b i j") + # restrict to intra-residue pairs with valid coords + dists = dists.masked_fill( + ~centre_mask_square | ~asym_mask | ~res_idx_mask, torch.inf + ) + + B, tokens = dists.shape[:2] + device = dists.device + + # Mask out diagonal + batch_indices = torch.arange(B, device=device)[..., None, None] + dists[batch_indices, torch.eye(tokens, device=device).bool()] = torch.inf + + _, idces = torch.topk(dists, 2, dim=-1, largest=False) # b, n_tokens, 2 + a, c = idces.unbind(dim=-1) + b = torch.arange(tokens, device=device).unsqueeze(0) # Token index + + # Convert from token index to ATOM index + batch_indices = torch.arange(B, device=device)[..., None] + abc_atom_indices = torch.stack( + [token_centre_atom_index[batch_indices, idx] for idx in [a, b, c]], + dim=-1, + ) + abc_coords_mask = torch.stack( + [centre_mask[batch_indices, idx] for idx in [a, b, c]], + dim=-1, + ).all(dim=-1) + + # Make mask for tokens within the same chain + a_res_idx = token_residue_index[batch_indices, a] + b_res_idx = token_residue_index[batch_indices, b] + c_res_idx = token_residue_index[batch_indices, c] + + a_asym, b_asym, c_asym = ( + token_asym_id[batch_indices, a], + token_asym_id[batch_indices, b], + token_asym_id[batch_indices, c], + ) + + same_residue = (a_res_idx == b_res_idx) & (b_res_idx == c_res_idx) + same_chain = (a_asym == b_asym) & (b_asym == c_asym) + + # Check for co-linearity (< 25 degrees deviation) + colinear = abc_is_colinear( + centre_coords[batch_indices, a], + centre_coords[batch_indices, b], + centre_coords[batch_indices, c], + ) + + # Positions where the token backbone was NOT already defined, shares the same + # entity_id, are not co-linear, and is actually a centre atom + mask = torch.ones_like(token_backbone_frame_mask) + for i in range(mask.shape[0]): + all_idces, counts = torch.unique(atom_token_index[i], return_counts=True) + not_single_idces = all_idces[counts != 1] + mask[i, not_single_idces] = False + + mask &= ( + ~token_backbone_frame_mask + & same_residue + & same_chain + & ~colinear + & abc_coords_mask + & token_exists_mask + ) + + return 
abc_atom_indices, mask + + +@typecheck +def get_frames_and_mask( + atom_coords: Float[Tensor, "b n_atoms 3"], + token_asym_id: Int[Tensor, "b n_tokens"], + token_residue_index: Int[Tensor, "b n_tokens"], + token_backbone_frame_mask: Bool[Tensor, "b n_tokens"], + token_centre_atom_index: Int[Tensor, "b n_tokens"], + token_exists_mask: Bool[Tensor, "b n_tokens"], + atom_exists_mask: Bool[Tensor, "b n_atoms"], + backbone_frame_idces: Int[Tensor, "b n_tokens 3"], + atom_token_index: Int[Tensor, "b n_atoms"], +) -> tuple[Int[Tensor, "b n_tokens 3"], Bool[Tensor, "b n_tokens"]]: + """Computes union of defined backbone frames and single atom frames""" + single_atom_frame_idces, single_atom_frames_mask = get_single_atom_frames( + atom_coords=atom_coords, + token_asym_id=token_asym_id, + token_residue_index=token_residue_index, + token_backbone_frame_mask=token_backbone_frame_mask, + token_centre_atom_index=token_centre_atom_index, + token_exists_mask=token_exists_mask, + atom_exists_mask=atom_exists_mask, + atom_token_index=atom_token_index, + ) + + frame_idces = backbone_frame_idces.clone() + mask = repeat(single_atom_frames_mask, "b n -> b n 3") + frame_idces[mask] = single_atom_frame_idces[mask] + + all_frames_mask = single_atom_frames_mask | token_backbone_frame_mask + + return frame_idces, all_frames_mask diff --git a/chai_lab/ranking/plddt.py b/chai_lab/ranking/plddt.py new file mode 100644 index 0000000..525c0dd --- /dev/null +++ b/chai_lab/ranking/plddt.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass + +from einops import repeat +from torch import Tensor + +import chai_lab.ranking.utils as rutils +from chai_lab.utils.tensor_utils import masked_mean +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +@typecheck +@dataclass +class PLDDTScores: + """ + complex_plddt: plddt score of the complex + per_chain_plddt: plddt score for each chain in the complex + per_atom_plddt: plddt score for each atom in the complex + """ + + complex_plddt: Float[Tensor, "..."] + per_chain_plddt: Float[Tensor, "... c"] + per_atom_plddt: Float[Tensor, "... a"] + + +@typecheck +def plddt( + logits: Float[Tensor, "... a bins"], + mask: Bool[Tensor, "... a"], + bin_centers: Float[Tensor, "bins"], + per_residue: bool = False, +) -> Float[Tensor, "..."] | Float[Tensor, "... a"]: + expectations = rutils.expectation(logits, bin_centers) + if per_residue: + return expectations + else: + return masked_mean(mask, expectations, dim=-1) + + +@typecheck +def per_chain_plddt( + logits: Float[Tensor, "... a bins"], + atom_mask: Bool[Tensor, "... a"], + asym_id: Int[Tensor, "... a"], + bin_centers: Float[Tensor, "bins"], +) -> Float[Tensor, "... c"]: + chain_masks, _ = rutils.get_chain_masks_and_asyms(asym_id, atom_mask) + logits = repeat(logits, "... a b -> ... c a b", c=chain_masks.shape[-2]) + return plddt(logits, chain_masks, bin_centers, per_residue=False) + + +@typecheck +def get_scores( + lddt_logits: Float[Tensor, "... a bins"], + atom_mask: Bool[Tensor, "... a"], + atom_asym_id: Int[Tensor, "... 
a"], + bin_centers: Float[Tensor, "bins"], +) -> PLDDTScores: + return PLDDTScores( + complex_plddt=plddt( + logits=lddt_logits, + mask=atom_mask, + bin_centers=bin_centers, + per_residue=False, + ), + per_atom_plddt=plddt( + logits=lddt_logits, + mask=atom_mask, + bin_centers=bin_centers, + per_residue=True, + ), + per_chain_plddt=per_chain_plddt( + logits=lddt_logits, + atom_mask=atom_mask, + asym_id=atom_asym_id, + bin_centers=bin_centers, + ), + ) diff --git a/chai_lab/ranking/ptm.py b/chai_lab/ranking/ptm.py new file mode 100644 index 0000000..25fbcc7 --- /dev/null +++ b/chai_lab/ranking/ptm.py @@ -0,0 +1,217 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange, reduce, repeat +from torch import Tensor + +from chai_lab.ranking.utils import expectation, get_chain_masks_and_asyms +from chai_lab.utils.tensor_utils import und +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +@typecheck +@dataclass +class PTMScores: + """ + complex_ptm: pTM score of the complex + interface_ptm: ipTM score of the complex + per_chain_ptm: pTM score for each chain in the complex + per_chain_pair_iptm: ipTM score for each chain pair in the complex + """ + + complex_ptm: Float[Tensor, "..."] + interface_ptm: Float[Tensor, "..."] + per_chain_ptm: Float[Tensor, "... c"] + per_chain_pair_iptm: Float[Tensor, "... c c"] + + +@typecheck +def tm_d0(n_tokens: Float[Tensor, "*dims"]) -> Float[Tensor, "*dims"]: + """Compute TM-Score d0 from the number of tokens""" + n_tokens = torch.clamp_min(n_tokens, 19) + return 1.24 * (n_tokens - 15) ** (1.0 / 3) - 1.8 + + +@typecheck +def _compute_ptm( + logits: Float[Tensor, "... n n bins"], + query_res_mask: Bool[Tensor, "... n"], + query_has_frame_mask: Bool[Tensor, "... n"], + key_res_mask: Bool[Tensor, "... n"], + bin_centers: Float[Tensor, "bins"], +) -> Float[Tensor, "..."]: + """ + Compute predicted TM score, normalized by the number of "key" tokens + """ + num_key_tokens = reduce(key_res_mask, "... n -> ...", "sum").to(logits.dtype) + # compute pairwise-TM normalized by the number of key tokens + d0 = rearrange(tm_d0(num_key_tokens), "... -> ... 1") + bin_weights: Float[Tensor, "bins"] = 1 / (1 + (bin_centers / d0) ** 2) + # btm has shape (b,bins). Need to broadcast with probs + # of shape (b,n,n,bins) + bin_weights = rearrange(bin_weights, "... bins -> ... 1 1 bins") + # determine key-query pairs with valid logits + valid_pairs = und( + query_has_frame_mask & query_res_mask, key_res_mask, "... i, ... j -> ... i j" + ) + # compute per-pair expected TM scores + expected_pair_tm = expectation(logits, bin_weights) + # normalized scores by the number of key tokens + num_key_tokens = rearrange(num_key_tokens, "... -> ... 1 1") + qk_weights = valid_pairs.float() / torch.clamp_min(num_key_tokens, 1) + # (b i j) -> (b i) + query_key_tm = torch.sum(qk_weights * expected_pair_tm, dim=-1) + # want to select the row with the most optimistic logits + # and compute TM for this rows predicted alignment + return torch.max(query_key_tm, dim=-1)[0] + + +@typecheck +def complex_ptm( + pae_logits: Float[Tensor, "... n n n_bins"], + token_exists_mask: Bool[Tensor, "... n"], + valid_frames_mask: Bool[Tensor, "... 
n"], + bin_centers: Float[Tensor, "n_bins"], +) -> Float[Tensor, "..."]: + """Compute pTM score of the complex""" + return _compute_ptm( + logits=pae_logits, + query_res_mask=token_exists_mask, + query_has_frame_mask=valid_frames_mask, + key_res_mask=token_exists_mask, + bin_centers=bin_centers, + ) + + +@typecheck +def interface_ptm( + pae_logits: Float[Tensor, "... n n n_bins"], + token_exists_mask: Bool[Tensor, "... n"], + valid_frames_mask: Bool[Tensor, "... n"], + bin_centers: Float[Tensor, "n_bins"], + token_asym_id: Int[Tensor, "... n"], +) -> Float[Tensor, "..."]: + """Compute Interface pTM score + + ipTM is the max TM score over chains c \in C, restricting + to interactions between c and C - {c}. + """ + query_res_mask, _ = get_chain_masks_and_asyms( + asym_id=token_asym_id, mask=token_exists_mask + ) + + per_chain_ptm = _compute_ptm( + logits=rearrange(pae_logits, "... i j n_bins -> ... 1 i j n_bins"), + query_res_mask=query_res_mask, + query_has_frame_mask=rearrange(valid_frames_mask, "... n -> ... 1 n"), + key_res_mask=~query_res_mask & rearrange(token_exists_mask, "... n -> ... 1 n"), + bin_centers=bin_centers, + ) + + return torch.max(per_chain_ptm, dim=-1)[0] + + +@typecheck +def per_chain_pair_iptm( + pae_logits: Float[Tensor, "... n n n_bins"], + token_exists_mask: Bool[Tensor, "... n"], + valid_frames_mask: Bool[Tensor, "... n"], + bin_centers: Float[Tensor, "n_bins"], + token_asym_id: Int[Tensor, "... n"], + batched=False, +) -> tuple[Float[Tensor, "... n_chains n_chains"], Int[Tensor, "n_chains"]]: + """Compute pairwise pTM score for each chain pair""" + chain_mask, asyms = get_chain_masks_and_asyms( + asym_id=token_asym_id, mask=token_exists_mask + ) + c = asyms.numel() + size = 32 * chain_mask.numel() ** 2 * c**2 + + batched = batched and size < 2**32 + + if not batched: + # in the interest of saving memory we compute this in a for-loop + results = [] + for i in range(c): + result = _compute_ptm( + logits=rearrange(pae_logits, "... i j n_bins -> ... 1 i j n_bins"), + query_res_mask=repeat(chain_mask[..., i, :], "... n -> ... k n", k=c), + query_has_frame_mask=rearrange(valid_frames_mask, "... n -> ... 1 n"), + key_res_mask=chain_mask, + bin_centers=bin_centers, + ) + results.append(result) + return torch.stack(results, dim=-2), asyms # b, query_chain, key_chain + else: + # compute batched + query_mask = repeat(chain_mask, "... c n -> ... c k n", k=c) + key_mask = repeat(chain_mask, "... c n -> ... k c n", k=c) + result = _compute_ptm( + logits=rearrange(pae_logits, "... i j n_bins -> ... 1 1 i j n_bins"), + query_res_mask=query_mask, + query_has_frame_mask=rearrange(valid_frames_mask, "... n -> ... 1 1 n"), + key_res_mask=key_mask, + bin_centers=bin_centers, + ) + return result, asyms + + +@typecheck +def per_chain_ptm( + pae_logits: Float[Tensor, "... n n n_bins"], + token_exists_mask: Bool[Tensor, "... n"], + valid_frames_mask: Bool[Tensor, "... n"], + bin_centers: Float[Tensor, "n_bins"], + token_asym_id: Int[Tensor, "... n"], +) -> tuple[Float[Tensor, "... n_chains"], Int[Tensor, "n_chains"]]: + """Computes pTM for each chain in the input""" + chain_mask, unique_asyms = get_chain_masks_and_asyms( + asym_id=token_asym_id, mask=token_exists_mask + ) + per_chain_ptm = _compute_ptm( + logits=rearrange(pae_logits, "... i j n_bins -> ... 1 i j n_bins"), + query_res_mask=chain_mask, + query_has_frame_mask=rearrange(valid_frames_mask, "... n -> ... 
1 n"), + key_res_mask=chain_mask, + bin_centers=bin_centers, + ) + return per_chain_ptm, unique_asyms + + +@typecheck +def get_scores( + pae_logits: Float[Tensor, "... n n n_bins"], + token_exists_mask: Bool[Tensor, "... n"], + valid_frames_mask: Bool[Tensor, "... n"], + bin_centers: Float[Tensor, "n_bins"], + token_asym_id: Int[Tensor, "... n"], +) -> PTMScores: + return PTMScores( + complex_ptm=complex_ptm( + pae_logits=pae_logits, + token_exists_mask=token_exists_mask, + valid_frames_mask=valid_frames_mask, + bin_centers=bin_centers, + ), + interface_ptm=interface_ptm( + pae_logits=pae_logits, + token_exists_mask=token_exists_mask, + valid_frames_mask=valid_frames_mask, + bin_centers=bin_centers, + token_asym_id=token_asym_id, + ), + per_chain_pair_iptm=per_chain_pair_iptm( + pae_logits=pae_logits, + token_exists_mask=token_exists_mask, + valid_frames_mask=valid_frames_mask, + bin_centers=bin_centers, + token_asym_id=token_asym_id, + )[0], + per_chain_ptm=per_chain_ptm( + pae_logits=pae_logits, + token_exists_mask=token_exists_mask, + valid_frames_mask=valid_frames_mask, + bin_centers=bin_centers, + token_asym_id=token_asym_id, + )[0], + ) diff --git a/chai_lab/ranking/rank.py b/chai_lab/ranking/rank.py new file mode 100644 index 0000000..01687a6 --- /dev/null +++ b/chai_lab/ranking/rank.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor + +import chai_lab.ranking.clashes as clashes +import chai_lab.ranking.plddt as plddt +import chai_lab.ranking.ptm as ptm +import chai_lab.ranking.utils as rutils +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +@typecheck +@dataclass +class SampleRanking: + """Sample Ranking Data + + token_chain_masks: a tensor of shape (..., c, n) containing a boolean mask + for each chain in the input + token_asyms: a tensor of shape (c,) containing the unique asym ids for + each chain in the sample. The token asyms are sorted numerically. + token_chain_masks: a tensor of shape (..., c, n) containing a mask + for each chain in the sample. The order of the chains is the same + as token_asyms. + aggregate_score: a tensor of shape (...) containing the aggregate ranking + score for the sample + ptm_scores: see ptm.get_scores for a description of the ptm scores + clash_scores: a dictionary of clash scores + plddt_scores: see plddt.PLDDTScores for a description of the plddt scores + """ + + token_chain_masks: Bool[Tensor, "... c n"] + token_asyms: Int[Tensor, "c"] + aggregate_score: Float[Tensor, "..."] + ptm_scores: ptm.PTMScores + clash_scores: clashes.ClashScores + plddt_scores: plddt.PLDDTScores + + +@typecheck +def rank( + atom_coords: Float[Tensor, "... a 3"], + atom_mask: Bool[Tensor, "... a"], + atom_token_index: Int[Tensor, "... a"], + token_exists_mask: Bool[Tensor, "... n"], + token_asym_id: Int[Tensor, "... n"], + token_entity_type: Int[Tensor, "... n"], + token_valid_frames_mask: Bool[Tensor, "... n"], + # lddt + lddt_logits: Float[Tensor, "... a lddt_bins"], + lddt_bin_centers: Float[Tensor, "lddt_bins"], + # pae + pae_logits: Float[Tensor, "... n n pae_bins"], + pae_bin_centers: Float[Tensor, "pae_bins"], + # clash + clash_threshold: float = 1.1, + max_clashes: int = 100, + max_clash_ratio: float = 0.5, +) -> SampleRanking: + """ + Compute ranking scores for a sample. + In addition to the pTM/ipTM aggregate score, we also return chain + and inter-chain level statistics for pTM and clashes. + see documentation for SampleRanking for a complete description. 
+ """ + + ptm_scores = ptm.get_scores( + pae_logits=pae_logits, + token_exists_mask=token_exists_mask, + valid_frames_mask=token_valid_frames_mask, + bin_centers=pae_bin_centers, + token_asym_id=token_asym_id, + ) + clash_scores = clashes.get_scores( + atom_coords=atom_coords, + atom_mask=atom_mask, + atom_asym_id=torch.gather( + token_asym_id, + dim=-1, + index=atom_token_index.long(), + ), + atom_entity_type=torch.gather( + token_entity_type, + dim=-1, + index=atom_token_index.long(), + ), + max_clashes=max_clashes, + max_clash_ratio=max_clash_ratio, + clash_threshold=clash_threshold, + ) + + plddt_scores = plddt.get_scores( + lddt_logits=lddt_logits, + atom_mask=atom_mask, + bin_centers=lddt_bin_centers, + atom_asym_id=torch.gather( + token_asym_id, + dim=-1, + index=atom_token_index.long(), + ), + ) + + # aggregate score + aggregate_score = ( + 0.2 * ptm_scores.complex_ptm + + 0.8 * ptm_scores.interface_ptm + - 100 * clash_scores.has_clashes.float() + ) + + chain_masks, asyms = rutils.get_chain_masks_and_asyms( + asym_id=token_asym_id, + mask=token_exists_mask, + ) + + return SampleRanking( + token_chain_masks=chain_masks, + token_asyms=asyms, + aggregate_score=aggregate_score, + ptm_scores=ptm_scores, + clash_scores=clash_scores, + plddt_scores=plddt_scores, + ) diff --git a/chai_lab/ranking/utils.py b/chai_lab/ranking/utils.py new file mode 100644 index 0000000..dbdf5a4 --- /dev/null +++ b/chai_lab/ranking/utils.py @@ -0,0 +1,83 @@ +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.utils.tensor_utils import cdist +from chai_lab.utils.typing import Bool, Float, Int, typecheck + + +@typecheck +def get_chain_masks_and_asyms( + asym_id: Int[Tensor, "... n"], + mask: Bool[Tensor, "... n"], +) -> tuple[Bool[Tensor, "... c n"], Int[Tensor, "c"]]: + """ + Returns a mask for each chain and the unique asym ids + """ + unique_asyms = torch.unique(asym_id[mask]) + sorted_unique_asyms, _ = torch.sort(unique_asyms) + # shape: (..., max_num_chains, n) + chain_masks = rearrange(asym_id, "... n -> ... 1 n") == rearrange( + sorted_unique_asyms, "nc -> nc 1" + ) # shape: (..., n, max_num_chains) + return chain_masks & rearrange(mask, "... n -> ... 1 n"), sorted_unique_asyms + + +@typecheck +def get_interface_mask( + coords: Float[Tensor, "... n 3"], + asym_id: Int[Tensor, "... n"], + mask: Bool[Tensor, "... n"], + interface_threshold: float, +) -> Bool[Tensor, "... n n"]: + valid_mask = rearrange(asym_id, "... n -> ... n 1") != rearrange( + asym_id, "... n -> ... 1 n" + ) + valid_mask &= rearrange(mask, "... n -> ... n 1") & rearrange( + mask, "... n -> ... 1 n" + ) + dists = torch.masked_fill(cdist(coords), ~valid_mask, torch.inf) + min_dists, _ = torch.min(dists, dim=-1) + return min_dists < interface_threshold + + +@typecheck +def expectation( + logits: Float[Tensor, "... bins"], + weights: Float[Tensor, "... bins"], +) -> Float[Tensor, "..."]: # last dim will be dropped + logits = torch.softmax(logits, dim=-1) + return (logits * weights).sum(dim=-1) + + +@typecheck +def num_atoms_per_chain( + atom_mask: Bool[Tensor, "... a"], + asym_id: Int[Tensor, "... a"], +) -> Int[Tensor, "... c"]: + masks, _ = get_chain_masks_and_asyms(asym_id, atom_mask) + return masks.sum(dim=-1) + + +@typecheck +def chain_is_polymer( + asym_id: Int[Tensor, "... n"], + mask: Bool[Tensor, "... n"], + entity_type: Int[Tensor, "... n"], +) -> Bool[Tensor, "... 
c"]: + chain_masks, _ = get_chain_masks_and_asyms(asym_id, mask) + polymer_types = torch.tensor( + [ + EntityType.PROTEIN.value, + EntityType.RNA.value, + EntityType.DNA.value, + EntityType.POLYMER_HYBRID.value, + ], + device=entity_type.device, + ) + is_polymer = torch.any(entity_type.unsqueeze(-1) == polymer_types, dim=-1) + chain_is_polymer = [] + for polymer_mask in chain_masks.unbind(dim=-2): + chain_is_polymer.append(torch.any(is_polymer & polymer_mask, dim=-1)) + return torch.stack(chain_is_polymer, dim=-1) diff --git a/chai_lab/utils/__init__.py b/chai_lab/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chai_lab/utils/defaults.py b/chai_lab/utils/defaults.py new file mode 100644 index 0000000..7731046 --- /dev/null +++ b/chai_lab/utils/defaults.py @@ -0,0 +1,7 @@ +from typing import TypeVar + +T = TypeVar("T") + + +def default(x: T | None, y: T) -> T: + return x if x is not None else y diff --git a/chai_lab/utils/dict.py b/chai_lab/utils/dict.py new file mode 100644 index 0000000..2ae3d89 --- /dev/null +++ b/chai_lab/utils/dict.py @@ -0,0 +1,19 @@ +from typing import TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +def list_dict_to_dict_list(list_dict: list[dict[K, V]]) -> dict[K, list[V]]: + """ + Converts a list of dicts that contain the same keys to a dict of lists, where each + list contains an ordered list of values of the corresponding dict. + """ + if len(list_dict) == 0: + return {} + + keys = list_dict[0].keys() + if any(d.keys() != keys for d in list_dict): + raise ValueError("All dicts must have the same keys") + + return {k: [d[k] for d in list_dict] for k in keys} diff --git a/chai_lab/utils/paths.py b/chai_lab/utils/paths.py new file mode 100644 index 0000000..d816cd6 --- /dev/null +++ b/chai_lab/utils/paths.py @@ -0,0 +1,62 @@ +import dataclasses +from pathlib import Path +from typing import Final + +import requests + +# use this path object to specify location +# of anything within repository +repo_root: Final[Path] = Path(__file__).parents[2].absolute() + +# minimal sanity check in case we start moving things around +assert repo_root.exists() + + +def download(http_url: str, path: Path): + print(f"downloading {http_url}") + tmp_path = path.with_suffix(".download_tmp") + + with requests.get(http_url, stream=True) as response: + response.raise_for_status() # Check if the request was successful + # Open a local file with the specified name + path.parent.mkdir(exist_ok=True, parents=True) + with tmp_path.open("wb") as file: + # Download the file in chunks + for chunk in response.iter_content(chunk_size=8192): + if chunk: # Filter out keep-alive new chunks + file.write(chunk) + tmp_path.rename(path) + assert path.exists() + + +@dataclasses.dataclass +class Downloadable: + url: str + path: Path + + def get_path(self) -> Path: + # downloads artifact if necessary + if not self.path.exists(): + download(self.url, path=self.path) + + return self.path + + +cached_conformers = Downloadable( + url="https://chaiassets.com/chai1-inference-depencencies/conformers.apkl", + path=repo_root.joinpath("downloads", "conformers.apkl"), +) + + +def chai1_component(comp_key: str) -> Path: + """ + Downloads exported model, stores in locally in the repo/downloads + comp_key: e.g. 
'384/trunk.pt2' + """ + assert comp_key.endswith(".pt2") + url = f"https://chaiassets.com/chai1-inference-depencencies/models/{comp_key}" + result = repo_root.joinpath("downloads", "models", comp_key) + if not result.exists(): + download(url, result) + + return result diff --git a/chai_lab/utils/pickle.py b/chai_lab/utils/pickle.py new file mode 100644 index 0000000..9cda4b3 --- /dev/null +++ b/chai_lab/utils/pickle.py @@ -0,0 +1,19 @@ +import antipickle +import torch + + +class TorchAntipickleAdapter(antipickle.AbstractAdapter): + typestring = "torch" + + def __init__(self): + self.cpu_device = torch.device("cpu") + + def check_type(self, obj): + return type(obj) is torch.Tensor # ignore inherited classes + + def to_dict(self, obj): + assert obj.device == self.cpu_device, "serializing only cpu tensors" + return {"data": antipickle.wrap(obj.numpy())} # use numpy serialization + + def from_dict(self, d): + return torch.from_numpy(d["data"]) diff --git a/chai_lab/utils/tensor_utils.py b/chai_lab/utils/tensor_utils.py new file mode 100644 index 0000000..9fbc0ba --- /dev/null +++ b/chai_lab/utils/tensor_utils.py @@ -0,0 +1,288 @@ +import typing +from functools import lru_cache +from typing import TypeVar + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor + +from chai_lab.utils.defaults import default +from chai_lab.utils.typing import Bool, Float, UInt8, typecheck + + +@typecheck +def cdist( + x: Float[Tensor, "... p m"], + y: Float[Tensor, "... r m"] | None = None, + p: float = 2.0, +) -> Float[Tensor, "... p r"]: + y = default(y, x) + assert x.ndim == y.ndim + + _threshold = 2147400000 + n, m = x.shape[-2], y.shape[-2] + + flat_size = torch.prod(torch.tensor(x.shape[:-2])) * n * m + + if x.is_cuda and flat_size > _threshold: + # Torch cdist without mm fails when the total number of + # points is > _threshold (in dimension 3) + # or 8192 points for batch size 32. + # To preserve accuracy, we fallback to naive distances + return _naive_pairwise_distances(x, y) + + return torch.cdist(x1=x, x2=y, compute_mode="donot_use_mm_for_euclid_dist", p=p) + + +@typecheck +def _naive_pairwise_distances( + x: Float[Tensor, "... p m"], + y: Float[Tensor, "... r m"] | None = None, + eps: float = 1e-10, +) -> Float[Tensor, "... p r"]: + y = default(y, x) + diff = x.unsqueeze(-2) - y.unsqueeze(-3) + + return diff.pow_(2).sum(dim=-1).add_(eps).sqrt_() + + +@typecheck +def masked_mean( + mask: Bool[Tensor, "..."], + value: Tensor, + dim: int | tuple[int, ...], + keepdim=False, +) -> Tensor: + mask = mask.expand(*value.shape) + num = torch.sum(mask * value, dim=dim, keepdim=keepdim) + denom = torch.sum(mask, dim=dim, keepdim=keepdim).clamp(min=1) + return num / denom + + +@typecheck +def one_hot(x: Tensor, v_bins: Tensor) -> Tensor: + """One hot encoding; v_bins should N-1 bins where N is desired bins.""" + bins = torch.searchsorted(v_bins, x) + return F.one_hot(bins, v_bins.shape[-1] + 1).float() + + +@lru_cache() +def _get_individual_und_patterns(multipattern: str) -> list[str]: + assert isinstance(multipattern, str), "pattern goes as last argument" + left_parts, right_part = multipattern.split("->") + assert "(" not in right_part, "parenthesis not supported for now" + result = [] + + all_left_ids = set() + all_left_parts_have_ellipsis = True + + for left_part in left_parts.split(","): + left_ids = set(left_part.split()) + if "..." 
not in left_ids: + all_left_parts_have_ellipsis = False + all_left_ids.update(left_ids) + right_parts = [] + for token in right_part.split(): + if token == "1" or token in left_ids: # '...' should be in left ids + right_parts.append(token) + elif token.isidentifier(): + right_parts.append("1") + elif token == "...": + raise RuntimeError( + f"Ellipis not in one of left sides of {multipattern=}" + ) + else: + raise RuntimeError(f"Unknown {token=} in {multipattern=}") + result.append(f"{left_part} -> " + " ".join(right_parts)) + + if "..." in right_part.split(): + msg = "for now ALL or NONE left parts should have ellipsis (...) " + assert all_left_parts_have_ellipsis, msg + + unk_ids = [ + x + for x in right_part.split() + if x not in all_left_ids and x != "1" and x != "..." + ] + assert len(unk_ids) == 0, f"{unk_ids=} not found on left side of {multipattern}" + + return result + + +@typing.overload +def und(t1: Tensor, pattern: str) -> Tensor: ... + + +@typing.overload +def und(t1: Tensor, t2: Tensor, pattern: str) -> Tensor: ... + + +@typing.overload +def und(t1: Tensor, t2: Tensor, t3: Tensor, pattern: str) -> Tensor: ... + + +@typing.overload +def und(t1: Tensor, t2: Tensor, t3: Tensor, t4: Tensor, pattern: str) -> Tensor: ... + + +def und(*args): + """ + Micro-extension to einops. + + Performs & (logical_and) for several masks. + Similar to einsum over masks, but additionally can add/remove 1-dims. + + > und(mask1, mask2, "b i, b j -> b 1 i j") + """ + *tensors, multipattern = args + patterns = _get_individual_und_patterns(multipattern) + + result = None + for arg_val, arg_pattern in zip(tensors, patterns, strict=True): + assert arg_val.dtype == torch.bool + if result is None: + result = rearrange(arg_val, arg_pattern) + else: + result = result & rearrange(arg_val, arg_pattern) + return result + + +def und_self(mask: Tensor, pattern: str) -> Tensor: + """ + Performs & (logical_and) for two replicas of the same tensor + + > und_self(mask, "b i, b j -> b 1 i j") + is a better version of + > und(mask, mask, "b i, b j -> b 1 i j") + """ + return und(mask, mask, pattern) + + +# 255 is not an ASCII char +TENSORCODE_PAD_TOKEN = torch.iinfo(torch.uint8).max + + +@typecheck +def string_to_tensorcode( + input: str, + pad_to_length: int | None = None, + device: torch.device | None = None, +) -> UInt8[Tensor, "l"]: + """ + Converts an ASCII string to a tensor of integers. + + If pad_to_length is specified, the output tensor will have this length, and we add a + special padding character if the tensor has less than the specified length. + + The minimum value of the output tensor is 0, and the maximum is 127 (excluding the + padding token, which can be 255). 
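+
+    Illustrative example (values follow directly from the ASCII codes and the
+    padding token):
+
+        string_to_tensorcode("AC", pad_to_length=4)
+        # tensor([ 65,  67, 255, 255], dtype=torch.uint8)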
+ """ + assert input.isascii(), "Expected input to be ASCII" + ords = [ord(c) for c in input] + + tensorcode = torch.tensor(ords, dtype=torch.uint8, device=device) + if pad_to_length is None: + return tensorcode + + input_length = len(input) + assert ( + pad_to_length >= input_length + ), f"Expected {input_length=} to be shorter than {pad_to_length=} for {input=}" + + return F.pad( + tensorcode, + (0, pad_to_length - input_length), + value=TENSORCODE_PAD_TOKEN, + ) + + +@typecheck +def tensorcode_to_string(tensor: UInt8[Tensor, "l"]) -> str: + """ + Applies the inverse of the string_to_tensorcode function + """ + assert tensor.device == torch.device("cpu") + chars = [chr(i) for i in tensor if i != TENSORCODE_PAD_TOKEN] + return "".join(chars) + + +@typecheck +def batch_tensorcode_to_string( + tensor: UInt8[Tensor, "*dims l"], +) -> list[str]: + tensor = rearrange(tensor, "... l -> (...) l") + tensor = tensor[tensor.amax(dim=1) > 0, :] + return [ + "".join(chr(i) for i in row if i != TENSORCODE_PAD_TOKEN) + for row in tensor.tolist() + ] + + +def unique_indexes(x: torch.Tensor, dim=-1, sorted: bool = True): + """Implements return_index=True behavior for torch.unique. + + See https://numpy.org/doc/stable/reference/generated/numpy.unique.html for info and + https://github.com/pytorch/pytorch/issues/36748 for context.""" + assert x.size(dim) > 0 + + unique, inverse = torch.unique(x, return_inverse=True, sorted=True, dim=dim) + perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device) + inverse, perm = inverse.flip([0]), perm.flip([0]) + inverse = inverse.new_empty(unique.size(dim)).scatter_(0, inverse, perm) + if sorted: + inverse = inverse.sort().values + + return unique, inverse + + +T = TypeVar("T") + + +# mypy is too angry when this function is directly annotated +def _move_data_to_device(x, device: torch.device): + if x is None: + return None + if isinstance(x, (str, int, float, bool)): + return x + if isinstance(x, torch.Tensor): + return x.to(device=device) + elif isinstance(x, dict): + return {k: move_data_to_device(v, device) for k, v in x.items()} + elif isinstance(x, list): + return [move_data_to_device(el, device) for el in x] + elif isinstance(x, tuple): + return tuple(move_data_to_device(el, device) for el in x) + else: + raise NotImplementedError(type(x)) + + +def move_data_to_device(x: T, device: torch.device) -> T: + return _move_data_to_device(x, device=device) + + +def set_seed(seed_sequence: list[int]) -> None: + """ + Seeds numpy, torch, and Python. + + This function is heavily inspired by Lightning's pl_worker_init_function. 
+ """ + import random + + import numpy as np + + # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module + np_ss = np.random.SeedSequence(seed_sequence) + torch_ss, stdlib_ss = np_ss.spawn(2) + + # Seed numpy, use 128 bits (4 x 32-bit words) + np.random.seed(np_ss.generate_state(4)) + + # Seed torch + torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) + + # Seed python, use 128 bits expressed as an integer + stdlib_seed = ( + stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1] + ).sum() + random.seed(stdlib_seed) diff --git a/chai_lab/utils/timeout.py b/chai_lab/utils/timeout.py new file mode 100644 index 0000000..0ed458e --- /dev/null +++ b/chai_lab/utils/timeout.py @@ -0,0 +1,98 @@ +""" +Timeout utility for a function, creates a new process + +Implementation modified from: +https://www.reddit.com/r/Python/comments/8t9bk4/the_absolutely_easiest_way_to_time_out_a_function/ +""" + +import multiprocessing +import queue as _queue +from enum import Enum +from functools import wraps +from multiprocessing import Process, Queue +from typing import Any, assert_never + + +# TODO: This is dangerous: revert once the underlying problem in rdkit is fixed +# RDKit Issue(https://github.com/rdkit/rdkit/discussions/7289) +class Undaemonize(object): + """Context Manager to resolve AssertionError: daemonic processes are not allowed to have children + See https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic""" + + def __init__(self): + self.conf: dict = multiprocessing.process.current_process()._config # type: ignore + if "daemon" in self.conf: + self.daemon_status_set = True + else: + self.daemon_status_set = False + self.daemon_status_value = self.conf.get("daemon") + + def __enter__(self): + if self.daemon_status_set: + del self.conf["daemon"] + + def __exit__(self, type, value, traceback): + if self.daemon_status_set: + self.conf["daemon"] = self.daemon_status_value + + +class HandlerStatus(Enum): + SUCCESS = 0 + EXCEPTION = 1 + + +class ChildProcessException(Exception): + pass + + +def timeout(timeout: float | int) -> Any: + """Force function to timeout after 'seconds'. + + Returns: + The return value of the wrapped function. + Raises: + TimeoutError if the function does not return before the timeout. + """ + + def handler(queue, func, args, kwargs) -> None: + try: + queue.put((HandlerStatus.SUCCESS, func(*args, **kwargs))) + except Exception as e: + queue.put((HandlerStatus.EXCEPTION, e)) + + def decorator(func): + @wraps(func) + def new_fn(*args, **kwargs): + queue: Queue = Queue() + proc = Process( + target=handler, args=(queue, func, args, kwargs), daemon=True + ) + with Undaemonize(): + proc.start() + proc.join(timeout=float(timeout)) + if proc.is_alive(): + proc.terminate() + proc.join() + raise TimeoutError(f"Function {func} timed out after {timeout} seconds") + else: + # When child process dies unexpectedly Queue.get waits indefinitely. 
+ # See Issue(https://bugs.python.org/issue43805) + # prevent queue from hanging with another very short timeout + try: + status, value = queue.get(timeout=0.1) + except _queue.Empty: + # in this case, child process has died unexpectedly + raise ChildProcessException("Child process died unexpectedly") + + match status: + case HandlerStatus.SUCCESS: + return value + case HandlerStatus.EXCEPTION: + # Re-raise the exception we caught in the child process + raise value + + assert_never(status) + + return new_fn + + return decorator diff --git a/chai_lab/utils/typing.py b/chai_lab/utils/typing.py new file mode 100644 index 0000000..91c4d16 --- /dev/null +++ b/chai_lab/utils/typing.py @@ -0,0 +1,44 @@ +import typing + +from beartype import beartype +from jaxtyping import ( + Bool, + Float, + Float32, + Int, + Int32, + Num, + Shaped, + TypeCheckError, + UInt8, + jaxtyped, +) + +# Modules are only loaded and executed the first time they are imported, so the value of +# should_typecheck will constant over the lifetime of the program. +should_typecheck = True + + +Func = typing.TypeVar("Func") + + +def typecheck(cls_or_func: Func) -> Func: + if should_typecheck: + return jaxtyped(typechecker=beartype)(cls_or_func) + else: + return cls_or_func + + +__all__ = [ + "typecheck", + "TypeCheckError", + # re-export jaxtyping types + "Bool", + "Float", + "Int", + "Int32", + "Float32", + "Num", + "Shaped", + "UInt8", +] diff --git a/examples/predict_structure.py b/examples/predict_structure.py new file mode 100644 index 0000000..5b4b1a7 --- /dev/null +++ b/examples/predict_structure.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import torch + +from chai_lab.chai1 import run_inference + +# We use fasta-like format for inputs. +# Every record may encode protein, ligand, RNA or DNA +# see example below + +example_fasta = """ +>protein|example-of-long-protein +AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGPEYWDRETQKYKRQAQTDRVSLRNLRGYYNQSEAGSHTLQWMFGCDLGPDGRLLRGYDQSAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAAREAEQRRAYLEGTCVEWLRRYLENGKETLQRAEHPKTHVTHHPVSDHEATLRCWALGFYPAEITLTWQWDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPEPLTLRWEP +>protein|example-of-short-protein +AIQRTPKIQVYSRHPAENGKSNFLNCYVSGFHPSDIEVDLLKNGERIEKVEHSDLSFSKDWSFYLLYYTEFTPTEKDEYACRVNHVTLSQPKIVKWDRDM +>protein|example-of-peptide +GAAL +>ligand|and-example-for-ligand-encoded-as-smiles +CCCCCCCCCCCCCC(=O)O +""".strip() + +fasta_path = Path("/tmp/example.fasta") +fasta_path.write_text(example_fasta) + +output_paths = run_inference( + fasta_file=fasta_path, + output_dir=Path("/tmp/outputs"), + # 'default' setup + num_trunk_recycles=3, + num_diffn_timesteps=200, + seed=42, + device=torch.device("cuda:0"), + use_esm_embeddings=True, +) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..870a3f2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,69 @@ +# important: install in editable mode +[build-system] +requires = [ + "hatchling>=1.20", # build backend + "hatch-requirements-txt", # plugin, to parse requirements.txt +] +build-backend = "hatchling.build" + + +[project] +name = "chai_lab" +description = "Chai Discovery tools for AI + protein research." 
+requires-python = ">=3.11" +authors = [{ name = "Chai Discovery" }] +# see both defined below +dynamic = ["version", "dependencies"] + +[tool.hatch.version] +path = "chai_lab/__init__.py" +[tool.hatch.metadata.hooks.requirements_txt] +files = ["requirements.in"] +[tool.hatch.metadata] +allow-direct-references = true + +[tool.mypy] +check_untyped_defs = true + +# Ignore missing imports for packages with missing type stubs +[[tool.mypy.overrides]] +module = [ + "anarci.*", + "fsspec.*", + "google.*", + "joblib.*", + "needletail.*", + "numba.*", + "pyximport.*", + "rdkit.*", + "scipy.*", + "seaborn.*", + "sh.*", + "tmtools.*", + "botocore.*", + "s3fs.*", + "biotite.*", + "DockQ.*", + "boto3.*", + "transformers.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +cache_dir = "/tmp/.common_pytest_cache" + + +[tool.hatch.build.targets.sdist] +exclude = [ + "/.devcontainer", + "/.github", + "/.idea", + "/.vscode", + "/.pytest_cache", + "/assets", + "/downloads", + "/outputs", +] + +[tool.hatch.build.targets.wheel] +# should use packages from sdist section \ No newline at end of file diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..69ceccb --- /dev/null +++ b/requirements.in @@ -0,0 +1,46 @@ +# dev-deps, still placed in the same requirements file +ruff==0.6.3 # in sync with pre-commit-hook +mypy +pytest +pre-commit + +# types/stubs are required by mypy +pandas-stubs +types-pyyaml +types-tqdm +typing-extensions +types-requests + +# CLI, administrator tools +typer~=0.12 # CLI generator +# pydantic~=2.5 # serialization/deserialization of configs + +# notebooks, plotting +ipykernel~=6.27 # needed by vs code to run notebooks in devcontainer +# seaborn +# matplotlib + +# misc +tqdm~=4.66 + +# data import/export, application-specific +gemmi~=0.6.3 # pdb/mmcif parsing +rdkit==2023.9.5 # parsing of ligands. 2023.9.6 has broken type stubs +biopython==1.83 # parsing, data access +antipickle==0.2.0 # save/load heterogeneous python structures +tmtools>=0.0.3 # Python bindings for the TM-align algorithm +# dockq metric for comparing predicted pdbs and ground truth pdbs +dockq @ git+https://github.com/bjornwallner/DockQ.git@v2.1.1 +# pip-compatible minimized version of anarci +anarci @ git+https://github.com/arogozhnikov/microANARCI@d81823395d0c3532d6e033d80b036b4aa4a4565e + +# computing, dl +numpy~=1.21 +pandas[parquet,gcp,aws]~=2.1 +# polars +einops~=0.8 +jaxtyping>=0.2.25 # versions <0.2.25 do not easily support runtime typechecking +beartype>=0.18 # compatible typechecker to use with jaxtyping +# do not use 2.2 because https://github.com/pytorch/pytorch/issues/122385 +torch~=2.3.1 +transformers~=4.44 # for esm inference \ No newline at end of file diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..1b6d203 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,12 @@ +# move ruff cache outside of worktree +cache-dir = "/tmp/.ruff_chai_cache" + + +[lint] +extend-select = ["I"] +# jaxtyping requires disabling two following errors +# https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error +ignore = ["F821", "F722"] + +[lint.isort] +known-first-party = ["chai", "chai_lab"]
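
A minimal usage sketch (not part of the diff above) for the @timeout decorator defined in chai_lab/utils/timeout.py; slow_parse is a hypothetical stand-in for a call that may hang. The decorated function runs in a child process and a TimeoutError is raised if it does not return in time.

    from chai_lab.utils.timeout import timeout

    @timeout(5)  # give the call at most 5 seconds in a child process
    def slow_parse(path: str) -> str:
        # hypothetical stand-in for e.g. an rdkit call that can hang
        ...

    try:
        result = slow_parse("ligand.sdf")
    except TimeoutError:
        result = None  # fall back when the child process exceeds the limit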