
Commit

[Feature] Enable bfloat16 convert functions in Python API (dmlc#5760)
itaraban authored Jul 31, 2023
1 parent b6f5ba9 commit 8c213ef
Showing 11 changed files with 183 additions and 13 deletions.
104 changes: 94 additions & 10 deletions docs/source/guide/mixed_precision.rst
@@ -4,8 +4,8 @@ Chapter 8: Mixed Precision Training
===================================
DGL is compatible with the `PyTorch Automatic Mixed Precision (AMP) package
<https://pytorch.org/docs/stable/amp.html>`_
for mixed precision training, thus saving both training time and GPU memory
consumption. This feature requires DGL 0.9+.
for mixed precision training, thus saving both training time and GPU/CPU memory
consumption. This feature requires DGL 0.9+ (and DGL 1.1+ for bfloat16 on CPU).

Message-Passing with Half Precision
-----------------------------------
@@ -58,18 +58,19 @@ DGL relies on PyTorch's AMP package for mixed precision training,
and the user experience is exactly
the same as `PyTorch's <https://pytorch.org/docs/stable/notes/amp_examples.html>`_.

By wrapping the forward pass with ``torch.cuda.amp.autocast()``, PyTorch automatically
By wrapping the forward pass with ``torch.amp.autocast()``, PyTorch automatically
selects the appropriate datatype for each op and tensor. Half precision tensors are memory
efficient, most operators on half precision tensors are faster as they leverage GPU tensorcores.
efficient, and most operators on half-precision tensors run faster because they leverage GPU tensor cores
and specialized CPU instruction sets.

.. code::

    import torch.nn.functional as F
    from torch.cuda.amp import autocast
    from torch.amp import autocast

    def forward(g, feat, label, mask, model, amp_dtype):
    def forward(device_type, g, feat, label, mask, model, amp_dtype):
        amp_enabled = amp_dtype in (torch.float16, torch.bfloat16)
        with autocast(enabled=amp_enabled, dtype=amp_dtype):
        with autocast(device_type, enabled=amp_enabled, dtype=amp_dtype):
            logit = model(g, feat)
            loss = F.cross_entropy(logit[mask], label[mask])
            return loss
@@ -104,7 +105,7 @@ Pay attention to the differences in the code when AMP is activated or not.
from dgl.nn import GATConv
from dgl.transforms import AddSelfLoop
amp_dtype = torch.float16 # or torch.bfloat16
amp_dtype = torch.bfloat16 # or torch.float16
class GAT(nn.Module):
def __init__(self,
@@ -130,7 +131,8 @@ Pay attention to the differences in the code when AMP is activated or not.
# Data loading
transform = AddSelfLoop()
data = RedditDataset(transform)
dev = torch.device('cuda')
device_type = 'cuda' # or 'cpu'
dev = torch.device(device_type)
g = data[0]
g = g.int().to(dev)
@@ -151,7 +153,7 @@ Pay attention to the differences in the code when AMP is activated or not.
for epoch in range(100):
optimizer.zero_grad()
loss = forward(g, feat, label, train_mask, model, amp_dtype)
loss = forward(device_type, g, feat, label, train_mask, model, amp_dtype)
if amp_dtype == torch.float16:
# Backprop w/ gradient scaling
@@ -169,5 +171,87 @@ If we change the number of heads to ``[2, 2, 2]``, training without fp16
triggers a GPU out-of-memory (OOM) error, while training with fp16 consumes
15.7GB of GPU memory.
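The diff above truncates the rest of the training loop. As a rough sketch (not part of this commit), the optimization step could reuse the ``forward`` helper defined earlier together with a ``torch.cuda.amp.GradScaler``; gradient scaling is only needed on the float16 path, since bfloat16 keeps float32's exponent range:

.. code::

    from torch.cuda.amp import GradScaler

    scaler = GradScaler()  # only exercised on the float16 path

    for epoch in range(100):
        optimizer.zero_grad()
        loss = forward(device_type, g, feat, label, train_mask, model, amp_dtype)
        if amp_dtype == torch.float16:
            # float16 gradients can underflow, so scale the loss before backprop
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # bfloat16 (or full precision) needs no gradient scaling
            loss.backward()
            optimizer.step()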

BFloat16 CPU Example
-----------------------------------
DGL supports training in the bfloat16 data type on the CPU.
Bfloat16 does not require any special CPU hardware support and can already improve the performance of memory-bound models.
Starting with 4th Generation Intel Xeon processors, which add the `AMX
<https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ instruction set, bfloat16 should significantly improve training and inference performance without major code changes.
Here is an example of simple GCN training in bfloat16:

.. code::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import dgl
    from dgl.data import CiteseerGraphDataset
    from dgl.nn import GraphConv
    from dgl.transforms import AddSelfLoop

    class GCN(nn.Module):
        def __init__(self, in_size, hid_size, out_size):
            super().__init__()
            self.layers = nn.ModuleList()
            # two-layer GCN
            self.layers.append(
                GraphConv(in_size, hid_size, activation=F.relu)
            )
            self.layers.append(GraphConv(hid_size, out_size))
            self.dropout = nn.Dropout(0.5)

        def forward(self, g, features):
            h = features
            for i, layer in enumerate(self.layers):
                if i != 0:
                    h = self.dropout(h)
                h = layer(g, h)
            return h

    # Data loading
    transform = AddSelfLoop()
    data = CiteseerGraphDataset(transform=transform)
    g = data[0]
    g = g.int()
    train_mask = g.ndata['train_mask']
    feat = g.ndata['feat']
    label = g.ndata['label']
    in_size = feat.shape[1]
    hid_size = 16
    out_size = data.num_classes
    model = GCN(in_size, hid_size, out_size)

    # Convert model and graph to bfloat16
    g = dgl.to_bfloat16(g)
    feat = feat.to(dtype=torch.bfloat16)
    model = model.to(dtype=torch.bfloat16)

    model.train()
    # Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    loss_fcn = nn.CrossEntropyLoss()

    for epoch in range(100):
        optimizer.zero_grad()  # clear gradients from the previous iteration
        logits = model(g, feat)
        loss = loss_fcn(logits[train_mask], label[train_mask])
        loss.backward()
        optimizer.step()
        print('Epoch {} | Loss {}'.format(epoch, loss.item()))
The only difference from regular full-precision training is the model and graph conversion before training/inference:

.. code::

    g = dgl.to_bfloat16(g)
    feat = feat.to(dtype=torch.bfloat16)
    model = model.to(dtype=torch.bfloat16)
DGL is still improving its half-precision support, and the performance of the
compute kernels is far from optimal; please stay tuned for future updates.
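After training, inference runs in bfloat16 with no further changes. A minimal sketch (reusing ``g``, ``feat``, ``label`` and ``model`` from the example above, plus the ``test_mask`` that the Citeseer dataset also provides) might look like:

.. code::

    model.eval()
    with torch.no_grad():
        logits = model(g, feat)              # computed in bfloat16
        pred = logits.argmax(dim=1)
        test_mask = g.ndata['test_mask']
        acc = (pred[test_mask] == label[test_mask]).float().mean()
        print('Test accuracy {:.4f}'.format(acc.item()))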
13 changes: 13 additions & 0 deletions examples/pytorch/gat/train.py
@@ -1,5 +1,6 @@
import argparse

import dgl
import dgl.nn as dglnn

import torch
@@ -88,6 +89,12 @@ def train(g, features, labels, masks, model):
default="cora",
help="Dataset name ('cora', 'citeseer', 'pubmed').",
)
parser.add_argument(
"--dt",
type=str,
default="float",
help="data type(float, bfloat16)",
)
args = parser.parse_args()
print(f"Training with DGL built-in GATConv module.")

@@ -115,6 +122,12 @@ def train(g, features, labels, masks, model):
out_size = data.num_classes
model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)

# convert model and graph to bfloat16 if needed
if args.dt == "bfloat16":
g = dgl.to_bfloat16(g)
features = features.to(dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)

# model training
print("Training...")
train(g, features, labels, masks, model)
12 changes: 12 additions & 0 deletions examples/pytorch/gcn/train.py
@@ -72,6 +72,12 @@ def train(g, features, labels, masks, model):
default="cora",
help="Dataset name ('cora', 'citeseer', 'pubmed').",
)
parser.add_argument(
"--dt",
type=str,
default="float",
help="data type(float, bfloat16)",
)
args = parser.parse_args()
print(f"Training with DGL built-in GraphConv module.")

@@ -99,6 +105,12 @@ def train(g, features, labels, masks, model):
out_size = data.num_classes
model = GCN(in_size, 16, out_size).to(device)

# convert model and graph to bfloat16 if needed
if args.dt == "bfloat16":
g = dgl.to_bfloat16(g)
features = features.to(dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)

# model training
print("Training...")
train(g, features, labels, masks, model)
12 changes: 12 additions & 0 deletions examples/pytorch/graphsage/node_classification.py
@@ -58,6 +58,7 @@ def inference(self, g, device, batch_size):
y = torch.empty(
g.num_nodes(),
self.hid_size if l != len(self.layers) - 1 else self.out_size,
dtype=feat.dtype,
device=buffer_device,
pin_memory=pin_memory,
)
@@ -171,6 +172,12 @@ def train(args, device, g, dataset, model, num_classes):
help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
"'puregpu' for pure-GPU training.",
)
parser.add_argument(
"--dt",
type=str,
default="float",
help="data type(float, bfloat16)",
)
args = parser.parse_args()
if not torch.cuda.is_available():
args.mode = "cpu"
@@ -189,6 +196,11 @@ def train(args, device, g, dataset, model, num_classes):
out_size = dataset.num_classes
model = SAGE(in_size, 256, out_size).to(device)

# convert model and graph to bfloat16 if needed
if args.dt == "bfloat16":
g = dgl.to_bfloat16(g)
model = model.to(dtype=torch.bfloat16)

# model training
print("Training...")
train(args, device, g, dataset, model, num_classes)
13 changes: 13 additions & 0 deletions examples/pytorch/graphsage/train_full.py
@@ -1,5 +1,6 @@
import argparse

import dgl
import dgl.nn as dglnn

import torch
@@ -69,6 +70,12 @@ def train(g, features, labels, masks, model):
default="cora",
help="Dataset name ('cora', 'citeseer', 'pubmed')",
)
parser.add_argument(
"--dt",
type=str,
default="float",
help="data type(float, bfloat16)",
)
args = parser.parse_args()
print(f"Training with DGL built-in GraphSage module")

@@ -96,6 +103,12 @@ def train(g, features, labels, masks, model):
out_size = data.num_classes
model = SAGE(in_size, 16, out_size).to(device)

# convert model and graph to bfloat16 if needed
if args.dt == "bfloat16":
g = dgl.to_bfloat16(g)
features = features.to(dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)

# model training
print("Training...")
train(g, features, labels, masks, model)
1 change: 1 addition & 0 deletions python/dgl/backend/backend.py
@@ -21,6 +21,7 @@ def data_type_dict():
"""Returns a dictionary from data type string to the data type.
The dictionary should include at least:
bfloat16
float16
float32
float64
1 change: 1 addition & 0 deletions python/dgl/backend/pytorch/tensor.py
@@ -18,6 +18,7 @@

def data_type_dict():
return {
"bfloat16": th.bfloat16,
"float16": th.float16,
"float32": th.float32,
"float64": th.float64,
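With the entry registered in the PyTorch backend's ``data_type_dict``, ``bfloat16`` becomes reachable as a backend attribute, which is how ``frame.py`` below refers to it (``F.bfloat16``). A quick sanity check, as a sketch assuming the PyTorch backend is active:

    import torch
    import dgl.backend as F

    # the new data_type_dict() entry resolves to the framework dtype
    assert F.bfloat16 is torch.bfloat16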
1 change: 1 addition & 0 deletions python/dgl/backend/tensorflow/tensor.py
@@ -30,6 +30,7 @@ def zerocopy_from_dlpack(dlpack_tensor):

def data_type_dict():
return {
"bfloat16": tf.bfloat16,
"float16": tf.float16,
"float32": tf.float32,
"float64": tf.float64,
13 changes: 12 additions & 1 deletion python/dgl/frame.py
@@ -990,18 +990,29 @@ def _astype_float(self, new_type):
F.float64,
F.float32,
F.float16,
F.bfloat16,
], "'new_type' must be floating-point type: %s" % str(new_type)
newframe = self.clone()
new_columns = {}
for name, column in self._columns.items():
dtype = column.dtype
if dtype != new_type and dtype in [F.float64, F.float32, F.float16]:
if dtype != new_type and dtype in [
F.float64,
F.float32,
F.float16,
F.bfloat16,
]:
new_columns[name] = column.astype(new_type)
else:
new_columns[name] = column
newframe._columns = new_columns
return newframe

def bfloat16(self):
"""Return a new frame with all floating-point columns converted
to bfloat16"""
return self._astype_float(F.bfloat16)

def half(self):
"""Return a new frame with all floating-point columns converted
to half-precision (float16)"""
19 changes: 19 additions & 0 deletions python/dgl/transforms/functional.py
@@ -86,6 +86,7 @@
"random_walk_pe",
"laplacian_pe",
"lap_pe",
"to_bfloat16",
"to_half",
"to_float",
"to_double",
@@ -3711,6 +3712,24 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
return lap_pe(g, k, padding, return_eigval)


def to_bfloat16(g):
r"""Cast this graph to use bfloat16 for any
floating-point edge and node feature data.
A shallow copy is returned so that the original graph is not modified.
Feature tensors that are not floating-point will not be modified.
Returns
-------
DGLGraph
Clone of graph with the feature data converted to bfloat16.
"""
ret = copy.copy(g)
ret._edge_frames = [frame.bfloat16() for frame in ret._edge_frames]
ret._node_frames = [frame.bfloat16() for frame in ret._node_frames]
return ret


def to_half(g):
r"""Cast this graph to use float16 (half-precision) for any
floating-point edge and node feature data.
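The new transform is exercised by the test below; as a quick standalone sketch (assuming a DGL build that includes this commit), only floating-point features are converted while integer features are left untouched:

    import torch
    import dgl

    g = dgl.graph(([0, 1], [1, 2]))
    g.ndata['feat'] = torch.randn(3, 4)           # float32 -> converted
    g.ndata['label'] = torch.tensor([0, 1, 0])    # int64   -> left as-is

    g = dgl.to_bfloat16(g)
    print(g.ndata['feat'].dtype)    # torch.bfloat16
    print(g.ndata['label'].dtype)   # torch.int64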
7 changes: 5 additions & 2 deletions tests/python/common/test_heterograph.py
@@ -2443,7 +2443,7 @@ def test_dtype_cast(idtype):


def test_float_cast():
for t in [F.float16, F.float32, F.float64]:
for t in [F.bfloat16, F.float16, F.float32, F.float64]:
idtype = F.int32
g = dgl.heterograph(
{
@@ -2469,6 +2469,7 @@ def test_float_cast():
("c", F.float64),
("d", F.int32),
("e", F.int64),
("f", F.bfloat16),
]
for name, type in dataNamesTypes:
g.nodes["user"].data[name] = F.copy_to(
@@ -2487,6 +2488,8 @@ def test_float_cast():
F.tensor(pvalues, dtype=type), ctx=F.ctx()
)

if t == F.bfloat16:
g = dgl.transforms.functional.to_bfloat16(g)
if t == F.float16:
g = dgl.transforms.functional.to_half(g)
if t == F.float32:
@@ -2498,7 +2501,7 @@ def test_float_cast():
# integer tensors shouldn't be converted
reqType = (
t
if (origType in [F.float16, F.float32, F.float64])
if (origType in [F.bfloat16, F.float16, F.float32, F.float64])
else origType
)

