feat(ScalingTensor, ScalingMeta): pickle and unpickle #84

Closed · wants to merge 8 commits
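This PR adds `__getstate__`/`__setstate__` hooks to `ScalingMeta` and `ScalingTensor` so both survive a pickle round-trip. A minimal usage sketch, assuming the import paths and the `Tensor.cast` extension used in the tests below:

import pickle
import torch
from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingMeta

# Cast an FP32 tensor to FP8; cast() returns a ScalingTensor and fills meta.
meta = ScalingMeta(Dtypes.kfloat8_e4m3)
fp8_value = torch.randn((3, 4), device='cuda').cast(meta.qtype, meta=meta)

# Round-trip through pickle; values and scaling metadata are preserved.
restored = pickle.loads(pickle.dumps(fp8_value))
assert torch.equal(fp8_value.float(), restored.float())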
11 changes: 11 additions & 0 deletions msamp/common/tensor/meta.py
@@ -34,6 +34,17 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,
    # Lock flag to prevent the reference to the meta from being changed.
    self.locked = False

def __getstate__(self):
    """Return the picklable state, excluding the process group."""
    state = self.__dict__.copy()
    # Distributed process groups are not picklable; drop the group here
    # and restore it as None in __setstate__.
    state.pop('group')
    return state

def __setstate__(self, state):
    """Restore pickled state; the process group is reset to None."""
    self.__dict__.update(state)
    self.group = None

@staticmethod
@torch.jit.script
def compute_scaling_factor(amax, scale, fp_max: float, margin: int):
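The dropped `group` attribute is a distributed process group handle, which is not picklable; `__setstate__` restores it as `None`, so callers must re-attach a group after unpickling if they need one. A sketch of the resulting behavior (imports as in the tests below):

import pickle
from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingMeta

meta = ScalingMeta(Dtypes.kfloat8_e4m3)
restored = pickle.loads(pickle.dumps(meta))
assert restored.group is None  # re-attach a process group manually if required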
13 changes: 13 additions & 0 deletions msamp/common/tensor/tensor.py
@@ -90,6 +90,19 @@ def __init__(self, value, meta):
    )
    self._requires_grad = False

def __getstate__(self):
    """Return the picklable state, excluding backward hooks and grad."""
    state = self.__dict__.copy()
    # Hook callables and the gradient tensor are not serialized; both are
    # recreated in __setstate__.
    state.pop('_backward_post_hooks')
    state.pop('_grad')
    return state

def __setstate__(self, state):
    """Restore pickled state with a fresh hook manager and no grad."""
    self.__dict__.update(state)
    self._backward_post_hooks = HookManager()
    self._grad = None

@property
def grad(self):
"""Decoration function to access _grad."""
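Since `torch.save` serializes via pickle by default, these hooks also make `ScalingTensor` checkpointable. Note that `_backward_post_hooks` is rebuilt empty and `_grad` comes back as `None`, so gradients must be restored separately. A hedged sketch (checkpoint path and flow are illustrative, not part of this PR):

import torch

# fp8_value: a ScalingTensor, e.g. produced by value.cast(meta.qtype, meta=meta)
torch.save(fp8_value, 'fp8_ckpt.pt')
loaded = torch.load('fp8_ckpt.pt')
assert loaded.grad is None  # grad and backward hooks are not serialized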
17 changes: 17 additions & 0 deletions tests/common/tensor/test_meta.py
@@ -4,6 +4,7 @@
"""Tests for ScalingMeta."""

import torch
import pickle
import unittest

from msamp.common.dtype import Dtypes
@@ -59,3 +60,19 @@ def test_disable_in_time_scaling(self):
    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
    self.assertFalse(meta.is_in_time_scaling())
    ScalingMeta.in_time_scaling = bak

@decorator.cuda_test
def test_meta_pickle(self):
    """Test pickle and unpickle of ScalingMeta."""
    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
    value = torch.randn((3, 4), device='cuda')
    # cast populates meta (scale, scale_inv, amax) in place; the casted
    # tensor itself is not needed for this test.
    value.cast(meta.qtype, meta=meta)

    meta2 = pickle.loads(pickle.dumps(meta))

    self.assertEqual(meta.qtype, meta2.qtype)
    self.assertEqual(meta.amax_counter, meta2.amax_counter)
    self.assertEqual(meta.window_size, meta2.window_size)
    self.assertTrue(torch.equal(meta.scale, meta2.scale))
    self.assertTrue(torch.equal(meta.scale_inv, meta2.scale_inv))
    self.assertTrue(torch.equal(meta.amax, meta2.amax))
18 changes: 18 additions & 0 deletions tests/common/tensor/test_tensor.py
@@ -4,6 +4,7 @@
"""Tests for ScalingTensor."""

import unittest
import pickle
import torch
import numpy as np

@@ -340,3 +341,20 @@ def test_grad_check_unscale_cuda(self):
    for dtype in dtypes:
        for qtype in qtypes:
            self._helper_test_grad_check_unscale('cuda', dtype=dtype, qtype=qtype)

@decorator.cuda_test
def test_tensor_pickle(self):
    """Test pickle and unpickle of ScalingTensor."""
    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
    value = torch.randn((3, 4), device='cuda')
    fp8_value = value.cast(meta.qtype, meta=meta)
    fp8_value.grad = torch.randn((3, 4), device='cuda')

    fp8_value2 = pickle.loads(pickle.dumps(fp8_value))

    self.assertTrue(torch.equal(fp8_value.value, fp8_value2.value))
    self.assertTrue(torch.equal(fp8_value.meta.scale_inv, fp8_value2.meta.scale_inv))
    self.assertTrue(torch.equal(fp8_value.float(), fp8_value2.float()))

    # The pickled state intentionally excludes grad.
    self.assertTrue(fp8_value2.grad is None)