diff --git a/msamp/common/tensor/meta.py b/msamp/common/tensor/meta.py
index d466c421..4f32673e 100644
--- a/msamp/common/tensor/meta.py
+++ b/msamp/common/tensor/meta.py
@@ -34,6 +34,17 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,
         # lock flag to avoid the reference of the meta changed.
         self.locked = False
 
+    def __getstate__(self):
+        """Get state."""
+        state = {k: v for k, v in self.__dict__.items()}
+        state.pop('group')
+        return state
+
+    def __setstate__(self, state):
+        """Set state."""
+        self.__dict__.update(state)
+        self.group = None
+
     @staticmethod
     @torch.jit.script
     def compute_scaling_factor(amax, scale, fp_max: float, margin: int):
diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py
index d355f0db..1eead719 100644
--- a/msamp/common/tensor/tensor.py
+++ b/msamp/common/tensor/tensor.py
@@ -90,6 +90,19 @@ def __init__(self, value, meta):
         )
         self._requires_grad = False
 
+    def __getstate__(self):
+        """Get state."""
+        state = {k: v for k, v in self.__dict__.items()}
+        state.pop('_backward_post_hooks')
+        state.pop('_grad')
+        return state
+
+    def __setstate__(self, state):
+        """Set state."""
+        self.__dict__.update(state)
+        self._backward_post_hooks = HookManager()
+        self._grad = None
+
     @property
     def grad(self):
         """Decoration function to access _grad."""
diff --git a/tests/common/tensor/test_meta.py b/tests/common/tensor/test_meta.py
index 09e3cd8b..24173f52 100644
--- a/tests/common/tensor/test_meta.py
+++ b/tests/common/tensor/test_meta.py
@@ -4,6 +4,7 @@
 """Tests for ScalingMeta."""
 
 import torch
+import pickle
 import unittest
 
 from msamp.common.dtype import Dtypes
@@ -59,3 +60,19 @@ def test_disable_in_time_scaling(self):
         meta = ScalingMeta(Dtypes.kfloat8_e4m3)
         self.assertFalse(meta.is_in_time_scaling())
         ScalingMeta.in_time_scaling = bak
+
+    @decorator.cuda_test
+    def test_meta_pickle(self):
+        """Test pickle and unpickle of ScalingMeta."""
+        meta = ScalingMeta(Dtypes.kfloat8_e4m3)
+        value = torch.randn((3, 4), device='cuda')
+        value.cast(meta.qtype, meta=meta)
+
+        meta2 = pickle.loads(pickle.dumps(meta))
+
+        self.assertEqual(meta.qtype, meta2.qtype)
+        self.assertEqual(meta.amax_counter, meta2.amax_counter)
+        self.assertEqual(meta.window_size, meta2.window_size)
+        self.assertTrue(torch.equal(meta.scale, meta2.scale))
+        self.assertTrue(torch.equal(meta.scale_inv, meta2.scale_inv))
+        self.assertTrue(torch.equal(meta.amax, meta2.amax))
diff --git a/tests/common/tensor/test_tensor.py b/tests/common/tensor/test_tensor.py
index cfc5f8f9..2beee14c 100644
--- a/tests/common/tensor/test_tensor.py
+++ b/tests/common/tensor/test_tensor.py
@@ -4,6 +4,7 @@
 """Tests for ScalingTensor."""
 
 import unittest
+import pickle
 import torch
 import numpy as np
 
@@ -340,3 +341,20 @@ def test_grad_check_unscale_cuda(self):
         for dtype in dtypes:
             for qtype in qtypes:
                 self._helper_test_grad_check_unscale('cuda', dtype=dtype, qtype=qtype)
+
+    @decorator.cuda_test
+    def test_tensor_pickle(self):
+        """Test pickle and unpickle of ScalingTensor."""
+        meta = ScalingMeta(Dtypes.kfloat8_e4m3)
+        value = torch.randn((3, 4), device='cuda')
+        fp8_value = value.cast(meta.qtype, meta=meta)
+        fp8_value.grad = torch.randn((3, 4), device='cuda')
+
+        fp8_value2 = pickle.loads(pickle.dumps(fp8_value))
+
+        self.assertTrue(torch.equal(fp8_value.value, fp8_value2.value))
+        self.assertTrue(torch.equal(fp8_value.meta.scale_inv, fp8_value2.meta.scale_inv))
+        self.assertTrue(torch.equal(fp8_value.float(), fp8_value2.float()))
+
+        # pickle state does not save grad
+        self.assertTrue(fp8_value2.grad is None)
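
Note (not part of the diff): the new __getstate__/__setstate__ hooks make both classes picklable by dropping members that cannot be serialized — the distributed process group on ScalingMeta, and the _backward_post_hooks manager and cached _grad on ScalingTensor — then restoring them as empty placeholders on load. A minimal usage sketch follows; the ScalingMeta import path is assumed from the file layout, and Tensor.cast is the MS-AMP extension to torch.Tensor that the tests above rely on.

    import pickle

    import torch

    from msamp.common.dtype import Dtypes
    from msamp.common.tensor import ScalingMeta

    # Cast a CUDA tensor to FP8 and round-trip the resulting ScalingTensor.
    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
    fp8_value = torch.randn((3, 4), device='cuda').cast(meta.qtype, meta=meta)

    payload = pickle.dumps(fp8_value)   # __getstate__ drops hooks and grad
    restored = pickle.loads(payload)    # __setstate__ rebuilds the placeholders

    assert torch.equal(fp8_value.float(), restored.float())
    assert restored.grad is None        # grad is intentionally not serialized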