Compare commits

...

2 Commits

SHA1        Message                                Date

a2fbe26c80  Update on "basic unbacked matmuls"     2025-11-17 16:43:33 -08:00

            cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu
            Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78
            kadeng chauhang amjames Lucaskabela H-Huang awgu wanchaol fegin
            fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

            [ghstack-poisoned]

6a7ec0e344  basic unbacked matmuls                 2025-11-17 16:09:28 -08:00

            [ghstack-poisoned]
6 changed files with 158 additions and 58 deletions

View File

@@ -464,6 +464,44 @@ def forward(self, b_parametrizations_buffer_original0, x):
         run(g, 64, 8)
         self.assertEqual(cnt.frame_count, 2)
 
+    def test_dtensor_unbacked_matmuls(self):
+        from torch.distributed.tensor import randn as d_randn
+
+        # use 2x2 mesh for testing
+        dist.destroy_process_group()
+        dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=4)
+        device_mesh = init_device_mesh(self.device_type, (2, 2))
+
+        for test_index, (px, py) in enumerate(
+            [
+                # these should pass, as a no-redistribute strategy is available
+                [[Replicate(), Replicate()], [Replicate(), Replicate()]],
+                [[Replicate(), Shard(0)], [Replicate(), Replicate()]],
+                [[Replicate(), Shard(1)], [Replicate(), Shard(0)]],
+            ]
+        ):
+            # create DTensors with unbacked outer/inner sizes
+            x_dt = d_randn(64, 64, device_mesh=device_mesh, placements=px)
+            y_dt = d_randn(64, 64, device_mesh=device_mesh, placements=py)
+            for i in range(2):
+                torch._dynamo.decorators.mark_unbacked(x_dt, i)
+                torch._dynamo.decorators.mark_unbacked(y_dt, i)
+
+            # full-graph capture
+            torch._dynamo.reset()
+            cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")
+            fn = torch.compile(torch.mm, backend=cnt, fullgraph=True)
+            fn(x_dt, y_dt)
+
+            # test sharded matmuls with zero-size shards
+            if test_index >= 1:
+                dx = d_randn(3, 1, device_mesh=device_mesh, placements=px)
+                dy = d_randn(1, 1, device_mesh=device_mesh, placements=py)
+                out, eager_out = fn(dx, dy), torch.mm(dx, dy)
+                self.assertEqual(tuple(out.shape), (3, 1))
+                self.assertEqual(cnt.frame_count, 1)
+                self.assertEqual(out.shape, eager_out.shape)
+
     def test_dtensor_requires_grad_recompile(self):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
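
For readers skimming the test: `mark_unbacked` differs from `mark_dynamic` in that the marked dimension becomes an unbacked SymInt, so the compiled graph holds no guards on its concrete value (not even the usual 0/1 specializations), and differently sized inputs reuse one graph; that is what the `frame_count == 1` assertion checks. A minimal single-process sketch of the same pattern with plain tensors (no process group, illustrative only):

    import torch
    import torch._dynamo.testing

    x = torch.randn(8, 8)
    torch._dynamo.decorators.mark_unbacked(x, 0)  # dim 0 becomes an unbacked SymInt

    cnt = torch._dynamo.testing.CompileCounter()
    fn = torch.compile(lambda t: t @ t.mT, backend=cnt, fullgraph=True)
    fn(x)
    fn(torch.randn(3, 8))  # different dim-0 size: same graph, no recompile
    assert cnt.frame_count == 1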

View File

@@ -552,6 +552,32 @@ graph():
     return (getitem,)""",  # noqa: B950
         )
 
+    def test_dtensor_mark_unbacked(self):
+        device_mesh = init_device_mesh(
+            self.device_type, mesh_shape=(self.world_size // 2, 2)
+        )
+
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                return x @ y
+
+        x_dt = distribute_tensor(
+            torch.randn(64, 64), device_mesh, placements=[Replicate(), Replicate()]
+        )
+        y_dt = x_dt.clone()
+        for i in range(2):
+            torch._dynamo.decorators.mark_unbacked(x_dt, i)
+            torch._dynamo.decorators.mark_unbacked(y_dt, i)
+
+        gm = dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
+        n = 0
+        for node in gm.graph.nodes:
+            if bindings := node.meta.get("unbacked_bindings", {}):
+                # 2 outer sizes, 2 inner sizes
+                self.assertEqual(len(bindings), 4)
+                n += 1
+        self.assertEqual(n, 2)  # 2 nodes with bindings (x, y)
+
 
 instantiate_parametrized_tests(DTensorExportTest)
 
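
The `unbacked_bindings` metadata asserted on above maps each unbacked symbol allocated during capture to the key path at which it can be read back off the node's example value; with both dimensions of a 2D DTensor marked, that is two outer sizes (via `.size()`) plus two inner local-shard sizes, hence 4 bindings each for x and y. A short sketch for dumping them from the captured `gm`:

    # Each entry is symbol -> key path, e.g. u0 -> .size()[0] for an outer
    # dim, or a path through the inner ._local_tensor for a local-shard dim.
    for node in gm.graph.nodes:
        for symbol, keypath in node.meta.get("unbacked_bindings", {}).items():
            print(node.name, symbol, keypath)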

View File

@@ -575,34 +575,46 @@ def mark_unbacked(
         specialize_on (Optional[list[Any]], default=None): A list of specialization criteria (e.g., lambdas) for this dimension.
             If provided, Dynamo will generate specialized compiled regions for each criterion in addition to a generic trace.
     """
-    # You could have copied the mark_dynamic behavior but I'm not convinced
-    # it's what you want
-    assert not is_traceable_wrapper_subclass(t), "not implemented yet"
+    if isinstance(t, torch.distributed.tensor.DTensor):
+        # apply on inner tensor sizes/strides
+        _apply_func_to_inner_tensors_of_same_dim(mark_unbacked, t, index)
+    elif is_traceable_wrapper_subclass(t):
+        # You could have copied the mark_dynamic behavior but I'm not convinced
+        # it's what you want
+        assert not is_traceable_wrapper_subclass(t), "not implemented yet"
 
     if isinstance(index, int):
         if strict:
             if not hasattr(t, "_dynamo_strict_unbacked_indices"):
+                # pyrefly: ignore [missing-attribute]
                 t._dynamo_strict_unbacked_indices = set()
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_strict_unbacked_indices.add(index)
             return
 
         if not hasattr(t, "_specialize_on"):
+            # pyrefly: ignore [missing-attribute]
             t._specialize_on = {}
 
         if not hasattr(t, "_dynamo_unbacked_indices"):
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_unbacked_indices = set()
 
         if not hasattr(t, "_dynamo_hint_overrides"):
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_hint_overrides = {}
 
         if hint_override:
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_hint_overrides[index] = hint_override
 
         # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
         # TypeError: 'Attribute' object does not support item assignment
         if isinstance(t._specialize_on, dict):
+            # pyrefly: ignore [missing-attribute]
             t._specialize_on[index] = specialize_on if specialize_on is not None else []
 
+        # pyrefly: ignore [missing-attribute]
         t._dynamo_unbacked_indices.add(index)
         return
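
Note that the new DTensor branch does not return early: `_apply_func_to_inner_tensors_of_same_dim` first re-invokes `mark_unbacked` on each inner tensor of matching rank (the local shard), then control falls through so the index is also recorded on the DTensor itself. A minimal sketch of the resulting call pattern under a fake process group (mirroring the tests above):

    import torch
    import torch.distributed as dist
    from torch.distributed.tensor import Replicate, distribute_tensor, init_device_mesh
    from torch.testing._internal.distributed.fake_pg import FakeStore

    dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=2)
    mesh = init_device_mesh("cpu", (2,))
    x = distribute_tensor(torch.randn(8, 8), mesh, placements=[Replicate()])

    torch._dynamo.decorators.mark_unbacked(x, 0)
    assert 0 in x._dynamo_unbacked_indices                 # outer dim marked
    assert 0 in x._local_tensor._dynamo_unbacked_indices   # local shard marked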

View File

@@ -160,6 +160,8 @@ def prod(xs: Iterable[int]) -> int:
 
 def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
     """Check if the shape is shardable according to the spec."""
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # number of shards in each tensor dimension
     shards_map = [1] * len(shape)
     for i, placement in enumerate(spec.placements):
@@ -172,7 +174,7 @@ def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
     for i, dim_size in enumerate(shape):
         # TODO: maybe we should determine is_shardable based on
         # whether it's evenly sharded or not
-        if shards_map[i] > 1 and dim_size < shards_map[i]:
+        if shards_map[i] > 1 and guard_or_false(dim_size < shards_map[i]):
             return False
 
     return True
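
`guard_or_false` is what keeps this check from raising on unbacked sizes: when the comparison is statically decidable it behaves like `bool(...)`, and when it is not (an unbacked `dim_size` from `mark_unbacked`) it returns `False` instead of emitting a data-dependent guard, so the dimension is optimistically treated as shardable and any real mismatch surfaces at runtime. Roughly:

    from torch.fx.experimental.symbolic_shapes import guard_or_false

    assert guard_or_false(1 < 4)      # decidable: behaves like bool()
    assert not guard_or_false(4 < 1)  # decidable: behaves like bool()
    # For an unbacked SymInt u0, `u0 < shards` is undecidable at trace time;
    # guard_or_false(u0 < shards) returns False rather than guarding on it.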

View File

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import contextlib
+import logging
 import threading
 from collections.abc import Callable, Sequence
 from functools import lru_cache
@@ -31,6 +32,8 @@ from torch.distributed.tensor._utils import (
 
 aten = torch.ops.aten
 
+log = logging.getLogger(__name__)
+
 
 def _length(obj) -> int:
     if obj is None:
@@ -589,12 +592,16 @@ class ShardingPropagator:
     def _select_strategy(
         self, strategy: OpStrategy, op_schema: Optional[OpSchema] = None
     ) -> OpSpec:
+        from torch.fx.experimental.symbolic_shapes import guard_or_false
+
         if len(strategy.strategies) == 1:
             # short cut with only one possible OpSpec
             return strategy.strategies[0]
 
-        op_spec_costs: list[float] = []
+        op_spec_costs: list[torch.types.FloatLikeType] = []
         no_redistribute_strategy_index: int = -1
+        negative_cost_index: int = -1
+        zero_cost_index: int = -1
         for strategy_idx, op_spec in enumerate(strategy.strategies):
             assert op_spec.redistribute_cost is not None, (
                 "must set redistribute cost each OpSpec!"
@@ -602,37 +609,45 @@ class ShardingPropagator:
             redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost))
             op_spec_costs.append(redistribute_cost)
 
-            # If there's no redistribute cost, we record the index of the strategy
-            # which doesn't need redistribute.
+            # If there are strategies with negative/zero/no redistribute cost,
+            # we record those indices.
             # TODO: Currently this only applies to OpStrategy selection. Requires extra
             # logic to make it work for TupleStrategy, if needed.
-            if op_schema is not None and redistribute_cost == 0:
-                needs_redistribute = False
-                for spec_idx, input_spec in enumerate(op_schema.args_spec):
-                    desired_spec = (
-                        op_spec.output_spec
-                        if op_spec.input_specs is None
-                        else op_spec.input_specs[spec_idx]
-                    )
-                    if input_spec.placements != desired_spec.placements:
-                        needs_redistribute = True
-                        break
-                if not needs_redistribute:
-                    no_redistribute_strategy_index = strategy_idx
+            if op_schema is not None:
+                if guard_or_false(redistribute_cost < 0):
+                    if (
+                        negative_cost_index == -1
+                        or redistribute_cost < op_spec_costs[negative_cost_index]
+                        # assume negative costs are coefficients, so we don't need guard_or_false here
+                    ):
+                        negative_cost_index = strategy_idx
+                elif guard_or_false(redistribute_cost == 0):
+                    needs_redistribute = False
+                    for spec_idx, input_spec in enumerate(op_schema.args_spec):
+                        desired_spec = (
+                            op_spec.output_spec
+                            if op_spec.input_specs is None
+                            else op_spec.input_specs[spec_idx]
+                        )
+                        if input_spec.placements != desired_spec.placements:
+                            needs_redistribute = True
+                            break
+                    if not needs_redistribute:
+                        no_redistribute_strategy_index = strategy_idx
+                    elif zero_cost_index == -1:
+                        zero_cost_index = strategy_idx
 
-        # for eager execution, we just select the one with the minimal redistribute cost
-        min_cost = min(op_spec_costs)
-        if min_cost < 0:
-            # If there's negative cost, we select the one with the minimal cost,
-            # even if this means we need to redistribute, e.g. via local chunking.
-            # E.g. this can happen for ops in self.op_to_shape_and_stride_idx
-            # when the inputs / outputs are sharded.
-            selected_strategy_index = op_spec_costs.index(min_cost)
-        elif min_cost == 0 and no_redistribute_strategy_index != -1:
-            # If there's no redistribute cost, we select the one with no redistribute.
+        # prioritize negative/zero/no redistribute cost strategies
+        if negative_cost_index != -1:
+            selected_strategy_index = negative_cost_index
+        elif no_redistribute_strategy_index != -1:
             selected_strategy_index = no_redistribute_strategy_index
+        elif zero_cost_index != -1:
+            selected_strategy_index = zero_cost_index
         else:
+            # default to choosing minimal redistribute cost
+            min_cost = min(op_spec_costs)
             selected_strategy_index = op_spec_costs.index(min_cost)
 
         return strategy.strategies[selected_strategy_index]
 
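The selection order after this change is: most-negative summed cost first, then a zero-cost strategy whose input placements already match (no redistribute needed), then any other zero-cost strategy, and only then the plain minimum. A condensed, hypothetical restatement of that ordering (illustrative names, not the real API):

    def select_index(costs, no_redistribute_index=-1):
        # costs: summed redistribute cost per strategy, as in op_spec_costs.
        negatives = [i for i, c in enumerate(costs) if c < 0]
        if negatives:
            return min(negatives, key=lambda i: costs[i])  # most negative wins
        if no_redistribute_index != -1:
            return no_redistribute_index                   # placements already match
        zeros = [i for i, c in enumerate(costs) if c == 0]
        if zeros:
            return zeros[0]                                # free, but needs redistribute
        return costs.index(min(costs))                     # fall back to min cost

    assert select_index([2.0, 0.0, -1.0]) == 2
    assert select_index([2.0, 0.0], no_redistribute_index=1) == 1
    assert select_index([2.0, 0.0]) == 1
    assert select_index([2.0, 3.0]) == 0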

View File

@@ -1192,35 +1192,13 @@ def _free_unbacked_symbols_with_path(
     if pending is None:
         pending = set()
     r = {}
-    if isinstance(a, (tuple, list)):
-        # NB: real is apparently not always a tuple/list here
-        # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu
-        for i in range(len(a)):
-            r.update(
-                go(
-                    a[i],
-                    path + (pytree.SequenceKey(i),),
-                    real=real[i] if real is not None else None,  # type: ignore[index]
-                )
-            )
-    elif is_traceable_wrapper_subclass(a):
-        # TODO: Determine if this is correct
-        attrs, _ = a.__tensor_flatten__()
-        for attr in attrs:
-            sub = getattr(a, attr)
-            r.update(go(sub, path + (InnerTensorKey(attr),)))
-    elif isinstance(a, torch.Tensor) and is_batchedtensor(a):
-        unwrapped_tensor = get_unwrapped(a)
-        r.update(go(unwrapped_tensor, path))
-    elif isinstance(a, torch.Tensor) and not is_batchedtensor(a):
-        from torch._subclasses.fake_tensor import FakeTensor
-
-        assert isinstance(a, FakeTensor)
+
+    def match_tensor(a: torch.Tensor, real_tensor: Optional[torch.Tensor] = None):
         r.update(
             go(
                 a.size(),
                 path + (CallMethodKey("size"),),
-                real=a.real_tensor.size() if a.real_tensor is not None else None,
+                real=real_tensor.size() if real_tensor is not None else None,
             )
         )
         if a.layout not in [
@@ -1233,7 +1211,7 @@ def _free_unbacked_symbols_with_path(
             go(
                 a.stride(),
                 path + (CallMethodKey("stride"),),
-                real=a.real_tensor.stride() if a.real_tensor is not None else None,
+                real=real_tensor.stride() if real_tensor is not None else None,
             )
         )
         r.update(
@@ -1241,13 +1219,42 @@ def _free_unbacked_symbols_with_path(
                 a.storage_offset(),
                 path + (CallMethodKey("storage_offset"),),
                 real=(
-                    a.real_tensor.storage_offset()
-                    if a.real_tensor is not None
-                    else None
+                    real_tensor.storage_offset() if real_tensor is not None else None
                 ),
             )
         )
+
+    if isinstance(a, (tuple, list)):
+        # NB: real is apparently not always a tuple/list here
+        # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu
+        for i in range(len(a)):
+            r.update(
+                go(
+                    a[i],
+                    path + (pytree.SequenceKey(i),),
+                    real=real[i] if real is not None else None,  # type: ignore[index]
+                )
+            )
+    elif is_traceable_wrapper_subclass(a):
+        from torch.distributed.tensor import DTensor
+
+        # TODO: Determine if this is correct
+        attrs, _ = a.__tensor_flatten__()
+        for attr in attrs:
+            sub = getattr(a, attr)
+            r.update(go(sub, path + (InnerTensorKey(attr),)))
+        # match DTensor outer shapes
+        if isinstance(a, DTensor):
+            match_tensor(a)
+    elif isinstance(a, torch.Tensor) and is_batchedtensor(a):
+        unwrapped_tensor = get_unwrapped(a)
+        r.update(go(unwrapped_tensor, path))
+    elif isinstance(a, torch.Tensor) and not is_batchedtensor(a):
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        assert isinstance(a, FakeTensor)
+        match_tensor(a, a.real_tensor)
     elif (
         isinstance(a, (torch.SymInt, torch.SymFloat))
         and isinstance(s := expr(a), sympy.Symbol)
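
For context, the key paths recorded by `go` and `match_tensor` compose into plain accessors on the original input; replaying one by hand looks roughly like this sketch over the key kinds used above:

    from torch.fx.experimental.symbolic_shapes import CallMethodKey, InnerTensorKey
    from torch.utils._pytree import SequenceKey

    def resolve(value, keypath):
        # e.g. (InnerTensorKey("_local_tensor"), CallMethodKey("size"),
        # SequenceKey(0)) resolves to value._local_tensor.size()[0].
        for key in keypath:
            if isinstance(key, CallMethodKey):
                value = getattr(value, key.name)()
            elif isinstance(key, InnerTensorKey):
                value = getattr(value, key.inner_name)
            elif isinstance(key, SequenceKey):
                value = value[key.idx]
        return value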