From 792c60640bc21c2ac98a98167691ec9f1a428043 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Tue, 9 Dec 2025 13:01:01 -0800 Subject: [PATCH] adding Int8DynamicActivationInt8WeightConfig and Int8WeightOnlyConfig to safetensors --- .../prototype/safetensors/test_safetensors_support.py | 6 ++++++ torchao/prototype/safetensors/safetensors_utils.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 6892a0ca22..1c32ca63cc 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -20,7 +20,9 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, Int8DynamicActivationIntxWeightConfig, + Int8WeightOnlyConfig, IntxWeightOnlyConfig, ) from torchao.utils import is_sm_at_least_89 @@ -50,6 +52,8 @@ class TestSafeTensors(TestCase): (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False), (IntxWeightOnlyConfig(), False), (Int8DynamicActivationIntxWeightConfig(), False), + (Int8WeightOnlyConfig(version=2), False), + (Int8DynamicActivationInt8WeightConfig(version=2), False), ], ) def test_safetensors(self, config, act_pre_scale=False): @@ -95,6 +99,8 @@ def test_safetensors(self, config, act_pre_scale=False): (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False), (IntxWeightOnlyConfig(), False), (Int8DynamicActivationIntxWeightConfig(), False), + (Int8WeightOnlyConfig(version=2), False), + (Int8DynamicActivationInt8WeightConfig(version=2), False), ], ) def test_safetensors_sharded(self, config, act_pre_scale=False): diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index 9630515039..71cbe2d58e 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -10,21 +10,29 @@ Float8Tensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxUnpackedToInt8Tensor, + MappingType, ) from torchao.quantization.quantize_.common import KernelPreference -from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs +from torchao.quantization.quantize_.workflows import ( + QuantizeTensorToFloat8Kwargs, + QuantizeTensorToInt8Kwargs, +) ALLOWED_CLASSES = { "Float8Tensor": Float8Tensor, "Int4Tensor": Int4Tensor, "Int4TilePackedTo4dTensor": Int4TilePackedTo4dTensor, "IntxUnpackedToInt8Tensor": IntxUnpackedToInt8Tensor, + "Int8Tensor": Int8Tensor, "Float8MMConfig": torchao.float8.inference.Float8MMConfig, "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs, + "QuantizeTensorToInt8Kwargs": QuantizeTensorToInt8Kwargs, "PerRow": torchao.quantization.PerRow, "PerTensor": torchao.quantization.PerTensor, "KernelPreference": KernelPreference, + "MappingType": MappingType, } ALLOWED_TENSORS_SUBCLASSES = [ @@ -32,6 +40,7 @@ "Int4Tensor", "Int4TilePackedTo4dTensor", "IntxUnpackedToInt8Tensor", + "Int8Tensor", ] __all__ = [