This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

debug and update

jkobject committed Aug 14, 2024
1 parent 122d110 commit 9e07a17
Showing 7 changed files with 271 additions and 125 deletions.
2 changes: 1 addition & 1 deletion config/base.yml
@@ -21,7 +21,7 @@ trainer:
  - class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: ${project}
-      save_dir: /pasteur/zeus/projets/p02/ml4ig_hot/Users/jkalfon/
+      save_dir: data/log/
      offline: True
  callbacks:
  - class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
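For reference, these logger settings map one-to-one onto the WandbLogger constructor. A minimal sketch of the equivalent Python (save_dir and offline copied from the hunk above, the project name taken from the dumped config below; everything else is stock Lightning):

from lightning.pytorch.loggers import WandbLogger

# Runs are now written under the repo-relative data/log/ instead of a
# cluster-specific path, and kept offline until synced with `wandb sync`.
logger = WandbLogger(project="scprint_scale", save_dir="data/log/", offline=True)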
255 changes: 255 additions & 0 deletions data/log/config.yaml
@@ -0,0 +1,255 @@
# lightning.pytorch==2.3.3
seed_everything: 42
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger:
  - class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      name: null
      save_dir: data/log/
      version: null
      offline: true
      dir: null
      id: null
      anonymous: null
      project: scprint_scale
      log_model: false
      experiment: null
      prefix: ''
      checkpoint_name: null
      job_type: null
      config: null
      entity: null
      reinit: null
      tags: null
      group: null
      notes: null
      magic: null
      config_exclude_keys: null
      config_include_keys: null
      mode: null
      allow_val_change: null
      resume: null
      force: null
      tensorboard: null
      sync_tensorboard: null
      monitor_gym: null
      save_code: null
      fork_from: null
      resume_from: null
      settings: null
  callbacks:
  - class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
    init_args:
      swa_lrs: 0.03
      swa_epoch_start: 0.8
      annealing_epochs: 10
      annealing_strategy: cos
      avg_fn: null
      device: Unable to serialize instance cpu
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      dirpath: null
      filename: null
      monitor: val_loss
      verbose: false
      save_last: true
      save_top_k: 20
      save_weights_only: false
      mode: min
      auto_insert_metric_name: true
      every_n_train_steps: null
      train_time_interval: null
      every_n_epochs: null
      save_on_train_epoch_end: null
      enable_version_counter: true
  fast_dev_run: false
  max_epochs: null
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: '{''hours'': 71}'
  limit_train_batches: 7000
  limit_val_batches: 2000
  limit_test_batches: 1
  limit_predict_batches: null
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: 100
  enable_checkpointing: null
  enable_progress_bar: null
  enable_model_summary: null
  accumulate_grad_batches: 1
  gradient_clip_val: 100
  gradient_clip_algorithm: null
  deterministic: null
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 1
  default_root_dir: null
model:
  normalization: sum
  d_model: 128
  nhead: 2
  attn_bias: none
  d_hid: 512
  edge_dim: 12
  nlayers: 4
  expr_encoder_layers: 2
  layers_cls:
  - 128
  dropout: 0.1
  transformer: flash
  expr_emb_style: continuous
  domain_spec_batchnorm: None
  n_input_bins: 0
  mvc_decoder: inner product
  pred_embedding:
  - cell_type_ontology_term_id
  - disease_ontology_term_id
  - self_reported_ethnicity_ontology_term_id
  - sex_ontology_term_id
  cell_emb_style: cls
  freeze_embeddings: true
  zinb: true
  lr: 0.0001
  optim: adamW
  weight_decay: 0.01
  residual_in_fp32: true
  num_heads_kv: null
  checkpointing: false
  fused_dropout_add_ln: false
  return_residual: false
  prenorm: true
  mlp_ratio: 4.0
  fused_mlp: false
  fused_bias_fc: false
  sequence_parallel: false
  drop_path_rate: 0.0
  weight_init: ''
data:
  collection_name: some
  clss_to_weight:
  - cell_type_ontology_term_id
  - disease_ontology_term_id
  - assay_ontology_term_id
  - self_reported_ethnicity_ontology_term_id
  - sex_ontology_term_id
  - organism_ontology_term_id
  organisms:
  - NCBITaxon:9606
  - NCBITaxon:10090
  weight_scaler: 50
  train_oversampling_per_epoch: 0.3
  validation_split: 0.02
  test_split: 0.02
  gene_embeddings: ./data/main/gene_embeddings.parquet
  use_default_col: true
  gene_position_tolerance: 10000
  clss_to_pred:
  - cell_type_ontology_term_id
  - disease_ontology_term_id
  - assay_ontology_term_id
  - self_reported_ethnicity_ontology_term_id
  - sex_ontology_term_id
  - organism_ontology_term_id
  all_clss:
  - cell_type_ontology_term_id
  - disease_ontology_term_id
  - assay_ontology_term_id
  - self_reported_ethnicity_ontology_term_id
  - sex_ontology_term_id
  - organism_ontology_term_id
  hierarchical_clss:
  - cell_type_ontology_term_id
  - disease_ontology_term_id
  - assay_ontology_term_id
  - self_reported_ethnicity_ontology_term_id
  how: random expr
  organism_name: organism_ontology_term_id
  max_len: 2200
  add_zero_genes: 0
  do_gene_pos: ./data/main/biomart_pos.parquet
  tp_name: null
  assays_to_drop:
  - EFO:0008853
  - EFO:0010961
  - EFO:0030007
  - EFO:0030062
  batch_size: 64
  shuffle: null
  batch_sampler: null
  num_workers: 12
  pin_memory: false
  drop_last: false
  timeout: 0
  worker_init_fn: null
  multiprocessing_context: null
  generator: null
  prefetch_factor: null
  persistent_workers: false
  pin_memory_device: ''
scprint_early_stopping:
  monitor: val_loss
  min_delta: 0.0
  patience: 3
  verbose: false
  mode: min
  strict: true
  check_finite: true
  stopping_threshold: null
  divergence_threshold: null
  check_on_train_epoch_end: null
  log_rank_zero_only: false
scprint_learning_rate_monitor:
  logging_interval: epoch
  log_momentum: false
  log_weight_decay: false
scprint_training:
  do_denoise: true
  noise:
  - 0.6
  do_cce: false
  cce_sim: 0.5
  cce_scale: 0.002
  do_ecs: false
  ecs_threshold: 0.3
  ecs_scale: 0.05
  do_mvc: false
  mvc_scale: 1.0
  do_adv_cls: false
  do_next_tp: false
  do_generate: true
  class_scale: 1.5
  mask_ratio: []
  warmup_duration: 500
  fused_adam: false
  adv_class_scale: 0.1
  lr_reduce_patience: 1
  lr_reduce_factor: 0.6
  lr_reduce_monitor: val_loss
  do_cls: true
  do_adv_batch: false
  run_full_forward: false
  lr: 0.001
  optim: adamW
  weight_decay: 0.01
  name: ''
set_float32_matmul_precision: true
wandblog: all
log_freq: 200
log_graph: true
project: scprint_scale
ckpt_path: null
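This file is the resolved config that LightningCLI dumps at launch, so a run can in principle be replayed from it. A sketch under assumptions — the scPrint and DataModule entry points are guesses from the repository layout, not verified imports:

from lightning.pytorch.cli import LightningCLI

from scprint import scPrint          # assumed model entry point
from scdataloader import DataModule  # assumed data-module entry point

if __name__ == "__main__":
    # Replays the run recorded above, including the trainer, model,
    # and data sections.
    LightningCLI(scPrint, DataModule, args=["fit", "--config", "data/log/config.yaml"])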
Binary file modified data/main/biomart_pos.parquet
Binary file not shown.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scprint"
-version = "1.0.13"
+version = "1.0.14"
license = "MIT"
description = "scPRINT is a Large Cell Model for Gene Network Inference, Denoising and more from scRNAseq data"
authors = ["jeremie kalfon"]
15 changes: 8 additions & 7 deletions scprint/model/model.py
@@ -4,7 +4,7 @@
from lightning.pytorch.tuner.lr_finder import _LRCallback
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
import torch
-from galore_torch import GaLoreAdamW
+#from galore_torch import GaLoreAdamW
from math import factorial
import lightning as L
import os
@@ -106,11 +106,11 @@ def __init__(
        """
        super().__init__()
        # training flags
-        self.do_denoise = False
-        self.noise = [0.3]
+        self.do_denoise = True
+        self.noise = [0.6]
        self.do_cce = False
-        self.cce_sim = 0.6
-        self.cce_scale = 0.01
+        self.cce_sim = 0.5
+        self.cce_scale = 0.002
        self.do_ecs = False
        self.ecs_threshold = 0.3
        self.ecs_scale = 0.05
@@ -127,7 +127,7 @@ def __init__(
        self.class_scale = 0.4
        self.do_next_tp = False
        self.do_generate = False
-        self.mask_ratio = [0.3]
+        self.mask_ratio = []
        self.warmup_duration = 500
        self.weight_decay = 0.01
        self.optim = "adamW"
@@ -588,6 +588,7 @@ def configure_optimizers(self):
                fused=self.fused_adam,
            )
        elif self.optim == "galore":
+            raise NotImplementedError("Galore optimizer not implemented")
            param_groups = [
                {
                    "params": [
@@ -604,7 +605,7 @@
                    "proj_type": "std",
                },
            ]
-            optimizer = GaLoreAdamW(param_groups, lr=self.hparams.lr)
+            #optimizer = GaLoreAdamW(param_groups, lr=self.hparams.lr)
        else:
            raise ValueError(f"Unknown optimizer: {self.optim}")
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
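Taken together, the two hunks make the "galore" branch fail fast, leaving its parameter-group setup as dead code. A condensed sketch of the resulting selection logic (a hypothetical standalone helper, not the verbatim method):

import torch
import torch.optim as optim

def build_optimizer(model: torch.nn.Module, name: str, lr: float,
                    weight_decay: float, fused: bool) -> optim.Optimizer:
    # Mirrors configure_optimizers after this commit: adamW is the only
    # usable choice; galore raises before its dead setup code runs.
    if name == "adamW":
        return optim.AdamW(model.parameters(), lr=lr,
                           weight_decay=weight_decay, fused=fused)
    elif name == "galore":
        raise NotImplementedError("Galore optimizer not implemented")
    raise ValueError(f"Unknown optimizer: {name}")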
2 changes: 1 addition & 1 deletion scprint/trainer/trainer.py
@@ -21,7 +21,7 @@ def __init__(
        class_scale: float = 1.5,
        mask_ratio: List[float] = [],  # 0.3
        warmup_duration: int = 500,
-        fused_adam: bool = True,
+        fused_adam: bool = False,
        adv_class_scale: float = 0.1,
        lr_reduce_patience: int = 1,
        lr_reduce_factor: float = 0.6,
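The flipped fused_adam default pairs with the helper sketched above: PyTorch's fused AdamW kernels have historically required CUDA tensors, so defaulting to False keeps CPU runs from crashing. A hypothetical guard in the same spirit:

import torch

def safe_fused_flag(requested: bool) -> bool:
    # Only honor the fused flag when a GPU is actually available.
    return requested and torch.cuda.is_available()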