Compare commits

...

1 Commit

Author SHA1 Message Date
4d2b8e830d [DeviceMesh] Use _flatten_rank_map to replace _flatten_mesh_list so that we don't need to compare root mesh (#166003) (#166264)
Summary:

Since we already share a flattened tensor `_rank_map` across all meshes from the same root mesh, we can just use a flattened list of it to replace the comparison of root_mesh and flattened_mesh_list (with the same _rank_map and layout, the mesh tensor is guaranteed to be the same). This way we also win back the CPU overhead added in https://github.com/pytorch/pytorch/pull/164510 and further simplify the code.

We do have a more ambitious universe-based change here: https://github.com/pytorch/pytorch/pull/165680, but it needs more discussion and would be BC-breaking. We might eventually merge that PR, but probably not now; this change is not BC-breaking and will help concatenate and the 2D integration with concatenate.
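To make the reasoning concrete, here is a minimal, purely illustrative sketch (not the DeviceMesh internals; `materialize` is a hypothetical stand-in): a mesh tensor is fully determined by the shared flattened rank map plus a layout, so comparing those two pieces is enough and no root-mesh comparison is needed.

```python
# Illustrative sketch only, not the DeviceMesh implementation.
import torch

def materialize(rank_map: list[int], shape: tuple[int, ...]) -> torch.Tensor:
    # Hypothetical stand-in: a mesh tensor is just the flattened rank map
    # viewed through the layout (here reduced to a shape).
    return torch.as_tensor(rank_map).view(shape)

rank_map = list(range(8))           # flattened rank map shared by all sub-meshes
a = materialize(rank_map, (2, 4))
b = materialize(rank_map, (2, 4))
assert torch.equal(a, b)            # same rank map + same layout => same mesh tensor
```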

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

imported-using-ghimport

Test Plan: Imported from OSS

Differential Revision: D85526705

Pulled By: fduwjj

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166264
Approved by: https://github.com/XilunWu
2025-10-27 10:51:39 -07:00
4 changed files with 319 additions and 24 deletions

View File

@@ -22,6 +22,7 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -82,7 +83,46 @@ class SimpleModelAnnotated(torch.nn.Module):
return self.mlp_1(x)
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
class FlexAttentionModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.proj_q = torch.nn.Linear(16, 128, device=device)
self.proj_k = torch.nn.Linear(16, 128, device=device)
self.proj_v = torch.nn.Linear(16, 128, device=device)
self.proj_out = torch.nn.Linear(128, 16, device=device)
self.num_heads = 8
self.head_dim = 16
def forward(self, x, *, block_mask=None):
batch_size, seq_len, embed_dim = x.shape
# Project to Q, K, V
q = self.proj_q(x)
k = self.proj_k(x)
v = self.proj_v(x)
# After colwise parallel, q/k/v are sharded on the last dimension
# Get the actual size after sharding
hidden_size = q.shape[-1]
num_heads_local = hidden_size // self.head_dim
# Reshape to (batch, num_heads, seq_len, head_dim)
q = q.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
# Apply flex_attention
attn_output_raw = flex_attention(q, k, v, block_mask=block_mask)
# Reshape back to (batch, seq_len, hidden_size)
attn_output = (
attn_output_raw.transpose(1, 2)
.contiguous()
.view(batch_size, seq_len, hidden_size)
)
# Output projection
output = self.proj_out(attn_output)
return output
def strict_export_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
if kwargs is None:
kwargs = {}
# needed for strict export
torch.utils._pytree.register_constant(DTensorSpec)
@@ -91,36 +131,43 @@ def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
install_free_tensors=True, inline_inbuilt_nn_modules=True
):
with torch._export.utils._disable_aten_to_metadata_assertions():
ep = torch.export.export(model, (inputs,), strict=True)
ep = torch.export.export(model, args, kwargs, strict=True)
# joint_gm produced here is missing the backward region, due to incompatibility
# between ep.module() and aot_export_joint_with_descriptors.
# Keeping this here to show the issue.
return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
return aot_export_joint_with_descriptors_alone(ep.module(), args, kwargs)
def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
gm = dynamo_graph_capture_for_export(model)(inputs)
def graph_capture_and_aot_export_joint_with_descriptors_v2(model, args, kwargs=None):
if kwargs is None:
kwargs = {}
gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
fake_mode = gm.meta.get("fake_mode", None)
with tracing(TracingContext(fake_mode)):
return aot_export_joint_with_descriptors_alone(gm, inputs)
return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
def graph_capture_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
if kwargs is None:
kwargs = {}
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(inputs)
gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
fake_mode = gm.meta.get("fake_mode", None)
with tracing(TracingContext(fake_mode)):
return aot_export_joint_with_descriptors_alone(gm, inputs)
return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
def aot_export_joint_with_descriptors_alone(model, inputs):
def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
if kwargs is None:
kwargs = {}
with contextlib.ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
model,
(inputs,),
args,
kwargs,
)
return joint_with_descriptors.graph_module
@@ -168,8 +215,8 @@ class DTensorExportTest(TestCase):
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
inp = torch.rand(20, 10, device=self.device_type)
inputs = (distribute_tensor(inp, mesh_2d["tp"], placements=[Replicate()]),)
joint_gm = export_fn(tp_model, inputs)
fw_gm, bw_gm = min_cut_rematerialization_partition(
@@ -352,9 +399,10 @@ class DTensorExportTest(TestCase):
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
inp = torch.rand(20, 10, device=self.device_type)
inp_dtensor = distribute_tensor(inp, mesh_2d["tp"], placements=[Replicate()])
torch._dynamo.mark_dynamic(inp_dtensor, 0, min=5, max=100)
inputs = (inp_dtensor,)
joint_gm = export_fn(tp_model, inputs)
@@ -390,15 +438,67 @@ class DTensorExportTest(TestCase):
z = torch.randn(16, 16)
y_dtensor = distribute_tensor(y, device_mesh, placements=[Replicate()])
z_dtensor = DTensor.from_local(z, device_mesh, placements=[Partial()])
inputs = (x_dtensor, y_dtensor, z_dtensor)
# Run model to verify it works
output = model(x_dtensor, y_dtensor, z_dtensor)
output = model(*inputs)
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
gm = export_fn(model)(*inputs)
output_gm = gm(*inputs)
self.assertEqual(output, output_gm)
def test_flex_attention_dtensor_export(self):
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
model = FlexAttentionModel(self.device_type)
# Parallelize the model: shard on head dimension
# proj_q, proj_k, proj_v are colwise parallel (output is sharded on head dimension)
# proj_out is rowwise parallel (input is sharded, output needs reduction)
parallelize_plan = {
"proj_q": ColwiseParallel(),
"proj_k": ColwiseParallel(),
"proj_v": ColwiseParallel(),
"proj_out": RowwiseParallel(),
}
tp_model = parallelize_module(model, device_mesh, parallelize_plan)
batch_size = 4
seq_len = 64
embed_dim = 16
num_heads = 8
# Input tensor replicated across all devices
inp = torch.randn(batch_size, seq_len, embed_dim, device=self.device_type)
inputs = (distribute_tensor(inp, device_mesh, placements=[Replicate()]),)
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(
causal_mask,
batch_size,
num_heads,
seq_len,
seq_len,
device=self.device_type,
)
flex_kwargs = {"block_mask": block_mask}
joint_gm = graph_capture_and_aot_export_joint_with_descriptors(
tp_model, inputs, flex_kwargs
)
self.assertTrue(
_count_op(joint_gm, torch.ops.higher_order.flex_attention),
1,
)
self.assertTrue(
_count_op(joint_gm, torch.ops.higher_order.flex_attention_backward),
2,
)
instantiate_parametrized_tests(DTensorExportTest)

View File

@@ -5734,6 +5734,141 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
self.assertEqual(flex_output, sdpa_output, atol=1e-3, rtol=1e-3)
@supported_platform
def test_pytree_flatten_unflatten(self, device):
"""Test that BlockMask can be correctly flattened and unflattened using class methods."""
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
# Create a BlockMask with various attributes set
block_mask = create_block_mask(
causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
)
# Flatten and unflatten using class methods
tensors, context = block_mask._flatten()
reconstructed_mask = BlockMask._unflatten(tensors, context)
# Verify the reconstructed mask has the same attributes
self.assertEqual(reconstructed_mask.shape, block_mask.shape)
self.assertEqual(reconstructed_mask.sparsity(), block_mask.sparsity())
# Verify all tensor attributes are equal (using _TENSOR_ATTRS)
for attr_name in BlockMask._TENSOR_ATTRS:
original_value = getattr(block_mask, attr_name)
reconstructed_value = getattr(reconstructed_mask, attr_name)
if original_value is None:
self.assertIsNone(
reconstructed_value,
f"Tensor attribute {attr_name} should be None but got {reconstructed_value}",
)
else:
self.assertIsInstance(
original_value,
torch.Tensor,
f"Expected {attr_name} to be a Tensor",
)
self.assertTrue(
torch.equal(original_value, reconstructed_value),
f"Tensor attribute {attr_name} not equal after reconstruction",
)
# Verify all context attributes are equal (using _CONTEXT_ATTRS)
for attr_name in BlockMask._CONTEXT_ATTRS:
original_value = getattr(block_mask, attr_name)
reconstructed_value = getattr(reconstructed_mask, attr_name)
self.assertEqual(
original_value,
reconstructed_value,
f"Context attribute {attr_name} not equal after reconstruction",
)
@supported_platform
def test_pytree_flatten_with_keys(self, device):
"""Test that BlockMask._flatten_with_keys works correctly for tracing."""
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(
causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
)
tensors_with_keys, context_with_keys = block_mask._flatten_with_keys()
self.assertEqual(len(tensors_with_keys), len(BlockMask._TENSOR_ATTRS))
self.assertEqual(len(context_with_keys), len(BlockMask._CONTEXT_ATTRS))
from torch.utils._pytree import GetAttrKey
for key, tensor in tensors_with_keys:
self.assertIsInstance(key, GetAttrKey)
self.assertIsNotNone(key)
for key, value in context_with_keys:
self.assertIsInstance(key, GetAttrKey)
self.assertIsNotNone(key)
@supported_platform
def test_pytree_preserves_new_attributes(self, device):
"""
Test that BlockMask._TENSOR_ATTRS and _CONTEXT_ATTRS are correctly defined
and that flatten/unflatten preserves all attributes in these lists.
"""
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(
causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
)
# Flatten and unflatten using class methods
tensors, context = block_mask._flatten()
reconstructed_mask = BlockMask._unflatten(tensors, context)
# Verify the number of tensors and context values matches the attribute lists
self.assertEqual(
len(tensors),
len(BlockMask._TENSOR_ATTRS),
"Number of tensors should match _TENSOR_ATTRS length",
)
self.assertEqual(
len(context),
len(BlockMask._CONTEXT_ATTRS),
"Number of context values should match _CONTEXT_ATTRS length",
)
# Verify all attributes from the lists exist and are equal after reconstruction
for attr_name in BlockMask._TENSOR_ATTRS + BlockMask._CONTEXT_ATTRS:
self.assertTrue(
hasattr(reconstructed_mask, attr_name),
f"Reconstructed mask missing attribute: {attr_name}",
)
original_value = getattr(block_mask, attr_name)
reconstructed_value = getattr(reconstructed_mask, attr_name)
if isinstance(original_value, torch.Tensor):
self.assertTrue(
torch.equal(original_value, reconstructed_value),
f"Tensor attribute {attr_name} not equal after reconstruction",
)
elif original_value is None:
self.assertIsNone(
reconstructed_value,
f"Attribute {attr_name} should be None but got {reconstructed_value}",
)
else:
self.assertEqual(
original_value,
reconstructed_value,
f"Attribute {attr_name} not equal after reconstruction",
)
@large_tensor_test_class("2GB", device=test_device[0])
class TestPagedAttention(InductorTestCase):

View File

@@ -249,8 +249,12 @@ def _shard_dict_of_args(
)
assert args_chunk_spec is not None # Should have been set by caller
values, tree_spec = tree_flatten(args_dict)
chunk_specs, _ = tree_flatten(args_chunk_spec)
values, tree_spec = tree_flatten(
args_dict, is_leaf=lambda x: isinstance(x, BlockMask)
)
chunk_specs, _ = tree_flatten(
args_chunk_spec, is_leaf=lambda x: isinstance(x, BlockMask)
)
# First check and find the actual number of chunks
split_sizes = []
@@ -369,10 +373,14 @@ def split_args_kwargs_into_chunks(
return _Replicate()
if args_chunk_spec is None:
args_chunk_spec = tree_map(default_spec, args)
args_chunk_spec = tree_map(
default_spec, args, is_leaf=lambda v: isinstance(v, BlockMask)
)
if kwargs_chunk_spec is None:
kwargs_chunk_spec = tree_map(default_spec, kwargs)
kwargs_chunk_spec = tree_map(
default_spec, kwargs, is_leaf=lambda v: isinstance(v, BlockMask)
)
args_split_dict = _shard_dict_of_args(
dict(enumerate(args)),
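
The hunks above pass `is_leaf=lambda x: isinstance(x, BlockMask)` so that, once `BlockMask` becomes a registered pytree node (see the last file in this PR), the microbatch splitter treats the whole mask as one leaf instead of descending into its index tensors and trying to chunk them; the default spec then applies to the mask as a single object (mapping it to `_Replicate()`). A standalone sketch of that `is_leaf` behaviour, using a hypothetical `Mask` stand-in rather than the real `BlockMask`:

```python
# Standalone illustration of `is_leaf`: treat a chosen container type as an
# opaque leaf instead of flattening it into its registered children.
from dataclasses import dataclass

import torch
from torch.utils._pytree import register_pytree_node, tree_flatten

@dataclass
class Mask:  # hypothetical stand-in for BlockMask
    kv_num_blocks: torch.Tensor
    kv_indices: torch.Tensor

register_pytree_node(
    Mask,
    lambda m: ((m.kv_num_blocks, m.kv_indices), None),
    lambda children, _: Mask(*children),
)

args = (torch.randn(8, 16), Mask(torch.ones(4), torch.arange(4)))

leaves, _ = tree_flatten(args)
print(len(leaves))  # 3 -- the mask was flattened into its two tensors

leaves, _ = tree_flatten(args, is_leaf=lambda x: isinstance(x, Mask))
print(len(leaves))  # 2 -- the mask stays whole and can be replicated per chunk
```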

View File

@@ -34,7 +34,7 @@ from torch.fx.experimental.proxy_tensor import (
_temp_remove_pre_dispatch_torch_function_mode,
)
from torch.nn.attention._utils import _validate_sdpa_input
from torch.utils._pytree import tree_map_only
from torch.utils._pytree import GetAttrKey, register_pytree_node, tree_map_only
# Private debug flag to disable internal compilation wrapping for debugging purposes.
@@ -519,6 +519,24 @@ class BlockMask:
BLOCK_SIZE: tuple[int, int]
mask_mod: _mask_mod_signature
# Attribute lists for pytree flatten/unflatten
_TENSOR_ATTRS = [
"kv_num_blocks",
"kv_indices",
"full_kv_num_blocks",
"full_kv_indices",
"q_num_blocks",
"q_indices",
"full_q_num_blocks",
"full_q_indices",
]
_CONTEXT_ATTRS = [
"seq_lengths",
"BLOCK_SIZE",
"mask_mod",
]
def __init__(
self,
seq_lengths: tuple[int, int],
@@ -913,6 +931,31 @@ class BlockMask:
)
return BlockMask(*mapped_attributes)
def _flatten(self):
"""Flatten BlockMask into a list of tensors and context."""
tensors = tuple(getattr(self, attr) for attr in self._TENSOR_ATTRS)
context = tuple(getattr(self, attr) for attr in self._CONTEXT_ATTRS)
return tensors, context
@classmethod
def _unflatten(cls, tensors, context):
"""Unflatten tensors and context back into a BlockMask."""
kwargs = {
**dict(zip(cls._CONTEXT_ATTRS, context)),
**dict(zip(cls._TENSOR_ATTRS, tensors)),
}
return cls(**kwargs)
def _flatten_with_keys(self):
"""Flatten BlockMask with keys for better tracing."""
tensors = tuple(
(GetAttrKey(attr), getattr(self, attr)) for attr in self._TENSOR_ATTRS
)
context = tuple(
(GetAttrKey(attr), getattr(self, attr)) for attr in self._CONTEXT_ATTRS
)
return tensors, context
def _broadcast_to_dim(x, dim):
while x.dim() < dim:
@@ -1605,3 +1648,12 @@ def flex_attention(
return _finalize_outputs(
out, lse, max_scores, return_aux=return_aux, return_lse=return_lse
)
register_pytree_node(
BlockMask,
BlockMask._flatten,
BlockMask._unflatten,
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)
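
With this registration in place, a `BlockMask` round-trips through the standard pytree utilities, which is what the new tests above exercise. A rough usage sketch, assuming a build that includes this change (device and sizes are arbitrary):

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from torch.utils._pytree import tree_flatten, tree_unflatten

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Build a mask; adjust the device to your setup.
block_mask = create_block_mask(
    causal_mask, B=1, H=1, Q_LEN=256, KV_LEN=256, device="cpu"
)

# Because BlockMask is now a registered pytree node, its index tensors become
# leaves, while (seq_lengths, BLOCK_SIZE, mask_mod) travel in the spec as context.
leaves, spec = tree_flatten(block_mask)
rebuilt = tree_unflatten(leaves, spec)

assert isinstance(rebuilt, BlockMask)
assert rebuilt.shape == block_mask.shape
```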