Compare commits


1 commit

SHA1:    724034df5e
Message: proxy track DTensor outer sizes/strides
         [ghstack-poisoned]
Date:    2025-11-17 14:28:26 -08:00

5 changed files with 83 additions and 53 deletions

View File

@@ -211,8 +211,8 @@ def forward(self, b_parametrizations_buffer_original0, x):
     _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
     _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
     view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
-    add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
-    view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
+    add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
+    view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None
     return (view_1,)""",  # noqa: B950
         )
@ -352,9 +352,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
self.assertEqual(res, ref)
@skipIfHpu
@unittest.skip(
"DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer"
)
def test_dtensor_dynamic_slice(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@@ -396,9 +393,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
         res = opt_fn(x)
         self.assertEqual(res, ref)

-    @unittest.skip(
-        "DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer"
-    )
     def test_dtensor_dynamic_cat(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
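
Note: the two tests un-skipped above compile functions over DTensor inputs with a dynamic dimension; the old skip reason refers to derived symbolic sizes (e.g. s77 + 8) that the proxy tracer previously failed to track. A rough standalone sketch of that usage pattern, assuming a single-rank gloo setup (shapes, mesh, and the compiled function are illustrative, not the tests' exact code):

import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, Replicate, distribute_tensor

# Illustrative single-rank setup so the sketch is self-contained.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = DeviceMesh("cpu", torch.arange(1))

def fn(dt):
    # cat over a DTensor whose dim 0 is symbolic; tracing the derived output
    # size is what used to fail before outer sizes/strides were proxy-tracked.
    return torch.cat([dt, dt], dim=0)

x = distribute_tensor(torch.randn(16, 8), mesh, placements=[Replicate()])
torch._dynamo.mark_dynamic(x, 0)
out = torch.compile(fn)(x)
dist.destroy_process_group()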

View File

@@ -552,6 +552,35 @@ graph():
     return (getitem,)""",  # noqa: B950
         )

+    def test_dtensor_unbacked_matmuls(self):
+        device_mesh = init_device_mesh(
+            self.device_type, mesh_shape=(self.world_size // 2, 2)
+        )
+
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                for i in range(2):
+                    torch._check(x.size(i) >= 8)
+                    torch._check(y.size(i) >= 8)
+                return x @ y
+
+        x = torch.randn(64, 64)
+        y = torch.randn(64, 64)
+        x_dt = distribute_tensor(x, device_mesh, placements=[Replicate(), Replicate()])
+        y_dt = distribute_tensor(y, device_mesh, placements=[Replicate(), Replicate()])
+        for i in range(2):
+            torch._dynamo.decorators.mark_unbacked(x_dt, i)
+            torch._dynamo.decorators.mark_unbacked(y_dt, i)
+
+        gm = dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
+        n = 0
+        for node in gm.graph.nodes:
+            if bindings := node.meta.get("unbacked_bindings", {}):
+                # 2 outer sizes, 2 inner sizes
+                self.assertEqual(len(bindings), 4)
+                n += 1
+        self.assertEqual(n, 2)  # 2 nodes with bindings
+

 instantiate_parametrized_tests(DTensorExportTest)
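
The new test asserts four unbacked bindings per matmul input: per its own comment, two for the DTensor's outer sizes and two for the local tensor's inner sizes. A small helper of my own (not part of the PR) for eyeballing those bindings on a captured graph:

import torch

def dump_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
    # node.meta["unbacked_bindings"] maps each fresh unbacked symbol to the
    # keypath where it is bound (outer DTensor size vs. inner local-tensor size).
    for node in gm.graph.nodes:
        bindings = node.meta.get("unbacked_bindings", {})
        if bindings:
            print(node.name, {str(sym): str(path) for sym, path in bindings.items()})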

View File

@@ -575,34 +575,46 @@ def mark_unbacked(
         specialize_on (Optional[list[Any]], default=None): A list of specialization criteria (e.g., lambdas) for this dimension.
             If provided, Dynamo will generate specialized compiled regions for each criterion in addition to a generic trace.
     """
-    # You could have copied the mark_dynamic behavior but I'm not convinced
-    # it's what you want
-    assert not is_traceable_wrapper_subclass(t), "not implemented yet"
+    if isinstance(t, torch.distributed.tensor.DTensor):
+        # apply on inner tensor sizes/strides
+        _apply_func_to_inner_tensors_of_same_dim(mark_unbacked, t, index)
+    elif is_traceable_wrapper_subclass(t):
+        # You could have copied the mark_dynamic behavior but I'm not convinced
+        # it's what you want
+        assert not is_traceable_wrapper_subclass(t), "not implemented yet"

     if isinstance(index, int):
         if strict:
             if not hasattr(t, "_dynamo_strict_unbacked_indices"):
+                # pyrefly: ignore [missing-attribute]
                 t._dynamo_strict_unbacked_indices = set()
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_strict_unbacked_indices.add(index)
             return

         if not hasattr(t, "_specialized_on"):
+            # pyrefly: ignore [missing-attribute]
             t._specialize_on = {}

         if not hasattr(t, "_dynamo_unbacked_indices"):
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_unbacked_indices = set()

         if not hasattr(t, "_dynamo_hint_overrides"):
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_hint_overrides = {}

         if hint_override:
+            # pyrefly: ignore [missing-attribute]
             t._dynamo_hint_overrides[index] = hint_override

         # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
         # TypeError: 'Attribute' object does not support item assignment
         if isinstance(t._specialize_on, dict):
+            # pyrefly: ignore [missing-attribute]
             t._specialize_on[index] = specialize_on if specialize_on is not None else []

+        # pyrefly: ignore [missing-attribute]
         t._dynamo_unbacked_indices.add(index)
         return
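
Note that the new DTensor branch does not return early: mark_unbacked is first applied to the inner local tensor (same number of dimensions), then control falls through so the index is also recorded on the DTensor wrapper itself, which is what makes both outer and inner sizes unbacked. A hedged sketch of the observable effect, reusing the private attribute names from the code above and an x_dt like the one in the test:

torch._dynamo.decorators.mark_unbacked(x_dt, 0)

# Recorded on the wrapper (outer sizes/strides)...
assert 0 in x_dt._dynamo_unbacked_indices
# ...and propagated to the local tensor of the same dimensionality (inner sizes/strides).
assert 0 in x_dt._local_tensor._dynamo_unbacked_indices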

View File

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-import contextlib
 import threading
 from collections.abc import Callable, Sequence
 from functools import lru_cache
@@ -167,20 +166,10 @@ class ShardingPropagator:
             return None

+        # DTensor.dispatch runs fake tensor prop twice, once here, and once for the actual
+        # local tensor result. The result here is never surfaced to tracing, and so if
+        # the op is data-dependent, can result in PendingUnbackedSymbolNotFound errors.
         # NOTE: We must call the tracing in fake tensor mode so that it avoids
-        # materializing memory. Also disable the proxy mode tracing to prevent
-        # these operators to be inserted in the fx graph.
-        from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
+        # materializing memory. This call should also be responsible for
+        # outer size/stride proxy tracking.
         fake_mode = detect_fake_mode() or FakeTensorMode()
-        suppress_fresh_symbols_ctx = (
-            fake_mode.shape_env.ignore_fresh_unbacked_symbols()
-            if fake_mode.shape_env
-            else contextlib.nullcontext()
-        )
-        with fake_mode, disable_proxy_modes_tracing(), suppress_fresh_symbols_ctx:
+        with fake_mode:
             fake_args = op_schema.gen_fake_args()
             fake_kwargs = op_schema.gen_fake_kwargs()
             fake_out = op_schema.op(*fake_args, **fake_kwargs)
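
The dropped context managers used to suppress fresh unbacked symbols and disable proxy tracing during this throwaway fake-propagation pass; per the new comment, the pass now also carries the outer size/stride proxy tracking, so it runs under the plain fake mode. For background, a standalone sketch of how a data-dependent op allocates fresh unbacked symbols under a fake mode with a ShapeEnv (independent of DTensor; symbol names may differ):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
with FakeTensorMode(shape_env=shape_env):
    x = torch.randn(8)
    out = torch.nonzero(x)  # output size depends on data -> unbacked SymInt
    print(out.shape)                                 # e.g. torch.Size([u0, 1])
    print(shape_env.pending_fresh_unbacked_symbols)  # symbols awaiting bindings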

View File

@@ -1192,35 +1192,13 @@ def _free_unbacked_symbols_with_path(
     if pending is None:
         pending = set()
     r = {}
-    if isinstance(a, (tuple, list)):
-        # NB: real is apparently not always a tuple/list here
-        # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu
-        for i in range(len(a)):
-            r.update(
-                go(
-                    a[i],
-                    path + (pytree.SequenceKey(i),),
-                    real=real[i] if real is not None else None,  # type: ignore[index]
-                )
-            )
-    elif is_traceable_wrapper_subclass(a):
-        # TODO: Determine if this is correct
-        attrs, _ = a.__tensor_flatten__()
-        for attr in attrs:
-            sub = getattr(a, attr)
-            r.update(go(sub, path + (InnerTensorKey(attr),)))
-    elif isinstance(a, torch.Tensor) and is_batchedtensor(a):
-        unwrapped_tensor = get_unwrapped(a)
-        r.update(go(unwrapped_tensor, path))
-    elif isinstance(a, torch.Tensor) and not is_batchedtensor(a):
-        from torch._subclasses.fake_tensor import FakeTensor
-
-        assert isinstance(a, FakeTensor)
+
+    def match_tensor(a: torch.Tensor, real_tensor: Optional[torch.Tensor] = None):
         r.update(
             go(
                 a.size(),
                 path + (CallMethodKey("size"),),
-                real=a.real_tensor.size() if a.real_tensor is not None else None,
+                real=real_tensor.size() if real_tensor is not None else None,
             )
         )
         if a.layout not in [
@@ -1233,7 +1211,7 @@
             go(
                 a.stride(),
                 path + (CallMethodKey("stride"),),
-                real=a.real_tensor.stride() if a.real_tensor is not None else None,
+                real=real_tensor.stride() if real_tensor is not None else None,
             )
         )
         r.update(
@@ -1241,13 +1219,41 @@
                 a.storage_offset(),
                 path + (CallMethodKey("storage_offset"),),
                 real=(
-                    a.real_tensor.storage_offset()
-                    if a.real_tensor is not None
-                    else None
+                    real_tensor.storage_offset() if real_tensor is not None else None
                 ),
             )
         )
+
+    if isinstance(a, (tuple, list)):
+        # NB: real is apparently not always a tuple/list here
+        # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu
+        for i in range(len(a)):
+            r.update(
+                go(
+                    a[i],
+                    path + (pytree.SequenceKey(i),),
+                    real=real[i] if real is not None else None,  # type: ignore[index]
+                )
+            )
+    elif is_traceable_wrapper_subclass(a):
+        from torch.distributed.tensor import DTensor
+
+        # TODO: Determine if this is correct
+        attrs, _ = a.__tensor_flatten__()
+        for attr in attrs:
+            sub = getattr(a, attr)
+            r.update(go(sub, path + (InnerTensorKey(attr),)))
+        if isinstance(a, DTensor):
+            match_tensor(a)
+    elif isinstance(a, torch.Tensor) and is_batchedtensor(a):
+        unwrapped_tensor = get_unwrapped(a)
+        r.update(go(unwrapped_tensor, path))
+    elif isinstance(a, torch.Tensor) and not is_batchedtensor(a):
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        assert isinstance(a, FakeTensor)
+        match_tensor(a, a.real_tensor)
     elif (
         isinstance(a, (torch.SymInt, torch.SymFloat))
         and isinstance(s := expr(a), sympy.Symbol)