🚀 Feature

LightningCLI support for external accelerators

Motivation
LightningCLI helps avoid boilerplate code for command-line tools. The current implementation does not appear to support external accelerators; it only accepts the accelerators that live in the Lightning source tree.
Pitch
Extend support for external accelerators in LightningCLI.
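For reference, the plain Trainer constructor already accepts an external accelerator instance directly; the ask is for LightningCLI to honor the same thing. A minimal sketch of the non-CLI path, assuming lightning-habana is installed and an HPU device is available:

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning_habana import HPUAccelerator

# Non-CLI baseline: hand the external accelerator instance straight to Trainer.
# Assumes an HPU device is reachable; otherwise device setup will fail.
trainer = Trainer(accelerator=HPUAccelerator(), devices=1, fast_dev_run=True)
trainer.fit(BoringModel())

LightningCLI should accept the same instance (or an importable class path) without extra glue.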
Alternatives

Additional context

First mentioned in #54

To reproduce:

from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.cli import LightningCLI
from lightning_habana import HPUAccelerator


class BMAccelerator(BoringModel):
    def on_fit_start(self):
        # The Trainer should have been built with the external accelerator.
        assert isinstance(self.trainer.accelerator, HPUAccelerator), self.trainer.accelerator


model = BMAccelerator
accelerator = HPUAccelerator()

if __name__ == "__main__":
    # Method 1: pass an accelerator instance from the external library.
    cli = LightningCLI(model, trainer_defaults={"accelerator": accelerator})
    # Method 2: pass the accelerator as a string (run one method at a time).
    # cli = LightningCLI(model, trainer_defaults={"accelerator": "hpu"})
Running the script gives the following tracebacks.

Method 1, passing an accelerator instance from an external library:
Traceback (most recent call last):
File "temp.py", line 34, in <module>
cli = LightningCLI(model, trainer_defaults={'accelerator': HPUAccelerator()})
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
self._run_subcommand(self.subcommand)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
fn(**fn_kwargs)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
call._call_and_handle_interrupt(
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 893, in _run
self.strategy.setup_environment()
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 127, in setup_environment
self.accelerator.setup_device(self.root_device)
File "/home/agola/lightning-habana-fork/src/lightning_habana/pytorch/accelerator.py", line 50, in setup_device
raise MisconfigurationException(f"Device should be HPU, got {device} instead.")
lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be HPU, got cpu instead.
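The root device here is cpu, which suggests the CLI-built Trainer never adopted the external accelerator when resolving devices. A hypothetical diagnostic, using LightningCLI's run=False mode so no subcommand executes, to inspect what the CLI actually resolved:

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning_habana import HPUAccelerator

# Hypothetical diagnostic: run=False builds the Trainer without executing a
# subcommand, so we can inspect what the CLI resolved from trainer_defaults.
cli = LightningCLI(
    BoringModel,
    trainer_defaults={"accelerator": HPUAccelerator()},
    run=False,
)
print(type(cli.trainer.accelerator))     # which Accelerator class got instantiated
print(cli.trainer.strategy.root_device)  # the device later passed to setup_device()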
Method 2, passing the accelerator as a string:
Traceback (most recent call last):
File "temp.py", line 33, in <module>
cli = LightningCLI(model, trainer_defaults={'accelerator': "hpu"})
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
self._run_subcommand(self.subcommand)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
fn(**fn_kwargs)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
call._call_and_handle_interrupt(
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 916, in _run
call._call_lightning_module_hook(self, "on_fit_start")
File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "temp.py", line 15, in on_fit_start
assert isinstance(self.trainer.accelerator,
AssertionError: <lightning.pytorch.accelerators.hpu.HPUAccelerator object at 0x7f37f62917c0>
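This run gets past device setup, but the assertion shows that the "hpu" string resolved to the in-tree lightning.pytorch.accelerators.hpu.HPUAccelerator rather than the lightning_habana class. The two classes share a name but are distinct types, so the isinstance check fails. A minimal sketch of the mismatch (assuming, per the traceback, the in-tree class still exists in this Lightning version and has no inheritance link to the external one):

from lightning.pytorch.accelerators.hpu import HPUAccelerator as InTreeHPUAccelerator
from lightning_habana import HPUAccelerator as ExternalHPUAccelerator

# Same class name, different types: the Trainer built the in-tree class from
# the "hpu" string, so isinstance() against the external class is False.
print(InTreeHPUAccelerator is ExternalHPUAccelerator)            # False
print(issubclass(InTreeHPUAccelerator, ExternalHPUAccelerator))  # assumed False: no inheritance link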
Env