-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUG] Fix issue when training TFT model on mac M1 mps device. element 0 of tensors does not require grad and does not have a grad_fn #1725
base: main
Are you sure you want to change the base?
Conversation
pinging @zy636 This seems to be a problem with When PyTorch detects an unsupported operation on MPS, it silently switches to CPU for that operation. However:
For now, I would advise using CPU for training, and if you want to use other optimizers once this PR is merged they'll work, I tested with Thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Commenting from a general best practice point - when fixing a bug and all tests passed before the fix, we should add a test tha would have failed before the fix, but passes afterwards ("regression" testing).
this was failing on macOS when using |
The following example works well on MPS accelerator: import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import copy
from pathlib import Path
import warnings
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import torch
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from lightning.pytorch.tuner import Tuner
warnings.filterwarnings("ignore")
import lightning.pytorch as pl
pl.seed_everything(42)
from pytorch_forecasting.data.examples import get_stallion_data
data = get_stallion_data()
# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()
# add additional features
data["month"] = data.date.dt.month.astype(str).astype("category") # categories have be strings
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")
# we want to encode special days as one variable and thus need to first reverse one-hot encoding
special_days = [
"easter_day",
"good_friday",
"new_year",
"christmas",
"labor_day",
"independence_day",
"revolution_day_memorial",
"regional_games",
"fifa_u_17_world_cup",
"football_gold_cup",
"beer_capital",
"music_fest",
]
data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
max_prediction_length = 6
max_encoder_length = 24
training_cutoff = data["time_idx"].max() - max_prediction_length
training = TimeSeriesDataSet(
data[lambda x: x.time_idx <= training_cutoff],
time_idx="time_idx",
target="volume",
group_ids=["agency", "sku"],
min_encoder_length=max_encoder_length // 2, # keep encoder length long (as it is in the validation set)
max_encoder_length=max_encoder_length,
min_prediction_length=1,
max_prediction_length=max_prediction_length,
static_categoricals=["agency", "sku"],
static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
time_varying_known_categoricals=["special_days", "month"],
variable_groups={"special_days": special_days}, # group of categorical variables can be treated as one variable
time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
time_varying_unknown_categoricals=[],
time_varying_unknown_reals=[
"volume",
"log_volume",
"industry_volume",
"soda_volume",
"avg_max_temp",
"avg_volume_by_agency",
"avg_volume_by_sku",
],
target_normalizer=GroupNormalizer(
groups=["agency", "sku"], transformation="softplus"
), # use softplus and normalize by group
add_relative_time_idx=True,
add_target_scales=True,
add_encoder_length=True,
)
# create validation set (predict=True) which means to predict the last max_prediction_length points in time
# for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
# create dataloaders for model
batch_size = 128 # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)
# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor() # log the learning rate
logger = TensorBoardLogger("lightning_logs") # logging results to a tensorboard
trainer = pl.Trainer(
max_epochs=1,
accelerator="mps",
enable_model_summary=True,
gradient_clip_val=0.1,
limit_train_batches=50, # coment in for training, running valiation every 30 batches
# fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
callbacks=[lr_logger, early_stop_callback],
logger=logger,
)
tft = TemporalFusionTransformer.from_dataset(
training,
learning_rate=0.03,
hidden_size=16,
attention_head_size=2,
dropout=0.1,
hidden_continuous_size=8,
loss=QuantileLoss(),
log_interval=10, # uncomment for learning rate finder and otherwise, e.g. to 10 for logging every 10 batches
optimizer="sgd",
reduce_on_plateau_patience=4,
)
trainer.fit(
tft,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
) |
when i try your fix on the above code i get: but when i just change accelerator to cpu instead of mps i get: so i Don't think the problem is solved |
Oh, I see. This is probably the effect of the mps accelerator. I will continue to debug to see where disparities are originating from. Somehow, the computation graph is being broken. |
thanks so much for working on this! |
any luck? |
sorry for the late reply I'm still digging into this. But for now, the temporary solution would be disabling import torch
def disable_torch_mps_is_available():
return False
torch._C._mps_is_available = disable_torch_mps_is_available
def main():
print(torch.backends.mps.is_available())
main() |
Questions:
|
I found the issue I think. there seams to be an issue with line 48 in pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py:
when x is on mps it returns nan my temp fix is:
|
Nice Catch I tried it locally and started working. |
Our tests run entirely on CPU; due to the limited resources on MacOs runners. This is only an issue for users who want to use |
I found just setting align_corners=False also works without needing to move to cpu. This way is faster. |
@fkiraly would you take a look at the changes before merging. |
This PR partially resolves an issue with training when we enable MPS fallback where tensors are being detached from the computation graph, this fix works for all optimizers supported in
pytorch-forecasting/pytorch_forecasting/models/base_model.py
Line 442 in d57b0bb
ranger
part of #1721