Teach ProxyTorchDispatchMode how to decompose sympy.Expr into known inputs (#164717)

In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem occurs because in sharding_prop we use FakeTensors to compute an operation's output size (so we don't have to run the op on the full "real" data). We turn off proxy tracing while doing that because we don't want the FakeTensor ops to end up in the graph. We then use that size in later operations.
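
For reference, this is roughly the kind of computation sharding_prop relies on (a minimal sketch, not dtensor's actual code): under `FakeTensorMode` the op runs on metadata only, so the output size is known without ever touching real data.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Run the op on fake tensors just to learn the output metadata - no real
# data is allocated or read.
with FakeTensorMode():
    weight = torch.empty(202048, 256)
    indices = torch.zeros(2, 512, dtype=torch.int64)
    out = torch.nn.functional.embedding(indices, weight)

print(out.shape)  # torch.Size([2, 512, 256]) - size known without real compute
```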

Normally this is fine, but when those sizes are dynamic shapes we have a problem: the proxy tracer wants to track the provenance of every shape operation (e.g. `s1*s2`), but since tracing is disabled it never sees the operation. When we later use the resulting shape, the proxy tracer gets confused because the SymNode appears out of nowhere.
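
Here is a minimal standalone sketch of how that mismatch can arise (an illustration under the assumption that `disable_proxy_modes_tracing` also hides the SymInt math from the tracer, as described above - not the dtensor repro, and the exact behavior/error text may differ):

```python
import torch
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx


def f(x):
    with disable_proxy_modes_tracing():
        # The s0*s1 multiplication happens while tracing is off, so no proxy
        # is recorded for the resulting SymInt.
        numel = x.shape[0] * x.shape[1]
    # Using the "unknown" SymInt once tracing is back on is where the tracer
    # used to give up ("... is not tracked with proxy ..."); this PR instead
    # decomposes s0*s1 back into the known inputs s0 and s1.
    return x.reshape(numel)


gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
print(gm.graph)
```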

At first we considered never disabling shape tracing, but that caused a slew of other downstream problems (lots of code actually needs shape tracing to be disabled). Instead we add a "sym tracing override" so that, in the places where we surgically disable proxy tracing, shape tracing can be left enabled.

After this change the dtensor embedding case is "fixed", but it then runs afoul of a FakeTensor cache bug, which is fixed in the next PR.
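
The PR title refers to the recovery mechanism itself, implemented below in `_sympy_handlers()` / `_build_proxy_for_sym_expr()`: when the tracer meets a SymNode it has never seen, it decomposes the underlying sympy.Expr and tries to rebuild a value from pieces it already tracks. A standalone sketch of that idea, independent of the tracer machinery (the names here are illustrative, not the PR's API):

```python
import functools
import operator

import sympy

# Map sympy node types to the Python operators used to recombine the pieces,
# in the spirit of _sympy_handlers().
HANDLERS = {sympy.Add: operator.add, sympy.Mul: operator.mul, sympy.Pow: operator.pow}


def rebuild(expr, known):
    """Rebuild a value for `expr` from values known for its leaf symbols."""
    if expr in known:  # a leaf we already track (e.g. a graph input's size)
        return known[expr]
    if expr.is_Integer:  # plain constants need no provenance
        return int(expr)
    handler = HANDLERS.get(expr.func)
    if handler is None:  # no handler for this sympy op - give up
        return None
    args = [rebuild(arg, known) for arg in expr.args]
    if any(arg is None for arg in args):
        return None
    return functools.reduce(handler, args)


s1, s2 = sympy.symbols("s1 s2")
print(rebuild(s1 * s2 + 3, {s1: 3, s2: 4}))  # 15
```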

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
Author: Aaron Orenstein
Date: 2025-10-12 08:50:28 -07:00
Committed by: PyTorch MergeBot
Parent: e86942f422
Commit: 5f21cc786a

2 changed files with 204 additions and 15 deletions


@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+from unittest.mock import patch
+
+import torch
+from torch.distributed.tensor import distribute_tensor, DTensor
+from torch.distributed.tensor.placement_types import Replicate
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+from torch.testing._internal.inductor_utils import GPU_TYPE
+from torch.testing._internal.triton_utils import requires_gpu
+
+
+class TestDynamic(DTensorTestBase):
+    @requires_gpu
+    @with_comms
+    # FIXME: Currently broken for fake tensor cache
+    @parametrize("fake_tensor_cache_enabled", [False])
+    def test_embedding(self, fake_tensor_cache_enabled):
+        with patch.object(
+            torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled
+        ):
+            device_mesh = self.build_device_mesh()
+            placements = (Replicate(),)
+
+            num_embeddings = 202048
+            embedding_dim = 256
+            weight = distribute_tensor(
+                torch.rand(
+                    [num_embeddings, embedding_dim],
+                    dtype=torch.float32,
+                    device=GPU_TYPE,
+                    requires_grad=True,
+                ),
+                device_mesh,
+                placements,  # [Replicate()]
+            )
+
+            def forward(input_batch_inputs_):
+                to = weight.to(torch.float32)
+                emb = torch.nn.functional.embedding(input_batch_inputs_, to)
+                return emb
+
+            arg0 = torch.randint(
+                low=0, high=100, size=(2, 512), dtype=torch.int64, device=GPU_TYPE
+            )
+            arg0 = DTensor.from_local(arg0, device_mesh, placements)
+
+            compiled_forward = torch.compile(forward, fullgraph=True, dynamic=True)
+            _out = compiled_forward(arg0)
+
+
+instantiate_parametrized_tests(TestDynamic)
+
+
+if __name__ == "__main__":
+    run_tests()


@@ -304,7 +304,9 @@ def set_proxy_slot(  # type: ignore[no-redef]
         import sympy

         if isinstance(obj.node.expr, sympy.Symbol):
-            tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue(proxy)
+            tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue(
+                proxy, obj
+            )


 def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
@@ -406,24 +408,144 @@ def get_proxy_slot(
     assert isinstance(obj, py_sym_types), type(obj)
     tracker = tracer.symnode_tracker

-    # pyrefly: ignore  # unsupported-operation
-    if obj not in tracker:
-        # Last ditch
-        if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
-            value = tracer.sympy_expr_tracker[obj.node.expr].proxy
-        else:
-            if isinstance(default, _NoDefault):
-                raise RuntimeError(
-                    f"{obj} ({id(obj)})is not tracked with proxy for {tracer}"
-                )
-            return default
-    else:
-        # pyrefly: ignore  # index-error
-        value = tracker[obj]
+    # pyrefly: ignore  # index-error
+    value = tracker.get(obj)
+
+    if value is None and isinstance(obj, py_sym_types):
+        if obj.node.is_symbolic():
+            # Last ditch - we found a SymInt (SymBool, etc) we don't know
+            # about.
+            if (tmp := tracer.sympy_expr_tracker.get(obj.node.expr)) is not None:
+                value = tmp.proxy
+            else:
+                # Attempt to build it from first principles.
+                _build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
+                value = tracker.get(obj)
+
+    if value is None:
+        # We don't know this value - return the default.
+        if isinstance(default, _NoDefault):
+            raise RuntimeError(
+                f"{obj} ({type(obj)}, {id(obj)}) is not tracked with proxy for {tracer}"
+            )
+        return default

     res = transform(value)
     return res


+@functools.cache
+def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:
+    """
+    Returns a dict converting sympy functions to python operators
+    (i.e. `sympy.Mul` -> `operator.mul`).
+    """
+    import torch.utils._sympy.interp
+
+    handlers = {}
+    for k, v in torch.utils._sympy.interp.handlers().items():
+        op = getattr(operator, v, None)
+        if op is not None:
+            handlers[k] = op
+    return handlers
+
+
+def _build_proxy_for_sym_expr(
+    tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
+) -> PySymType | None:
+    """
+    Decompose `expr` and look for the pieces as inputs. If `out` is provided
+    then that will be the resulting SymNode (and `out.expr` must be the same as
+    `expr`).
+
+    This function is used when the ProxyTorchDispatchMode sees a SymNode
+    that it hasn't seen before to try to associate it with traced inputs.
+
+    How can this happen?
+
+    The first thing to remember is that although sympy.Exprs are interned (so
+    `sympy.Expr("s3*s4")` will always have the same `id` and will always compare
+    equal), SymNodes are not (so doing `SymNode("s3")*SymNode("s4")` twice in a
+    row will give two unique SymNodes).
+
+    - One way for this to happen is if we turn off tracing to compute an
+      intermediate value and then USE that value with tracing turned on - for
+      example if we turn off tracing to do some FakeTensor propagation to
+      compute a size (dtensor does this) but then turn tracing back on and use
+      that computed size.
+
+    - Another way is if we compute a size in one graph and stash it somewhere
+      hidden (such as in some metadata) and later use it in a different graph
+      (dtensor does this too). Since the size was computed in the first graph
+      and isn't an official input to the second graph it isn't tracked
+      properly. This often shows up as code that works with fullgraph but
+      fails once a graph break is introduced.
+
+    To handle this we decompose the sympy.Expr and look for the pieces as
+    inputs. But there are problems with this approach:
+
+    - We lose operation provenance: We end up figuring out where to get the
+      inputs - but those may not actually be correct. If we have "s1" coming in
+      from both tensor1 and tensor2 and we pick the wrong one we could end up
+      keeping a tensor alive longer than intended.
+
+    - There's no guarantee that those values are inputs to the graph: If we have
+      "s1*s2" computed in graph #1 and used in graph #2 there's no guarantee
+      that the input that holds "s1" is actually an input to graph #2.
+
+    - The decomposition isn't guaranteed to be the same: Sympy can "simplify"
+      expressions so it's possible that our inputs are "s1*s2" and "s3" but we
+      decompose the expression into "s1" and "s2*s3" - which wouldn't be found.
+
+    Other ways we could handle this:
+
+    - Don't: Just require that all inputs are tracked properly. This is the
+      "correct" solution but harder because you need to track down each
+      potential problem one by one and fix them. And when it fails it's a lot
+      of work to figure out both why it's failing and the right way to fix it.
+      This is complicated by the fact that a stashed value could be incorrect
+      but work fine until we happen to get a graph break in the wrong place -
+      so it may be a while before the bug is found. (Maybe we need a "dynamo
+      abuse mode" where we run tests with as many graph breaks inserted as
+      possible?)
+
+    - Track SymNode ops separately from proxy tracing: Right now SymNode
+      operations are tracked as part of the proxy tracing - so when we disable
+      proxy tracing we also disable SymNode tracing. But we don't have to do
+      that - we could instead always have SymNodes track where they came from
+      and just use that when needed. This solves the problem of tracing being
+      temporarily turned off but doesn't help if an input isn't present after
+      a graph break.
+
+    - Better decomposition: Right now the decomposition is pretty simple. We do
+      have a sat-solver available to us so we could theoretically do a better
+      job figuring out a "correct" decomposition. But that still relies on
+      having the inputs available at all - which isn't guaranteed.
+    """
+    if (value := tracer.sympy_expr_tracker.get(expr)) is not None:
+        assert not out
+        return value.value
+
+    args = []
+    for arg in expr.args:
+        if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None:
+            return None
+        args.append(arg_value)
+    args = tuple(args)
+
+    func: OpOverload | None = _sympy_handlers().get(expr.func)  # type: ignore[assignment]
+    if not func:
+        # Handler not found.
+        return None
+
+    if out is None:
+        out = func(*args)
+    else:
+        _sym_register(tracer, func, args, out)
+    return out
+
+
 def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]:
     # val.detach() will also eventually call fast_detach(),
     # but this saves us a full trip into __torch_dispatch__
@@ -1112,6 +1234,7 @@ class _SymNodeDict:
 @dataclass
 class _SympyExprTrackerValue:
     proxy: _PySymProxyType
+    value: PySymType


 class PythonKeyTracer(Tracer):