sketch #483

Closed
wants to merge 1 commit
13 changes: 12 additions & 1 deletion etna/datasets/tsdataset.py
@@ -37,7 +37,7 @@
from etna.transforms.base import Transform

if SETTINGS.torch_required:
-    from torch.utils.data import Dataset
+    from torch.utils.data import Dataset, IterableDataset


class TSDataset:
@@ -1829,6 +1829,17 @@ def info(self, segments: Optional[Sequence[str]] = None) -> None:
        result_string = "\n".join(lines)
        print(result_string)

+    def to_torch_generator_dataset(
+        self, make_samples: Callable[[pd.DataFrame], Union[Iterator[dict], Iterable[dict]]], dropna: bool = True
+    ) -> "Dataset":
+        """Convert the dataset into a torch dataset that generates samples lazily, segment by segment."""
+        df = self.to_pandas(flatten=True)
+        if dropna:
+            df = df.dropna()  # TODO: Fix this
+
+        # Sketch assumption: use the segment names as indexes, one lazily generated sample per segment.
+        # Assumes _TorchGeneratorDataset is imported from etna.datasets.utils (import not shown in this sketch).
+        indexes = df["segment"].unique().tolist()
+
+        return _TorchGeneratorDataset(df=df, indexes=indexes, make_samples=make_samples)

    def to_torch_dataset(
        self, make_samples: Callable[[pd.DataFrame], Union[Iterator[dict], Iterable[dict]]], dropna: bool = True
    ) -> "Dataset":
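For orientation, a minimal sketch of how the new method could be driven from user code. Here ts is an already constructed TSDataset, every segment is assumed to have at least encoder_length + decoder_length points, and the column name "target" plus the sample-dict keys are illustrative assumptions rather than the keys etna's networks actually use:

import numpy as np
from torch.utils.data import DataLoader

def make_samples(df_segment):
    # Toy sample builder: first 14 points as encoder input, next 7 as decoder target.
    encoder_length, decoder_length = 14, 7
    values = df_segment["target"].to_numpy(dtype=np.float32)
    yield {
        "encoder_real": values[:encoder_length].reshape(-1, 1),
        "decoder_real": values[encoder_length:encoder_length + decoder_length].reshape(-1, 1),
    }

torch_dataset = ts.to_torch_generator_dataset(make_samples, dropna=True)
loader = DataLoader(torch_dataset, batch_size=16, shuffle=True)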
20 changes: 19 additions & 1 deletion etna/datasets/utils.py
@@ -15,7 +15,7 @@
from etna import SETTINGS

if SETTINGS.torch_required:
-    from torch.utils.data import Dataset
+    from torch.utils.data import Dataset, IterableDataset
else:
    from unittest.mock import Mock

@@ -195,6 +195,24 @@ def duplicate_data(df: pd.DataFrame, segments: Sequence[str], format: str = Data
    return df_long


+class _TorchGeneratorDataset(Dataset):
+    """Dataset that generates samples on the fly from the stored dataframe."""
+
+    def __init__(self, df, indexes, make_samples):
+        self.df = df  # flat dataframe with a "segment" column
+        self.indexes = indexes  # sketch assumption: one segment name per index
+        self.make_samples = make_samples
+
+    def __getitem__(self, index):
+        # Sketch: the index picks a segment whose samples are generated lazily at access time;
+        # a fuller version would also encode the slice of the segment dataframe in the index.
+        df_segment = self.df[self.df["segment"] == self.indexes[index]]
+        samples = list(self.make_samples(df_segment))
+        return samples[0]
+
+    def __len__(self):
+        return len(self.indexes)
+
+
class _TorchDataset(Dataset):
    """In memory dataset for torch dataloader."""

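A self-contained toy check of the class above, consistent with the segment-per-index assumption sketched in __getitem__; the dataframe layout is an assumption of this example:

import pandas as pd
from etna.datasets.utils import _TorchGeneratorDataset

df = pd.DataFrame(
    {
        "timestamp": list(pd.date_range("2021-01-01", periods=5)) * 2,
        "segment": ["A"] * 5 + ["B"] * 5,
        "target": list(range(10)),
    }
)

def make_samples(df_segment):
    # Toy sample builder: the whole segment becomes a single sample.
    yield {"target": df_segment["target"].to_numpy()}

dataset = _TorchGeneratorDataset(df=df, indexes=["A", "B"], make_samples=make_samples)
assert len(dataset) == 2
sample = dataset[1]  # generated lazily from segment "B" at access time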
22 changes: 16 additions & 6 deletions etna/models/base.py
@@ -485,6 +485,7 @@ def __init__(
        decoder_length: int,
        train_batch_size: int,
        test_batch_size: int,
+        generate_running_samples: bool,
        trainer_params: Optional[dict],
        train_dataloader_params: Optional[dict],
        test_dataloader_params: Optional[dict],
@@ -527,6 +528,7 @@ def __init__(
        self.decoder_length = decoder_length
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
+        self.generate_running_samples = generate_running_samples
        self.train_dataloader_params = {} if train_dataloader_params is None else train_dataloader_params
        self.test_dataloader_params = {} if test_dataloader_params is None else test_dataloader_params
        self.val_dataloader_params = {} if val_dataloader_params is None else val_dataloader_params
@@ -553,12 +555,20 @@ def fit(self, ts: TSDataset) -> "DeepBaseModel":
        :
            Model after fit
        """
-        torch_dataset = ts.to_torch_dataset(
-            functools.partial(
-                self.net.make_samples, encoder_length=self.encoder_length, decoder_length=self.decoder_length
-            ),
-            dropna=True,
-        )
+        if self.generate_running_samples:
+            torch_dataset = ts.to_torch_generator_dataset(
+                functools.partial(
+                    self.net.make_samples, encoder_length=self.encoder_length, decoder_length=self.decoder_length
+                ),
+                dropna=True,
+            )
+        else:
+            torch_dataset = ts.to_torch_dataset(
+                functools.partial(
+                    self.net.make_samples, encoder_length=self.encoder_length, decoder_length=self.decoder_length
+                ),
+                dropna=True,
+            )
        self.raw_fit(torch_dataset)
        return self

2 changes: 2 additions & 0 deletions etna/models/nn/rnn.py
@@ -283,6 +283,7 @@ def __init__(
        loss: Optional["torch.nn.Module"] = None,
        train_batch_size: int = 16,
        test_batch_size: int = 16,
+        generate_running_samples: bool = False,
        optimizer_params: Optional[dict] = None,
        trainer_params: Optional[dict] = None,
        train_dataloader_params: Optional[dict] = None,
@@ -353,6 +354,7 @@ def __init__(
            encoder_length=encoder_length,
            train_batch_size=train_batch_size,
            test_batch_size=test_batch_size,
+            generate_running_samples=generate_running_samples,
            train_dataloader_params=train_dataloader_params,
            test_dataloader_params=test_dataloader_params,
            val_dataloader_params=val_dataloader_params,
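With the flag exposed on RNNModel, usage would look roughly like the sketch below; the remaining constructor arguments are assumed to follow the existing RNNModel signature, and ts is an existing TSDataset:

from etna.models.nn import RNNModel

model = RNNModel(
    input_size=1,
    encoder_length=14,
    decoder_length=7,
    generate_running_samples=True,  # build the lazy generator dataset in fit() instead of the in-memory one
    trainer_params={"max_epochs": 5},
)
model.fit(ts)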