
Improve MRSTFT loss on MPS #453

Open
sdatkinson opened this issue Aug 13, 2024 · 1 comment
Labels
enhancement (New feature or request) · priority:high (Prioritize) · upstream: pytorch (Resolving this depends on resolving an issue with PyTorch)

Comments

@sdatkinson (Owner)

Prerequisite to #436. Try to get the MRSTFT loss to run at least partially on MPS to minimize the speed hit.

sdatkinson added the enhancement and priority:high labels on Aug 13, 2024
@sdatkinson (Owner, Author) commented Aug 13, 2024

It looks like the fix is to get the missing op implemented in PyTorch itself. The exception raised is:

The operator 'aten::angle' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
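
For context, here's a minimal sketch (not code from this repo, and assuming a PyTorch version where torch.stft itself runs on MPS, since the trace above suggests only aten::angle is missing) of the kind of call that hits that op, with the suggested fallback applied. The STFT parameters are just illustrative:

```python
# Minimal sketch (not from this repo): the phase/angle computation in an
# MRSTFT-style loss is what trips aten::angle on MPS. The fallback env var
# must be set before torch is imported.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")  # remove to see the error above

import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

x = torch.randn(1, 48000, device=device)
window = torch.hann_window(1024, device=device)
spec = torch.stft(x, n_fft=1024, hop_length=256, window=window, return_complex=True)

mag = spec.abs()           # runs natively on MPS
phase = torch.angle(spec)  # aten::angle: CPU fallback (or an error without the env var)
print(mag.shape, phase.shape)
```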

I tried this, and it worked, but there were no time savings; having this one op fall back to the CPU seems to fully account for the overhead associated with the MRSTFT loss. Unless that is addressed upstream, no other work on this is worth doing.
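
For reference, a rough way to check that is to time the same STFT + angle path on CPU against MPS with the fallback enabled. This is just a sketch with made-up shapes, not the exact benchmark run here:

```python
# Hedged sketch: rough per-iteration timing of STFT + angle on CPU vs. MPS with
# the CPU fallback enabled. Shapes and STFT parameters are illustrative only.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import time
import torch

def stft_angle(x, window):
    spec = torch.stft(x, n_fft=1024, hop_length=256, window=window, return_complex=True)
    return spec.abs(), torch.angle(spec)  # torch.angle is the op that falls back to CPU

def benchmark(device, iters=50):
    x = torch.randn(8, 48000, device=device)
    window = torch.hann_window(1024, device=device)
    stft_angle(x, window)  # warm-up
    if device.type == "mps":
        torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        stft_angle(x, window)
    if device.type == "mps":
        torch.mps.synchronize()
    return (time.perf_counter() - start) / iters

print("cpu            :", benchmark(torch.device("cpu")))
if torch.backends.mps.is_available():
    print("mps + fallback :", benchmark(torch.device("mps")))
```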

Here's the PyTorch issue to +1: https://github.com/pytorch/pytorch/issues/77764

sdatkinson added the upstream: pytorch label on Sep 9, 2024