[DCP][HuggingFace] Add Support for dequantization of SafeTensors checkpoints (#160682)
This PR introduces the QuantizedHuggingFaceStorageReader component, which enables reading and dequantizing the quantized tensors in a SafeTensors checkpoint. The following capabilities are introduced:
- Configuration of the target dtype and the block size.
- Multi-threaded dequantization for efficiency.

Test Plan: buck test //caffe2/test/distributed/checkpoint\:test_quantized_hf_storage
```
Time elapsed: 2:34.1s
Tests finished: Pass 31. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Differential Revision: D80174674
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160682
Approved by: https://github.com/ankitageorge
Commit 1281470155 (parent 9458d1ac3b), committed by PyTorch MergeBot.
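For reference, a minimal usage sketch of the new reader (it mirrors the end-to-end test added below; the checkpoint path and tensor name are illustrative):

```python
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)

# Pre-allocated destination tensors; shape and name are illustrative.
state_dict = {"linear1.weight": torch.zeros(256, 128, dtype=torch.float32)}

# Dequantizes block-quantized weights (using their *_scale_inv tensors)
# while loading a HuggingFace-style SafeTensors checkpoint.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=QuantizedHuggingFaceStorageReader(
        path="/path/to/checkpoint",
        target_dtype=torch.float32,
        block_size=128,
        thread_count=4,
    ),
)
```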
@@ -2514,6 +2514,8 @@ coverage_ignore_classes = [
    # torch.distributed.checkpoint.hf_storage
    "HuggingFaceStorageReader",
    "HuggingFaceStorageWriter",
    # torch.distributed.checkpoint.quantized_hf_storage
    "QuantizedHuggingFaceStorageReader",
    # torch.distributed.checkpoint.metadata
    "BytesStorageMetadata",
    "ChunkStorageMetadata",
@@ -173,6 +173,9 @@ We also provide other storage layers, including ones to interact with HuggingFace

.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageWriter
   :members:

.. autoclass:: torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader
   :members:

We provide default implementations of `LoadPlanner` and `SavePlanner` that
can handle all of torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.
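As an aside, a minimal sketch of that default-planner behavior (the checkpoint path and tensor shape are illustrative; assumes a single rank or an already-initialized process group):

```python
import torch
import torch.distributed.checkpoint as dist_cp

# state_dict may contain plain tensors or FSDP/DDP/DTensor/ShardedTensor entries;
# when no planner is passed, the default Save/LoadPlanner resolves the sharding.
state_dict = {"weight": torch.randn(4, 4)}

dist_cp.save(state_dict, checkpoint_id="/tmp/ckpt")  # default SavePlanner
dist_cp.load(state_dict, checkpoint_id="/tmp/ckpt")  # default LoadPlanner, loads in place
```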
@@ -1139,6 +1139,10 @@ If you are running single node training, it may be convenient to interactively b
.. py:module:: torch.distributed.checkpoint.hf_storage
```

```{eval-rst}
.. py:module:: torch.distributed.checkpoint.quantized_hf_storage
```

```{eval-rst}
.. py:module:: torch.distributed.checkpoint.metadata
```
@@ -1,11 +1,15 @@
# Owner(s): ["oncall: distributed checkpointing"]

import importlib
import json
import os

import torch
import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard, zeros
@@ -157,6 +161,118 @@ class TestSingleRankSaveLoad(TestCase):
                torch.equal(state_dict_to_save[key], state_dict_to_load[key])
            )

    @with_temp_dir
    def test_quantized_checkpoint_loading(self) -> None:
        """Test end-to-end saving a quantized checkpoint and loading it."""
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            return

        CHECKPOINT_DIR = self.temp_dir

        # Create original (unquantized) tensors to validate against
        original_tensors = {
            "linear1.weight": torch.randn(256, 128, dtype=torch.float32) * 2.0,
            "linear2.weight": torch.randn(128, 64, dtype=torch.float32) * 1.5,
            "embedding.weight": torch.randn(512, 256, dtype=torch.float32) * 3.0,
        }

        # Create quantized tensors and scale tensors
        quantized_checkpoint = {}
        block_size = 128

        for tensor_name, original_tensor in original_tensors.items():
            # Simulate quantization: scale down the tensor for quantization
            # This is a simplified quantization - in real scenarios it would be more complex
            rows, cols = original_tensor.shape

            # Create scale tensor for block-wise dequantization
            block_rows = (rows + block_size - 1) // block_size
            block_cols = (cols + block_size - 1) // block_size

            # Create scale inverse tensor (used for dequantization)
            scale_inv = torch.ones(block_rows, block_cols, dtype=torch.float32) * 2.0

            # Create quantized version (divide by scale for quantization)
            quantized_tensor = original_tensor / 2.0  # Simplified quantization

            # Store quantized tensor and its scale
            quantized_checkpoint[tensor_name] = quantized_tensor
            quantized_checkpoint[f"{tensor_name}_scale_inv"] = scale_inv

        # Save quantized checkpoint to safetensors file
        safetensors_file = os.path.join(CHECKPOINT_DIR, "model.safetensors")
        save_file(quantized_checkpoint, safetensors_file)

        # Create model.safetensors.index.json with weight mapping
        weight_map = {}
        for key in quantized_checkpoint.keys():
            weight_map[key] = "model.safetensors"

        index_data = {
            "metadata": {
                "total_size": sum(
                    t.numel() * t.element_size() for t in quantized_checkpoint.values()
                )
            },
            "weight_map": weight_map,
        }

        index_file = os.path.join(CHECKPOINT_DIR, "model.safetensors.index.json")
        with open(index_file, "w") as f:
            json.dump(index_data, f, indent=2)

        # Prepare state dict to load into
        state_dict_to_load = {}
        for tensor_name, original_tensor in original_tensors.items():
            state_dict_to_load[tensor_name] = torch.zeros_like(original_tensor)

        # Load using QuantizedHuggingFaceStorageReader
        dist_cp.load(
            state_dict=state_dict_to_load,
            storage_reader=QuantizedHuggingFaceStorageReader(
                path=CHECKPOINT_DIR,
                target_dtype=torch.float32,
                block_size=block_size,
                thread_count=2,
            ),
        )

        # Validate that loaded tensors match original tensors
        self.assertEqual(
            sorted(original_tensors.keys()), sorted(state_dict_to_load.keys())
        )

        for tensor_name in original_tensors.keys():
            original = original_tensors[tensor_name]
            loaded = state_dict_to_load[tensor_name]

            # Verify shapes match
            self.assertEqual(
                original.shape,
                loaded.shape,
                f"Shape mismatch for {tensor_name}: {original.shape} vs {loaded.shape}",
            )

            # Verify dtypes match
            self.assertEqual(
                original.dtype,
                loaded.dtype,
                f"Dtype mismatch for {tensor_name}: {original.dtype} vs {loaded.dtype}",
            )

            # Verify dequantized values match original values
            # We expect exact match since we used simple 2x scaling
            torch.testing.assert_close(
                loaded,
                original,
                rtol=1e-5,
                atol=1e-5,
                msg=f"Value mismatch for tensor {tensor_name}",
            )


class TestDistributedHFSafetensorsConsolidation(DTensorTestBase):
    @with_comms
test/distributed/checkpoint/test_quantized_hf_storage.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Owner(s): ["oncall: distributed checkpointing"]

import tempfile
from unittest.mock import MagicMock

import torch
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestQuantizedHfStorage(TestCase):
    def setUp(self):
        """Set up common test fixtures."""
        self.temp_dir = tempfile.TemporaryDirectory()
        self.path = self.temp_dir.name

    def tearDown(self):
        """Clean up test fixtures."""
        self.temp_dir.cleanup()

    def test_dequantization(self):
        """Test that quantized tensors are properly dequantized during read operations."""
        reader = QuantizedHuggingFaceStorageReader(self.path, thread_count=1)

        # Test data
        quantized_tensor = torch.ones(4, 4, dtype=torch.float32)
        scale_inv = torch.tensor([[2.0]], dtype=torch.float32)

        # Mock the safetensors file for reading data
        mock_file = MagicMock()

        # Mock get_slice to return a tensor that can be sliced
        def mock_get_slice(tensor_name):
            mock_tensor = MagicMock()
            mock_tensor.__getitem__ = lambda self, slices: quantized_tensor
            return mock_tensor

        mock_file.get_slice = mock_get_slice
        mock_file.get_tensor.return_value = scale_inv

        reader._weight_scale_mapping = {
            "model.layers.0.self_attn.kv_b_proj.weight": "model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
        }

        # Create a read request for quantized tensor
        read_item = ReadItem(
            type=LoadItemType.TENSOR,
            storage_index=MetadataIndex(
                fqn="model.layers.0.self_attn.kv_b_proj.weight",
                offset=torch.Size([0, 0]),
            ),
            dest_index=MetadataIndex(
                fqn="model.layers.0.self_attn.kv_b_proj.weight",
                offset=torch.Size([0, 0]),
            ),
            storage_offsets=[0, 0],
            dest_offsets=[0, 0],
            lengths=[4, 4],
        )

        # Mock planner
        target_tensor = torch.zeros(4, 4, dtype=torch.float32)
        mock_planner = MagicMock()
        mock_planner.resolve_tensor.return_value = target_tensor

        # Test the _process_read_request method
        reader._process_read_request(mock_file, read_item, mock_planner)

        # Verify the tensor was dequantized (ones * 2.0 = twos)
        expected_result = torch.ones(4, 4, dtype=torch.float32) * 2.0
        mock_planner.commit_tensor.assert_called_once()

        # Check that target_tensor was updated correctly
        args, _ = mock_planner.commit_tensor.call_args
        committed_tensor = args[1]  # second argument is the tensor
        torch.testing.assert_close(committed_tensor, expected_result)


if __name__ == "__main__":
    run_tests()
@@ -11,6 +11,7 @@ from .metadata import (
)
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader
from .state_dict_loader import load, load_state_dict
from .state_dict_saver import async_save, save, save_state_dict
from .storage import StorageReader, StorageWriter
torch/distributed/checkpoint/quantized_hf_storage.py (new file, 223 lines)
@@ -0,0 +1,223 @@
# mypy: allow-untyped-defs
import json
import logging
from pathlib import Path
from typing import Any

import torch
from torch.distributed.checkpoint._hf_utils import _metadata_fn
from torch.distributed.checkpoint.planner import LoadPlanner, ReadItem

from .hf_storage import HuggingFaceStorageReader


logger: logging.Logger = logging.getLogger(__name__)

__all__ = ["QuantizedHuggingFaceStorageReader"]


class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
    """
    Extension of HuggingFaceStorageReader that handles quantized tensors.
    The checkpoint should contain the full tensor in a SafeTensors file; a
    quantized tensor should not be sharded across multiple files.

    This reader handles the dequantization of tensors during the read process,
    converting them from quantized blocks to full dequantized tensors before
    copying to the target tensor.
    """

    def __init__(
        self,
        path: str,
        thread_count: int = 1,
        target_dtype: torch.dtype = torch.float32,
        block_size: int = 128,
    ):
        """
        Initialize the HuggingFace storage reader to load quantized checkpoints.

        Args:
            path: directory where the checkpoint will be read from.
            thread_count: Number of threads to use to read distributed checkpoint. Defaults to 1.
            target_dtype: Target dtype for dequantized tensor. Defaults to torch.float32.
            block_size: Fixed block size for dequantization. Defaults to 128.
        """
        super().__init__(path=path, thread_count=thread_count)

        self.target_dtype: torch.dtype = target_dtype
        self.block_size: int = block_size
        self._weight_scale_mapping: dict[str, str] = {}
        self._scale_tensor_cache: dict[str, torch.Tensor] = {}

    def read_metadata(self) -> Any:
        self._load_quantization_metadata()
        return super().read_metadata()

    def _load_quantization_metadata(self):
        """Load quantization metadata from the checkpoint."""
        checkpoint_path = Path(self.path)
        # Load weight mapping from index file
        index_file = checkpoint_path / _metadata_fn

        with open(index_file) as f:
            index_data = json.load(f)
            weight_map = index_data.get("weight_map", {})
            self._build_weight_scale_mapping(weight_map)

    def _build_weight_scale_mapping(self, weight_map: dict[str, str]):
        """Analyze and build weight-scale tensor pairs from the weight mapping."""
        for tensor_name in weight_map.keys():
            if tensor_name.endswith(".weight_scale_inv"):
                weight_name = tensor_name.replace(".weight_scale_inv", ".weight")
                if weight_name in weight_map:
                    self._weight_scale_mapping[weight_name] = tensor_name

    def _process_read_request(
        self, f: Any, req: ReadItem, planner: LoadPlanner
    ) -> None:
        """Override of the helper function that processes a single read request."""
        tensor_fqn = req.storage_index.fqn

        # Check if this is a quantized tensor that needs dequantization
        if self._is_tensor_quantized(tensor_fqn):
            tensor = self._read_quantized_tensor_with_block_alignment(req, f)
        else:
            # Standard tensor reading
            slices = tuple(
                slice(offset, offset + length)
                for offset, length in zip(req.storage_offsets, req.lengths)
            )
            tensor = f.get_slice(tensor_fqn)[slices]

        target_tensor = planner.resolve_tensor(req).detach()

        assert target_tensor.size() == tensor.size(), (
            f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
        )

        target_tensor.copy_(tensor)
        planner.commit_tensor(req, target_tensor)

    def _calculate_scale_shape(
        self, weight: torch.Tensor, block_size: int
    ) -> tuple[int, int]:
        """Calculate the expected scale tensor shape based on the weight tensor and block size."""
        rows, cols = weight.shape
        block_rows = (rows + block_size - 1) // block_size  # Ceiling division
        block_cols = (cols + block_size - 1) // block_size  # Ceiling division
        return (block_rows, block_cols)

    def _dequantize_tensor(
        self,
        weight: torch.Tensor,
        scale_inv: torch.Tensor,
    ) -> torch.Tensor:
        """
        Dequantize a tensor using block-wise scaling.

        Args:
            weight: Quantized weight tensor
            scale_inv: Scale inverse tensor for dequantization

        Returns:
            Dequantized tensor
        """
        # Get original dimensions
        orig_shape = weight.shape

        # Calculate block dimensions for the local shard
        expected_scale_shape = self._calculate_scale_shape(weight, self.block_size)
        block_rows, block_cols = expected_scale_shape

        # Create output tensor in target dtype
        dequantized = weight.detach().clone().to(dtype=self.target_dtype)

        # Apply scaling factors to each block
        for i in range(block_rows):
            row_start = i * self.block_size
            row_end = min(row_start + self.block_size, orig_shape[0])

            for j in range(block_cols):
                col_start = j * self.block_size
                col_end = min(col_start + self.block_size, orig_shape[1])

                # Get the block
                block = weight[row_start:row_end, col_start:col_end]

                scale = scale_inv[i, j]
                block = block * scale

                # Explicitly convert block to target dtype
                block_converted = block.to(dtype=self.target_dtype)
                # Store the dequantized block
                dequantized[row_start:row_end, col_start:col_end] = block_converted

        return dequantized

    def _is_tensor_quantized(self, tensor_fqn: str) -> bool:
        """
        Check if a tensor is quantized.

        Args:
            tensor_fqn: Fully qualified name of the tensor

        Returns:
            True if the tensor is quantized and has a corresponding scale tensor,
            False otherwise
        """
        # Skip scale tensors themselves
        if tensor_fqn.endswith(".weight_scale_inv"):
            return False

        # Check if this weight tensor has a corresponding scale tensor
        if tensor_fqn not in self._weight_scale_mapping:
            return False

        return True

    def _read_quantized_tensor_with_block_alignment(
        self, req: ReadItem, safetensor_file: Any
    ) -> torch.Tensor:
        """
        Read a quantized tensor with block alignment.

        Args:
            req: Read request containing tensor info and required slices
            safetensor_file: Open safetensors file handle

        Returns:
            Dequantized tensor ready for use
        """
        tensor_fqn = req.storage_index.fqn
        scale_fqn = self._weight_scale_mapping[tensor_fqn]

        try:
            # Load the quantized weight
            weight_slices = tuple(
                slice(offset, offset + length)
                for offset, length in zip(req.storage_offsets, req.lengths)
            )
            quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices]

            # Load the corresponding scale inverse tensor
            # For scale tensors, we typically need the full tensor for proper block alignment
            if scale_fqn not in self._scale_tensor_cache:
                scale_inv = safetensor_file.get_tensor(
                    scale_fqn
                )  # Load full scale tensor
                self._scale_tensor_cache[scale_fqn] = scale_inv
            else:
                scale_inv = self._scale_tensor_cache[scale_fqn]

            # Perform dequantization
            dequantized_tensor = self._dequantize_tensor(
                weight=quantized_tensor,
                scale_inv=scale_inv,
            )

            return dequantized_tensor

        except Exception as e:
            logger.error("Failed to read the quantized tensor!!")
            raise e
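For illustration, a standalone sketch (not part of the diff) of the block-wise dequantization that `_dequantize_tensor` performs, using made-up values and a tiny block size:

```python
import torch

# Each block of size (block_size x block_size) in the quantized weight is
# multiplied by its corresponding entry in scale_inv.
block_size = 2
weight = torch.tensor([[1.0, 1.0, 2.0, 2.0],
                       [1.0, 1.0, 2.0, 2.0]])      # quantized values
scale_inv = torch.tensor([[3.0, 0.5]])             # one scale per 2x2 block

rows, cols = weight.shape
dequantized = weight.clone()
for i in range((rows + block_size - 1) // block_size):
    for j in range((cols + block_size - 1) // block_size):
        r0, r1 = i * block_size, min((i + 1) * block_size, rows)
        c0, c1 = j * block_size, min((j + 1) * block_size, cols)
        dequantized[r0:r1, c0:c1] = weight[r0:r1, c0:c1] * scale_inv[i, j]

print(dequantized)
# tensor([[3., 3., 1., 1.],
#         [3., 3., 1., 1.]])
```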