Highlights
We are excited to announce the 0.7.0 release of torchao! This release moves QAT out of prototype with improved LoRA support and more flexible APIs, and adds support for new experimental kernels such as Marlin QQQ (for CUDA), int8_dynamic_activation_intx_weight (for ARM CPU), and more!
QAT moved out of prototype, LoRA integration, new flexible APIs (#1020, #1037, #1085, #1152)
QAT has been moved out of prototype to `torchao/quantization/qat` to provide better API stability guarantees moving forward. In addition to the existing `*QATQuantizer` classes, we now also support the more flexible `FakeQuantizedLinear` and `FakeQuantizedEmbedding` modules, which let users configure the exact quantization settings they wish to use during QAT.
```python
import torch

from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear

# Specify the quantization schemes to use during QAT
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)

# Replace nn.Linear and nn.Embedding with these in your model
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)
```
We also leveraged the new flexible APIs to build a new QAT + LoRA fine-tuning flow in torchtune. Try it out today!
```
tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
```
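For users of the existing quantizer classes, the usual prepare/train/convert flow still applies under the new import path. Below is a minimal sketch, assuming `model` is an existing `nn.Module`; the `groupsize` value is illustrative and the fine-tuning loop is omitted.

```python
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Insert fake quantization ops into the model's linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = qat_quantizer.prepare(model)

# ... fine-tune the model with fake quantization enabled ...

# Swap fake quantization for real quantized weights after training
model = qat_quantizer.convert(model)
```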
Marlin QQQ for CUDA (#1113)
Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed-precision GEMM (4-bit weights, 8-bit activations). For more details about Marlin QQQ, please refer to the paper.
```python
from torchao.dtypes import MarlinQQQLayout
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from torchao.quantization.quant_primitives import MappingType

quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=128,
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=MarlinQQQLayout(),
    ),
)
```
Benchmarking results can be found at https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#marlin-qqq.
This is a prototype feature, so feel free to try it out!
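As a quick sanity check after quantizing, you can run the model as usual. The sketch below assumes `model` and `example_input` are already defined on a CUDA device; compiling the model is typically how the kernel's speedup is realized.

```python
import torch

# Run the Marlin QQQ-quantized model; torch.compile usually helps
# realize the kernel's speedup (model and example_input are assumed)
model = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    output = model(example_input)
```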
int8_dynamic_activation_intx_weight Quantization for ARM CPU (#995, #1027, #1254, #1353)
We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., Mac computers with Apple silicon).
```python
import os

import torch

from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization import quantize_

# `precision` is the model dtype in the surrounding script; these kernels require fp32
assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires fp32 precision"

# Build the kernels in a temp location and load them into torch.
# This requires an ARM CPU.
from torchao.experimental.temp_build import temp_build_and_load_torchao_ops
temp_build_and_load_torchao_ops(
    cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + "/../../experimental"
)

# Quantize the model
nbit = 4
assert 1 <= nbit <= 8, "nbit must be between 1 and 8"
group_size = 128
has_weight_zeros = False
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        group_size=group_size,
        nbit=nbit,
        has_weight_zeros=has_weight_zeros,
    ),
)
```
Benchmarking results can be found at https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#int8_dynamic_activation_intx_weight-quantization
We are still working out how to ship the ARM CPU kernels, so the exact API is subject to change.
BC Breaking
AQT rename #2: LayoutType -> Layout (#1049)
Before:
```python
from torchao.dtypes import (
    BlockSparseLayoutType,
    Int4CPULayoutType,
    MarlinQQQLayoutType,
    MarlinSparseLayoutType,
    SemiSparseLayoutType,
    TensorCoreTiledLayoutType,
    UintxLayoutType,
    Float8LayoutType,
    LayoutType,
    PlainLayoutType,
)
```
After:
```python
from torchao.dtypes import (
    BlockSparseLayout,
    Int4CPULayout,
    MarlinQQQLayout,
    MarlinSparseLayout,
    SemiSparseLayout,
    TensorCoreTiledLayout,
    UintxLayout,
    Float8Layout,
    Layout,
    PlainLayout,
)
```
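For context, here is a minimal sketch of passing one of the renamed layouts to `quantize_` after this change; the `group_size` and `inner_k_tiles` values are illustrative, and `model` is assumed to be a bf16 model on CUDA.

```python
from torchao.dtypes import TensorCoreTiledLayout
from torchao.quantization import quantize_, int4_weight_only

# Pass the renamed layout class (previously TensorCoreTiledLayoutType)
quantize_(
    model,
    int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8)),
)
```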
QAT imports after move out of prototype (#1091)
Before:
```python
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```
After:
```python
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```
New Features
- Add BF16 stochastic rounding option for optimizers (#1124)
- Add quantize_() API support for NF4 (#1216)
- Support W4A8 Marlin kernel (#1113)
Improvements
quantize_
- Add default filtering to remove misaligned weights (#1194)
- Add tensor parallelism support for int4_weight_only quantization (#1120)
- Add support for asymmetric act quant for int8 dynamic quant (#1131)
- Add support for groupwise quantization for int8 weight only quantization (#1121)
- Add AQT tensor parallel for float8_dynamic_quant (#1078)
- Int8wo Embedding Quant (#1167)
- Making sure int4 weight only supports cpu as well (#1203)
- BF16 support for Quant-LLM kernel (#1147)
- Add hardware check to fp8 quant (#1314)
- Add support for quantize_() with Float8Linear module (#1344)
autoquant
- Added support for Per Tensor Scaling for Float8 Dynamic Autoquant (#1175)
- Add floating point options for autoquant and add accuracy measurement (#1355)
benchmarks
- Adding batchsize support for torchao llama benchmarks (#1182)
- Add capability of benchmarking arbitrary binary (#1107)
experimental
- Add embedding ops aten (#1129)
- Add embedding ops executorch (#1137)
- Add quantized embedding kernels to torchao (#1018)
- Allow deprecated declarations when using Parallel ExecuTorch (#1031)
- Introduce lowbit quantized linear MPS kernels (#954)
- Enable 6-bit kernel (#1027)
- Kleidi 4b blockwise gemv prototype (#997)
- Experimental 6-bit quantization for Llama in torchchat (#1094)
- Introduce 7-bit quantization for Llama in torchchat. (#1139)
- Executorch Subclass API (#966) (#995)
- 8-bit packing support (#1248)
- Experimental Enable 8-bit (#1254)
- Experimental Benchmarking (#1353)
optimizer
- [low-bit optim] Upcast everything to FP32 for internal calculations (#1068)
- [Low-bit optim] Support for dcp.save() and dcp.load() (#1217)
- Enable CPU Offload for Intel GPU (#1324)
SAM2
- SAM2.1 copy (#1172)
- SAM2 AMG server side request batching (#1197)
- More SAM2-fast server improvements (#1285)
- SAM2 Fast AMG: memory profiling and more compile (#1296)
- SAM2 AMG cli and other QoL improvements (#1336)
- SAM2 AMG cli.py on modal (#1349)
- Reduce SAM2 AMG cli startup by using deploy (#1350)
- Reduce startup time for SAM2 AMG by using torch.export (#1358)
- More batching and improved furious accuracy/performance (#1253)
- SAM2.1 and example README (#1048)
- SAM2 AMG example mIoU, perf numbers and more SAM2 model annotations (#1196)
other
- Add SpinQuant to generate.py (#1069)
- SpinQuant (#983)
- SmoothQuant using tensor subclassing (#1030)
- Expose FakeQuantizeConfigs in QAT quantizers (#1214)
- Add module-swap UX for INT8 mixed-precision training (#1179)
- Float8 training: move module attribute setting to sync function (#1341)
Bug Fixes
- Header bug fix (#1079)
- Temporary fix for QAT quantizer when linear layer bias is True (#1087)
- Fix out-of-bounds memory access in Galore dequant kernel (#1125)
- Fixed weights_only=True load for float8_dynamic_activation_float8_weight in quant_api (#1122)
- Fix int8_weight_only group_size (#1165)
- Is_linear fix for MHA (#1141)
- Fixing eval.py to use GPTQ_MT for gptq (#1176)
- [CPU offload optim] Fix when there are non-trainable params (#1210)
- Fix for weights-only load (#1228)
- Pin nightlies to deal with std::bad_alloc (#1256)
- Fix 2.5.1 failing sparsity test (#1261)
- Call narrow only for TensorCoreTiledLayout (#1207)
- Fix an autoquant bug in flatten/unflatten (#1288)
- Float8 with delayed scaling: fix autocast handling (#1306)
- Fix bug with float8 training + FSDP2 + TP (#1327)
- Float8 training: fix bug with AC + compile (#1329)
- Fix torchtitan + float8 + delayed + compile (#1334)
- [low-bit optim] Fix edge cases for FSDP2 integration (#1269)
- [NF4] .to() fixes (#1312)
- Check scale.ndim before applying t/transpose (#1339)
Performance
- Swap in faster uint6 bitpacking function (#1098)
- Implement more efficient pack and unpack uint5 (#1138)
- Fix 20x slowdown of FP6 kernel due to device properties query (#1092)
Documentation
- Add a developer guide for exporting to executorch (#1219)
- Enable AWQ example on CPU (#1043)
- Add readme doc for experimental (#1130)
- Move float8 out of prototype in quantization README (#1166)
- Update torchao api reference and add contributor guide (#1255)
- Fix pickle.dump missing file argument typo in README (#1316)
- Update README.md (#1319)
- Update README.md: Fix bibtex and sglang links (#1361)
- Add bibtex (#1177)
- Clarify torchao.float8 PyTorch version support (#1191)
Developers
- [Tp Test] Fix the placement of the device tensor (#1054)
- Skip test_fpx_weight_only in fbcode (#1056)
- Pin pt nightly CPU version (#1061)
- Unpin CUDA Nightly (#1064)
- Update smoke test (#1111)
- Update regression_test.yml (#1163)
- Add PyTorch 2.5 to regression test (#1168)
- Fix Bias APIs, re-enable kleidi tests for arm64 (#1162)
- Create CITATION.cff (#1178)
- Unpin nightlies (#1183)
- [experimental] Kleidi - add operator level tests (#1173)
- Ruff format and lint (#1226)
- Update pre-commit to match CI/CD (#1227)
- Fixing pytest skip for only test_floatx.py (#1251)
- Fixed invalid url in citation section (#1348)
- Add to safe globals (#1171)
- Aqt rename#1 Layout -> TensorImpl (#1046)
- Move and rename GranularityType -> Granularity (#1038)
- Change torchao quantization types from int to size_t and preface vars with "preferred_" (#1041)
- Shrink hadamard matrices (#1051)
- Use ExecuTorch prebuilt library in pip package to build custom kernels (#1059)
- Update base.h unit to unsigned int (#962)
- Create header for packed weight ops (#1072)
- Update cmake files (#1070)
- Create build_wheels_aarch64_linux.yml (#1083)
- ROCM binary upload (#1099)
- Create build_wheels_windows.yml (#1101)
- Use fewer instructions when unpacking uint6s. (#1109)
- [CI] XPU binary build enable (#1105)
- Move common ET/Aten op stuff to ops/library.h (#1116)
- Move bias from kernel to packed_weights (#1119)
- Update gpu_sparsity kernel benchmarking script (#1143)
- [ROCm] use dataclass for fnuz type setting (#1142)
- Move files to prototype/sparsity (#1145)
- c10::nullopt -> std::nullopt (#1032) (#1151)
- [reland][ROCm] use dataclass for fnuz type setting (#1150)
- Move float8_aten_api to float8_ops (#1155)
- Initialize model with meta device for generation benchmarking (#1144)
- Replace torch.empty with torch.zeros (#1157)
- Update utils.py (#1186)
- Remove int_scaled_mm's dependency on triton for cpu (#128)
- at::optional -> std::optional (#1170) (#1212)
- fast_flush kwarg of do_bench is removed (#1222)
- Remove calibration args from generate.py (#1258)
- Skip marlin QQQ ops test in fbcode (#1289)
- Fix Marlin QQQ ops test with unittest (#1294)
- Fix Failing CI - Update bitsandbytes import (#1343)
- Remove lm_eval warning (#1347)
- Refactor Affine Quantized Tensor (#1234)
- Move files from quantization/prototype -> prototype/quantization (#1187)
- Add TTFT benchmarks + update sparsity benchmarks (#1140)
- Add "_gemm_input_role" to dunder slots (#984)
- Add an option to use fp8-all-gather only without fp8 computation. (#1093)
- Bump version to 0.7 (#1045)
New Contributors
- @Jack-Khuu made their first contribution in #1031
- @keyan made their first contribution in #1041
- @digantdesai made their first contribution in #997
- @EnragedAntelope made their first contribution in #962
- @c4lcut3c made their first contribution in #1094
- @elfisworking made their first contribution in #1087
- @chuanqi129 made their first contribution in #1105
- @p4arth made their first contribution in #1122
- @xuzijian629 made their first contribution in #1138
- @jeffdaily made their first contribution in #1142
- @r-barnes made their first contribution in #1151
- @helunwencser made their first contribution in #1157
- @bertmaher made their first contribution in #1222
- @tibidoh made their first contribution in #1248
- @mandroid6 made their first contribution in #1250
- @HandH1998 made their first contribution in #1113
- @readleyj made their first contribution in #1316
- @22dimensions made their first contribution in #1318
- @galqiwi made their first contribution in #1348
- @dbyoung18 made their first contribution in #1324
- @sunjiweiswift made their first contribution in #1259
- @merrymercy made their first contribution in #1361
Full Changelog: v0.6.1...v0.7.0-rc1