
[BUG] Fix issue when training TFT model on mac M1 mps device. element 0 of tensors does not require grad and does not have a grad_fn #1725

Open
wants to merge 5 commits into base: main
Conversation

fnhirwa
Member

@fnhirwa fnhirwa commented Dec 9, 2024

This PR partially resolves an issue with training when MPS fallback is enabled, where tensors are detached from the computation graph. The fix works for all optimizers supported in

optimizer (str): Optimizer, "ranger", "sgd", "adam", "adamw" or class name of optimizer in ``torch.optim``
except for ranger.

part of #1721
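For illustration, a minimal sketch of picking one of the supported optimizers (this assumes the `training` TimeSeriesDataSet built in the stallion example further down this thread; it is not code from the PR itself):

from pytorch_forecasting import TemporalFusionTransformer

# "adam", "adamw" and "sgd" were tested with this fix; "ranger" is not covered yet
tft = TemporalFusionTransformer.from_dataset(training, optimizer="adamw")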

@fnhirwa
Member Author

fnhirwa commented Dec 10, 2024

pinging @zy636

This seems to be a problem with the Ranger21 optimizer on the MPS accelerator; the proposed solution works for the other optimizers.

When PyTorch detects an unsupported operation on MPS, it silently switches to CPU for that operation. However:

  1. Tensor detachment: The switch between devices breaks the link in the computation graph.
  2. Non-differentiable parameters: Parameters involved in unsupported operations may no longer be tracked by PyTorch's autograd.
  3. Gradient issues: Ranger21 depends on having proper gradients for updates. If any parameter is non-differentiable or detached, it fails with errors like ZeroParameterSizeError (a minimal illustration of the detachment symptom is sketched below).
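As a minimal illustration of the symptom (not code from this PR): once a tensor has been detached from the graph, the loss loses its grad_fn and backward() raises exactly the error in the title.

import torch

x = torch.randn(4, 3, requires_grad=True)
y = x.detach()  # simulates the break introduced by the silent device switch
loss = y.sum()
print(loss.requires_grad, loss.grad_fn)  # False None
# loss.backward() would raise:
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn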

For now, I would advise using CPU for training. Once this PR is merged, the other optimizers will work; I tested with adam, adamw and sgd.

Thanks

@fnhirwa fnhirwa marked this pull request as ready for review December 10, 2024 18:25
Collaborator

@fkiraly fkiraly left a comment


Commenting from a general best-practice point of view: when fixing a bug where all tests passed before the fix, we should add a test that would have failed before the fix but passes afterwards ("regression" testing).

@fnhirwa
Member Author

fnhirwa commented Dec 10, 2024

Commenting from a general best-practice point of view: when fixing a bug where all tests passed before the fix, we should add a test that would have failed before the fix but passes afterwards ("regression" testing).

This was failing on macOS when using mps as an accelerator to run the stallion example notebook. A unit test can't be provided for this case because we aren't using MPS on CI, as discussed in #1648; however, I'll comment with a working example on this PR.

@fnhirwa
Member Author

fnhirwa commented Dec 11, 2024

The following example works well on the MPS accelerator:

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import copy
from pathlib import Path
import warnings

from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from lightning.pytorch.tuner import Tuner


warnings.filterwarnings("ignore")

import lightning.pytorch as pl

pl.seed_everything(42)

from pytorch_forecasting.data.examples import get_stallion_data

data = get_stallion_data()

# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()

# add additional features
data["month"] = data.date.dt.month.astype(str).astype("category")  # categories have be strings
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")

# we want to encode special days as one variable and thus need to first reverse one-hot encoding
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]
data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")

max_prediction_length = 6
max_encoder_length = 24
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    time_varying_known_categoricals=["special_days", "month"],
    variable_groups={"special_days": special_days},  # group of categorical variables can be treated as one variable
    time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "volume",
        "log_volume",
        "industry_volume",
        "soda_volume",
        "avg_max_temp",
        "avg_volume_by_agency",
        "avg_volume_by_sku",
    ],
    target_normalizer=GroupNormalizer(
        groups=["agency", "sku"], transformation="softplus"
    ),  # use softplus and normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# create validation set (predict=True) which means to predict the last max_prediction_length points in time
# for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 128  # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)



# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard


trainer = pl.Trainer(
    max_epochs=1,
    accelerator="mps",
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    log_interval=10,  # set to e.g. 10 to log every 10 batches; useful with the learning rate finder
    optimizer="sgd",
    reduce_on_plateau_patience=4,
)


trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

@fnhirwa fnhirwa requested a review from fkiraly December 12, 2024 11:07
@moshesimon

moshesimon commented Dec 19, 2024

When I try your fix on the above code I get:
train_loss_epoch=1.42e+6
and pytorch_forecasting/metrics/base_metrics.py:817: UserWarning: Loss is not finite. Resetting it to 1e9
warnings.warn("Loss is not finite. Resetting it to 1e9")

But when I just change the accelerator to cpu instead of mps I get:
train_loss_epoch=301.0

So I don't think the problem is solved.

@fnhirwa
Member Author

fnhirwa commented Dec 19, 2024

When I try your fix on the above code I get: train_loss_epoch=1.42e+6 and pytorch_forecasting/metrics/base_metrics.py:817: UserWarning: Loss is not finite. Resetting it to 1e9 warnings.warn("Loss is not finite. Resetting it to 1e9")

But when I just change the accelerator to cpu instead of mps I get: train_loss_epoch=301.0

So I don't think the problem is solved.

Oh, I see. This is probably an effect of the mps accelerator. I will continue debugging to see where the disparities originate; somehow, the computation graph is being broken.

@fnhirwa fnhirwa marked this pull request as draft December 19, 2024 16:02
@moshesimon

Thanks so much for working on this!
I really need it fixed xD

@moshesimon

Any luck?

@fnhirwa
Member Author

fnhirwa commented Jan 1, 2025

Any luck?

Sorry for the late reply, I'm still digging into this.

For now, the temporary workaround is to disable MPS during training, so that code using accelerator="auto" doesn't pick mps when it is available. This can be achieved by monkeypatching the torch._C._mps_is_available function, which makes torch.backends.mps.is_available() return False, hence MPS is disabled.

import torch

def disable_torch_mps_is_available():
    return False

# monkeypatch the low-level check so that torch.backends.mps.is_available() reports False
torch._C._mps_is_available = disable_torch_mps_is_available

def main():
    print(torch.backends.mps.is_available())  # False

main()
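Alternatively, if you control the Trainer construction, a simpler workaround (a sketch, not part of this PR) is to pin the accelerator to CPU explicitly instead of relying on "auto":

import lightning.pytorch as pl

# force CPU training so the mps backend is never selected
trainer = pl.Trainer(max_epochs=1, accelerator="cpu")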

@fkiraly
Collaborator

fkiraly commented Jan 5, 2025

@moshesimon

moshesimon commented Jan 6, 2025

I think I found the issue.

There seems to be an issue with line 48 in pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py:

upsampled = F.interpolate(x.unsqueeze(1), self.output_size, mode="linear", align_corners=True).squeeze(1)

When x is on mps it returns nan; when x is on cpu it works.

My temp fix is:

def interpolate(self, x):
    if x.device.type == 'mps':
        # run the interpolation on cpu (nan on mps), then move the result back
        x = x.to('cpu')
        upsampled = F.interpolate(x.unsqueeze(1), self.output_size, mode="linear", align_corners=True).squeeze(1)
        upsampled = upsampled.to('mps')
    else:
        upsampled = F.interpolate(x.unsqueeze(1), self.output_size, mode="linear", align_corners=True).squeeze(1)
    if self.trainable:
        upsampled = upsampled * self.gate(self.mask.unsqueeze(0)) * 2.0
    return upsampled
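For reference, a hypothetical minimal repro of the behaviour described above (assumes an Apple silicon machine; whether NaNs appear may depend on the torch version):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, device="mps")
out = F.interpolate(x.unsqueeze(1), 16, mode="linear", align_corners=True).squeeze(1)
print(out.isnan().any())  # reported above to be True on mps, False on cpu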

@fnhirwa
Member Author

fnhirwa commented Jan 6, 2025

I think I found the issue.

There seems to be an issue with line 48 in pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py:

upsampled = F.interpolate(x.unsqueeze(1), self.output_size, mode="linear", align_corners=True).squeeze(1)

When x is on mps it returns nan; when x is on cpu it works.

My temp fix is:

def interpolate(self, x):
    if x.device.type == 'mps':
        # run the interpolation on cpu (nan on mps), then move the result back
        x = x.to('cpu')
        upsampled = F.interpolate(x.unsqueeze(1), self.output_size, mode="linear", align_corners=True).squeeze(1)
        upsampled = upsampled.to('mps')
    else:
        upsampled = F.interpolate(x.unsqueeze(1), self.output_size, mode="linear", align_corners=True).squeeze(1)
    if self.trainable:
        upsampled = upsampled * self.gate(self.mask.unsqueeze(0)) * 2.0
    return upsampled

Nice catch! I tried it locally and it started working.

@fnhirwa fnhirwa marked this pull request as ready for review January 6, 2025 17:06
@fnhirwa
Member Author

fnhirwa commented Jan 6, 2025

Questions:

Our tests run entirely on CPU due to the limited resources on macOS runners, so this is only an issue for users who want to use the mps accelerator.
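If an MPS-capable runner ever becomes available, a regression test could be gated on the backend roughly like this (a sketch assuming pytest; the test name and body are placeholders, not code from this PR):

import pytest
import torch

@pytest.mark.skipif(
    not torch.backends.mps.is_available(),
    reason="requires the mps backend (Apple silicon)",
)
def test_tft_training_on_mps():
    # would train a small TemporalFusionTransformer with accelerator="mps"
    # on the stallion example and assert that the training loss stays finite
    ...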

@moshesimon

moshesimon commented Jan 7, 2025

I found that just setting align_corners=False also works, without needing to move to cpu. This way is faster.
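For reference, a sketch of what the interpolate method could look like with that change (it mirrors the temporary fix quoted above with align_corners flipped, and is not necessarily the final code in the PR). Note that align_corners=False changes values slightly near the boundaries, so results will not be bit-identical to the original.

import torch.nn.functional as F

def interpolate(self, x):
    # align_corners=False avoids the nan reported on mps, without a cpu round-trip
    upsampled = F.interpolate(
        x.unsqueeze(1), self.output_size, mode="linear", align_corners=False
    ).squeeze(1)
    if self.trainable:
        upsampled = upsampled * self.gate(self.mask.unsqueeze(0)) * 2.0
    return upsampled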

@fnhirwa
Member Author

fnhirwa commented Jan 14, 2025

@fkiraly would you take a look at the changes before merging?
