
Support for loading fp8 checkpoint #68

Open
wenscarl opened this issue May 21, 2024 · 12 comments

@wenscarl
Contributor

wenscarl commented May 21, 2024

There is a `use_fp` flag for the `offline_quantize` tool in saxml/tools that quantizes the weights to fp8, but the weights still have to be stored as int8 (the code notes: "# This is needed since fp8 cannot be saved."). If that is always the case, is there any example showcasing how to load a checkpoint stored as int8 but interpret it as fp8? @jianlijianli @zhangqiaorjc

@jianlijianli
Copy link
Contributor

Hi wenscarl, could you please take a look here? Since we did a bitcast during save, we need to use a bitcast during serving.
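For illustration, the save/serve round trip amounts to a bitcast between dtypes of the same width; a minimal JAX sketch (not the exact saxml code path):

```
import jax.numpy as jnp
from jax import lax

# fp8 tensors cannot be saved directly, so the tool bitcasts them to int8
# before writing the checkpoint; serving bitcasts them back to fp8.
w_fp8 = jnp.ones((4, 4), dtype=jnp.float8_e4m3fn)
w_saved = lax.bitcast_convert_type(w_fp8, jnp.int8)                 # stored as int8
w_restored = lax.bitcast_convert_type(w_saved, jnp.float8_e4m3fn)   # interpreted as fp8 again
```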

@wenscarl
Contributor Author

Thanks. The checkpoint is generated with use_fp=True, but at weight setup the dtype is still int8. Is there something I am missing? I only see one example here.
For fp8 inference, we only set four layers in the transformer (ffn1, ffn2, combined_qkv and post) to fp8, and they do not go through quantized_einsum. Could you provide some design advice? Thanks!

@jianlijianli
Contributor

Hi wenscarl, we can enable running the fp8 model (stored as int8) by setting dtype=jnp.float8_e4m3fn in the for_transformer() decorator.

There are several architecture variations, but for a basic transformer with MHA, ffn1, ffn2, combined_qkv and post are the weights that for_transformer() quantizes (i.e. that API would have trouble if you wanted to quantize ffn1, ffn2 and post, but not combined_qkv). So setting dtype=jnp.float8_e4m3fn might just work for your use case.

Could you please elaborate on "they do not go through quantized_einsum"?

@wenscarl
Contributor Author

wenscarl commented May 23, 2024

> Could you please elaborate on "they do not go through quantized_einsum"?

For fp8 training/inference, those layers are replaced by Fp8EinsumOp. See the USE_FP8 option in PAXML here.
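For context, a minimal sketch of what such an fp8 einsum does at inference time with per-tensor scales (illustrative only, not the actual Fp8EinsumOp implementation):

```
import jax.numpy as jnp

def fp8_einsum(eqn, x, w, x_scale, w_scale):
  # Scale the inputs into the fp8 dynamic range and cast to fp8.
  x_q = (x / x_scale).astype(jnp.float8_e4m3fn)
  w_q = (w / w_scale).astype(jnp.float8_e4m3fn)
  # Accumulate in a wider type; XLA can map this onto an fp8 matmul.
  out = jnp.einsum(eqn, x_q, w_q, preferred_element_type=jnp.float32)
  # Undo the per-tensor scales on the output.
  return out * (x_scale * w_scale)
```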
I tried dtype=jnp.float8_e4m3fn with the decorator; it complains about the following when loading the train_state:

```
ValueError: Cannot intersect index domain {  } with index domain { [0, 512) }: Ranks do not match [source locations='tensorstore/index_space/index_transform.cc:507']
E0523 20:23:00.429406 140603804873152 server.py:146] Exception during loading: <_InactiveRpcError of RPC that terminated with:
        status = StatusCode.INVALID_ARGUMENT
        details = "Cannot intersect index domain {  } with index domain { [0, 512) }: Ranks do not match [source locations='tensorstore/index_space/index_transform.cc:507']"
        debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:10001 {grpc_message:"Cannot intersect index domain {  } with index domain { [0, 512) }: Ranks do not match [source locations=\'tensorstore/index_space/index_transform.cc:507\']", grpc_status:3, created_time:"2024-05-23T20:23:00.428698069+00:00"}"
```

From the shape [512], it looks like some bias, since model_dim is set to 512 (see the model config below).
The way I used offline_quantize.py is:

```
python saxml/tools/offline_quantize.py --input_dir /tmp/ckpts/fp8_train//checkpoints/checkpoint_00000000/state/ --output_dir /output/patch/checkoutpoint_0000000/state --quantization_configs "gptj" --use_fp True
```

where the input is an f32 checkpoint.
The gptj config inside `quantization_config.py` is modified as follows:
```
  factor = 1.0
  configs = {
      'ff_layer.ffn_layer1.linear.w': ([0, 1], factor, 0, -1),
      'ff_layer.ffn_layer2.linear.w': ([0, 1], factor, 0, -1),
      'self_attention.combined_qkv.w': ([0, 1, 2, 3], factor, 0, -1),
      'self_attention.post.w': ([0, 1, 2], factor, 0, -1),
  }
```

to be in line with the shape of the weights in each layer. The full model config is:

```
@servable_model_registry.register
@template.make_servable()
@quantization.for_transformer(quantize_on_the_fly=False, dtype=jnp.float8_e4m3fn, linear_only=True)
class LmCloudSpmd2B(lm_cloud.LmCloudSpmd2B):
  # pylint: disable=line-too-long
  """Servable config on 1x1x4.

  Checkpoint:
  gs://sax-data/lm_cloud_2b_mesh_3/1/checkpoints/checkpoint_00000000
  """
  # pylint: enable=line-too-long

  SPM_MODEL = os.path.join(os.path.dirname(__file__), 'test_model.model')
  ICI_MESH_SHAPE = [1, 1, 1]
  NUM_LAYERS = 5
  MODEL_DIMS = 512
  HIDDEN_DIMS = MODEL_DIMS * 4
  FPROP_FOR_PREFIX = True
  BATCH_SIZE = 1
  TRAINING_OPTIMIZED_SHARDING = False
  USE_REPEATED_LAYER = False

  @property
  def test_mode(self) -> bool:
    return False

  def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    task_p = super().task()
    task_p = template.set_decoding_sharding_hparams(
        task_p,
        mesh_shape=self.ICI_MESH_SHAPE,
    )
    return task_p
```

@wenscarl
Contributor Author

wenscarl commented May 23, 2024

The shapes for model weights are:

```
mdl_vars.params.lm.transformer.x_layers_1.ff_layer.ffn_layer1.linear.w
[512, 2048]

mdl_vars.params.lm.transformer.x_layers_1.ff_layer.ffn_layer2.linear.w
[2048, 512]

mdl_vars.params.lm.transformer.x_layers_1.self_attention.combined_qkv.w
[3, 512, 4, 128]

mdl_vars.params.lm.transformer.x_layers_1.self_attention.post.w
[512, 4, 128]
```

If the gptj config is set to

```
  configs = {
      'ff_layer.ffn_layer1.linear.w': ([0], factor, 0, -1),
      'ff_layer.ffn_layer2.linear.w': ([0], factor, 0, -1),
      'self_attention.combined_qkv.w': ([0, 1, 2], factor, 0, -1),
      'self_attention.post.w': ([0, 1], factor, 0, -1),
  }
```

the checkpoint loading makes it through, but then hits this error:

```
pybind11_abseil.status.StatusNotOk: OUT_OF_RANGE: Invalid id: 11011
         [[{{function_node map_1_while_body_459}}{{node SentenceTokenizer/SentencepieceDetokenizeOp}}]]
E0523 21:46:05.865277 140622177190336 server.py:146] Exception during loading: <_InactiveRpcError of RPC that terminated with:
        status = StatusCode.INTERNAL
        details = "Loading error: OUT_OF_RANGE: Invalid id: 11011
         [[{{function_node map_1_while_body_459}}{{node SentenceTokenizer/SentencepieceDetokenizeOp}}]]"
        debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:10001 {created_time:"2024-05-23T21:46:05.864696671+00:00", grpc_status:13, grpc_message:"Loading error: OUT_OF_RANGE: Invalid id: 11011\n\t [[{{function_node map_1_while_body_459}}{{node SentenceTokenizer/SentencepieceDetokenizeOp}}]]"}"
>
```

Is there any example showcasing how to run the offline_quantize tool to generate a quantized checkpoint that can then be loaded by saxml for inference?

@jianlijianli
Contributor

Hi wenscarl, thanks for all the details. I can confirm fp8 worked 100% when the infra was first added. There are no public examples around fp8 since it was experimental.

I think your offline_quantize script is correct. The decorator should be

```
@quantization.for_transformer(quantize_on_the_fly=False, dtype=jnp.float8_e4m3fn)
```

With linear_only, it will load quantized tensors for ff1 and ff2, and float tensors for combined_qkv and post.

The sentencepiece error is unrelated. If you are able to run the float model with that config (excluding the quantization parts), sentencepiece should work for the quantized checkpoint/config as well. Did you run into any issues with float?

@wenscarl
Contributor Author

> Did you run into any issues with float?

I can run with float without any issue by: 1. not using the for_transformer decorator, and 2. directly loading a float checkpoint rather than a pre-quantized one.

@wenscarl
Contributor Author

wenscarl commented May 28, 2024

Updates:
Running inference with the weights in fp8 is experimentally feasible with some changes, except that the scale is a dummy.
Only the feed_forward layers are quantized, with linear_only=True.
For fp8 matmul, only per-tensor scale is currently supported by cublasLt. The gptj config is set as

```
  configs = {
      'ff_layer.ffn_layer1.linear.w': ([0, 1], factor, 0, -1),
      'ff_layer.ffn_layer2.linear.w': ([0, 1], factor, 0, -1),
  }
```

such that the resulting w_quantized_scale is a scalar of shape []. But here the scale_shape is specified to be a vector. Even when forcing scale_shape to be a scalar, there is still an error:

```
I0528 21:12:05.819829 140044237002304 checkpointer.py:164] Restoring item from /tmp/ckpts/mybucket/fp8/checkpoint_00000000.
E0528 21:12:05.862007 140044237002304 model_service_base.py:1557] Invalid load request. model_key: /sax/test/lm2b, model_path: saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2B, error: Cannot intersect index domain {  } with index domain { [0, 1) }: Ranks do not match [source locations='tensorstore/index_space/index_transform.cc:507']
Traceback (most recent call last):
  File "/model-server-bin/server.runfiles/__main__/saxml/server/model_service_base.py", line 1547, in _run_primary_worker_loop
    self._load_model(
  File "/model-server-bin/server.runfiles/__main__/saxml/server/model_service_base.py", line 1368, in _load_model
    self._loaded_models.load(
  File "/model-server-bin/server.runfiles/__main__/saxml/server/model_service_base.py", line 511, in load
    loaded = params.load(key, ckpt_path, self._primary_process_id, prng_key)
  File "/model-server-bin/server.runfiles/__main__/saxml/server/pax/servable_model_params.py", line 154, in load
    model.load(checkpoint_path, jax.random.PRNGKey(prng_key))
  File "/model-server-bin/server.runfiles/__main__/saxml/server/pax/servable_model.py", line 410, in load
    model, model_state = self.load_state(checkpoint_path, init_key, precompile)
  File "/model-server-bin/server.runfiles/__main__/saxml/server/pax/servable_model.py", line 506, in load_state
    partitioned_train_state = CKPT_MODULE.restore_checkpoint(
  File "/opt/paxml/paxml/checkpoints.py", line 246, in restore_checkpoint
    output = checkpoint_manager.restore(
  File "/opt/paxml/paxml/checkpoint_managers.py", line 605, in restore
    restored = self._manager.restore(
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_manager.py", line 867, in restore
    restored = self._checkpointer.restore(restore_directory, args=args)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpointer.py", line 166, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 459, in restore
    restored[item_name] = handler.restore(
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 137, in restore
    return self._handler.restore(directory, *args.args, **args.kwargs)
  File "/opt/paxml/paxml/checkpoints.py", line 587, in restore
    restored_train_state = super().restore(
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1073, in restore
    restored_item = asyncio.run(
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/opt/paxml/paxml/checkpoints.py", line 511, in _maybe_deserialize
    return await super()._maybe_deserialize(
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 903, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py", line 1531, in deserialize
    ret = await asyncio.gather(*deserialize_ops)
  File "/opt/jax/jax/experimental/array_serialization/serialization.py", line 311, in async_deserialize
    return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
  File "/opt/jax/jax/experimental/array_serialization/serialization.py", line 76, in create_async_array_from_callback
    dbs = await asyncio.gather(*future_arrays)
  File "/opt/jax/jax/experimental/array_serialization/serialization.py", line 281, in cb
    restricted_domain = t.domain.intersect(requested_domain)
ValueError: Cannot intersect index domain {  } with index domain { [0, 1) }: Ranks do not match [source locations='tensorstore/index_space/index_transform.cc:507']
E0528 21:12:05.864998 140067501011392 server.py:146] Exception during loading: <_InactiveRpcError of RPC that terminated with:
```

Is there an example showcasing how to do per-tensor-scaling properly?

@wenscarl
Contributor Author

@jianlijianli for visibility.

@jianlijianli
Contributor

Hi wenscarl, sorry for the delay. Pax defaults everything to per-channel quantization, so there is no API for per-tensor quantization, but it should be easy to hack a bit locally to run per-tensor. I think all we need is to set the scale dims to [1] here.

The runtime should be able to handle both per-tensor and per-channel scales, thanks to the broadcasting behavior of multiply.
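To illustrate that broadcasting (hypothetical shapes, not the exact praxis code):

```
import jax.numpy as jnp

w_q = jnp.ones((512, 2048), dtype=jnp.int8)               # quantized weight
per_channel_scale = jnp.ones((2048,), dtype=jnp.float32)  # one scale per output channel
per_tensor_scale = jnp.ones((1,), dtype=jnp.float32)      # scale dims set to [1]

# The same dequantization expression handles both cases via broadcasting.
w_deq_per_channel = w_q.astype(jnp.float32) * per_channel_scale  # shape (512, 2048)
w_deq_per_tensor = w_q.astype(jnp.float32) * per_tensor_scale    # shape (512, 2048)
```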

@wenscarl
Contributor Author

wenscarl commented Jun 5, 2024

Hi @jianlijianli, some updates.
I am able to run fp8 inference with the following experimental praxis and saxml PRs:
google/saxml#27
#72
Due to cublasLt fp8 matmul limitations, the quantization has to be weight-and-activation, per-tensor, and symmetric.

  1. Static activation quantization is not supported yet (here). How should one properly add support for it in the future?
  2. The comment here is a bit confusing. Doesn't it support activation quantization in the attention layer?
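
For reference, the weight-and-activation, per-tensor, symmetric scheme described above amounts to something like the sketch below (illustrative only; 448 is the largest finite float8_e4m3fn value, and quantize_per_tensor is a hypothetical helper, not a saxml/praxis API):

```
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def quantize_per_tensor(x):
  # Symmetric: a single positive scale per tensor, no zero point.
  scale = jnp.max(jnp.abs(x)) / E4M3_MAX
  x_q = (x / scale).astype(jnp.float8_e4m3fn)
  return x_q, scale

# Both the weight and the activation get their own per-tensor scale,
# matching the cublasLt fp8 matmul requirement.
w_q, w_scale = quantize_per_tensor(jnp.ones((512, 2048), dtype=jnp.float32))
```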

@jianlijianli
Contributor

Hi wenscarl, apologies for the delay, and really glad you have fp8 working now.

Static activation quantization is less useful, so we haven't supported it yet, but it should not be too hard to add. How do you plan to do static activation? Are the static activation scales collected from QAT training or from calibration?

The comment in https://github.com/google/praxis/blob/main/praxis/layers/quantization/quantize.py#L286 is indeed a bit confusing. The entire quantize.py rewrites checkpoints, so from that point of view it is always weight-only, for both ffw and attention.
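
One way the calibration path mentioned above could look (purely illustrative; calibrate_activation_scale is a hypothetical helper, not an existing API):

```
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def calibrate_activation_scale(activation_batches):
  # Track the max |activation| over a set of representative batches and
  # derive a static per-tensor scale to reuse at serving time.
  amax = 0.0
  for act in activation_batches:
    amax = max(amax, float(jnp.max(jnp.abs(act))))
  return amax / E4M3_MAX
```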
