fix: Use the 'keep_ctrl' parameter in class DataSplitter
dzyim committed Nov 10, 2023
1 parent 19732b3 commit d64a5f9
Showing 5 changed files with 47 additions and 24 deletions.
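
At a glance, the commit threads a single boolean, keep_ctrl, from the TOML configs through PertData into DataSplitter, which previously did not consult it. A minimal sketch of the resulting wiring, assuming the perturb package from this repository is importable; the values mirror conf.gene.toml below:

    from perturb.data.data import PertData

    # keep_ctrl now comes from the config file instead of being hardcoded
    # to False in the finetune scripts (see the script diffs below).
    pert_data = PertData(dataset='adamson', keep_ctrl=True, seed=42)

    # PertData.__init__ now builds the splitter itself; DataSplitter copies
    # adata, pert_col, ctrl_str, and keep_ctrl from the PertData instance.
    splitter = pert_data.splitter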
5 changes: 2 additions & 3 deletions examples/pert/conf.drug.toml
@@ -2,6 +2,7 @@
 
 project = "Perturb"
 dataset = "data/K562_compound188"
+keep_ctrl = false
 split = "unseen"
 seed = 42
 
@@ -15,7 +16,6 @@
 pad_value = -2 # 0 or -2
 pert_pad_id = 2
 #n_bins = 51
 include_zero_gene = "batch-wise" # include zero expr genes in training input: "all", "batch-wise", "row-wise", or False
-#n_hvg = 0 # number of highly variable genes
 max_seq_len = 1500 # n_hvg+1 if n_hvg > 0
 #per_seq_batch_sample = true
 
@@ -72,9 +72,8 @@
 normalize_total = 1e4 # Step 3: whether to normalize the raw data and
 result_normed_key = "X_normed" # the key in adata.layers to store the normalized data
 log1p = true # Step 4: whether to log1p the normalized data # log1p = data_is_raw
 result_log1p_key = "X_log1p" # the key in adata.layers to store the log transformed data
-subset_hvg = 5000 # Step 5: whether to subset the raw data to highly variable genes # subset_hvg = n_hvg
+subset_hvg = 5000 # Step 5: whether to subset the raw data to, and the number of, highly variable genes
 hvg_flavor = "seurat_v3" # hvg_flavor = "seurat_v3" if data_is_raw else "cell_ranger"
 binning = false # Step 6: whether to bin the raw data and to what number of bins
 result_binned_key = "X_binned" # the key in adata.layers to store the binned data
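
The new keep_ctrl key is a plain top-level boolean. The loader that turns this file into the conf.h.p object used by the scripts below is not part of this diff, so the sketch here uses Python's stdlib tomllib (3.11+) as a stand-in to show what the key parses to:

    import tomllib  # stdlib since Python 3.11

    with open('examples/pert/conf.drug.toml', 'rb') as f:
        conf = tomllib.load(f)

    print(conf['keep_ctrl'])  # False: control cells are dropped for this dataset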


5 changes: 2 additions & 3 deletions examples/pert/conf.gene.toml
@@ -2,6 +2,7 @@
 
 project = "Perturb"
 dataset = "adamson"
+keep_ctrl = true
 split = "unseen"
 seed = 42
 
@@ -15,7 +16,6 @@
 pad_value = -2 # 0 or -2
 pert_pad_id = 2
 #n_bins = 51
 include_zero_gene = "all" # include zero expr genes in training input: "all", "batch-wise", "row-wise", or False
-#n_hvg = 0 # number of highly variable genes
 max_seq_len = 1500 # n_hvg+1 if n_hvg > 0
 #per_seq_batch_sample = true
 
@@ -73,9 +73,8 @@
 normalize_total = 1e4 # Step 3: whether to normalize the raw data and
 result_normed_key = "X_normed" # the key in adata.layers to store the normalized data
 log1p = true # Step 4: whether to log1p the normalized data # log1p = data_is_raw
 result_log1p_key = "X_log1p" # the key in adata.layers to store the log transformed data
-subset_hvg = false # Step 5: whether to subset the raw data to highly variable genes # subset_hvg = n_hvg
+subset_hvg = false # Step 5: whether to subset the raw data to, and the number of, highly variable genes
 hvg_flavor = "seurat_v3" # hvg_flavor = "seurat_v3" if data_is_raw else "cell_ranger"
 binning = false # Step 6: whether to bin the raw data and to what number of bins
 result_binned_key = "X_binned" # the key in adata.layers to store the binned data


2 changes: 1 addition & 1 deletion examples/pert/finetune_drug_perturb.py
@@ -50,7 +50,7 @@ def get_config():
 ]
 pert_data = PertData(
     dataset=conf.h.p.dataset,
-    keep_ctrl=False,
+    keep_ctrl=conf.h.p.keep_ctrl,
     seed=conf.h.p.seed,
     vocab_file=vocab_file,
 )
2 changes: 1 addition & 1 deletion examples/pert/finetune_gene_perturb.py
@@ -50,7 +50,7 @@ def get_config():
 ]
 pert_data = PertData(
     dataset=conf.h.p.dataset,
-    keep_ctrl=False,
+    keep_ctrl=conf.h.p.keep_ctrl,
     seed=conf.h.p.seed,
     vocab_file=vocab_file,
 )
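
Both scripts previously pinned keep_ctrl=False, so the drug and gene pipelines could not differ on this point; now conf.drug.toml keeps it false while conf.gene.toml turns it on. A toy illustration of what the flag governs downstream (the same filter DataSplitter.split_data applies below); the column name 'condition' and the label 'ctrl' are placeholders for whatever pert_col and ctrl_str actually hold:

    import anndata as ad
    import numpy as np
    import pandas as pd

    adata = ad.AnnData(
        X=np.zeros((4, 2), dtype=np.float32),
        obs=pd.DataFrame({'condition': ['ctrl', 'ctrl', 'KLF1', 'BPGM']}),
    )

    keep_ctrl = False
    filtered = adata if keep_ctrl else adata[adata.obs['condition'] != 'ctrl']
    print(filtered.n_obs)  # 2 -- the two control cells are excluded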
57 changes: 41 additions & 16 deletions perturb/data/data.py
@@ -355,18 +355,20 @@ def __init__(
         vocab_file: Optional[Path] = None,
     ) -> None:
         self.logger = create_logger(name=self.__class__.__name__)
+        self.logger.info('Init...')
         self.dataset = str(dataset)
         self.workdir = Path(workdir)
         if not self.workdir.exists():
             self.workdir.mkdir(parents=True)
         self.keep_ctrl = keep_ctrl
         self.seed = seed
 
         self._load_adata()
         self._check_mode()
         self.logger.info(f'mode = {repr(self.mode)}')
         self._drop_NA_genes()
         self.set_vocab(vocab_file)
-        self.splitter = DataSplitter(self, seed)
+        self.splitter = DataSplitter(self)
+        self.logger.info(f'Got a data splitter: seed = {seed}')
 
     def _load_adata(self) -> None:
@@ -776,13 +778,26 @@ class DataSplitter:
 
     split_types = ['unseen', 'shuffle']
 
-    def __init__(self, pert_data: PertData, seed: Optional[int] = None) -> None:
-        self.seed = seed
-        self.data = pert_data
-        self.save_dir = Path(pert_data.dataset_path)
+    def __init__(
+        self,
+        pert_data: PertData,
+        save_dir: Optional[Path] = None,
+        seed: Optional[int] = None,
+    ) -> None:
+        self.adata = pert_data.adata
+        self.pert_col = pert_data.pert_col
+        self.ctrl_str = pert_data.ctrl_str
+        self.keep_ctrl = pert_data.keep_ctrl
 
-    def set_save_dir(self, save_dir: Path) -> None:
-        self.save_dir = Path(save_dir)
+        if save_dir is None:
+            self.save_dir = pert_data.dataset_path
+        else:
+            self.save_dir = Path(save_dir)
+
+        if seed is None:
+            self.seed = pert_data.seed
+        else:
+            self.seed = int(seed)
 
     def get_save_file(self) -> Path:
         return self.save_dir / 'train_test_split.json'
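
The splitter no longer holds a back-reference to the whole PertData object: it copies the four fields it needs, and save_dir and seed now default to the parent's dataset_path and seed, replacing the old set_save_dir method. A sketch of the two construction paths, assuming pert_data is a fully initialized PertData instance:

    from pathlib import Path

    splitter = DataSplitter(pert_data)  # save_dir and seed inherited from pert_data
    custom = DataSplitter(pert_data, save_dir=Path('runs/exp1'), seed=0)  # explicit overrides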
@@ -805,15 +820,18 @@ def prepare_split(
             raise ValueError(f"Invalid split_type: {repr(split_type)}!")
 
         if split_type == 'unseen':
-            perts = self.data.adata.obs[self.data.pert_col].unique()
-            perts_no_ctrl = np.setdiff1d(perts, self.data.ctrl_str)
+            perts = self.adata.obs[self.pert_col].unique()
+            if not self.keep_ctrl:
+                perts = np.setdiff1d(perts, self.ctrl_str)
+
+            perts_no_ctrl = np.setdiff1d(perts, self.ctrl_str)
             test_not_given = (
                 test_perts is None
-                or not len(np.intersect1d(test_perts, perts))
+                or len(np.intersect1d(test_perts, perts)) <= 0
             )
             val_not_given = (
                 val_perts is None
-                or not len(np.intersect1d(val_perts, perts))
+                or len(np.intersect1d(val_perts, perts)) <= 0
             )
             np.random.seed(self.seed)
             if test_not_given and val_not_given:
@@ -847,6 +865,7 @@ def prepare_split(
         write_json(
             {
                 'type': split_type,
+                'keep_ctrl': self.keep_ctrl,
                 'val_size': val_size,
                 'test_size': test_size,
             },
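
Recording 'keep_ctrl' next to the split type and sizes lets split_data reproduce the same control-cell handling when it reloads the file. What the saved metadata plausibly looks like for a 'shuffle' split (the 0.1 sizes are illustrative; an 'unseen' split instead stores 'train', 'val', and 'test' perturbation lists, which split_data reads back below):

    # Contents of <save_dir>/train_test_split.json, shown as a Python literal:
    split_meta = {
        'type': 'shuffle',
        'keep_ctrl': False,
        'val_size': 0.1,
        'test_size': 0.1,
    }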
@@ -859,15 +878,20 @@ def split_data(self) -> dict:
         except OSError:
             split = None
 
-        adata = self.data.adata
         if split is None:
+            adata = self.adata if self.keep_ctrl else self.adata[
+                self.adata.obs[self.pert_col] != self.ctrl_str
+            ]
             train_data, test_data = train_test_split(
                 adata, test_size=0.1, shuffle=True,
             )
             train_data, val_data = train_test_split(
                 train_data, test_size=0.1, shuffle=True,
             )
         elif split['type'] == 'shuffle':
+            adata = self.adata if split['keep_ctrl'] else self.adata[
+                self.adata.obs[self.pert_col] != self.ctrl_str
+            ]
             train_data, test_data = train_test_split(
                 adata, test_size=split['test_size'], shuffle=True,
             )
@@ -878,11 +902,12 @@ def split_data(self) -> dict:
             train_perts, val_perts, test_perts = (
                 split['train'], split['val'], split['test']
             )
+            # Are shuffles needed for train and val sets here?
             train_data, val_data, test_data = (
-                adata[adata.obs[self.data.pert_col].isin(train_perts)],
-                adata[adata.obs[self.data.pert_col].isin(val_perts)],
-                adata[adata.obs[self.data.pert_col].isin(test_perts)],
+                self.adata[self.adata.obs[self.pert_col].isin(train_perts)],
+                self.adata[self.adata.obs[self.pert_col].isin(val_perts)],
+                self.adata[self.adata.obs[self.pert_col].isin(test_perts)],
             )
 
         return {'train': train_data, 'val': val_data, 'test': test_data}
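
An end-to-end sketch of the new behavior: prepare_split persists the flag, and split_data honors it whether the split is freshly computed or reloaded. The keyword arguments to prepare_split are inferred from its body above (its full signature is not visible in this diff), so treat them as assumptions:

    splitter = DataSplitter(pert_data, seed=42)
    splitter.prepare_split(split_type='shuffle', val_size=0.1, test_size=0.1)

    splits = splitter.split_data()  # reads 'keep_ctrl' back from the saved JSON
    print({name: part.n_obs for name, part in splits.items()})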
