Support fp8 direct quantization #69

Merged 1 commit on Oct 6, 2024
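This PR adds a use_direct_quant flag (default True) to Fp8EinsumOp and, by inheritance, Fp8EinsumGatedOp. With the flag enabled, the einsum runs through a custom _dot_general built on fp8_ops.q_dot_dq (or fp8_ops.in_q / fp8_ops.quantized_dot / fp8_ops.out_dq in the gated case) instead of the previous quantize-dequantize (QDQ) path. A minimal sketch of enabling it through pax_fiddle, mirroring the updated test below; the layer injection in the last line is an assumption about typical usage, not part of this diff:

# Illustrative only; Fp8EinsumOp and use_direct_quant come from this PR,
# the rest is assumed boilerplate.
from praxis import pax_fiddle
from praxis.layers.injection import fp8_nvidia_gpu

einsum_tpl = pax_fiddle.Config(
    fp8_nvidia_gpu.Fp8EinsumOp, use_direct_quant=True
)
# Inject into a layer template that exposes an einsum op, e.g. a praxis
# Linear config (attribute name assumed):
# linear_tpl.einsum_tpl = einsum_tpl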
praxis/layers/injection/fp8_nvidia_gpu.py (122 additions, 26 deletions)
@@ -70,6 +70,7 @@ class Fp8EinsumOp(base_layer.BaseLayer):
   """Wrapper around jnp.einsum used in standard Pax layers."""

   amax_history_length: int = 1024
+  use_direct_quant: bool = True

   def setup(self) -> None:
     scale_args, amax_history_args = _get_fp8_args(
@@ -128,9 +129,7 @@ def quantized_einsum(
       return y, x_qdq
     return y

-  def __call__(
-      self, equation: str, *args: JTensor
-  ) -> Union[JTensor, tuple[JTensor, JTensor]]:
+  def __call__(self, equation: str, *args: JTensor) -> JTensor:
     assert len(args) == 2
     x = args[0]
     k = args[1]
@@ -141,10 +140,58 @@ def __call__(
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
     x = jnp.asarray(x, comp_dtype)

-    y = self.quantized_einsum(equation, x, k, return_quantized_x=False)
+    if self.use_direct_quant:
+      def _quantized_dot_general(
+          lhs, rhs, dimension_numbers, precision=None,
+          preferred_element_type=None
+      ):
+        theta = self.theta
+        return fp8_ops.q_dot_dq(
+            lhs,
+            rhs,
+            lhs_scale=theta.input_scale,
+            rhs_scale=theta.kernel_scale,
+            out_grad_scale=theta.output_grad_scale,
+            lhs_amax_history=theta.input_amax_history,
+            rhs_amax_history=theta.kernel_amax_history,
+            out_grad_amax_history=theta.output_grad_amax_history,
+            compute_dtype=comp_dtype,
+            dimension_numbers=dimension_numbers,
+            precision=precision,
+            preferred_element_type=preferred_element_type,
+        )
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+    else:
+      y = self.quantized_einsum(equation, x, k, return_quantized_x=False)

     return y

+# quantized_dot_config is a decorator factory for quantized dot products: it
+# binds the pre-quantized operands, their scales, and the amax histories to
+# fp8_ops.quantized_dot and returns a wrapper with the _dot_general signature,
+# so jnp.einsum can run FP8 matrix multiplication with managed quantization.
+def quantized_dot_config(
+    compute_dtype, q_lhs, lhs_scale, q_rhs, rhs_scale, out_grad_scale,
+    out_grad_amax_history
+):
+  def decorator(func):
+    def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+      return fp8_ops.quantized_dot(
+          lhs=lhs,
+          q_lhs=q_lhs,
+          lhs_scale=lhs_scale,
+          rhs=rhs,
+          q_rhs=q_rhs,
+          rhs_scale=rhs_scale,
+          out_grad_scale=out_grad_scale,
+          out_grad_amax_history=out_grad_amax_history,
+          compute_dtype=compute_dtype,
+          dimension_numbers=dimension_numbers,
+          precision=precision,
+          preferred_element_type=preferred_element_type
+      )
+    return wrapper
+  return decorator

 class Fp8EinsumGatedOp(Fp8EinsumOp):
   """Wrapper around two jnp.einsum for gated FFN."""
@@ -181,29 +228,78 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
     x = jnp.asarray(x, comp_dtype)

-    y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
-
     theta = self.theta

-    k_gated_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        k_gated,
-        theta.kernel_scale_gated,
-        theta.kernel_amax_history_gated,
-    )
-    y_gated_qdq = jnp.einsum(
-        equation,
-        x_qdq,
-        k_gated_qdq,
-        _dot_general=fp8_ops.dot_general_with_precision,
-    )
-    y_gated = fp8_ops.out_qdq(
-        comp_dtype,
-        jnp.float8_e5m2,
-        y_gated_qdq,
-        theta.output_grad_scale_gated,
-        theta.output_grad_amax_history_gated,
-    )
+    if self.use_direct_quant:
+      q_x, new_input_scale = fp8_ops.in_q(
+          comp_dtype, jnp.float8_e4m3fn, x, theta.input_scale, theta.input_amax_history
+      )
+      q_k, new_kernel_scale = fp8_ops.in_q(
+          comp_dtype, jnp.float8_e4m3fn, k, theta.kernel_scale, theta.kernel_amax_history
+      )
+      q_k_gated, new_kernel_scale_gated = fp8_ops.in_q(
+          comp_dtype, jnp.float8_e4m3fn, k_gated, theta.kernel_scale_gated, theta.kernel_amax_history_gated
+      )
+      common_args = (comp_dtype, q_x, new_input_scale)
+      main_fp8_metas = (
+          q_k,
+          new_kernel_scale,
+          theta.output_grad_scale,
+          theta.output_grad_amax_history
+      )
+      gated_fp8_metas = (
+          q_k_gated,
+          new_kernel_scale_gated,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated
+      )
+
+      @quantized_dot_config(*common_args, *main_fp8_metas)
+      def _quantized_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+        pass
+
+      @quantized_dot_config(*common_args, *gated_fp8_metas)
+      def _quantized_dot_general_gated(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+        pass
+
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+      y_gated = jnp.einsum(equation, x, k_gated, _dot_general=_quantized_dot_general_gated)
+
+      y = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale,
+          out=y
+      )
+      y_gated = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale_gated,
+          out=y_gated
+      )
+    else:
+      y, x_qdq = self.quantized_einsum(
+          equation, x, k, return_quantized_x=True
+      )
+      k_gated_qdq = fp8_ops.in_qdq(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k_gated,
+          theta.kernel_scale_gated,
+          theta.kernel_amax_history_gated,
+      )
+      y_gated_qdq = jnp.einsum(
+          equation,
+          x_qdq,
+          k_gated_qdq,
+          _dot_general=fp8_ops.dot_general_with_precision,
+      )
+      y_gated = fp8_ops.out_qdq(
+          comp_dtype,
+          jnp.float8_e5m2,
+          y_gated_qdq,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated,
+      )

     return y, y_gated
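Both code paths above rely on the same hook: jnp.einsum accepts a private _dot_general argument, so any function with jax.lax.dot_general's signature can replace the underlying matmul, which is what _quantized_dot_general and the quantized_dot_config-wrapped functions provide. A minimal standalone sketch of that injection point, using a plain pass-through instead of the FP8 helpers (everything here is illustrative and not part of this diff):

# Sketch of the _dot_general injection used above. A direct-quant version
# would quantize lhs/rhs (fp8_ops.in_q), run the FP8 dot, and dequantize the
# result (fp8_ops.out_dq); here we simply forward to jax.lax.dot_general.
import jax
from jax import numpy as jnp

def _passthrough_dot_general(lhs, rhs, dimension_numbers, precision=None,
                             preferred_element_type=None):
  return jax.lax.dot_general(
      lhs, rhs, dimension_numbers, precision=precision,
      preferred_element_type=preferred_element_type)

x = jnp.ones((4, 8), jnp.bfloat16)
k = jnp.ones((8, 16), jnp.bfloat16)
y = jnp.einsum('bd,df->bf', x, k, _dot_general=_passthrough_dot_general)
print(y.shape)  # (4, 16)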
praxis/layers/injection/fp8_nvidia_gpu_test.py (7 additions, 4 deletions)
@@ -17,7 +17,7 @@

 from functools import partial

-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 from flax.linen.fp8_ops import qdq
 import jax
 from jax import numpy as jnp
@@ -30,9 +30,10 @@

 PARAMS = base_layer.PARAMS

-class Fp8LinearsTest(test_utils.TestCase):
+class Fp8LinearsTest(test_utils.TestCase, parameterized.TestCase):

-  def test_fp8_einsum_injection(self):
+  @parameterized.parameters([True, False])
+  def test_fp8_einsum_injection(self, use_direct_quant):
     # Used to cast the inputs to be representable in FP8, so that the difference
     # of the results from the original gemm and fp8 gemm is small.
     cast_to_representable = partial(
@@ -100,7 +101,9 @@ def _train(variables, x):
     }

     output1a, output1b = run(None, expected_shapes_original)
-    einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
+    einsum_tpl = pax_fiddle.Config(
+        fp8_ops.Fp8EinsumOp, use_direct_quant=use_direct_quant
+    )
     output2a, output2b = run(einsum_tpl, expected_shapes_new)
     dw1, dw2 = output1b[0][PARAMS]['w'], output2b[0][PARAMS]['w']
     dx1, dx2 = output1b[1], output2b[1]