Skip to content

Commit

Permalink
Improve bfloat16 serialization (backward compatible) (#553)
Browse files Browse the repository at this point in the history
-    added bfloat16 serialization that sends 2 bytes per value (previously, we sent 4);
-    changed de-serialization code so it supports both modes of serialization.
-    the new mode can be enabled via export USE_LEGACY_BFLOAT16=0
-    tested in pytorch 1.12 and 1.13

---------

Co-authored-by: Aleksandr Borzunov <[email protected]>
  • Loading branch information
justheuristic and borzunov authored Feb 9, 2023
1 parent e6b6219 commit 7d1bb7d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
27 changes: 21 additions & 6 deletions hivemind/compression/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import os
import warnings
from abc import ABC, abstractmethod
from enum import Enum, auto
Expand All @@ -13,6 +14,7 @@
# While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
warnings.filterwarnings("ignore", message="The given NumPy array is not writable", category=UserWarning)

USE_LEGACY_BFLOAT16 = bool(int(os.environ.get("USE_LEGACY_BFLOAT16", 1)))

Key = Any

Expand Down Expand Up @@ -81,26 +83,39 @@ class NoCompression(CompressionBase):

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
tensor = tensor.detach()
shape = tensor.shape
dtype_name = str(tensor.dtype).lstrip("torch.")
raw_data = tensor
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float32)
if USE_LEGACY_BFLOAT16:
raw_data = tensor.to(torch.float32)
else:
typed_storage = tensor.storage()
storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
raw_data = torch.tensor(storage, dtype=torch.int8)

return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=tensor.numpy().tobytes(),
size=tensor.shape,
buffer=raw_data.numpy().tobytes(),
size=shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
shape = torch.Size(serialized_tensor.size)
if serialized_tensor.dtype == "bfloat16":
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
tensor = torch.as_tensor(array, dtype=torch.bfloat16)
if len(serialized_tensor.buffer) // shape.numel() == 4: # legacy mode: convert to fp32
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
tensor = torch.as_tensor(array, dtype=torch.bfloat16)
else: # efficient mode: send bfloat16 data directly
storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
else:
array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
tensor = torch.as_tensor(array)
return tensor.reshape(tuple(serialized_tensor.size))
return tensor.reshape(shape)

def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return 1.0
4 changes: 3 additions & 1 deletion tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def test_serialize_tensor():
_check(torch.tensor(1.0), CompressionType.FLOAT16)


@pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
@pytest.mark.forked
def test_serialize_bfloat16():
def test_serialize_bfloat16(use_legacy_bfloat16: bool):
hivemind.compression.base.USE_LEGACY_BFLOAT16 = use_legacy_bfloat16
tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
_check(tensor, CompressionType.NONE)
_check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
Expand Down

0 comments on commit 7d1bb7d

Please sign in to comment.