diff --git a/config/environment.yml b/config/environment.yml new file mode 100644 index 0000000..1ec0e6f --- /dev/null +++ b/config/environment.yml @@ -0,0 +1,36 @@ +name: esmfold +channels: + - conda-forge + - bioconda + - pytorch +dependencies: + - conda-forge::python=3.7 + - conda-forge::setuptools=59.5.0 + - conda-forge::pip + - conda-forge::openmm=7.5.1 + - conda-forge::pdbfixer + - conda-forge::cudatoolkit==11.3.* + - conda-forge::einops + - conda-forge::fairscale + - conda-forge::omegaconf + - conda-forge::hydra-core + - conda-forge::pandas + - conda-forge::pytest + - bioconda::hmmer==3.3.2 + - bioconda::hhsuite==3.3.0 + - bioconda::kalign2==2.04 + - pytorch::pytorch=1.12.* + - pip: + - biopython==1.79 + - deepspeed==0.5.9 + - dm-tree==0.1.6 + - ml-collections==0.1.0 + - numpy==1.21.2 + - PyYAML==5.4.1 + - requests==2.26.0 + - scipy==1.7.1 + - tqdm==4.62.2 + - typing-extensions==3.10.0.2 + - pytorch_lightning==1.5.10 + - wandb==0.12.21 + - git+https://github.com/NVIDIA/dllogger.git \ No newline at end of file diff --git a/config/esmfold_config.yaml b/config/esmfold_config.yaml new file mode 100644 index 0000000..db4b87c --- /dev/null +++ b/config/esmfold_config.yaml @@ -0,0 +1,62 @@ +defaults: + - _self_ + - dataset: gearnet_ec + +task: + class: MultipleBinaryClassification + model: + output_dim: 0 + graph_construction_model: + class: GraphConstruction + node_layers: + - class: AlphaCarbonNode + edge_layers: + - class: SequentialEdge + criterion: bce + num_mlp_layer: 0 + metric: ['auprc@micro', 'f1_max'] + +compute: + array_parallelism: 10 + cpus_per_task: 5 + mem_per_cpu: 2g + timeout_min: 1440 + job_name: esmfold + partition: p.hpcl94g + gpus_per_node: 1 + gpus_per_task: 1 + tasks_per_node: 1 + +data: + batch_size: 20 + fold_batch_size: 10 + +seed: 1234 +model: pst_t6 +use_edge_attr: false +datapath: datasets +nogpu: false +metric: 'f1_max' +batch_size: 4 +num_workers: 4 +device: null + +truncation_seq_length: 5000 +toks_per_batch: 4096 +include_seq: false 
+aggr: concat +use_pca: null +parallel: false +pretrained: .cache/pst + +logs: + prefix: logs_pst/gearnet_data + path: ${logs.prefix}/${dataset.name}/${model}/${aggr}/${include_seq}/${seed} + +# output directory, generated dynamically on each run +hydra: + run: + dir: ${logs.path} + sweep: + dir: ${logs.prefix}/${dataset.name}/${model}/${aggr}/${include_seq}/${seed} + subdir: "" diff --git a/config/hydra/launcher/slurm.yaml b/config/hydra/launcher/slurm.yaml new file mode 100644 index 0000000..e4e1584 --- /dev/null +++ b/config/hydra/launcher/slurm.yaml @@ -0,0 +1,27 @@ +# @package hydra.launcher +_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + +# SLURM queue parameters +# partition: p.hpcl8 +partition: p.hpcl91 +# partition: p.hpcl91 + +# Job resource requirements +timeout_min: 14_400 +cpus_per_task: 12 +gpus_per_node: 1 +gpus_per_task: 1 +tasks_per_node: 1 +mem_per_cpu: 2g +nodes: 1 +exclude: hpcl9101 + +# Job naming and output +name: "pst" +submitit_folder: ./logs/submitit/%j + +# Additional settings +signal_delay_s: 5 +max_num_timeout: 0 +additional_parameters: {} +array_parallelism: 30 \ No newline at end of file diff --git a/config/pst_edge_perturb.yaml b/config/pst_edge_perturb.yaml new file mode 100644 index 0000000..95f5925 --- /dev/null +++ b/config/pst_edge_perturb.yaml @@ -0,0 +1,55 @@ +defaults: + - _self_ + - training: default + - base_model: esm2_t6 + - mode: default + +debug: false +seed: 1234 + +data: + organism: swissprot + datapath: datasets/AlphaFold/${data.organism} + graph_eps: 8.0 + crop_len: 1024 + mask_rate: 0.15 + val_datapath: datasets/dms + edge_perturb: random + +compute: + accelerator: gpu + precision: 16-mixed + strategy: ddp + num_workers: 8 + n_jobs: 10 + devices: auto + +logs: + prefix: logs_pst/random_ablation/edge_perturb_${data.edge_perturb} + path: ${logs.prefix}/${model.name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} + wandb: + enable: true + name: ${model.name}_edge_perturb_${data.edge_perturb} + 
tags: + - random_ablation + - organism_${data.organism} + - model_${model.name} + - edge_perturb_${data.edge_perturb} + entity: "BorgwardtLab" + project: "PST" + save_dir: ${logs.path} + +model: + k_hop: 2 + train_struct_only: true + use_edge_attr: false + gnn_type: gin + edge_dim: null + +# output directory, generated dynamically on each run +hydra: + run: + dir: ${logs.path} + sweep: + dir: ${logs.prefix}/${model.name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/config/pst_gearnet.yaml b/config/pst_gearnet.yaml index 8d75857..7d79241 100644 --- a/config/pst_gearnet.yaml +++ b/config/pst_gearnet.yaml @@ -16,6 +16,9 @@ task: num_mlp_layer: 0 metric: ['auprc@micro', 'f1_max'] +data: + edge_perturb: null + seed: 1234 model: pst_t6 use_edge_attr: false diff --git a/config/pst_gearnet_esmfold.yaml b/config/pst_gearnet_esmfold.yaml new file mode 100644 index 0000000..600a85b --- /dev/null +++ b/config/pst_gearnet_esmfold.yaml @@ -0,0 +1,51 @@ +defaults: + - _self_ + - dataset: gearnet_ec + +task: + class: MultipleBinaryClassification + model: + output_dim: 0 + graph_construction_model: + class: GraphConstruction + node_layers: + - class: AlphaCarbonNode + edge_layers: + - class: SequentialEdge + criterion: bce + num_mlp_layer: 0 + metric: ['auprc@micro', 'f1_max'] + +data: + edge_perturb: null + esmfold_structures_path: datasets/esmfold/structures/ + +seed: 1234 +model: pst_t6 +use_edge_attr: false +datapath: datasets +nogpu: false +metric: 'f1_max' +batch_size: 4 +num_workers: 4 +device: null + +truncation_seq_length: 5000 +toks_per_batch: 4096 +include_seq: false +aggr: concat +use_pca: null + +pretrained: .cache/pst + +logs: + prefix: logs_pst/gearnet_data + path: ${logs.prefix}/${dataset.name}/${model}/${aggr}/${include_seq}/${seed} + +# output directory, generated dynamically on each run +hydra: + run: + dir: ${logs.path} + sweep: + dir: ${logs.prefix}/${dataset.name}/${model}/${aggr}/${include_seq}/${seed} + subdir: "" diff 
--git a/config/pst_gearnet_finetune.yaml b/config/pst_gearnet_finetune.yaml index ffb7237..1c1ac74 100644 --- a/config/pst_gearnet_finetune.yaml +++ b/config/pst_gearnet_finetune.yaml @@ -39,6 +39,7 @@ compute: strategy: auto num_workers: 8 n_jobs: 10 + devices: auto pretrained: .cache/pst diff --git a/config/pst_pretrain.yaml b/config/pst_pretrain.yaml index 29824af..b142b31 100644 --- a/config/pst_pretrain.yaml +++ b/config/pst_pretrain.yaml @@ -21,6 +21,7 @@ compute: strategy: ddp num_workers: 8 n_jobs: 10 + devices: auto logs: prefix: logs_pst/pretrain diff --git a/config/pst_proteinshake.yaml b/config/pst_proteinshake.yaml index 41a8430..76584e9 100644 --- a/config/pst_proteinshake.yaml +++ b/config/pst_proteinshake.yaml @@ -10,16 +10,16 @@ split: structure batch_size: 4 num_workers: 4 device: null - truncation_seq_length: 5000 toks_per_batch: 4096 include_seq: false - pretrained: .cache/pst +perturbation: null +model_path: null logs: prefix: logs_pst/proteinshake - path: ${logs.prefix}/${task.name}/${model}/${split}/${seed} + path: ${logs.prefix}/${task.name}/${model}/${split}/${seed}/${perturbation} # output directory, generated dynamically on each run hydra: diff --git a/environment.yaml b/environment.yaml index e60ba4f..7caf4eb 100644 --- a/environment.yaml +++ b/environment.yaml @@ -4,23 +4,213 @@ channels: - pytorch - nvidia - defaults + - nodefaults - conda-forge dependencies: - - python=3.9 - - pytorch>=2.0.0 - - pytorch-cuda=11.8 - - lightning - - pytorch-scatter - - pytorch-cluster + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - anyio=4.2.0=py39h06a4308_0 + - arrow=1.2.3=py39h06a4308_1 + - backoff=2.2.1=py39h06a4308_1 + - beautifulsoup4=4.12.3=py39h06a4308_0 + - blas=1.0=mkl + - blessed=1.20.0=py39h06a4308_0 + - boto3=1.34.154=py39h06a4308_0 + - botocore=1.34.154=py39h06a4308_0 + - brotli-python=1.0.9=py39h6a678d5_8 + - ca-certificates=2024.7.2=h06a4308_0 + - certifi=2024.7.4=py39h06a4308_0 + - cffi=1.16.0=py39h5eee18b_1 + - 
charset-normalizer=3.3.2=pyhd3eb1b0_0 + - click=8.1.7=py39h06a4308_0 + - croniter=1.3.7=py39h06a4308_0 + - cryptography=43.0.0=py39hdda0065_0 + - cuda-cudart=11.8.89=0 + - cuda-cupti=11.8.87=0 + - cuda-libraries=11.8.0=0 + - cuda-nvrtc=11.8.89=0 + - cuda-nvtx=11.8.86=0 + - cuda-runtime=11.8.0=0 + - cuda-version=12.6=3 + - dateutils=0.6.12=py39h06a4308_0 + - deepdiff=7.0.1=py39h2f386ee_0 + - exceptiongroup=1.2.0=py39h06a4308_0 + - fastapi=0.112.2=py39h06a4308_0 + - filelock=3.13.1=py39h06a4308_0 + - fsspec=2024.6.1=py39h06a4308_0 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py39heeb90bb_0 + - h11=0.14.0=py39h06a4308_0 + - idna=3.7=py39h06a4308_0 + - inquirer=3.1.4=py39h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - itsdangerous=2.2.0=py39h06a4308_0 + - jinja2=3.1.4=py39h06a4308_0 + - jmespath=1.0.1=py39h06a4308_0 + - joblib=1.4.2=py39h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libcublas=11.11.3.6=0 + - libcufft=10.9.0.58=0 + - libcufile=1.11.1.6=0 + - libcurand=10.3.7.68=0 + - libcusolver=11.4.1.48=0 + - libcusparse=11.7.5.86=0 + - libffi=3.4.4=h6a678d5_1 + - libgcc=14.1.0=h77fa898_1 + - libgcc-ng=14.1.0=h69a702a_1 + - libgfortran-ng=11.2.0=h00389a5_1 + - libgfortran5=11.2.0=h1234567_1 + - libgomp=14.1.0=h77fa898_1 + - libnpp=11.8.0.86=0 + - libnvjpeg=11.9.0.86=0 + - libstdcxx=14.1.0=hc0a3c3a_1 + - libstdcxx-ng=14.1.0=h4852527_1 + - lightning=2.0.9.post0=py39h06a4308_0 + - lightning-cloud=0.5.57=py39h06a4308_0 + - lightning-utilities=0.9.0=py39h06a4308_0 + - llvm-openmp=14.0.6=h9e868ea_0 + - markdown-it-py=2.2.0=py39h06a4308_1 + - markupsafe=2.1.3=py39h5eee18b_0 + - mdurl=0.1.0=py39h06a4308_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py39h5eee18b_1 + - mkl_fft=1.3.8=py39h5eee18b_0 + - mkl_random=1.2.4=py39hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py39h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - networkx=3.2.1=py39h06a4308_0 + - numpy=1.26.4=py39h5f9d8c6_0 + - numpy-base=1.26.4=py39hb5e798b_0 + - 
openssl=3.0.14=h5eee18b_0 + - ordered-set=4.1.0=py39h06a4308_0 + - orjson=3.9.15=py39h97a8848_0 + - packaging=24.1=py39h06a4308_0 + - pip=24.2=py39h06a4308_0 + - psutil=5.9.0=py39h5eee18b_0 + - pybind11-abi=4=hd3eb1b0_1 + - pycparser=2.21=pyhd3eb1b0_0 + - pydantic=1.10.12=py39h5eee18b_1 + - pygments=2.15.1=py39h06a4308_1 + - pyjwt=2.8.0=py39h06a4308_0 + - pyopenssl=24.2.1=py39h06a4308_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.19=h955ad1f_1 + - python-dateutil=2.9.0post0=py39h06a4308_2 + - python-editor=1.0.4=pyhd3eb1b0_0 + - python-multipart=0.0.9=py39h06a4308_0 + - python_abi=3.9=2_cp39 + - pytorch=2.4.0=py3.9_cuda11.8_cudnn9.1.0_0 + - pytorch-cluster=1.6.3=py39_torch_2.4.0_cu118 + - pytorch-cuda=11.8=h7e8668a_5 + - pytorch-lightning=2.3.0=py39h06a4308_0 + - pytorch-mutex=1.0=cuda + - pytorch-scatter=2.1.2=py39_torch_2.4.0_cu118 + - pytz=2024.1=py39h06a4308_0 + - pyyaml=6.0.1=py39h5eee18b_0 + - readchar=4.0.5=py39h06a4308_0 + - readline=8.2=h5eee18b_0 + - requests=2.32.3=py39h06a4308_0 + - rich=13.7.1=py39h06a4308_0 + - s3transfer=0.10.1=py39h06a4308_0 + - scipy=1.13.1=py39h5f9d8c6_0 + - setuptools=72.1.0=py39h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sniffio=1.3.0=py39h06a4308_0 + - soupsieve=2.5=py39h06a4308_0 + - sqlite=3.45.3=h5eee18b_0 + - starlette=0.38.2=py39h06a4308_0 + - starsessions=1.3.0=py39h06a4308_0 + - sympy=1.13.2=py39h06a4308_0 + - tbb=2021.8.0=hdb19cb5_0 + - threadpoolctl=3.5.0=py39h2f386ee_0 + - tk=8.6.14=h39e8969_0 + - torchmetrics=1.4.0.post0=py39h06a4308_0 + - torchtriton=3.0.0=py39 + - tqdm=4.66.5=py39h2f386ee_0 + - traitlets=5.14.3=py39h06a4308_0 + - typing-extensions=4.11.0=py39h06a4308_0 + - typing_extensions=4.11.0=py39h06a4308_0 + - urllib3=1.26.19=py39h06a4308_0 + - uvicorn=0.20.0=py39h06a4308_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - websocket-client=1.8.0=py39h06a4308_0 + - websockets=10.4=py39h5eee18b_1 + - wheel=0.43.0=py39h06a4308_0 + - xz=5.4.6=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.13=h5eee18b_1 - pip: - - 
torch_geometric==2.3.1 - - proteinshake - - fair-esm - - pyprojroot - - einops - - pandas - - easydict - - pyprojroot - - scikit-learn - - hydra-core - - tensorboard \ No newline at end of file + - absl-py==2.1.0 + - antlr4-python3-runtime==4.9.3 + - asttokens==2.4.1 + - biopandas==0.5.1 + - biotite==0.40.0 + - cloudpickle==3.1.0 + - contourpy==1.3.0 + - cycler==0.12.1 + - decorator==5.1.1 + - docopt==0.6.2 + - easydict==1.13 + - einops==0.8.0 + - et-xmlfile==1.1.0 + - executing==2.1.0 + - fair-esm==2.0.0 + - fastavro==1.9.5 + - fastpdb==1.3.1 + - fonttools==4.53.1 + - freesasa==2.2.1 + - ftpretty==0.4.0 + - goatools==1.4.12 + - grpcio==1.66.1 + - huggingface-hub==0.25.2 + - hydra-core==1.3.2 + - importlib-metadata==8.4.0 + - importlib-resources==6.4.4 + - ipython==8.18.1 + - jedi==0.19.1 + - kiwisolver==1.4.5 + - lmdb==1.5.1 + - loguru==0.7.2 + - looseversion==1.1.2 + - markdown==3.7 + - matplotlib==3.9.2 + - matplotlib-inline==0.1.7 + - mmtf-python==1.1.3 + - msgpack==1.0.8 + - ninja==1.11.1.1 + - omegaconf==2.3.0 + - openpyxl==3.1.5 + - pandas==2.2.2 + - parso==0.8.4 + - patsy==0.5.6 + - pexpect==4.9.0 + - pillow==10.4.0 + - prompt-toolkit==3.0.47 + - proteinshake==0.3.14 + - protobuf==5.28.0 + - pst==0.0.0 + - ptyprocess==0.7.0 + - pure-eval==0.2.3 + - pydot==3.0.1 + - pyparsing==3.1.4 + - pyprojroot==0.3.0 + - rdkit-pypi==2022.9.5 + - regex==2024.9.11 + - safetensors==0.4.5 + - scikit-learn==1.5.1 + - stack-data==0.6.3 + - statsmodels==0.14.2 + - submitit==1.5.2 + - tensorboard==2.17.1 + - tensorboard-data-server==0.7.2 + - tokenizers==0.20.1 + - torch-geometric==2.3.1 + - torchdrug==0.2.1 + - transformers==4.45.2 + - tzdata==2024.1 + - werkzeug==3.0.4 + - xlsxwriter==3.2.0 + - zipp==3.20.1 +prefix: /fs/pool/pool-hartout/.conda/envs/pst diff --git a/experiments/esmfold/command_perturbed_experiment.sh b/experiments/esmfold/command_perturbed_experiment.sh new file mode 100644 index 0000000..1f96949 --- /dev/null +++ 
b/experiments/esmfold/command_perturbed_experiment.sh @@ -0,0 +1,33 @@ +# Baselines +python experiments/perturbed/predict_gearnet_perturbed.py +python experiments/perturbed/predict_gearnet_perturbed.py dataset=gearnet_go_bp +python experiments/perturbed/predict_gearnet_perturbed.py dataset=gearnet_go_cc +python experiments/perturbed/predict_gearnet_perturbed.py dataset=gearnet_go_mf +python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/baselines +python experiments/perturbed/predict_scop_perturbed.py dataset=scop + +# Complete graphs +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt data.edge_perturb=complete +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt dataset=gearnet_go_bp data.edge_perturb=complete +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt dataset=gearnet_go_cc data.edge_perturb=complete +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt dataset=gearnet_go_mf data.edge_perturb=complete +python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/complete --perturbation complete --pretrained ./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt +python experiments/perturbed/predict_scop_perturbed.py dataset=scop data.edge_perturb=complete pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt + +# Sequence graphs +python experiments/perturbed/predict_gearnet_perturbed.py 
pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt data.edge_perturb=sequence +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt dataset=gearnet_go_bp data.edge_perturb=sequence +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt dataset=gearnet_go_cc data.edge_perturb=sequence +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt dataset=gearnet_go_mf data.edge_perturb=sequence +python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/sequence --perturbation sequence --pretrained ./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt +python experiments/perturbed/predict_scop_perturbed.py dataset=scop data.edge_perturb=sequence pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt + +# Random graphs +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt data.edge_perturb=random +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt dataset=gearnet_go_bp data.edge_perturb=random +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt dataset=gearnet_go_cc data.edge_perturb=random +python experiments/perturbed/predict_gearnet_perturbed.py 
pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt dataset=gearnet_go_mf data.edge_perturb=random +python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/random --perturbation random --pretrained ./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt +python experiments/perturbed/predict_scop_perturbed.py dataset=scop data.edge_perturb=random pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt + +# python experiments/perturbed/predict_proteinshake_perturbed.py logs.prefix=logs_pst/proteinshake_perturbed perturbation=complete model_path=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt task=binding_site_detection,enzyme_class,gene_ontology,pfam_task,structural_class diff --git a/experiments/esmfold/fold_sequences_esmfold.py b/experiments/esmfold/fold_sequences_esmfold.py new file mode 100644 index 0000000..0b2ebe5 --- /dev/null +++ b/experiments/esmfold/fold_sequences_esmfold.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +"""fold_sequences_esmfold.py +This must be executed separately to ensure that all structures are folded + +""" + +import pickle +from pathlib import Path + +import esm +import hydra +import pandas as pd +import submitit +import torch +import torchdrug +from easydict import EasyDict as edict +from loguru import logger +from matplotlib import pyplot as plt +from omegaconf import OmegaConf +from pyprojroot import here +from torchdrug import core, datasets, models, tasks # noqa + +from tqdm import tqdm +from transformers import AutoTokenizer, EsmForProteinFolding, EsmModel +# from transformers.models.esm.openfold_utils.protein import to_pdb, from_prediction, Protein +# from tqdm.rich import tqdm + + +# def fold_sequence(idx: int, sequence: str, model: torch.nn.Module, path_prefix: Path) -> int: +# """ +# Folds a 
# NOTE: a commented-out single-sequence `fold_sequence` helper and an
# experimental `contact_prediction_batch` previously lived here; both were
# superseded by the batched version below and have been removed.


def fold_sequence_batch(sequences: list, path_prefix: Path) -> list:
    """Fold a batch of protein sequences with ESMFold and write PDB files.

    The model is loaded inside the function (rather than passed in) so the
    call can be shipped as-is to a remote worker (e.g. via submitit) without
    serializing model weights.

    Args:
        sequences: List of ``(idx, sequence)`` tuples to fold.
        path_prefix: Directory where PDB files are written, one file per
            sequence, named ``<idx>.pdb``.

    Returns:
        list: Output paths for every sequence in the batch (both newly
        folded and already-existing files).
    """
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
    model.eval()  # inference only

    out_paths = []
    for idx, sequence in sequences:
        out_path = path_prefix / f"{idx}.pdb"
        # Skip structures already folded by a previous (partial) run, so the
        # job array can be resubmitted idempotently.
        if not out_path.exists():
            with torch.no_grad():
                pdb_string = model.infer_pdb(sequence)
            with open(out_path, "w") as f:
                f.write(pdb_string)
        out_paths.append(out_path)
    return out_paths


def submitit_executor_wrapper(cfg, sequences_batches, folding_function, path_prefix):
    """Submit one folding job per batch to SLURM through submitit.

    Args:
        cfg: Hydra config; ``cfg.compute`` carries the SLURM resources.
        sequences_batches: List of batches, each a list of ``(idx, sequence)``.
        folding_function: Callable executed on the worker as
            ``folding_function(sequences=batch, path_prefix=path_prefix)``.
        path_prefix: Output directory forwarded to ``folding_function``.

    Returns:
        list: The submitted submitit jobs, so the caller can wait on or
        inspect them.
    """
    executor = submitit.AutoExecutor(folder=str(here() / "logs"))
    executor.update_parameters(
        slurm_array_parallelism=cfg.compute.array_parallelism,
        cpus_per_task=cfg.compute.cpus_per_task,
        slurm_mem_per_cpu=cfg.compute.mem_per_cpu,
        timeout_min=cfg.compute.timeout_min,
        slurm_job_name=cfg.compute.job_name,
        slurm_partition=cfg.compute.partition,
        slurm_gpus_per_node=cfg.compute.gpus_per_node,
        slurm_gpus_per_task=cfg.compute.gpus_per_task,
        slurm_tasks_per_node=cfg.compute.tasks_per_node,
    )

    total_jobs = len(sequences_batches)
    logger.info(f"Starting {total_jobs} jobs")

    jobs = []
    # executor.batch() groups the submissions into a single SLURM job array.
    with executor.batch():
        for batch in sequences_batches:
            job = executor.submit(
                folding_function, sequences=batch, path_prefix=path_prefix
            )
            jobs.append(job)
    return jobs
+@hydra.main( + version_base="1.3", + config_path=str(here() / "config"), + config_name="esmfold_config", +) +def main(cfg): + # task = core.Configurable.load_config_dict( + # edict(OmegaConf.to_container(cfg.task, resolve=True)) + # ) + cached_seq = here() / "datasets/.cached_esmfold_data.pkl" + if not cached_seq.exists(): + dataset = core.Configurable.load_config_dict( + OmegaConf.to_container(cfg.dataset, resolve=True) + ) + # train, val, test = dataset.split() + sequences = [ + (idx, protein["graph"].to_sequence().replace(".G", "").replace(".", "")) + for idx, protein in tqdm( + enumerate(dataset), + desc="Extracting sequences", + total=len(dataset), + ) + ] + sequences_batches = [ + sequences[i : i + cfg.data.fold_batch_size] + for i in range(0, len(sequences), cfg.data.fold_batch_size) + ] + # Save list as pkl file + with open(cached_seq, "wb") as file: + pickle.dump(sequences_batches, file) + + else: + with open(cached_seq, 'rb') as file: + sequences_batches = pickle.load(file) + # if not cfg.parallel: + # for batch in tqdm(sequences_batches, desc="Folding sequences without submitit"): + # fold_sequence_batch(sequences=batch, path_prefix=here() / "datasets/esmfold/structures/") + # else: + submitit_executor_wrapper(cfg, sequences_batches, fold_sequence_batch, here() / "datasets/esmfold/structures/") + # train_loader = torchdrug.data.DataLoader( + # train, batch_size=1, shuffle=False + # ) + # val_loader = torchdrug.data.DataLoader( + # val, batch_size=1, shuffle=False + # ) + # test_loader = torchdrug.data.DataLoader( + # test, batch_size=1, shuffle=False + # ) + # loader_wrapper(train_loader, cfg, partition="train", parallel=cfg.parallel) + # loader_wrapper(val_loader, cfg, partition="val", parallel=cfg.parallel) + # loader_wrapper(test_loader, cfg, partition="test", parallel=cfg.parallel) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/experiments/esmfold/predict_gearnet_esmfold.py 
# -*- coding: utf-8 -*-
"""Evaluate PST embeddings on GearNet tasks using ESMFold-predicted structures.

Builds radius graphs from the pre-folded PDB files in
``datasets/esmfold/structures/``, computes PST representations, and trains
an MLP head on top.
"""
import json
import logging
import pickle
from pathlib import Path
from timeit import default_timer as timer

import biotite.structure as struc
import esm
import fastpdb
import hydra
import numpy as np
import pandas as pd
import torch
import torch_geometric.nn as gnn
import torchdrug
from easydict import EasyDict as edict
from omegaconf import OmegaConf
from pyprojroot import here
from scipy.spatial.distance import pdist, squareform
from sklearn.neighbors import radius_neighbors_graph
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_scipy_sparse_matrix
from torchdrug import core, datasets, models, tasks  # noqa
from tqdm import tqdm

from pst.downstream import (
    convert_to_numpy,
    mask_cls_idx,
    preprocess,
)
from pst.downstream.mlp import train_and_eval_mlp
from pst.esm2 import PST

log = logging.getLogger(__name__)

esm_alphabet = esm.data.Alphabet.from_architecture("ESM-1b")

# Three-letter to one-letter amino-acid codes; unknown residues map to "X".
AA_THREE_TO_ONE = {
    "ALA": "A",
    "CYS": "C",
    "ASP": "D",
    "GLU": "E",
    "PHE": "F",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LYS": "K",
    "LEU": "L",
    "MET": "M",
    "ASN": "N",
    "PRO": "P",
    "GLN": "Q",
    "ARG": "R",
    "SER": "S",
    "THR": "T",
    "VAL": "V",
    "TRP": "W",
    "TYR": "Y",
    "UNK": "X",
}


@torch.no_grad()
def compute_repr(data_loader, model, cfg):
    """Compute mean-pooled per-protein PST embeddings for a whole loader.

    If ``cfg.include_seq`` is set, the structure-aware and sequence-only
    (edge-free) representations are averaged.
    """
    embeddings = []
    for batch_idx, data in enumerate(tqdm(data_loader, desc="Computing embeddings")):
        data = data.to(cfg.device)
        out = model(data, return_repr=True, aggr=cfg.aggr)
        # Drop special-token positions (idx_mask) before pooling per graph.
        out, batch = out[data.idx_mask], data.batch[data.idx_mask]
        out = gnn.global_mean_pool(out, batch)

        if cfg.include_seq:
            # Removing edges makes the model behave as a pure sequence model.
            data.edge_index = None
            out_seq = model(data, return_repr=True, aggr=cfg.aggr)
            out_seq = out_seq[data.idx_mask]
            out_seq = gnn.global_mean_pool(out_seq, batch)
            out = (out + out_seq) * 0.5

        out = out.cpu()

        embeddings = embeddings + list(torch.chunk(out, len(data.ptr) - 1))

    return torch.cat(embeddings)


def create_graph_from_pdb(idx, protein, eps=8.0):
    """Build a radius graph from the ESMFold PDB file for sample ``idx``.

    Args:
        idx: Dataset index; the structure is read from
            ``datasets/esmfold/structures/<idx>.pdb``.
        protein: Unused; kept for interface compatibility with callers that
            pass the torchdrug protein alongside the index.
        eps: Radius (in the same units as the PDB coordinates, i.e.
            angstroms) for the neighbors graph. Defaults to 8.0.

    Returns:
        torch_geometric ``Data`` with ESM token ids as node features.
    """
    # Load the protein structure
    pdb_file = fastpdb.PDBFile.read(here() / "datasets" / "esmfold" / "structures" / f"{idx}.pdb")
    structure = pdb_file.get_structure(model=1)

    coords = structure.coord
    element = structure.element
    resname = structure.res_name
    resid = structure.res_id
    chain_id = structure.chain_id
    atom_name = structure.atom_name

    df = pd.DataFrame(
        {
            "x": coords[:, 0],
            "y": coords[:, 1],
            "z": coords[:, 2],
            "element": element,
            "resname": resname,
            "atom_name": atom_name,
            "resid": resid,
            "chain_id": chain_id,
        }
    )

    # One node per residue: use the CA atom coordinates.
    coordinates = df.loc[df["atom_name"] == "CA", ["x", "y", "z"]].values

    sequence = "".join(
        df.loc[df.atom_name == "CA"].resname.map(AA_THREE_TO_ONE).tolist()
    )
    x = torch.LongTensor(
        [esm_alphabet.get_idx(res) for res in esm_alphabet.tokenize(sequence)]
    )

    # Radius graph over CA atoms. NOTE: previously the radius was hard-coded
    # to 8.0 even though callers passed a configured eps; it is now threaded
    # through (default preserves the old behavior).
    edge_index, edge_attr = from_scipy_sparse_matrix(
        radius_neighbors_graph(coordinates, eps)
    )
    return Data(edge_index=edge_index, x=x, edge_attr=edge_attr)


def get_structures(dataset, task, eps=8):
    """Build graphs + labels for every protein in a torchdrug dataset split.

    Args:
        dataset: torchdrug dataset split exposing ``indices``.
        task: Unused here; kept for interface compatibility.
        eps: Radius forwarded to :func:`create_graph_from_pdb`.

    Returns:
        Tuple ``(structures, labels)`` where ``labels`` is the concatenated
        target tensor.
    """
    data_loader = torchdrug.data.DataLoader(dataset, batch_size=1, shuffle=False)
    structures = []
    labels = []
    idx_range = dataset.indices
    for idx, protein in tqdm(zip(idx_range, data_loader), total=len(list(idx_range)), desc="Get structures"):
        # Fix: eps is now actually honored instead of the hard-coded 8.0.
        graph = create_graph_from_pdb(idx, protein, eps=eps)
        labels.append(protein["targets"])
        structures.append(graph)

    return structures, torch.cat(labels)


@hydra.main(
    version_base="1.3", config_path=str(here() / "config"), config_name="pst_gearnet_esmfold"
)
def main(cfg):
    """End-to-end evaluation: load PST, build/cache graphs, embed, train MLP."""
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}")

    # Structure-only checkpoints carry an "_so" suffix.
    if cfg.include_seq and "so" not in cfg.model:
        cfg.model = f"{cfg.model}_so"

    pretrained_path = Path(cfg.pretrained) / f"{cfg.model}.pt"
    pretrained_path.parent.mkdir(parents=True, exist_ok=True)

    model, model_cfg = PST.from_pretrained_url(
        cfg.model, pretrained_path,
    )

    model.eval()
    model.to(cfg.device)

    task = core.Configurable.load_config_dict(
        edict(OmegaConf.to_container(cfg.task, resolve=True))
    )

    # Graphs are cached per radius so re-runs skip the expensive PDB parsing.
    structure_path = (
        Path(cfg.data.esmfold_structures_path) / f"structures_{model_cfg.data.graph_eps}.pt"
    )
    if structure_path.exists():
        tmp = torch.load(structure_path)
        train_str, y_tr = tmp["train_str"], tmp["y_tr"]
        val_str, y_val = tmp["val_str"], tmp["y_val"]
        test_str, y_te = tmp["test_str"], tmp["y_te"]
        del tmp
    else:
        # To make torchdrug work, one has to delete unrecognized attributes...
        cfg.dataset.__delattr__('name')
        dataset = core.Configurable.load_config_dict(
            OmegaConf.to_container(cfg.dataset, resolve=True)
        )
        train_dset, val_dset, test_dset = dataset.split()

        train_str, y_tr = get_structures(train_dset, task, eps=model_cfg.data.graph_eps)
        val_str, y_val = get_structures(val_dset, task, eps=model_cfg.data.graph_eps)
        test_str, y_te = get_structures(test_dset, task, eps=model_cfg.data.graph_eps)
        # Make sure the cache directory exists before saving.
        structure_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "train_str": train_str,
                "val_str": val_str,
                "test_str": test_str,
                "y_tr": y_tr,
                "y_val": y_val,
                "y_te": y_te,
            },
            structure_path,
        )

    # this is awful i know, todo: proper transform and dataset
    train_str = [mask_cls_idx(data) for data in train_str]
    val_str = [mask_cls_idx(data) for data in val_str]
    test_str = [mask_cls_idx(data) for data in test_str]

    train_loader = DataLoader(
        train_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    val_loader = DataLoader(
        val_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    test_loader = DataLoader(
        test_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )

    # compute embeddings
    tic = timer()
    X_tr = compute_repr(train_loader, model, cfg)
    X_val = compute_repr(val_loader, model, cfg)
    X_te = compute_repr(test_loader, model, cfg)
    compute_time = timer() - tic
    # NOTE(review): return values discarded — presumably preprocess() mutates
    # the tensors in place; confirm in pst.downstream.
    preprocess(X_tr)
    preprocess(X_val)
    preprocess(X_te)

    X_tr, X_val, X_te, y_tr, y_val, y_te = convert_to_numpy(
        X_tr, X_val, X_te, y_tr, y_val, y_te
    )
    # Drop training rows containing NaNs.
    X_mask = np.isnan(X_tr.sum(1))
    X_tr, y_tr = X_tr[~X_mask], y_tr[~X_mask]
    log.info(f"X_tr shape: {X_tr.shape} y_tr shape: {y_tr.shape}")

    if cfg.use_pca is not None:
        from sklearn.decomposition import PCA

        # cfg.use_pca acts as an on/off switch; the component count is
        # derived from the embedding width, not from the config value.
        cfg.use_pca = 1024 if X_tr.shape[1] < 10000 else 2048
        pca = PCA(cfg.use_pca)
        pca = pca.fit(X_tr)
        X_tr = pca.transform(X_tr)
        X_val = pca.transform(X_val)
        X_te = pca.transform(X_te)
        log.info(f"PCA done. X_tr shape: {X_tr.shape}")

    X_tr, y_tr = torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).float()
    X_val, y_val = torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float()
    X_te, y_te = torch.from_numpy(X_te).float(), torch.from_numpy(y_te).float()

    train_and_eval_mlp(
        X_tr,
        y_tr,
        X_val,
        y_val,
        X_te,
        y_te,
        cfg,
        task,
        batch_size=32,
        epochs=100,
    )


if __name__ == "__main__":
    main()
@torch.no_grad()
def compute_repr(data_loader, model, cfg):
    """Compute one pooled PST embedding per protein.

    Runs the model on each batch, keeps only the residue positions selected
    by ``data.idx_mask`` (CLS/EOS positions were masked out upstream), and
    mean-pools them per graph. If ``cfg.include_seq`` is set, a second,
    sequence-only forward pass (edges removed) is averaged with the
    structural one.

    Args:
        data_loader: iterable of torch_geometric batches exposing
            ``idx_mask``, ``batch`` and ``ptr``.
        model: PST model supporting ``return_repr`` and ``aggr``.
        cfg: config providing ``device``, ``aggr`` and ``include_seq``.

    Returns:
        Tensor of shape (num_proteins, hidden_dim) on CPU.
    """
    embeddings = []
    for data in tqdm(data_loader):
        data = data.to(cfg.device)
        out = model(data, return_repr=True, aggr=cfg.aggr)
        out, batch = out[data.idx_mask], data.batch[data.idx_mask]
        out = gnn.global_mean_pool(out, batch)

        if cfg.include_seq:
            # Sequence-only pass: dropping the edges makes the model see the
            # protein as a plain sequence. NOTE: this mutates the batch, so
            # the structural pass above must run first.
            data.edge_index = None
            out_seq = model(data, return_repr=True, aggr=cfg.aggr)
            out_seq = gnn.global_mean_pool(out_seq[data.idx_mask], batch)
            out = (out + out_seq) * 0.5

        # extend() instead of repeated list concatenation (which is quadratic
        # in the number of batches).
        embeddings.extend(torch.chunk(out.cpu(), len(data.ptr) - 1))

    return torch.cat(embeddings)
def get_structures(dataset, task, eps=8):
    """Convert a torchdrug dataset split into torch_geometric ``Data`` objects.

    Node features are ESM-1b token ids (wrapped in CLS/EOS); edges come from a
    precomputed ESMFold contact-map file ``datasets/esmfold/contacts/{idx}.pt``
    thresholded at 0.5. Returns ``(structures, labels)`` with the per-protein
    targets concatenated.

    Args:
        dataset: torchdrug dataset split (must expose ``indices``).
        task: torchdrug task providing ``graph_construction_model``.
        eps: radius for the "vanilla" radius graph below — currently dead
            code, see NOTE(review).
    """
    data_loader = torchdrug.data.DataLoader(dataset, batch_size=1, shuffle=False)
    structures = []
    labels = []
    # Original dataset indices are needed to locate each protein's contact file.
    idx_range = dataset.indices
    for protein, idx in tqdm(
        zip(data_loader, idx_range), total=len(list(idx_range)), desc="Get structures"
    ):
        out = task.graph_construction_model(protein["graph"])
        sequence = out.to_sequence()[0]
        if len(sequence) == 0:
            continue

        # TODO, add warning if file is not found.
        contact_map = torch.load(
            here() / f"datasets/esmfold/contacts/{idx}.pt", map_location="cpu"
        )

        # Vanilla structure
        # NOTE(review): this radius-graph edge_index is dead code — it is
        # unconditionally overwritten by the contact-map edges just below.
        coords = out.node_position
        graph_adj = radius_neighbors_graph(coords, radius=eps, mode="connectivity")
        edge_index = from_scipy_sparse_matrix(graph_adj)[0].long()

        # Save contact_map as image
        # plt.imshow(contact_map[0])
        # plt.savefig(here() / f"datasets/esmfold/contact_images/{idx}.png")

        # Threshold the predicted contact probabilities at 0.5 to get edges.
        edge_index, edge_attr = to_edge_index(
            ((contact_map > 0.5).long())[0].to_sparse()
        )
        # Just making sure we are not undercounting the edges.
        edge_index, edge_attr = to_undirected(
            edge_index, edge_attr, num_nodes=len(sequence)
        )

        labels.append(protein["targets"])

        torch_sequence = torch.LongTensor(
            [esm_alphabet.get_idx(res) for res in esm_alphabet.tokenize(sequence)]
        )

        # Wrap the token sequence with CLS/EOS as ESM models expect.
        torch_sequence = torch.cat(
            [
                torch.LongTensor([esm_alphabet.cls_idx]),
                torch_sequence,
                torch.LongTensor([esm_alphabet.eos_idx]),
            ]
        )
        edge_index = edge_index + 1  # shift for cls_idx

        # NOTE(review): the edge weights produced by to_undirected above are
        # deliberately discarded; the graphs are unweighted.
        edge_attr = None

        structures.append(
            Data(edge_index=edge_index, x=torch_sequence, edge_attr=edge_attr)
        )

    return structures, torch.cat(labels)
@hydra.main(
    version_base="1.3", config_path=str(here() / "config"), config_name="pst_gearnet"
)
def main(cfg):
    """Evaluate a pretrained PST model on a GearNet-style task using ESMFold
    contact-map graphs: build or load cached structures, embed every split,
    then train and evaluate an MLP head.
    """
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}")

    # The "_so" model variant is required when sequence and structure passes
    # are combined. -- TODO confirm suffix semantics against PST model zoo
    if cfg.include_seq and "so" not in cfg.model:
        cfg.model = f"{cfg.model}_so"

    # Perturbation experiments load a locally trained checkpoint; otherwise
    # the released pretrained weights are downloaded.
    if cfg.data.edge_perturb is not None:
        pretrained_path = cfg.pretrained
        model, model_cfg = PST.from_pretrained(
            model_path=pretrained_path,
        )
    else:
        pretrained_path = Path(cfg.pretrained) / f"{cfg.model}.pt"
        pretrained_path.parent.mkdir(parents=True, exist_ok=True)
        model, model_cfg = PST.from_pretrained_url(
            cfg.model,
            pretrained_path,
        )

    model.eval()
    model.to(cfg.device)

    task = core.Configurable.load_config_dict(
        edict(OmegaConf.to_container(cfg.task, resolve=True))
    )

    # Save matrices if they have been loaded before
    esm_cache_path = here() / "datasets/esmfold/contacts/.cache/esm_cache_path.pt"
    # if True:
    if not esm_cache_path.exists():
        # torchdrug rejects unrecognized config keys, so drop `name` first.
        cfg.dataset.__delattr__("name")
        dataset = core.Configurable.load_config_dict(
            OmegaConf.to_container(cfg.dataset, resolve=True)
        )
        proteins = [
            protein["graph"].to_sequence().replace(".G", "") for protein in tqdm(dataset, desc="Get sequences")
        ]
        # NOTE(review): `to_csv` returns None, so `proteins` is clobbered here;
        # only the CSV file side effect is used.
        proteins = pd.DataFrame(proteins, columns=["sequences"]).to_csv(
            here() / "datasets/esmfold/sequences.csv", index=False
        )
        train_dset, val_dset, test_dset = dataset.split()
        train_str, y_tr = get_structures(train_dset, task, eps=model_cfg.data.graph_eps)
        val_str, y_val = get_structures(val_dset, task, eps=model_cfg.data.graph_eps)
        test_str, y_te = get_structures(test_dset, task, eps=model_cfg.data.graph_eps)
        # NOTE(review): the next line is a no-op expression (dead code).
        train_str, y_tr, val_str, y_val, test_str, y_te
        # NOTE(review): torch.save will fail if the .cache directory does not
        # already exist — consider esm_cache_path.parent.mkdir(...) first.
        torch.save(
            {
                "train_str": train_str,
                "y_tr": y_tr,
                "val_str": val_str,
                "y_val": y_val,
                "test_str": test_str,
                "y_te": y_te,
            },
            esm_cache_path,
        )
    else:
        cached_dataset = torch.load(esm_cache_path)
        train_str, y_tr = cached_dataset["train_str"], cached_dataset["y_tr"]
        val_str, y_val = cached_dataset["val_str"], cached_dataset["y_val"]
        test_str, y_te = cached_dataset["test_str"], cached_dataset["y_te"]

    # this is awful i know, todo: proper transform and dataset
    train_str = [mask_cls_idx(data) for data in train_str]
    val_str = [mask_cls_idx(data) for data in val_str]
    test_str = [mask_cls_idx(data) for data in test_str]

    # Map the configured perturbation name to its transform class.
    if cfg.data.edge_perturb == "random":
        PerturbationTransform = RandomizeEdges
    elif cfg.data.edge_perturb == "sequence":
        PerturbationTransform = SequenceEdges
    elif cfg.data.edge_perturb == "complete":
        PerturbationTransform = CompleteEdges
    elif cfg.data.edge_perturb is None:
        pass
    else:
        raise ValueError("Invalid value for cfg.data.edge_perturb")

    if cfg.data.edge_perturb is not None:
        train_str = [
            PerturbationTransform()(data)
            for data in tqdm(train_str, desc=f"Apply {cfg.data.edge_perturb} to train")
        ]
        val_str = [
            PerturbationTransform()(data)
            for data in tqdm(val_str, desc=f"Apply {cfg.data.edge_perturb} to val")
        ]
        test_str = [
            PerturbationTransform()(data)
            for data in tqdm(test_str, desc=f"Apply {cfg.data.edge_perturb} to test")
        ]
    else:
        pass

    train_loader = DataLoader(
        train_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    val_loader = DataLoader(
        val_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    test_loader = DataLoader(
        test_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )

    # compute embeddings
    tic = timer()
    X_tr = compute_repr(train_loader, model, cfg)
    X_val = compute_repr(val_loader, model, cfg)
    X_te = compute_repr(test_loader, model, cfg)
    # NOTE(review): compute_time is measured but never reported in this script.
    compute_time = timer() - tic
    # preprocess() appears to operate in place (return value unused) — TODO confirm.
    preprocess(X_tr)
    preprocess(X_val)
    preprocess(X_te)

    X_tr, X_val, X_te, y_tr, y_val, y_te = convert_to_numpy(
        X_tr, X_val, X_te, y_tr, y_val, y_te
    )
    # Drop training samples whose embeddings contain NaNs.
    X_mask = np.isnan(X_tr.sum(1))
    X_tr, y_tr = X_tr[~X_mask], y_tr[~X_mask]
    logger.info(f"X_tr shape: {X_tr.shape} y_tr shape: {y_tr.shape}")

    if cfg.use_pca is not None:
        from sklearn.decomposition import PCA

        # NOTE(review): the configured use_pca value is overridden by a
        # heuristic that depends only on the embedding width.
        cfg.use_pca = 1024 if X_tr.shape[1] < 10000 else 2048
        pca = PCA(cfg.use_pca)
        pca = pca.fit(X_tr)
        X_tr = pca.transform(X_tr)
        X_val = pca.transform(X_val)
        X_te = pca.transform(X_te)
        logger.info(f"PCA done. X_tr shape: {X_tr.shape}")

    X_tr, y_tr = torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).float()
    X_val, y_val = torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float()
    X_te, y_te = torch.from_numpy(X_te).float(), torch.from_numpy(y_te).float()

    train_and_eval_mlp(
        X_tr,
        y_tr,
        X_val,
        y_val,
        X_te,
        y_te,
        cfg,
        task,
        batch_size=32,
        epochs=100,
    )
X_tr shape: {X_tr.shape}") + + X_tr, y_tr = torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).float() + X_val, y_val = torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float() + X_te, y_te = torch.from_numpy(X_te).float(), torch.from_numpy(y_te).float() + + train_and_eval_mlp( + X_tr, + y_tr, + X_val, + y_val, + X_te, + y_te, + cfg, + task, + batch_size=32, + epochs=100, + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/esmfold/predict_mutation_esmfold.py b/experiments/esmfold/predict_mutation_esmfold.py new file mode 100644 index 0000000..35ecac3 --- /dev/null +++ b/experiments/esmfold/predict_mutation_esmfold.py @@ -0,0 +1,238 @@ +import argparse +import os +from collections import defaultdict +from pathlib import Path +from timeit import default_timer as timer + +import pandas as pd +import scipy +import torch +from proteinshake.utils import residue_alphabet +from proteinshake.transforms import Compose +from torch_geometric.loader import DataLoader +from tqdm import tqdm + +from pst.downstream.mutation import DeepSequenceDataset +from pst.esm2 import PST +from pst.transforms import MutationDataset, Proteinshake2ESM, RandomizeEdges, SequenceEdges, CompleteEdges + + +def load_args(): + parser = argparse.ArgumentParser( + description="Use PST for mutation prediction", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-dir", type=str, default='.cache/pst', help="directory for downloading models" + ) + parser.add_argument( + "--model", type=str, default='pst_t6', help="pretrained model names (see README for models)" + ) + parser.add_argument( + "--pretrained", type=str, default=None, help="pretrained model path" + ) + parser.add_argument( + "--datapath", type=str, default="./datasets/dms", help="dataset prefix" + ) + parser.add_argument( + "--outdir", type=str, default="./logs_pst/dms", help="output directory", + ) + parser.add_argument( + "--protein_id", type=int, default=-1, nargs="+", 
@torch.no_grad()
def predict_masked(model, data_loader):
    """Score mutations from the model's masked-token log-probabilities.

    Relies on the module-level global ``cfg`` (set in ``main``) for the device
    and the scoring strategy. For each batch the token logits are converted to
    log-probabilities; the score is either the log-probability of the observed
    token at every position ("mt-all") or the log-ratio of mutant over
    wildtype at the prepared index positions (other strategies).

    Returns:
        dict with raw logits (stored under the key "probabilities"), labels,
        per-sample summed scores, the gathered mutant/wildtype indices,
        per-sample lengths and total wall-clock time.
    """
    logits = []
    y_true = []
    all_scores = []
    sample_lengths = []
    mt_indices = []
    wt_indices = []
    tic = timer()
    for data in tqdm(data_loader, desc="Predicting"):
        data = data.to(cfg.device)
        out = model.mask_predict(data)
        probs = torch.log_softmax(out, dim=-1)
        if cfg.strategy == "mt-all":
            # Log-probability of the observed token at every position.
            score = probs.gather(-1, data.x.view(-1, 1))
        else:
            # Log-ratio of mutant vs wildtype at the prepared positions.
            score = probs.gather(-1, data.mt_indices) - probs.gather(
                -1, data.wt_indices
            )
        logits.append(out.cpu())
        y_true.append(data.y.cpu())
        # Sum over positions -> one scalar score per sample.
        all_scores.append(score.sum(dim=0).cpu())
        mt_indices.append(data.mt_indices.cpu())
        wt_indices.append(data.wt_indices.cpu())
        sample_lengths.append(len(out))
    toc = timer()

    logits = torch.cat(logits)
    y_true = torch.cat(y_true)
    all_scores = torch.cat(all_scores)
    mt_indices = torch.cat(mt_indices)
    wt_indices = torch.cat(wt_indices)
    return {
        # NOTE(review): these are raw logits, not probabilities, despite the key.
        "probabilities": logits,
        "y_true": y_true,
        "y_score": all_scores,
        "mt_indices": mt_indices,
        "wt_indices": wt_indices,
        "sample_lengths": sample_lengths,
        "total_time": toc - tic,
    }
def label_row_wt(row, probs):
    """Score one mutation row under the wildtype-marginals strategy.

    ``row`` is a whitespace-separated list of mutation codes; the first
    character of each code is taken as the wildtype residue and the last as
    the mutant residue. ``probs`` holds per-position log-probabilities over
    the residue alphabet.

    NOTE(review): the positional part of each mutation code is never parsed —
    ``gather`` with (N, 1) indices pairs the i-th mutation with row i of
    ``probs`` rather than with the mutated sequence position. Verify this is
    intended before reusing the function.
    """
    row = row.split()
    wt_indices = torch.tensor(
        list(map(lambda x: residue_alphabet.index(x[0]), row))
    ).view(-1, 1)
    mt_indices = torch.tensor(
        list(map(lambda x: residue_alphabet.index(x[-1]), row))
    ).view(-1, 1)
    # Log-ratio of mutant vs wildtype residue, summed over all mutations.
    score = probs.gather(-1, mt_indices) - probs.gather(-1, wt_indices)
    return score.sum(dim=0).item()
def main():
    """Run PST mutation-effect prediction over the DeepSequence DMS datasets
    and write per-protein Spearman correlations plus per-mutation scores.
    """
    # predict_masked() reads the parsed arguments through this module global.
    global cfg
    cfg = load_args()
    print(cfg)
    transforms = [Proteinshake2ESM()]

    # Optional edge perturbation stacked on top of the base conversion.
    if cfg.perturbation == "random":
        transforms.append(RandomizeEdges())
    elif cfg.perturbation == "sequence":
        transforms.append(SequenceEdges())
    elif cfg.perturbation == "complete":
        transforms.append(CompleteEdges())
    else:
        pass

    dataset_cls = DeepSequenceDataset

    protein_ids = dataset_cls.available_ids()
    # -1 (the default) means "all proteins"; otherwise keep the chosen subset.
    if isinstance(cfg.protein_id, list) and cfg.protein_id[0] != -1:
        protein_ids = [protein_ids[i] for i in cfg.protein_id if i < len(protein_ids)]
    else:
        cfg.protein_id = list(range(len(protein_ids)))
    print(f"# of Datasets: {len(protein_ids)}")

    dataset = dataset_cls(root=cfg.datapath)
    mutations_list = dataset.mutations


    # Perturbation runs load a locally trained checkpoint; the default path
    # downloads the released weights.
    if cfg.perturbation is not None:
        pretrained_path = cfg.pretrained
        model, model_cfg = PST.from_pretrained(
            model_path=pretrained_path,
        )
    else:
        pretrained_path = Path(f"{cfg.model_dir}/{cfg.model}.pt")
        pretrained_path.parent.mkdir(parents=True, exist_ok=True)
        model, model_cfg = PST.from_pretrained_url(
            cfg.model, pretrained_path
        )

    model.eval()
    model.to(cfg.device)
    dataset = dataset.to_graph(eps=model_cfg.data.graph_eps, transform=Compose(transforms=transforms)).pyg()

    all_results = defaultdict(list)
    all_scores = []

    for i, protein_id in zip(cfg.protein_id, protein_ids):
        print("-" * 40)
        print(f"Protein id: {protein_id}")

        mutations = mutations_list[i]
        graph, protein_dict = dataset[i]

        df = mutations.copy()
        df.rename(columns={"y": "effect"}, inplace=True)
        df["protein_id"] = protein_id
        df = df[["protein_id", "mutations", "effect"]]
        # Very large proteins are skipped (no prediction) but kept in the
        # score table so the output covers every protein.
        if graph.num_nodes > 3000:
            all_scores.append(df)
            continue
        if cfg.strategy == "masked" or cfg.strategy == "mt" or cfg.strategy == "mt-all":
            ds = MutationDataset(
                graph,
                protein_dict,
                mutations,
                strategy=cfg.strategy,
                use_transform=True,
                transform=Compose(transforms=transforms),
            )
            data_loader = DataLoader(ds, batch_size=1, shuffle=False)
            results = predict_masked(model, data_loader)

            if cfg.strategy == "mt-all":
                # Subtract the wildtype sequence's total log-probability so
                # the score reflects the change caused by the mutations.
                data_loader = DataLoader([graph], batch_size=1, shuffle=False)
                graph = next(iter(data_loader)).to(cfg.device)
                with torch.no_grad():
                    out = model.mask_predict(graph)
                probs = torch.log_softmax(out, dim=-1).cpu()
                bias = probs.gather(-1, graph.x.cpu().view(-1, 1)).sum(dim=0)
                results["y_score"] = results["y_score"] - bias

            current_dir = cfg.outdir / f"{protein_id}"
            os.makedirs(current_dir, exist_ok=True)
            torch.save(results, current_dir / "results.pt")

            df["PST"] = results["y_score"]
        elif cfg.strategy == "wt":
            # Single unmasked forward pass; each row is scored from the
            # wildtype log-probabilities.
            data_loader = DataLoader([graph], batch_size=1, shuffle=False)
            graph = next(iter(data_loader)).to(cfg.device)
            with torch.no_grad():
                out = model.mask_predict(graph)
            probs = torch.log_softmax(out, dim=-1).cpu()

            df["PST"] = df.apply(
                lambda row: label_row_wt(
                    row["mutations"],
                    probs,
                ),
                axis=1,
            )

        rho = scipy.stats.spearmanr(df["effect"], df["PST"])
        print(f"Spearmanr: {rho}")

        all_scores.append(df)
        all_results["protein_id"].append(protein_id)
        all_results["spearmanr"].append(rho.correlation)

    all_results = pd.DataFrame.from_dict(all_results)
    all_results.to_csv(cfg.outdir / "results.csv")
    all_scores = pd.concat(all_scores, ignore_index=True)
    all_scores.to_csv(cfg.outdir / "scores.csv")
@torch.no_grad()
def compute_repr(data_loader, model, task, cfg):
    """Compute PST embeddings per protein or per residue.

    For protein-level tasks the residue representations are mean-pooled per
    graph; for residue-level tasks per-residue embeddings are returned, split
    per protein. If ``cfg.include_seq`` is set, a second sequence-only pass
    (edges removed) is averaged with the structural one.

    Args:
        data_loader: iterable of torch_geometric batches exposing
            ``idx_mask``, ``batch`` and ``ptr``.
        model: PST model supporting ``return_repr``.
        task: proteinshake task; ``task.task_type[0]`` decides pooling level.
        cfg: config providing ``device`` and ``include_seq``.

    Returns:
        list of CPU tensors, one per protein.
    """
    # The pooling level is a property of the task, not of the batch — hoist
    # the check out of the loop instead of re-testing it per batch.
    protein_level = "protein" in task.task_type[0]
    embeddings = []
    for data in tqdm(data_loader):
        data = data.to(cfg.device)
        out = model(data, return_repr=True)
        out, batch = out[data.idx_mask], data.batch[data.idx_mask]
        if protein_level:
            out = gnn.global_mean_pool(out, batch)
        if cfg.include_seq:
            # Sequence-only pass; mutates the batch, so it must come second.
            data.edge_index = None
            out_seq = model(data, return_repr=True)
            out_seq = out_seq[data.idx_mask]
            if protein_level:
                out_seq = gnn.global_mean_pool(out_seq, batch)
            out = (out + out_seq) * 0.5
        out = out.cpu()
        # extend() instead of quadratic list re-concatenation.
        if protein_level:
            embeddings.extend(torch.chunk(out, len(data.ptr) - 1))
        else:
            # ptr counts include the CLS/EOS positions that idx_mask removed,
            # hence the -2. -- TODO confirm against Proteinshake2ESM(mask_cls_idx=True)
            embeddings.extend(torch.split(out, tuple(torch.diff(data.ptr) - 2)))
    return embeddings
@hydra.main(
    version_base="1.3",
    config_path=str(here() / "config"),
    config_name="pst_proteinshake",
)
def main(cfg):
    """Evaluate PST embeddings on a proteinshake task with a scikit-learn
    head: embed all proteins, grid-search hyper-parameters on a predefined
    train/val split, refit on train only, then report val/test metrics.
    """
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}")

    # The "_so" model variant is required for the combined sequence+structure
    # scoring. -- TODO confirm suffix semantics
    if cfg.include_seq and "so" not in cfg.model:
        cfg.model = f"{cfg.model}_so"

    # Perturbation runs load a locally trained checkpoint; otherwise the
    # released pretrained weights are downloaded.
    if cfg.perturbation is not None:
        pretrained_path = cfg.model_path
        model, model_cfg = PST.from_pretrained(
            model_path=pretrained_path,
        )
    else:
        pretrained_path = Path(cfg.pretrained) / f"{cfg.model}.pt"
        pretrained_path.parent.mkdir(parents=True, exist_ok=True)
        model, model_cfg = PST.from_pretrained_url(
            cfg.model, pretrained_path,
        )

    model.eval()
    model.to(cfg.device)

    transforms = [
        PretrainingAttr(),
        Proteinshake2ESM(mask_cls_idx=True),
    ]
    # Optional edge perturbation stacked on top of the base conversion.
    if cfg.perturbation == "random":
        transforms.append(RandomizeEdges())
    elif cfg.perturbation == "sequence":
        transforms.append(SequenceEdges())
    elif cfg.perturbation == "complete":
        transforms.append(CompleteEdges())
    else:
        pass

    task = get_task(cfg.task.class_name)(root=cfg.task.path, split=cfg.split)
    dataset = task.dataset.to_graph(
        eps=model_cfg.data.graph_eps
    ).pyg(
        transform=Compose(
            transforms=transforms,
        )
    )

    data_loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    tic = timer()
    X = compute_repr(data_loader, model, task, cfg)
    compute_time = timer() - tic
    X_tr, y_tr, X_val, y_val, X_te, y_te = prepare_data(X, task)
    log.info(f"X_tr shape: {X_tr.shape} y_tr shape: {y_tr.shape}")

    ## Solving the problem with sklearn
    estimator = SklearnPredictor(task.task_out)

    grid = estimator.get_grid()

    scoring = lambda y_true, y_pred: compute_metrics(y_true, y_pred, task)[
        cfg.task.metric
    ]
    # Threshold-based metrics need decision scores rather than hard labels.
    if task.task_out == "multi_label" or task.task_out == "binary":
        scoring = make_scorer(scoring, needs_threshold=True)
    else:
        scoring = make_scorer(scoring)

    # -1 marks training rows, 0 the single validation fold.
    test_split_index = [-1] * len(y_tr) + [0] * len(y_val)
    X_tr_val, y_tr_val = np.concatenate((X_tr, X_val), axis=0), np.concatenate(
        (y_tr, y_val)
    )

    splits = PredefinedSplit(test_fold=test_split_index)

    clf = GridSearchCV(
        estimator=estimator,
        param_grid=grid,
        scoring=scoring,
        cv=splits,
        refit=False,
        n_jobs=-1,
    )

    tic = timer()
    clf.fit(X_tr_val, y_tr_val)
    log.info(pd.DataFrame.from_dict(clf.cv_results_).sort_values("rank_test_score"))
    # refit=False above: refit manually on the training split only, using the
    # best hyper-parameters found on the validation fold.
    estimator.set_params(**clf.best_params_)
    clf = estimator
    clf.fit(X_tr, y_tr)
    clf_time = timer() - tic
    #####

    if task.task_out == "multi_label" or task.task_out == "binary":
        # NOTE(review): bare except — should be narrowed (e.g. AttributeError
        # for estimators without decision_function).
        try:
            y_pred = clf.decision_function(X_te)
        except:
            y_pred = clf.predict(X_te)
    else:
        y_pred = clf.predict(X_te)
    if isinstance(y_pred, list):
        # Per-output probability arrays -> keep the positive-class column.
        if y_pred[0].ndim > 1:
            y_pred = [y[:, 1] for y in y_pred]
        y_pred = np.asarray(y_pred).T
    test_score = compute_metrics(y_te, y_pred, task)[cfg.task.metric]
    log.info(f"Test score: {test_score:.3f}")

    if task.task_out == "multi_label" or task.task_out == "binary":
        # NOTE(review): bare except — same remark as above.
        try:
            y_val_pred = clf.decision_function(X_val)
        except:
            y_val_pred = clf.predict(X_val)
    else:
        y_val_pred = clf.predict(X_val)
    if isinstance(y_val_pred, list):
        if y_val_pred[0].ndim > 1:
            y_val_pred = [y[:, 1] for y in y_val_pred]
        y_val_pred = np.asarray(y_val_pred).T
    val_score = compute_metrics(y_val, y_val_pred, task)[cfg.task.metric]

    results = [
        {
            "test_score": test_score,
            "val_score": val_score,
            "compute_time": compute_time,
            "clf_time": clf_time,
        }
    ]

    pd.DataFrame(results).to_csv(f"{cfg.logs.path}/results.csv")
def get_structures(dataset, eps=8.0):
    """Build torch_geometric ``Data`` objects and stacked labels for SCOP.

    Nodes carry ESM-1b token ids wrapped in CLS/EOS; edges connect residues
    whose coordinates lie within ``eps`` of each other (same units as
    ``protein.pos``). Proteins with an empty sequence are skipped.
    """
    graphs, targets = [], []
    for protein in tqdm(dataset):
        seq = protein.seq
        if not seq:
            continue
        targets.append(torch.tensor(protein.y))

        # Tokenize with the ESM-1b alphabet and add CLS/EOS delimiters.
        tokens = [esm_alphabet.get_idx(res) for res in esm_alphabet.tokenize(seq)]
        node_feats = torch.cat(
            [
                torch.LongTensor([esm_alphabet.cls_idx]),
                torch.LongTensor(tokens),
                torch.LongTensor([esm_alphabet.eos_idx]),
            ]
        )

        # Radius graph over residue coordinates; the +1 shifts node ids past
        # the CLS token prepended above.
        adjacency = radius_neighbors_graph(protein.pos, radius=eps, mode="connectivity")
        edges = from_scipy_sparse_matrix(adjacency)[0].long() + 1

        graphs.append(Data(edge_index=edges, x=node_feats, edge_attr=None))

    return graphs, torch.stack(targets)
@hydra.main(
    version_base="1.3", config_path=str(here() / "config"), config_name="pst_gearnet"
)
def main(cfg):
    """SCOP fold classification with PST embeddings: build or load cached
    structure graphs, embed, optionally perturb edges and reduce with PCA,
    then train a linear head and report accuracy stratified by
    family/superfamily/fold.
    """
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}")

    if cfg.include_seq and "so" not in cfg.model:
        cfg.model = f"{cfg.model}_so"


    # Perturbation runs load a locally trained checkpoint; otherwise the
    # released pretrained weights are downloaded.
    if cfg.data.edge_perturb is not None:
        pretrained_path = cfg.pretrained
        model, model_cfg = PST.from_pretrained(
            pretrained_path,
        )
    else:
        pretrained_path = Path(cfg.pretrained) / f"{cfg.model}.pt"
        pretrained_path.parent.mkdir(parents=True, exist_ok=True)
        model, model_cfg = PST.from_pretrained_url(
            cfg.model, pretrained_path,
        )

    model.eval()
    model.to(cfg.device)

    scop_data = torch.load(Path(cfg.dataset.path) / "data.pt")

    # Structure graphs are cached per graph-eps value.
    structure_path = (
        Path(cfg.dataset.path) / f"structures_{model_cfg.data.graph_eps}.pt"
    )
    if structure_path.exists():
        tmp = torch.load(structure_path)
        train_str, y_tr = tmp["train_str"], tmp["y_tr"]
        val_str, y_val = tmp["val_str"], tmp["y_val"]
        test_str, y_te = tmp["test_str"], tmp["y_te"]
        stratified_indices = tmp["stratified_indices"]
        del tmp
    else:
        train_str, y_tr = get_structures(
            scop_data["train"], eps=model_cfg.data.graph_eps
        )
        val_str, y_val = get_structures(scop_data["val"], eps=model_cfg.data.graph_eps)
        # Concatenate the three test splits and remember their index ranges
        # so scores can be reported per difficulty level.
        test_data = (
            scop_data["test_family"]
            + scop_data["test_superfamily"]
            + scop_data["test_fold"]
        )
        n_fm = len(scop_data["test_family"])
        n_sf = len(scop_data["test_superfamily"])
        n_fo = len(scop_data["test_fold"])
        test_str, y_te = get_structures(test_data, eps=model_cfg.data.graph_eps)
        stratified_indices = {}
        stratified_indices["family"] = torch.arange(0, n_fm)
        stratified_indices["superfamily"] = torch.arange(n_fm, n_fm + n_sf)
        stratified_indices["fold"] = torch.arange(n_fm + n_sf, n_fm + n_sf + n_fo)
        torch.save(
            {
                "train_str": train_str,
                "val_str": val_str,
                "test_str": test_str,
                "y_tr": y_tr,
                "y_val": y_val,
                "y_te": y_te,
                "stratified_indices": stratified_indices,
            },
            structure_path,
        )

    # this is awful i know, todo: proper transform and dataset
    train_str = [mask_cls_idx(data) for data in train_str]
    val_str = [mask_cls_idx(data) for data in val_str]
    test_str = [mask_cls_idx(data) for data in test_str]

    # Map the configured perturbation name to its transform class.
    if cfg.data.edge_perturb == "random":
        PerturbationTransform = RandomizeEdges
    elif cfg.data.edge_perturb == "sequence":
        PerturbationTransform = SequenceEdges
    elif cfg.data.edge_perturb == "complete":
        PerturbationTransform = CompleteEdges
    elif cfg.data.edge_perturb is None:
        pass
    else:
        raise ValueError("Invalid value for cfg.data.edge_perturb")

    if cfg.data.edge_perturb is not None:
        train_str = [
            PerturbationTransform()(data)
            for data in tqdm(train_str, desc=f"Apply {cfg.data.edge_perturb} to train")
        ]
        val_str = [
            PerturbationTransform()(data)
            for data in tqdm(val_str, desc=f"Apply {cfg.data.edge_perturb} to val")
        ]
        test_str = [
            PerturbationTransform()(data)
            for data in tqdm(test_str, desc=f"Apply {cfg.data.edge_perturb} to test")
        ]
    else:
        pass
    train_loader = DataLoader(
        train_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    val_loader = DataLoader(
        val_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )
    test_loader = DataLoader(
        test_str,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )

    # compute embeddings
    tic = timer()
    X_tr = compute_repr(train_loader, model, cfg)
    X_val = compute_repr(val_loader, model, cfg)
    X_te = compute_repr(test_loader, model, cfg)
    compute_time = timer() - tic
    # preprocess() appears to operate in place (return value unused) — TODO confirm.
    preprocess(X_tr)
    preprocess(X_val)
    preprocess(X_te)
    X_tr, X_val, X_te, y_tr, y_val, y_te = convert_to_numpy(
        X_tr, X_val, X_te, y_tr, y_val, y_te
    )

    log.info(f"X_tr shape: {X_tr.shape} y_tr shape: {y_tr.shape}")

    if cfg.use_pca is not None:
        from sklearn.decomposition import PCA

        # NOTE(review): the configured use_pca value is overridden by a
        # heuristic that depends only on the embedding width.
        cfg.use_pca = 1024 if X_tr.shape[1] < 10000 else 2048
        pca = PCA(cfg.use_pca)
        pca = pca.fit(X_tr)
        X_tr = pca.transform(X_tr)
        X_val = pca.transform(X_val)
        X_te = pca.transform(X_te)
        log.info(f"PCA done. X_tr shape: {X_tr.shape}")

    X_tr, y_tr = torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).long()
    X_val, y_val = torch.from_numpy(X_val).float(), torch.from_numpy(y_val).long()
    X_te, y_te = torch.from_numpy(X_te).float(), torch.from_numpy(y_te).long()

    # 1195 is the number of output classes. -- TODO confirm against the SCOP
    # label set used to build data.pt
    val_score, test_score, test_stratified_score = train_and_eval_linear(
        X_tr,
        y_tr,
        X_val,
        y_val,
        X_te,
        y_te,
        1195,
        stratified_indices,
        use_cuda=torch.cuda.is_available(),
    )

    results = [
        {
            "test_top1": test_score[0],
            "test_family": test_stratified_score["family"][0],
            "test_superfamily": test_stratified_score["superfamily"][0],
            "test_fold": test_stratified_score["fold"][0],
            "val_acc": val_score,
            "compute_time": compute_time,
        }
    ]

    pd.DataFrame(results).to_csv(f"{cfg.logs.path}/results.csv")
X_tr shape: {X_tr.shape}") + + X_tr, y_tr = torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).long() + X_val, y_val = torch.from_numpy(X_val).float(), torch.from_numpy(y_val).long() + X_te, y_te = torch.from_numpy(X_te).float(), torch.from_numpy(y_te).long() + + val_score, test_score, test_stratified_score = train_and_eval_linear( + X_tr, + y_tr, + X_val, + y_val, + X_te, + y_te, + 1195, + stratified_indices, + use_cuda=torch.cuda.is_available(), + ) + + results = [ + { + "test_top1": test_score[0], + "test_family": test_stratified_score["family"][0], + "test_superfamily": test_stratified_score["superfamily"][0], + "test_fold": test_stratified_score["fold"][0], + "val_acc": val_score, + "compute_time": compute_time, + } + ] + + pd.DataFrame(results).to_csv(f"{cfg.logs.path}/results.csv") + + +if __name__ == "__main__": + main() diff --git a/experiments/fixed/predict_gearnet.py b/experiments/fixed/predict_gearnet.py index 31d0a22..76e0ccb 100644 --- a/experiments/fixed/predict_gearnet.py +++ b/experiments/fixed/predict_gearnet.py @@ -38,7 +38,7 @@ @torch.no_grad() def compute_repr(data_loader, model, cfg): embeddings = [] - for batch_idx, data in enumerate(tqdm(data_loader)): + for batch_idx, data in enumerate(tqdm(data_loader, desc="Computing embeddings")): data = data.to(cfg.device) out = model(data, return_repr=True, aggr=cfg.aggr) out, batch = out[data.idx_mask], data.batch[data.idx_mask] diff --git a/experiments/perturbed/command_perturbed_experiment.sh b/experiments/perturbed/command_perturbed_experiment.sh new file mode 100644 index 0000000..1f96949 --- /dev/null +++ b/experiments/perturbed/command_perturbed_experiment.sh @@ -0,0 +1,33 @@ +# Baselines +python experiments/perturbed/predict_gearnet_perturbed.py +python experiments/perturbed/predict_gearnet_perturbed.py dataset=gearnet_go_bp +python experiments/perturbed/predict_gearnet_perturbed.py dataset=gearnet_go_cc +python experiments/perturbed/predict_gearnet_perturbed.py dataset=gearnet_go_mf 
+python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/baselines +python experiments/perturbed/predict_scop_perturbed.py dataset=scop + +# Complete graphs +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt data.edge_perturb=complete +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt dataset=gearnet_go_bp data.edge_perturb=complete +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt dataset=gearnet_go_cc data.edge_perturb=complete +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt dataset=gearnet_go_mf data.edge_perturb=complete +python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/complete --perturbation complete --pretrained ./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt +python experiments/perturbed/predict_scop_perturbed.py dataset=scop data.edge_perturb=complete pretrained=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt + +# Sequence graphs +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt data.edge_perturb=sequence +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt dataset=gearnet_go_bp data.edge_perturb=sequence +python experiments/perturbed/predict_gearnet_perturbed.py 
pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt dataset=gearnet_go_cc data.edge_perturb=sequence +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt dataset=gearnet_go_mf data.edge_perturb=sequence +python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/sequence --perturbation sequence --pretrained ./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt +python experiments/perturbed/predict_scop_perturbed.py dataset=scop data.edge_perturb=sequence pretrained=./logs_pst/random_ablation/edge_perturb_sequence/esm2_t6_8M_UR50D/runs/2024-07-20_17-54-40/model.pt + +# Random graphs +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt data.edge_perturb=random +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt dataset=gearnet_go_bp data.edge_perturb=random +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt dataset=gearnet_go_cc data.edge_perturb=random +python experiments/perturbed/predict_gearnet_perturbed.py pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt dataset=gearnet_go_mf data.edge_perturb=random +python experiments/perturbed/predict_mutation_perturbed.py --outdir logs/pst_perturbation/random --perturbation random --pretrained ./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt +python experiments/perturbed/predict_scop_perturbed.py dataset=scop 
data.edge_perturb=random pretrained=./logs_pst/random_ablation/edge_perturb_random/esm2_t6_8M_UR50D/runs/2024-07-21_18-41-49/model.pt + +# python experiments/perturbed/predict_proteinshake_perturbed.py logs.prefix=logs_pst/proteinshake_perturbed perturbation=complete model_path=./logs_pst/random_ablation/edge_perturb_complete/esm2_t6_8M_UR50D/runs/2024-07-17_17-41-55/model.pt task=binding_site_detection,enzyme_class,gene_ontology,pfam_task,structural_class diff --git a/experiments/perturbed/predict_gearnet_perturbed.py b/experiments/perturbed/predict_gearnet_perturbed.py new file mode 100644 index 0000000..a59c0e9 --- /dev/null +++ b/experiments/perturbed/predict_gearnet_perturbed.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- +import json +import logging +import pickle +from pathlib import Path +from timeit import default_timer as timer + +import esm +import hydra +import numpy as np +import pandas as pd +import torch +import torch_geometric.nn as gnn +import torchdrug +from easydict import EasyDict as edict +from omegaconf import OmegaConf +from pyprojroot import here +from sklearn.neighbors import radius_neighbors_graph +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from torch_geometric.utils import from_scipy_sparse_matrix +from torchdrug import core, datasets, models, tasks # noqa +from tqdm.rich import tqdm + +from pst.esm2 import PST +from pst.downstream.mlp import train_and_eval_mlp +from pst.downstream import ( + preprocess, + convert_to_numpy, + mask_cls_idx, +) +from pst.transforms import RandomizeEdges, SequenceEdges, CompleteEdges +from proteinshake.transforms import Compose + +log = logging.getLogger(__name__) + +esm_alphabet = esm.data.Alphabet.from_architecture("ESM-1b") + + +@torch.no_grad() +def compute_repr(data_loader, model, cfg): + embeddings = [] + for batch_idx, data in enumerate(tqdm(data_loader)): + data = data.to(cfg.device) + out = model(data, return_repr=True, aggr=cfg.aggr) + out, batch = 
out[data.idx_mask], data.batch[data.idx_mask] + out = gnn.global_mean_pool(out, batch) + + if cfg.include_seq: + data.edge_index = None + out_seq = model(data, return_repr=True, aggr=cfg.aggr) + out_seq = out_seq[data.idx_mask] + out_seq = gnn.global_mean_pool(out_seq, batch) + out = (out + out_seq) * 0.5 + + out = out.cpu() + + embeddings = embeddings + list(torch.chunk(out, len(data.ptr) - 1)) + + return torch.cat(embeddings) + + +def get_structures(dataset, task, eps=8): + data_loader = torchdrug.data.DataLoader(dataset, batch_size=1, shuffle=False) + structures = [] + labels = [] + for protein in tqdm(data_loader): + out = task.graph_construction_model(protein["graph"]) + sequence = out.to_sequence()[0] + if len(sequence) == 0: + continue + coords = out.node_position + labels.append(protein["targets"]) + + torch_sequence = torch.LongTensor( + [esm_alphabet.get_idx(res) for res in esm_alphabet.tokenize(sequence)] + ) + graph_adj = radius_neighbors_graph(coords, radius=eps, mode="connectivity") + edge_index = from_scipy_sparse_matrix(graph_adj)[0].long() + torch_sequence = torch.cat( + [ + torch.LongTensor([esm_alphabet.cls_idx]), + torch_sequence, + torch.LongTensor([esm_alphabet.eos_idx]), + ] + ) + edge_index = edge_index + 1 # shift for cls_idx + + edge_attr = None + + structures.append( + Data(edge_index=edge_index, x=torch_sequence, edge_attr=edge_attr) + ) + + return structures, torch.cat(labels) + + +@hydra.main( + version_base="1.3", config_path=str(here() / "config"), config_name="pst_gearnet" +) +def main(cfg): + cfg.device = "cuda" if torch.cuda.is_available() else "cpu" + log.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}") + + if cfg.include_seq and "so" not in cfg.model: + cfg.model = f"{cfg.model}_so" + + + if cfg.data.edge_perturb is not None: + pretrained_path = cfg.pretrained + model, model_cfg = PST.from_pretrained( + model_path=pretrained_path, + ) + else: + pretrained_path = Path(cfg.pretrained) / f"{cfg.model}.pt" + 
pretrained_path.parent.mkdir(parents=True, exist_ok=True) + model, model_cfg = PST.from_pretrained_url( + cfg.model, pretrained_path, + ) + + model.eval() + model.to(cfg.device) + + task = core.Configurable.load_config_dict( + edict(OmegaConf.to_container(cfg.task, resolve=True)) + ) + + structure_path = ( + Path(cfg.dataset.path) / f"structures_{model_cfg.data.graph_eps}.pt" + ) + if structure_path.exists(): + tmp = torch.load(structure_path) + train_str, y_tr = tmp["train_str"], tmp["y_tr"] + val_str, y_val = tmp["val_str"], tmp["y_val"] + test_str, y_te = tmp["test_str"], tmp["y_te"] + del tmp + else: + # To make torchdrug work, one has to delete unrecognized attributes... + cfg.dataset.__delattr__('name') + dataset = core.Configurable.load_config_dict( + OmegaConf.to_container(cfg.dataset, resolve=True) + ) + train_dset, val_dset, test_dset = dataset.split() + + train_str, y_tr = get_structures(train_dset, task, eps=model_cfg.data.graph_eps) + val_str, y_val = get_structures(val_dset, task, eps=model_cfg.data.graph_eps) + test_str, y_te = get_structures(test_dset, task, eps=model_cfg.data.graph_eps) + torch.save( + { + "train_str": train_str, + "val_str": val_str, + "test_str": test_str, + "y_tr": y_tr, + "y_val": y_val, + "y_te": y_te, + }, + structure_path, + ) + + # this is awful i know, todo: proper transform and dataset + train_str = [mask_cls_idx(data) for data in train_str] + val_str = [mask_cls_idx(data) for data in val_str] + test_str = [mask_cls_idx(data) for data in test_str] + + if cfg.data.edge_perturb == "random": + PerturbationTransform = RandomizeEdges + elif cfg.data.edge_perturb == "sequence": + PerturbationTransform = SequenceEdges + elif cfg.data.edge_perturb == "complete": + PerturbationTransform = CompleteEdges + elif cfg.data.edge_perturb is None: + pass + else: + raise ValueError("Invalid value for cfg.data.edge_perturb") + + if cfg.data.edge_perturb is not None: + train_str = [PerturbationTransform()(data) for data in tqdm(train_str, 
desc=f"Apply {cfg.data.edge_perturb} to train")] + val_str = [PerturbationTransform()(data) for data in tqdm(val_str, desc=f"Apply {cfg.data.edge_perturb} to val")] + test_str = [PerturbationTransform()(data) for data in tqdm(test_str, desc=f"Apply {cfg.data.edge_perturb} to test")] + else: + pass + + train_loader = DataLoader( + train_str, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + ) + val_loader = DataLoader( + val_str, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + ) + test_loader = DataLoader( + test_str, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + ) + + # compute embeddings + tic = timer() + X_tr = compute_repr(train_loader, model, cfg) + X_val = compute_repr(val_loader, model, cfg) + X_te = compute_repr(test_loader, model, cfg) + compute_time = timer() - tic + preprocess(X_tr) + preprocess(X_val) + preprocess(X_te) + + X_tr, X_val, X_te, y_tr, y_val, y_te = convert_to_numpy( + X_tr, X_val, X_te, y_tr, y_val, y_te + ) + X_mask = np.isnan(X_tr.sum(1)) + X_tr, y_tr = X_tr[~X_mask], y_tr[~X_mask] + log.info(f"X_tr shape: {X_tr.shape} y_tr shape: {y_tr.shape}") + + if cfg.use_pca is not None: + from sklearn.decomposition import PCA + + cfg.use_pca = 1024 if X_tr.shape[1] < 10000 else 2048 + pca = PCA(cfg.use_pca) + pca = pca.fit(X_tr) + X_tr = pca.transform(X_tr) + X_val = pca.transform(X_val) + X_te = pca.transform(X_te) + log.info(f"PCA done. 
X_tr shape: {X_tr.shape}") + + X_tr, y_tr = torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).float() + X_val, y_val = torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float() + X_te, y_te = torch.from_numpy(X_te).float(), torch.from_numpy(y_te).float() + + train_and_eval_mlp( + X_tr, + y_tr, + X_val, + y_val, + X_te, + y_te, + cfg, + task, + batch_size=32, + epochs=100, + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/perturbed/predict_mutation_perturbed.py b/experiments/perturbed/predict_mutation_perturbed.py new file mode 100644 index 0000000..35ecac3 --- /dev/null +++ b/experiments/perturbed/predict_mutation_perturbed.py @@ -0,0 +1,238 @@ +import argparse +import os +from collections import defaultdict +from pathlib import Path +from timeit import default_timer as timer + +import pandas as pd +import scipy +import torch +from proteinshake.utils import residue_alphabet +from proteinshake.transforms import Compose +from torch_geometric.loader import DataLoader +from tqdm import tqdm + +from pst.downstream.mutation import DeepSequenceDataset +from pst.esm2 import PST +from pst.transforms import MutationDataset, Proteinshake2ESM, RandomizeEdges, SequenceEdges, CompleteEdges + + +def load_args(): + parser = argparse.ArgumentParser( + description="Use PST for mutation prediction", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-dir", type=str, default='.cache/pst', help="directory for downloading models" + ) + parser.add_argument( + "--model", type=str, default='pst_t6', help="pretrained model names (see README for models)" + ) + parser.add_argument( + "--pretrained", type=str, default=None, help="pretrained model path" + ) + parser.add_argument( + "--datapath", type=str, default="./datasets/dms", help="dataset prefix" + ) + parser.add_argument( + "--outdir", type=str, default="./logs_pst/dms", help="output directory", + ) + parser.add_argument( + "--protein_id", type=int, default=-1, 
nargs="+", help="protein id list" + ) + parser.add_argument( + "--perturbation", type=str, default=None, help="perturbation type" + ) + parser.add_argument( + "--strategy", + type=str, + choices=["masked", "wt", "mt", "mt-all"], + default="masked", + help="scoring strategy: masked marginals or wildtype marginals", + ) + args = parser.parse_args() + + args.datapath = Path(args.datapath) + args.outdir = Path(args.outdir) + + args.device = ( + torch.device(torch.cuda.current_device()) + if torch.cuda.is_available() + else torch.device("cpu") + ) + + return args + + +@torch.no_grad() +def predict_masked(model, data_loader): + logits = [] + y_true = [] + all_scores = [] + sample_lengths = [] + mt_indices = [] + wt_indices = [] + tic = timer() + for data in tqdm(data_loader, desc="Predicting"): + data = data.to(cfg.device) + out = model.mask_predict(data) + probs = torch.log_softmax(out, dim=-1) + if cfg.strategy == "mt-all": + score = probs.gather(-1, data.x.view(-1, 1)) + else: + score = probs.gather(-1, data.mt_indices) - probs.gather( + -1, data.wt_indices + ) + logits.append(out.cpu()) + y_true.append(data.y.cpu()) + all_scores.append(score.sum(dim=0).cpu()) + mt_indices.append(data.mt_indices.cpu()) + wt_indices.append(data.wt_indices.cpu()) + sample_lengths.append(len(out)) + toc = timer() + + logits = torch.cat(logits) + y_true = torch.cat(y_true) + all_scores = torch.cat(all_scores) + mt_indices = torch.cat(mt_indices) + wt_indices = torch.cat(wt_indices) + return { + "probabilities": logits, + "y_true": y_true, + "y_score": all_scores, + "mt_indices": mt_indices, + "wt_indices": wt_indices, + "sample_lengths": sample_lengths, + "total_time": toc - tic, + } + + +def label_row_wt(row, probs): + row = row.split() + wt_indices = torch.tensor( + list(map(lambda x: residue_alphabet.index(x[0]), row)) + ).view(-1, 1) + mt_indices = torch.tensor( + list(map(lambda x: residue_alphabet.index(x[-1]), row)) + ).view(-1, 1) + score = probs.gather(-1, mt_indices) - 
probs.gather(-1, wt_indices) + return score.sum(dim=0).item() + + +def main(): + global cfg + cfg = load_args() + print(cfg) + transforms = [Proteinshake2ESM()] + + if cfg.perturbation == "random": + transforms.append(RandomizeEdges()) + elif cfg.perturbation == "sequence": + transforms.append(SequenceEdges()) + elif cfg.perturbation == "complete": + transforms.append(CompleteEdges()) + else: + pass + + dataset_cls = DeepSequenceDataset + + protein_ids = dataset_cls.available_ids() + if isinstance(cfg.protein_id, list) and cfg.protein_id[0] != -1: + protein_ids = [protein_ids[i] for i in cfg.protein_id if i < len(protein_ids)] + else: + cfg.protein_id = list(range(len(protein_ids))) + print(f"# of Datasets: {len(protein_ids)}") + + dataset = dataset_cls(root=cfg.datapath) + mutations_list = dataset.mutations + + + if cfg.perturbation is not None: + pretrained_path = cfg.pretrained + model, model_cfg = PST.from_pretrained( + model_path=pretrained_path, + ) + else: + pretrained_path = Path(f"{cfg.model_dir}/{cfg.model}.pt") + pretrained_path.parent.mkdir(parents=True, exist_ok=True) + model, model_cfg = PST.from_pretrained_url( + cfg.model, pretrained_path + ) + + model.eval() + model.to(cfg.device) + dataset = dataset.to_graph(eps=model_cfg.data.graph_eps, transform=Compose(transforms=transforms)).pyg() + + all_results = defaultdict(list) + all_scores = [] + + for i, protein_id in zip(cfg.protein_id, protein_ids): + print("-" * 40) + print(f"Protein id: {protein_id}") + + mutations = mutations_list[i] + graph, protein_dict = dataset[i] + + df = mutations.copy() + df.rename(columns={"y": "effect"}, inplace=True) + df["protein_id"] = protein_id + df = df[["protein_id", "mutations", "effect"]] + if graph.num_nodes > 3000: + all_scores.append(df) + continue + if cfg.strategy == "masked" or cfg.strategy == "mt" or cfg.strategy == "mt-all": + ds = MutationDataset( + graph, + protein_dict, + mutations, + strategy=cfg.strategy, + use_transform=True, + 
transform=Compose(transforms=transforms), + ) + data_loader = DataLoader(ds, batch_size=1, shuffle=False) + results = predict_masked(model, data_loader) + + if cfg.strategy == "mt-all": + data_loader = DataLoader([graph], batch_size=1, shuffle=False) + graph = next(iter(data_loader)).to(cfg.device) + with torch.no_grad(): + out = model.mask_predict(graph) + probs = torch.log_softmax(out, dim=-1).cpu() + bias = probs.gather(-1, graph.x.cpu().view(-1, 1)).sum(dim=0) + results["y_score"] = results["y_score"] - bias + + current_dir = cfg.outdir / f"{protein_id}" + os.makedirs(current_dir, exist_ok=True) + torch.save(results, current_dir / "results.pt") + + df["PST"] = results["y_score"] + elif cfg.strategy == "wt": + data_loader = DataLoader([graph], batch_size=1, shuffle=False) + graph = next(iter(data_loader)).to(cfg.device) + with torch.no_grad(): + out = model.mask_predict(graph) + probs = torch.log_softmax(out, dim=-1).cpu() + + df["PST"] = df.apply( + lambda row: label_row_wt( + row["mutations"], + probs, + ), + axis=1, + ) + + rho = scipy.stats.spearmanr(df["effect"], df["PST"]) + print(f"Spearmanr: {rho}") + + all_scores.append(df) + all_results["protein_id"].append(protein_id) + all_results["spearmanr"].append(rho.correlation) + + all_results = pd.DataFrame.from_dict(all_results) + all_results.to_csv(cfg.outdir / "results.csv") + all_scores = pd.concat(all_scores, ignore_index=True) + all_scores.to_csv(cfg.outdir / "scores.csv") + + +if __name__ == "__main__": + main() diff --git a/experiments/perturbed/predict_proteinshake_perturbed.py b/experiments/perturbed/predict_proteinshake_perturbed.py new file mode 100644 index 0000000..6f55dba --- /dev/null +++ b/experiments/perturbed/predict_proteinshake_perturbed.py @@ -0,0 +1,197 @@ +import logging +from pathlib import Path +from timeit import default_timer as timer + +import hydra +import numpy as np +import pandas as pd +import torch +import torch_geometric.nn as gnn +from omegaconf import OmegaConf +from 
proteinshake.transforms import Compose +from pyprojroot import here +from sklearn.metrics import make_scorer +from sklearn.model_selection import GridSearchCV, PredefinedSplit +from torch_geometric.loader import DataLoader +from tqdm import tqdm + +from pst.downstream import compute_metrics, get_task, prepare_data +from pst.downstream.sklearn_wrapper import SklearnPredictor +from pst.esm2 import PST +from pst.transforms import ( + CompleteEdges, + PretrainingAttr, + Proteinshake2ESM, + RandomizeEdges, + SequenceEdges, +) + +log = logging.getLogger(__name__) + + +@torch.no_grad() +def compute_repr(data_loader, model, task, cfg): + embeddings = [] + for batch_idx, data in enumerate(tqdm(data_loader)): + data = data.to(cfg.device) + out = model(data, return_repr=True) + out, batch = out[data.idx_mask], data.batch[data.idx_mask] + if "protein" in task.task_type[0]: + out = gnn.global_mean_pool(out, batch) + if cfg.include_seq: + data.edge_index = None + out_seq = model(data, return_repr=True) + out_seq = out_seq[data.idx_mask] + if "protein" in task.task_type[0]: + out_seq = gnn.global_mean_pool(out_seq, batch) + out = (out + out_seq) * 0.5 + out = out.cpu() + if "protein" in task.task_type[0]: + embeddings = embeddings + list(torch.chunk(out, len(data.ptr) - 1)) + else: + embeddings = embeddings + list( + torch.split(out, tuple(torch.diff(data.ptr) - 2)) + ) + return embeddings + + +@hydra.main( + version_base="1.3", + config_path=str(here() / "config"), + config_name="pst_proteinshake", +) +def main(cfg): + cfg.device = "cuda" if torch.cuda.is_available() else "cpu" + log.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}") + + if cfg.include_seq and "so" not in cfg.model: + cfg.model = f"{cfg.model}_so" + + if cfg.perturbation is not None: + pretrained_path = cfg.model_path + model, model_cfg = PST.from_pretrained( + model_path=pretrained_path, + ) + else: + pretrained_path = Path(cfg.pretrained) / f"{cfg.model}.pt" + pretrained_path.parent.mkdir(parents=True, exist_ok=True) 
+ model, model_cfg = PST.from_pretrained_url( + cfg.model, pretrained_path, + ) + + model.eval() + model.to(cfg.device) + + transforms = [ + PretrainingAttr(), + Proteinshake2ESM(mask_cls_idx=True), + ] + if cfg.perturbation == "random": + transforms.append(RandomizeEdges()) + elif cfg.perturbation == "sequence": + transforms.append(SequenceEdges()) + elif cfg.perturbation == "complete": + transforms.append(CompleteEdges()) + else: + pass + + task = get_task(cfg.task.class_name)(root=cfg.task.path, split=cfg.split) + dataset = task.dataset.to_graph( + eps=model_cfg.data.graph_eps + ).pyg( + transform=Compose( + transforms=transforms, + ) + ) + + data_loader = DataLoader( + dataset, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + ) + tic = timer() + X = compute_repr(data_loader, model, task, cfg) + compute_time = timer() - tic + X_tr, y_tr, X_val, y_val, X_te, y_te = prepare_data(X, task) + log.info(f"X_tr shape: {X_tr.shape} y_tr shape: {y_tr.shape}") + + ## Solving the problem with sklearn + estimator = SklearnPredictor(task.task_out) + + grid = estimator.get_grid() + + scoring = lambda y_true, y_pred: compute_metrics(y_true, y_pred, task)[ + cfg.task.metric + ] + if task.task_out == "multi_label" or task.task_out == "binary": + scoring = make_scorer(scoring, needs_threshold=True) + else: + scoring = make_scorer(scoring) + + test_split_index = [-1] * len(y_tr) + [0] * len(y_val) + X_tr_val, y_tr_val = np.concatenate((X_tr, X_val), axis=0), np.concatenate( + (y_tr, y_val) + ) + + splits = PredefinedSplit(test_fold=test_split_index) + + clf = GridSearchCV( + estimator=estimator, + param_grid=grid, + scoring=scoring, + cv=splits, + refit=False, + n_jobs=-1, + ) + + tic = timer() + clf.fit(X_tr_val, y_tr_val) + log.info(pd.DataFrame.from_dict(clf.cv_results_).sort_values("rank_test_score")) + estimator.set_params(**clf.best_params_) + clf = estimator + clf.fit(X_tr, y_tr) + clf_time = timer() - tic + ##### + + if task.task_out == 
"multi_label" or task.task_out == "binary": + try: + y_pred = clf.decision_function(X_te) + except: + y_pred = clf.predict(X_te) + else: + y_pred = clf.predict(X_te) + if isinstance(y_pred, list): + if y_pred[0].ndim > 1: + y_pred = [y[:, 1] for y in y_pred] + y_pred = np.asarray(y_pred).T + test_score = compute_metrics(y_te, y_pred, task)[cfg.task.metric] + log.info(f"Test score: {test_score:.3f}") + + if task.task_out == "multi_label" or task.task_out == "binary": + try: + y_val_pred = clf.decision_function(X_val) + except: + y_val_pred = clf.predict(X_val) + else: + y_val_pred = clf.predict(X_val) + if isinstance(y_val_pred, list): + if y_val_pred[0].ndim > 1: + y_val_pred = [y[:, 1] for y in y_val_pred] + y_val_pred = np.asarray(y_val_pred).T + val_score = compute_metrics(y_val, y_val_pred, task)[cfg.task.metric] + + results = [ + { + "test_score": test_score, + "val_score": val_score, + "compute_time": compute_time, + "clf_time": clf_time, + } + ] + + pd.DataFrame(results).to_csv(f"{cfg.logs.path}/results.csv") + + +if __name__ == "__main__": + main() diff --git a/experiments/perturbed/predict_scop_perturbed.py b/experiments/perturbed/predict_scop_perturbed.py new file mode 100644 index 0000000..0978710 --- /dev/null +++ b/experiments/perturbed/predict_scop_perturbed.py @@ -0,0 +1,267 @@ +import logging +from pathlib import Path +from timeit import default_timer as timer + +import esm +import hydra +import pandas as pd +import torch +import torch_geometric.nn as gnn +from omegaconf import OmegaConf +from pyprojroot import here +from sklearn.neighbors import radius_neighbors_graph +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from torch_geometric.utils import from_scipy_sparse_matrix +from tqdm import tqdm +from pst.transforms import ( + RandomizeEdges, + SequenceEdges, + CompleteEdges, +) +from pst.esm2 import PST +from pst.downstream.mlp import train_and_eval_linear +from pst.downstream import ( + preprocess, + 
convert_to_numpy, + mask_cls_idx, +) + +log = logging.getLogger(__name__) + +esm_alphabet = esm.data.Alphabet.from_architecture("ESM-1b") + + +@torch.no_grad() +def compute_repr(data_loader, model, cfg): + embeddings = [] + for batch_idx, data in enumerate(tqdm(data_loader)): + data = data.to(cfg.device) + out = model(data, return_repr=True, aggr=cfg.aggr) + out, batch = out[data.idx_mask], data.batch[data.idx_mask] + out = gnn.global_mean_pool(out, batch) + + if cfg.include_seq: + data.edge_index = None + out_seq = model(data, return_repr=True, aggr=cfg.aggr) + out_seq = out_seq[data.idx_mask] + out_seq = gnn.global_mean_pool(out_seq, batch) + out = (out + out_seq) * 0.5 + + out = out.cpu() + + embeddings = embeddings + list(torch.chunk(out, len(data.ptr) - 1)) + + return torch.cat(embeddings) + + +def get_structures(dataset, eps=8.0): + structures = [] + labels = [] + for protein in tqdm(dataset): + sequence = protein.seq + if len(sequence) == 0: + continue + coords = protein.pos + labels.append(torch.tensor(protein.y)) + + torch_sequence = torch.LongTensor( + [esm_alphabet.get_idx(res) for res in esm_alphabet.tokenize(sequence)] + ) + graph_adj = radius_neighbors_graph(coords, radius=eps, mode="connectivity") + edge_index = from_scipy_sparse_matrix(graph_adj)[0].long() + torch_sequence = torch.cat( + [ + torch.LongTensor([esm_alphabet.cls_idx]), + torch_sequence, + torch.LongTensor([esm_alphabet.eos_idx]), + ] + ) + edge_index = edge_index + 1 # shift for cls_idx + + edge_attr = None + + structures.append( + Data(edge_index=edge_index, x=torch_sequence, edge_attr=edge_attr) + ) + + return structures, torch.stack(labels) + + +@hydra.main( + version_base="1.3", config_path=str(here() / "config"), config_name="pst_gearnet" +) +def main(cfg): + cfg.device = "cuda" if torch.cuda.is_available() else "cpu" + log.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}") + + if cfg.include_seq and "so" not in cfg.model: + cfg.model = f"{cfg.model}_so" + + + if cfg.data.edge_perturb is 
not None: + pretrained_path = cfg.pretrained + model, model_cfg = PST.from_pretrained( + pretrained_path, + ) + else: + pretrained_path = Path(cfg.pretrained) / f"{cfg.model}.pt" + pretrained_path.parent.mkdir(parents=True, exist_ok=True) + model, model_cfg = PST.from_pretrained_url( + cfg.model, pretrained_path, + ) + + model.eval() + model.to(cfg.device) + + scop_data = torch.load(Path(cfg.dataset.path) / "data.pt") + + structure_path = ( + Path(cfg.dataset.path) / f"structures_{model_cfg.data.graph_eps}.pt" + ) + if structure_path.exists(): + tmp = torch.load(structure_path) + train_str, y_tr = tmp["train_str"], tmp["y_tr"] + val_str, y_val = tmp["val_str"], tmp["y_val"] + test_str, y_te = tmp["test_str"], tmp["y_te"] + stratified_indices = tmp["stratified_indices"] + del tmp + else: + train_str, y_tr = get_structures( + scop_data["train"], eps=model_cfg.data.graph_eps + ) + val_str, y_val = get_structures(scop_data["val"], eps=model_cfg.data.graph_eps) + test_data = ( + scop_data["test_family"] + + scop_data["test_superfamily"] + + scop_data["test_fold"] + ) + n_fm = len(scop_data["test_family"]) + n_sf = len(scop_data["test_superfamily"]) + n_fo = len(scop_data["test_fold"]) + test_str, y_te = get_structures(test_data, eps=model_cfg.data.graph_eps) + stratified_indices = {} + stratified_indices["family"] = torch.arange(0, n_fm) + stratified_indices["superfamily"] = torch.arange(n_fm, n_fm + n_sf) + stratified_indices["fold"] = torch.arange(n_fm + n_sf, n_fm + n_sf + n_fo) + torch.save( + { + "train_str": train_str, + "val_str": val_str, + "test_str": test_str, + "y_tr": y_tr, + "y_val": y_val, + "y_te": y_te, + "stratified_indices": stratified_indices, + }, + structure_path, + ) + + # this is awful i know, todo: proper transform and dataset + train_str = [mask_cls_idx(data) for data in train_str] + val_str = [mask_cls_idx(data) for data in val_str] + test_str = [mask_cls_idx(data) for data in test_str] + + if cfg.data.edge_perturb == "random": + 
PerturbationTransform = RandomizeEdges + elif cfg.data.edge_perturb == "sequence": + PerturbationTransform = SequenceEdges + elif cfg.data.edge_perturb == "complete": + PerturbationTransform = CompleteEdges + elif cfg.data.edge_perturb is None: + pass + else: + raise ValueError("Invalid value for cfg.data.edge_perturb") + + if cfg.data.edge_perturb is not None: + train_str = [ + PerturbationTransform()(data) + for data in tqdm(train_str, desc=f"Apply {cfg.data.edge_perturb} to train") + ] + val_str = [ + PerturbationTransform()(data) + for data in tqdm(val_str, desc=f"Apply {cfg.data.edge_perturb} to val") + ] + test_str = [ + PerturbationTransform()(data) + for data in tqdm(test_str, desc=f"Apply {cfg.data.edge_perturb} to test") + ] + else: + pass + train_loader = DataLoader( + train_str, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + ) + val_loader = DataLoader( + val_str, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + ) + test_loader = DataLoader( + test_str, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + ) + + # compute embeddings + tic = timer() + X_tr = compute_repr(train_loader, model, cfg) + X_val = compute_repr(val_loader, model, cfg) + X_te = compute_repr(test_loader, model, cfg) + compute_time = timer() - tic + preprocess(X_tr) + preprocess(X_val) + preprocess(X_te) + X_tr, X_val, X_te, y_tr, y_val, y_te = convert_to_numpy( + X_tr, X_val, X_te, y_tr, y_val, y_te + ) + + log.info(f"X_tr shape: {X_tr.shape} y_tr shape: {y_tr.shape}") + + if cfg.use_pca is not None: + from sklearn.decomposition import PCA + + cfg.use_pca = 1024 if X_tr.shape[1] < 10000 else 2048 + pca = PCA(cfg.use_pca) + pca = pca.fit(X_tr) + X_tr = pca.transform(X_tr) + X_val = pca.transform(X_val) + X_te = pca.transform(X_te) + log.info(f"PCA done. 
X_tr shape: {X_tr.shape}") + + X_tr, y_tr = torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).long() + X_val, y_val = torch.from_numpy(X_val).float(), torch.from_numpy(y_val).long() + X_te, y_te = torch.from_numpy(X_te).float(), torch.from_numpy(y_te).long() + + val_score, test_score, test_stratified_score = train_and_eval_linear( + X_tr, + y_tr, + X_val, + y_val, + X_te, + y_te, + 1195, + stratified_indices, + use_cuda=torch.cuda.is_available(), + ) + + results = [ + { + "test_top1": test_score[0], + "test_family": test_stratified_score["family"][0], + "test_superfamily": test_stratified_score["superfamily"][0], + "test_fold": test_stratified_score["fold"][0], + "val_acc": val_score, + "compute_time": compute_time, + } + ] + + pd.DataFrame(results).to_csv(f"{cfg.logs.path}/results.csv") + + +if __name__ == "__main__": + main() diff --git a/pst/downstream/__init__.py b/pst/downstream/__init__.py index 5f4c4ff..b5cda0b 100644 --- a/pst/downstream/__init__.py +++ b/pst/downstream/__init__.py @@ -1,8 +1,10 @@ -import torch -import numpy as np import importlib -from sklearn import metrics + +import numpy as np +import torch from scipy.stats import spearmanr +from sklearn import metrics +from torch_geometric import utils def mask_cls_idx(data): @@ -31,7 +33,7 @@ def get_task(task_name): def prepare_data(X, task, use_pca=False): train_idx, val_idx, test_idx = task.train_index, task.val_index, task.test_index - if not "pair" in task.task_type[0]: + if "pair" not in task.task_type[0]: y_tr = [task.target(task.proteins[idx]) for idx in train_idx] y_val = [task.target(task.proteins[idx]) for idx in val_idx] y_te = [task.target(task.proteins[idx]) for idx in test_idx] diff --git a/pst/trainer.py b/pst/trainer.py index 362118e..79507b4 100644 --- a/pst/trainer.py +++ b/pst/trainer.py @@ -7,7 +7,6 @@ get_inverse_sqrt_schedule_with_warmup, ) - class BertTrainer(pl.LightningModule): def __init__(self, model, cfg, iterations): super().__init__() diff --git a/pst/transforms.py 
b/pst/transforms.py index 9d2b172..2fd9419 100644 --- a/pst/transforms.py +++ b/pst/transforms.py @@ -63,6 +63,49 @@ def __call__(self, data): return data +class RandomizeEdges(object): + def __init__(self, seed=42): + self.seed = seed + + def __call__(self, data): + edge_index = data.edge_index + torch.manual_seed(self.seed) + perm_1 = torch.randperm(edge_index.size(1)) + perm_2 = torch.randperm(edge_index.size(1)) + edge_index[0,:] = edge_index[0, perm_1] + edge_index[1,:] = edge_index[1, perm_2] + data.edge_index = edge_index + return data + + +class CompleteEdges(object): + def __init__(self): + pass + + def __call__(self, data): + num_nodes = data.num_nodes + # Create complete graph without self loops but with reciprocal edges + edge_index = utils.to_undirected( + torch.combinations(torch.arange(num_nodes), 2).t() + ) + data.edge_index = edge_index + return data + + +class SequenceEdges(object): + def __init__(self): + pass + + def __call__(self, data): + num_nodes = data.num_nodes + # Create edges between consecutive nodes + edge_index = utils.to_undirected( + torch.stack([torch.arange(num_nodes - 1), torch.arange(1, num_nodes)]) + ) + data.edge_index = edge_index + return data + + class MaskNode(object): def __init__( self, mask_idx=esm_alphabet.mask_idx, mask_rate=0.15, probs=[0.8, 0.1, 0.1] @@ -103,8 +146,8 @@ def __init__( mask_idx=esm_alphabet.mask_idx, strategy="masked", use_transform=True, + transform=None, ): - transform = Proteinshake2ESM() if use_transform: self.graph = transform(graph) else: diff --git a/train_pst.py b/train_pst.py index c13b15d..73676f1 100644 --- a/train_pst.py +++ b/train_pst.py @@ -19,11 +19,32 @@ PretrainingAttr, Proteinshake2ESM, RandomCrop, + RandomizeEdges, + SequenceEdges, + CompleteEdges, ) from pst.utils import get_graph_from_ps_protein log = logging.getLogger(__name__) +torch.set_float32_matmul_precision('medium') + +def get_loggers(cfg): + loggers = [ + pl.loggers.CSVLogger(cfg.logs.path, name="csv_logs"), + 
pl.loggers.TensorBoardLogger(cfg.logs.path, name="tb_logs"), + ] + if cfg.logs.wandb.enable: + loggers.append( + pl.loggers.WandbLogger( + name=cfg.logs.wandb.name, + tags=cfg.logs.wandb.tags, + entity=cfg.logs.wandb.entity, + project=cfg.logs.wandb.project, + save_dir=cfg.logs.wandb.save_dir, + ) + ) + return loggers @hydra.main( version_base="1.3", config_path=str(here() / "config"), config_name="pst_pretrain" @@ -35,6 +56,11 @@ def main(cfg): featurizer_fn = partial( get_graph_from_ps_protein, use_rbfs=True, eps=cfg.data.graph_eps ) + transforms = [ + RandomCrop(cfg.data.crop_len), + MaskNode(mask_rate=cfg.data.mask_rate), + ] + dataset = CustomGraphDataset( root=cfg.data.datapath, dataset=ps_dataset.AlphaFoldDataset( @@ -42,10 +68,7 @@ def main(cfg): ), pre_transform=featurizer_fn, transform=Compose( - [ - RandomCrop(cfg.data.crop_len), - MaskNode(mask_rate=cfg.data.mask_rate), - ] + transforms=transforms, ), n_jobs=cfg.compute.n_jobs, ) @@ -54,14 +77,27 @@ def main(cfg): dataset = datasets.AlphaFoldDataset( root=cfg.data.datapath, organism=cfg.data.organism ) + transforms = [ + PretrainingAttr(), + Proteinshake2ESM(), + RandomCrop(cfg.data.crop_len), + MaskNode(mask_rate=cfg.data.mask_rate), + ] + + if cfg.data.edge_perturb == "random": + transforms.append(RandomizeEdges()) + elif cfg.data.edge_perturb == "sequence": + transforms.append(SequenceEdges()) + elif cfg.data.edge_perturb == "complete": + transforms.append(CompleteEdges()) + elif cfg.data.edge_perturb is None: + pass + else: + raise ValueError("Invalid value for cfg.data.edge_perturb") + dataset = dataset.to_graph(eps=cfg.data.graph_eps).pyg( transform=Compose( - [ - PretrainingAttr(), - Proteinshake2ESM(), - RandomCrop(cfg.data.crop_len), - MaskNode(mask_rate=cfg.data.mask_rate), - ] + transforms=transforms, ) ) @@ -73,7 +109,6 @@ def main(cfg): shuffle=True, num_workers=cfg.training.num_workers, ) - net = PST.from_model_name( cfg.model.name, k_hop=cfg.model.k_hop, @@ -98,14 +133,11 @@ def 
main(cfg): max_epochs=cfg.training.epochs, precision=cfg.compute.precision, accelerator=cfg.compute.accelerator, - devices="auto", + devices=cfg.compute.devices, strategy=cfg.compute.strategy, enable_checkpointing=True, default_root_dir=cfg.logs.path, - logger=[ - pl.loggers.CSVLogger(cfg.logs.path, name="csv_logs"), - pl.loggers.TensorBoardLogger(cfg.logs.path, name="tb_logs"), - ], + logger=get_loggers(cfg), callbacks=[ pl.callbacks.LearningRateMonitor(logging_interval="epoch"), pl.callbacks.RichProgressBar(),