From 447e8893f7c0197f1dbe1159e46b43dbab434fec Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 25 Jun 2023 00:07:06 +0800 Subject: [PATCH 1/5] feat(ScalingTensor, ScalingMeta): pickle and unpickle --- msamp/common/tensor/meta.py | 10 ++++++++++ msamp/common/tensor/tensor.py | 11 +++++++++++ tests/common/tensor/test_meta.py | 17 +++++++++++++++++ tests/common/tensor/test_tensor.py | 18 ++++++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/msamp/common/tensor/meta.py b/msamp/common/tensor/meta.py index a6f5a036..465725c2 100644 --- a/msamp/common/tensor/meta.py +++ b/msamp/common/tensor/meta.py @@ -34,6 +34,16 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1, # lock flag to avoid the reference of the meta changed. self.locked = False + def __getstate__(self): + """Get state.""" + state = {k: v for k, v in self.__dict__.items()} + state.pop('group') + return state + + def __setstate__(self, state): + """Set state.""" + self.__dict__.update(state) + @staticmethod @torch.jit.script def compute_scaling_factor(amax, scale, fp_max: float, margin: int): diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py index d90396ac..46c8a67f 100644 --- a/msamp/common/tensor/tensor.py +++ b/msamp/common/tensor/tensor.py @@ -86,6 +86,17 @@ def __init__(self, value, meta): raise TypeError(f'Type mismatch, value.type is {value.type}, meta.type is {meta.type}') self._requires_grad = False + def __getstate__(self): + """Get state.""" + state = {k: v for k, v in self.__dict__.items()} + state.pop('_backward_post_hooks') + state.pop('_grad') + return state + + def __setstate__(self, state): + """Set state.""" + self.__dict__.update(state) + @property def grad(self): """Decoration function to access _grad.""" diff --git a/tests/common/tensor/test_meta.py b/tests/common/tensor/test_meta.py index 09e3cd8b..334aef0e 100644 --- a/tests/common/tensor/test_meta.py +++ b/tests/common/tensor/test_meta.py @@ -4,6 +4,7 @@ """Tests for ScalingMeta.""" import torch +import pickle import unittest from msamp.common.dtype import Dtypes @@ -59,3 +60,19 @@ def test_disable_in_time_scaling(self): meta = ScalingMeta(Dtypes.kfloat8_e4m3) self.assertFalse(meta.is_in_time_scaling()) ScalingMeta.in_time_scaling = bak + + @decorator.cuda_test + def test_meta_pickle(self): + """Test pickle and unpickle of ScalingMeta.""" + meta = ScalingMeta(Dtypes.kfloat8_e4m3) + value = torch.randn((3, 4), device='cuda') + fp8_value = value.cast(meta.qtype, meta=meta) + + meta2 = pickle.loads(pickle.dumps(meta)) + + self.assertEqual(meta.qtype, meta2.qtype) + self.assertEqual(meta.amax_counter, meta2.amax_counter) + self.assertEqual(meta.window_size, meta2.window_size) + self.assertTrue(torch.equal(meta.scale, meta2.scale)) + self.assertTrue(torch.equal(meta.scale_inv, meta2.scale_inv)) + self.assertTrue(torch.equal(meta.amax, meta2.amax)) diff --git a/tests/common/tensor/test_tensor.py b/tests/common/tensor/test_tensor.py index 51b878cc..6384814f 100644 --- a/tests/common/tensor/test_tensor.py +++ b/tests/common/tensor/test_tensor.py @@ -4,6 +4,7 @@ """Tests for ScalingTensor.""" import unittest +import pickle import torch import numpy as np @@ -326,3 +327,20 @@ def test_grad_check_unscale_cuda(self): for dtype in dtypes: for qtype in qtypes: self._helper_test_grad_check_unscale('cuda', dtype=dtype, qtype=qtype) + + @decorator.cuda_test + def test_tensor_pickle(self): + """Test pickle and unpickle of ScalingTensor.""" + meta = ScalingMeta(Dtypes.kfloat8_e4m3) + value = torch.randn((3, 4), device='cuda') + fp8_value = value.cast(meta.qtype, meta=meta) + fp8_value.grad = torch.randn((3, 4), device='cuda') + + fp8_value2 = pickle.loads(pickle.dumps(fp8_value)) + + self.assertTrue(torch.equal(fp8_value.value, fp8_value2.value)) + self.assertTrue(torch.equal(fp8_value.scale_inv, fp8_value2.scale_inv)) + self.assertTrue(torch.equal(fp8_value.float(), fp8_value2.float())) + + # pickle state does not save grad + self.assertTrue(fp8_value2.grad is None) From b35d7bbbcaa7ef980657664a5fe850af0a82ffc3 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 25 Jun 2023 00:11:19 +0800 Subject: [PATCH 2/5] fix --- tests/common/tensor/test_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common/tensor/test_tensor.py b/tests/common/tensor/test_tensor.py index 6384814f..770cdb54 100644 --- a/tests/common/tensor/test_tensor.py +++ b/tests/common/tensor/test_tensor.py @@ -339,7 +339,7 @@ def test_tensor_pickle(self): fp8_value2 = pickle.loads(pickle.dumps(fp8_value)) self.assertTrue(torch.equal(fp8_value.value, fp8_value2.value)) - self.assertTrue(torch.equal(fp8_value.scale_inv, fp8_value2.scale_inv)) + self.assertTrue(torch.equal(fp8_value.meta.scale_inv, fp8_value2.meta.scale_inv)) self.assertTrue(torch.equal(fp8_value.float(), fp8_value2.float())) # pickle state does not save grad From b6b4c0c7197db99fb6bfd4ff830af6cf14e0f200 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 25 Jun 2023 00:13:06 +0800 Subject: [PATCH 3/5] unpickle of ScalingTensor --- msamp/common/tensor/tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py index 46c8a67f..28e958c2 100644 --- a/msamp/common/tensor/tensor.py +++ b/msamp/common/tensor/tensor.py @@ -96,6 +96,8 @@ def __getstate__(self): def __setstate__(self, state): """Set state.""" self.__dict__.update(state) + self._backward_post_hooks = HookManager() + self._grad = None @property def grad(self): From 9eec4ff0c6025ac55972f74c299ff1eeed3ba51b Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 25 Jun 2023 00:15:17 +0800 Subject: [PATCH 4/5] unpickle of ScalingMeta --- msamp/common/tensor/meta.py | 1 + 1 file changed, 1 insertion(+) diff --git a/msamp/common/tensor/meta.py b/msamp/common/tensor/meta.py index 465725c2..da366961 100644 --- a/msamp/common/tensor/meta.py +++ b/msamp/common/tensor/meta.py @@ -43,6 +43,7 @@ def __getstate__(self): def __setstate__(self, state): """Set state.""" self.__dict__.update(state) + self.group = None @staticmethod @torch.jit.script From a750b629c2b870b09a329a4b3577c19860418bbb Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 25 Jun 2023 23:58:38 +0800 Subject: [PATCH 5/5] lint for meta unit test --- tests/common/tensor/test_meta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common/tensor/test_meta.py b/tests/common/tensor/test_meta.py index 334aef0e..24173f52 100644 --- a/tests/common/tensor/test_meta.py +++ b/tests/common/tensor/test_meta.py @@ -66,7 +66,7 @@ def test_meta_pickle(self): """Test pickle and unpickle of ScalingMeta.""" meta = ScalingMeta(Dtypes.kfloat8_e4m3) value = torch.randn((3, 4), device='cuda') - fp8_value = value.cast(meta.qtype, meta=meta) + value.cast(meta.qtype, meta=meta) meta2 = pickle.loads(pickle.dumps(meta))