Skip to content

Commit

Permalink
pass transform and its kwargs to dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
RemyLau committed Oct 17, 2023
1 parent 261be95 commit 38899d5
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/obnb/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import obnb.label.split
from obnb.dataset.base import Dataset
from obnb.label import filters
from obnb.typing import List, LogLevel, Optional
from obnb.typing import Any, Callable, Dict, List, LogLevel, Optional
from obnb.util.converter import GenePropertyConverter
from obnb.util.version import parse_data_version

Expand Down Expand Up @@ -35,6 +35,9 @@ class OpenBiomedNetBench(Dataset):
gene ids for filtering. More specifically, only genes that are
present in the network and in the provided selected gene list will
be used. Only use network genes if this is list is not provided.
transform: Transform function or name of the transform class.
transform_kwargs: Keyword arguments for initializing the transform
function. Only effective when transform is passed as a string.
log_level: Logging level.
"""
Expand All @@ -55,6 +58,8 @@ def __init__(
val_ratio: float = 0.2,
test_ratio: float = 0.2,
selected_genes: Optional[List[str]] = None,
transform: Optional[Callable] = None,
transform_kwargs: Optional[Dict[str, Any]] = None,
log_level: LogLevel = "INFO",
):
"""Initialize OpenBiomedNetBench object."""
Expand Down Expand Up @@ -125,4 +130,6 @@ def __init__(
label=label,
splitter=splitter,
auto_generate_feature=auto_generate_feature,
transform=transform,
transform_kwargs=transform_kwargs,
)

0 comments on commit 38899d5

Please sign in to comment.