[DCP][HuggingFace] Add Support for dequantization of SafeTensors checkpoints (#160682)

This PR introduces the QuantizedHuggingFaceStorageReader component, which enables reading and dequantizing quantized tensors in SafeTensors checkpoints. The following capabilities are introduced (a usage sketch follows the list):
- Configuration of the target dtype and the block size.
- Multi-threaded dequantization for efficiency.
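
A minimal usage sketch, mirroring the new end-to-end test (the tensor name, shape, and checkpoint path are placeholders; `safetensors` must be installed and the checkpoint directory must contain the `model.safetensors.index.json` weight map):

```
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)

# Pre-allocated destination tensors, keyed by fully qualified name.
state_dict = {"linear1.weight": torch.zeros(256, 128, dtype=torch.float32)}

# The reader pairs each "<name>.weight" with its "<name>.weight_scale_inv"
# tensor from the checkpoint and dequantizes block-wise while reading.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=QuantizedHuggingFaceStorageReader(
        path="/path/to/checkpoint_dir",  # placeholder
        target_dtype=torch.float32,
        block_size=128,
        thread_count=2,
    ),
)
```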

Test Plan:
buck test //caffe2/test/distributed/checkpoint\:test_quantized_hf_storage
```
Time elapsed: 2:34.1s
Tests finished: Pass 31. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D80174674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160682
Approved by: https://github.com/ankitageorge
Author: Saurabh Mishra
Date: 2025-09-04 01:09:53 +00:00
Committed by: PyTorch MergeBot
Parent: 9458d1ac3b
Commit: 1281470155

7 changed files with 433 additions and 0 deletions


@@ -2514,6 +2514,8 @@ coverage_ignore_classes = [
    # torch.distributed.checkpoint.hf_storage
    "HuggingFaceStorageReader",
    "HuggingFaceStorageWriter",
    # torch.distributed.checkpoint.quantized_hf_storage
    "QuantizedHuggingFaceStorageReader",
    # torch.distributed.checkpoint.metadata
    "BytesStorageMetadata",
    "ChunkStorageMetadata",


@@ -173,6 +173,9 @@ We also provide other storage layers, including ones to interact with HuggingFac
.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageWriter
  :members:

.. autoclass:: torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader
  :members:

We provide default implementations of `LoadPlanner` and `SavePlanner` that
can handle all of torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.


@@ -1139,6 +1139,10 @@ If you are running single node training, it may be convenient to interactively b
.. py:module:: torch.distributed.checkpoint.hf_storage
```
```{eval-rst}
.. py:module:: torch.distributed.checkpoint.quantized_hf_storage
```
```{eval-rst}
.. py:module:: torch.distributed.checkpoint.metadata
```


@@ -1,11 +1,15 @@
# Owner(s): ["oncall: distributed checkpointing"]

import importlib
import json
import os

import torch
import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard, zeros
@@ -157,6 +161,118 @@ class TestSingleRankSaveLoad(TestCase):
                torch.equal(state_dict_to_save[key], state_dict_to_load[key])
            )

    @with_temp_dir
    def test_quantized_checkpoint_loading(self) -> None:
        """Test end-to-end saving a quantized checkpoint and loading it."""
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            return

        CHECKPOINT_DIR = self.temp_dir

        # Create original (unquantized) tensors to validate against
        original_tensors = {
            "linear1.weight": torch.randn(256, 128, dtype=torch.float32) * 2.0,
            "linear2.weight": torch.randn(128, 64, dtype=torch.float32) * 1.5,
            "embedding.weight": torch.randn(512, 256, dtype=torch.float32) * 3.0,
        }

        # Create quantized tensors and scale tensors
        quantized_checkpoint = {}
        block_size = 128

        for tensor_name, original_tensor in original_tensors.items():
            # Simulate quantization: scale down the tensor for quantization
            # This is a simplified quantization - in real scenarios it would be more complex
            rows, cols = original_tensor.shape

            # Create scale tensor for block-wise dequantization
            block_rows = (rows + block_size - 1) // block_size
            block_cols = (cols + block_size - 1) // block_size

            # Create scale inverse tensor (used for dequantization)
            scale_inv = torch.ones(block_rows, block_cols, dtype=torch.float32) * 2.0

            # Create quantized version (divide by scale for quantization)
            quantized_tensor = original_tensor / 2.0  # Simplified quantization

            # Store quantized tensor and its scale
            quantized_checkpoint[tensor_name] = quantized_tensor
            quantized_checkpoint[f"{tensor_name}_scale_inv"] = scale_inv

        # Save quantized checkpoint to safetensors file
        safetensors_file = os.path.join(CHECKPOINT_DIR, "model.safetensors")
        save_file(quantized_checkpoint, safetensors_file)

        # Create model.safetensors.index.json with weight mapping
        weight_map = {}
        for key in quantized_checkpoint.keys():
            weight_map[key] = "model.safetensors"

        index_data = {
            "metadata": {
                "total_size": sum(
                    t.numel() * t.element_size() for t in quantized_checkpoint.values()
                )
            },
            "weight_map": weight_map,
        }

        index_file = os.path.join(CHECKPOINT_DIR, "model.safetensors.index.json")
        with open(index_file, "w") as f:
            json.dump(index_data, f, indent=2)

        # Prepare state dict to load into
        state_dict_to_load = {}
        for tensor_name, original_tensor in original_tensors.items():
            state_dict_to_load[tensor_name] = torch.zeros_like(original_tensor)

        # Load using QuantizedHuggingFaceStorageReader
        dist_cp.load(
            state_dict=state_dict_to_load,
            storage_reader=QuantizedHuggingFaceStorageReader(
                path=CHECKPOINT_DIR,
                target_dtype=torch.float32,
                block_size=block_size,
                thread_count=2,
            ),
        )

        # Validate that loaded tensors match original tensors
        self.assertEqual(
            sorted(original_tensors.keys()), sorted(state_dict_to_load.keys())
        )

        for tensor_name in original_tensors.keys():
            original = original_tensors[tensor_name]
            loaded = state_dict_to_load[tensor_name]

            # Verify shapes match
            self.assertEqual(
                original.shape,
                loaded.shape,
                f"Shape mismatch for {tensor_name}: {original.shape} vs {loaded.shape}",
            )

            # Verify dtypes match
            self.assertEqual(
                original.dtype,
                loaded.dtype,
                f"Dtype mismatch for {tensor_name}: {original.dtype} vs {loaded.dtype}",
            )

            # Verify dequantized values match original values
            # We expect exact match since we used simple 2x scaling
            torch.testing.assert_close(
                loaded,
                original,
                rtol=1e-5,
                atol=1e-5,
                msg=f"Value mismatch for tensor {tensor_name}",
            )


class TestDistributedHFSafetensorsConsolidation(DTensorTestBase):
    @with_comms


@@ -0,0 +1,84 @@
# Owner(s): ["oncall: distributed checkpointing"]
import tempfile
from unittest.mock import MagicMock

import torch
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestQuantizedHfStorage(TestCase):
    def setUp(self):
        """Set up common test fixtures."""
        self.temp_dir = tempfile.TemporaryDirectory()
        self.path = self.temp_dir.name

    def tearDown(self):
        """Clean up test fixtures."""
        self.temp_dir.cleanup()

    def test_dequantization(self):
        """Test that quantized tensors are properly dequantized during read operations."""
        reader = QuantizedHuggingFaceStorageReader(self.path, thread_count=1)

        # Test data
        quantized_tensor = torch.ones(4, 4, dtype=torch.float32)
        scale_inv = torch.tensor([[2.0]], dtype=torch.float32)

        # Mock the safetensors file for reading data
        mock_file = MagicMock()

        # Mock get_slice to return a tensor that can be sliced
        def mock_get_slice(tensor_name):
            mock_tensor = MagicMock()
            mock_tensor.__getitem__ = lambda self, slices: quantized_tensor
            return mock_tensor

        mock_file.get_slice = mock_get_slice
        mock_file.get_tensor.return_value = scale_inv

        reader._weight_scale_mapping = {
            "model.layers.0.self_attn.kv_b_proj.weight": "model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
        }

        # Create a read request for quantized tensor
        read_item = ReadItem(
            type=LoadItemType.TENSOR,
            storage_index=MetadataIndex(
                fqn="model.layers.0.self_attn.kv_b_proj.weight",
                offset=torch.Size([0, 0]),
            ),
            dest_index=MetadataIndex(
                fqn="model.layers.0.self_attn.kv_b_proj.weight",
                offset=torch.Size([0, 0]),
            ),
            storage_offsets=[0, 0],
            dest_offsets=[0, 0],
            lengths=[4, 4],
        )

        # Mock planner
        target_tensor = torch.zeros(4, 4, dtype=torch.float32)
        mock_planner = MagicMock()
        mock_planner.resolve_tensor.return_value = target_tensor

        # Test the _process_read_request method
        reader._process_read_request(mock_file, read_item, mock_planner)

        # Verify the tensor was dequantized (ones * 2.0 = twos)
        expected_result = torch.ones(4, 4, dtype=torch.float32) * 2.0
        mock_planner.commit_tensor.assert_called_once()

        # Check that target_tensor was updated correctly
        args, _ = mock_planner.commit_tensor.call_args
        committed_tensor = args[1]  # second argument is the tensor
        torch.testing.assert_close(committed_tensor, expected_result)


if __name__ == "__main__":
    run_tests()


@@ -11,6 +11,7 @@ from .metadata import (
)
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader
from .state_dict_loader import load, load_state_dict
from .state_dict_saver import async_save, save, save_state_dict
from .storage import StorageReader, StorageWriter


@@ -0,0 +1,223 @@
# mypy: allow-untyped-defs
import json
import logging
from pathlib import Path
from typing import Any

import torch
from torch.distributed.checkpoint._hf_utils import _metadata_fn
from torch.distributed.checkpoint.planner import LoadPlanner, ReadItem

from .hf_storage import HuggingFaceStorageReader


logger: logging.Logger = logging.getLogger(__name__)

__all__ = ["QuantizedHuggingFaceStorageReader"]


class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
    """
    Extension of HuggingFaceStorageReader that handles quantized tensors.
    The checkpoint should have the full tensor in a SafeTensors file; a quantized
    tensor should not be sharded across multiple files.
    This reader handles the dequantization of tensors during the read process,
    converting them from quantized blocks to full dequantized tensors before
    copying to the target tensor.
    """

    def __init__(
        self,
        path: str,
        thread_count: int = 1,
        target_dtype: torch.dtype = torch.float32,
        block_size: int = 128,
    ):
        """
        Initialize the HuggingFace storage reader to load quantized checkpoints.

        Args:
            path: directory where the checkpoint will be read from.
            thread_count: Number of threads to use to read distributed checkpoint. Defaults to 1.
            target_dtype: Target dtype for dequantized tensor. Defaults to torch.float32.
            block_size: Fixed block size for dequantization. Defaults to 128.
        """
        super().__init__(path=path, thread_count=thread_count)
        self.target_dtype: torch.dtype = target_dtype
        self.block_size: int = block_size
        self._weight_scale_mapping: dict[str, str] = {}
        self._scale_tensor_cache: dict[str, torch.Tensor] = {}

    def read_metadata(self) -> Any:
        self._load_quantization_metadata()
        return super().read_metadata()

    def _load_quantization_metadata(self):
        """Load quantization metadata from the checkpoint."""
        checkpoint_path = Path(self.path)
        # Load weight mapping from index file
        index_file = checkpoint_path / _metadata_fn
        with open(index_file) as f:
            index_data = json.load(f)
        weight_map = index_data.get("weight_map", {})
        self._build_weight_scale_mapping(weight_map)

    def _build_weight_scale_mapping(self, weight_map: dict[str, str]):
        """Analyze and build weight-scale tensor pairs from weight mapping."""
        for tensor_name in weight_map.keys():
            if tensor_name.endswith(".weight_scale_inv"):
                weight_name = tensor_name.replace(".weight_scale_inv", ".weight")
                if weight_name in weight_map:
                    self._weight_scale_mapping[weight_name] = tensor_name

    def _process_read_request(
        self, f: Any, req: ReadItem, planner: LoadPlanner
    ) -> None:
        """Override of the helper function that processes a single read request."""
        tensor_fqn = req.storage_index.fqn

        # Check if this is a quantized tensor that needs dequantization
        if self._is_tensor_quantized(tensor_fqn):
            tensor = self._read_quantized_tensor_with_block_alignment(req, f)
        else:
            # Standard tensor reading
            slices = tuple(
                slice(offset, offset + length)
                for offset, length in zip(req.storage_offsets, req.lengths)
            )
            tensor = f.get_slice(tensor_fqn)[slices]

        target_tensor = planner.resolve_tensor(req).detach()
        assert target_tensor.size() == tensor.size(), (
            f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
        )
        target_tensor.copy_(tensor)
        planner.commit_tensor(req, target_tensor)

    def _calculate_scale_shape(
        self, weight: torch.Tensor, block_size: int
    ) -> tuple[int, int]:
        """Calculate expected scale tensor shape based on weight tensor and block size."""
        rows, cols = weight.shape
        block_rows = (rows + block_size - 1) // block_size  # Ceiling division
        block_cols = (cols + block_size - 1) // block_size  # Ceiling division
        return (block_rows, block_cols)

    def _dequantize_tensor(
        self,
        weight: torch.Tensor,
        scale_inv: torch.Tensor,
    ) -> torch.Tensor:
        """
        Dequantize a tensor using block-wise scaling.

        Args:
            weight: Quantized weight tensor
            scale_inv: Scale inverse tensor for dequantization

        Returns:
            Dequantized tensor
        """
        # Get original dimensions
        orig_shape = weight.shape

        # Calculate block dimensions for the local shard
        expected_scale_shape = self._calculate_scale_shape(weight, self.block_size)
        block_rows, block_cols = expected_scale_shape

        # Create output tensor in target dtype
        dequantized = weight.detach().clone().to(dtype=self.target_dtype)

        # Apply scaling factors to each block
        for i in range(block_rows):
            row_start = i * self.block_size
            row_end = min(row_start + self.block_size, orig_shape[0])
            for j in range(block_cols):
                col_start = j * self.block_size
                col_end = min(col_start + self.block_size, orig_shape[1])

                # Get the block
                block = weight[row_start:row_end, col_start:col_end]
                scale = scale_inv[i, j]
                block = block * scale

                # Explicitly convert block to target dtype
                block_converted = block.to(dtype=self.target_dtype)
                # Store the dequantized block
                dequantized[row_start:row_end, col_start:col_end] = block_converted

        return dequantized

    def _is_tensor_quantized(self, tensor_fqn: str) -> bool:
        """
        Check if a tensor is quantized.

        Args:
            tensor_fqn: Fully qualified name of the tensor

        Returns:
            True if the tensor is quantized and has a corresponding scale tensor,
            False otherwise
        """
        # Skip scale tensors themselves
        if tensor_fqn.endswith(".weight_scale_inv"):
            return False

        # Check if this weight tensor has a corresponding scale tensor
        if tensor_fqn not in self._weight_scale_mapping:
            return False

        return True

    def _read_quantized_tensor_with_block_alignment(
        self, req: ReadItem, safetensor_file: Any
    ) -> torch.Tensor:
        """
        Read a quantized tensor with block alignment.

        Args:
            req: Read request containing tensor info and required slices
            safetensor_file: Open safetensors file handle

        Returns:
            Dequantized tensor ready for use
        """
        tensor_fqn = req.storage_index.fqn
        scale_fqn = self._weight_scale_mapping[tensor_fqn]

        try:
            # Load the requested slice of the quantized weight
            weight_slices = tuple(
                slice(offset, offset + length)
                for offset, length in zip(req.storage_offsets, req.lengths)
            )
            quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices]

            # Load the corresponding scale inverse tensor.
            # For scale tensors, we typically need the full tensor for proper block alignment.
            if scale_fqn not in self._scale_tensor_cache:
                scale_inv = safetensor_file.get_tensor(scale_fqn)  # Load full scale tensor
                self._scale_tensor_cache[scale_fqn] = scale_inv
            else:
                scale_inv = self._scale_tensor_cache[scale_fqn]

            # Perform dequantization
            dequantized_tensor = self._dequantize_tensor(
                weight=quantized_tensor,
                scale_inv=scale_inv,
            )
            return dequantized_tensor
        except Exception:
            logger.error("Failed to read quantized tensor %s", tensor_fqn)
            raise
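
Note: the block-wise dequantization in `_dequantize_tensor` pairs a quantized weight of shape `(rows, cols)` with a `scale_inv` of shape `(ceil(rows / block_size), ceil(cols / block_size))` and multiplies each tile by its scalar scale. A minimal standalone sketch of that relationship follows (not part of the diff; shapes and values are illustrative, and the vectorized expansion matches the loop above only when both dimensions divide evenly by `block_size`, whereas the reader also handles ragged edge blocks):

```
import torch

block_size = 2

# Illustrative 4x4 quantized weight and its 2x2 scale_inv (one scalar per tile).
weight = torch.ones(4, 4)
scale_inv = torch.tensor([[2.0, 3.0], [4.0, 5.0]])

# Expand each scale over its block_size x block_size tile, then multiply.
expanded = scale_inv.repeat_interleave(block_size, dim=0).repeat_interleave(
    block_size, dim=1
)
dequantized = (weight * expanded).to(torch.float32)

print(dequantized)
# Top-left tile is all 2.0, top-right 3.0, bottom-left 4.0, bottom-right 5.0.
```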