mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Teach ProxyTorchDispatchMode how to decompose sympy.Expr into known inputs (#164717)
In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor. The problem is occuring because in sharding_prop we use FakeTensors to compute an operation size (so we don't have to use the full "real" data). We turn off proxy tracing while we're doing that because we don't want the FakeTensor ops to end up in the graph. We then use that size when doing later operations. Normally this is no problem - but when those sizes are dynamic shapes then we have a problem - the proxy tracer wants to track the provenance of all shape operations (`s1*s2`) but since tracing is disabled it doesn't see the operation and when we then use the result shape later on the proxy tracer gets all confused (because the SymNode appeared out of nowhere). At first we were thinking to never disable shape tracing - but that caused a slew of other downstream problems (lots of code that actually needs the shape tracing to be disabled) so instead we enable having a "sym tracing override" and surgically when we disable proxy tracing we leave shape tracing enabled. After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717 Approved by: https://github.com/bobrenjc93, https://github.com/ezyang ghstack dependencies: #165266
This commit is contained in:
committed by
PyTorch MergeBot
parent
e86942f422
commit
5f21cc786a
66
test/distributed/tensor/test_dynamic.py
Normal file
66
test/distributed/tensor/test_dynamic.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.distributed.tensor import distribute_tensor, DTensor
|
||||
from torch.distributed.tensor.placement_types import Replicate
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
)
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE
|
||||
from torch.testing._internal.triton_utils import requires_gpu
|
||||
|
||||
|
||||
class TestDynamic(DTensorTestBase):
|
||||
@requires_gpu
|
||||
@with_comms
|
||||
# FIXME: Currently broken for fake tensor cache
|
||||
@parametrize("fake_tensor_cache_enabled", [False])
|
||||
def test_embedding(self, fake_tensor_cache_enabled):
|
||||
with patch.object(
|
||||
torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled
|
||||
):
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
placements = (Replicate(),)
|
||||
|
||||
num_embeddings = 202048
|
||||
embedding_dim = 256
|
||||
weight = distribute_tensor(
|
||||
torch.rand(
|
||||
[num_embeddings, embedding_dim],
|
||||
dtype=torch.float32,
|
||||
device=GPU_TYPE,
|
||||
requires_grad=True,
|
||||
),
|
||||
device_mesh,
|
||||
placements, # [Replicate()],
|
||||
)
|
||||
|
||||
def forward(input_batch_inputs_):
|
||||
to = weight.to(torch.float32)
|
||||
emb = torch.nn.functional.embedding(input_batch_inputs_, to)
|
||||
return emb
|
||||
|
||||
arg0 = torch.randint(
|
||||
low=0, high=100, size=(2, 512), dtype=torch.int64, device=GPU_TYPE
|
||||
)
|
||||
arg0 = DTensor.from_local(arg0, device_mesh, placements)
|
||||
|
||||
compiled_forward = torch.compile(forward, fullgraph=True, dynamic=True)
|
||||
_out = compiled_forward(arg0)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDynamic)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -304,7 +304,9 @@ def set_proxy_slot( # type: ignore[no-redef]
|
||||
import sympy
|
||||
|
||||
if isinstance(obj.node.expr, sympy.Symbol):
|
||||
tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue(proxy)
|
||||
tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue(
|
||||
proxy, obj
|
||||
)
|
||||
|
||||
|
||||
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
|
||||
@ -406,24 +408,144 @@ def get_proxy_slot(
|
||||
assert isinstance(obj, py_sym_types), type(obj)
|
||||
tracker = tracer.symnode_tracker
|
||||
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
if obj not in tracker:
|
||||
# Last ditch
|
||||
if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
|
||||
value = tracer.sympy_expr_tracker[obj.node.expr].proxy
|
||||
else:
|
||||
if isinstance(default, _NoDefault):
|
||||
raise RuntimeError(
|
||||
f"{obj} ({id(obj)})is not tracked with proxy for {tracer}"
|
||||
)
|
||||
return default
|
||||
else:
|
||||
# pyrefly: ignore # index-error
|
||||
value = tracker[obj]
|
||||
# pyrefly: ignore # index-error
|
||||
value = tracker.get(obj)
|
||||
|
||||
if value is None and isinstance(obj, py_sym_types):
|
||||
if obj.node.is_symbolic():
|
||||
# Last ditch - we found a SymInt (SymBool, etc) we don't know
|
||||
# about.
|
||||
if (tmp := tracer.sympy_expr_tracker.get(obj.node.expr)) is not None:
|
||||
value = tmp.proxy
|
||||
|
||||
else:
|
||||
# Attempt to build it from first principles.
|
||||
_build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
|
||||
value = tracker.get(obj)
|
||||
|
||||
if value is None:
|
||||
# We don't know this value - return the default.
|
||||
if isinstance(default, _NoDefault):
|
||||
raise RuntimeError(
|
||||
f"{obj} ({type(obj)}, {id(obj)})is not tracked with proxy for {tracer}"
|
||||
)
|
||||
return default
|
||||
|
||||
res = transform(value)
|
||||
return res
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:
|
||||
"""
|
||||
Returns a dict converting sympy functions to python operators
|
||||
(i.e. `sympy.Mul` -> `operator.mul`)
|
||||
"""
|
||||
import torch.utils._sympy.interp
|
||||
|
||||
handlers = {}
|
||||
for k, v in torch.utils._sympy.interp.handlers().items():
|
||||
op = getattr(operator, v, None)
|
||||
if op is not None:
|
||||
handlers[k] = op
|
||||
return handlers
|
||||
|
||||
|
||||
def _build_proxy_for_sym_expr(
|
||||
tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
|
||||
) -> PySymType | None:
|
||||
"""
|
||||
Decompose `expr` and look for the pieces as inputs. If `out` is provided
|
||||
then that will be the resulting SymNode (and `out.expr` must be the same as
|
||||
`expr`).
|
||||
|
||||
This function is used when the ProxyTorchDispatchMode sees a SymNode
|
||||
that it hasn't seen before to try to associate it with traced inputs.
|
||||
|
||||
How can this happen?
|
||||
|
||||
First thing to remember is that although sympy.Exprs are interned (so
|
||||
`sympy.Expr("s3*s4")` will always have the same `id` and will always compare
|
||||
equal) SymNode does not (so doing `SymNode("s3")*SymNode("s4")` twice in a
|
||||
row will give two unique SymNodes).
|
||||
|
||||
- On way for this to happen is if we turn off tracing to compute an
|
||||
intermediate value and then USE that value with tracing turned on - for
|
||||
example if we turn off tracing to do some FakeTensor propagation to
|
||||
compute a size (dtensor does this) but then turn tracing back on and use
|
||||
that computed size.
|
||||
|
||||
- Another way is if we compute a size in one graph and stash it somewhere
|
||||
hidden (such as in some meta-data) and later use it in a different graph
|
||||
(dtensor does this too). Since the size was computed in the first graph
|
||||
and it's not an official input to the second graph it's not tracked
|
||||
properly. This is often going to show up as it usually works in fullgraph
|
||||
but a graph break causes a failure.
|
||||
|
||||
To handle this we decompose the sympy.Expr and look for the pieces as
|
||||
inputs. But there are problems with this approach:
|
||||
|
||||
- We lose operation provanance: We end up figuring out where to get the
|
||||
inputs - but those may not actually be correct. If we have "s1" coming in
|
||||
from both tensor1 and tensor2 and we pick the wrong one we could end up
|
||||
keeping a tensor alive longer than intended.
|
||||
|
||||
- There's no guarantee that those values are inputs to the graph: If we have
|
||||
"s1*s2" computed in a graph #1 and used in graph #2 there's no guarantee
|
||||
that the input that holds "s1" is actually an input on graph #2.
|
||||
|
||||
- The decomposition isn't guaranteed to be the same: Sympy can "simplify"
|
||||
expressions so it's possible that our inputs are "s1*s2" and "s3" but we
|
||||
decompose it into "s1" and "s2*s3" - which wouldn't be found.
|
||||
|
||||
Other ways we could handle this:
|
||||
|
||||
- Don't: Just require that all inputs are tracked properly. This is the
|
||||
"correct" solution but harder because you need to track down each
|
||||
potential problem one by one and fix them. And when it fails it's a lot of
|
||||
work to figure out both why it's failing and the right way to fix it. This
|
||||
is complicated by the fact that a stashed value could be incorrect but
|
||||
work fine until we happen to get an graph break in the wrong place - so it
|
||||
may be a while before the bug is found. (Maybe we need a "dynamo abuse
|
||||
mode" where we run tests with as many graph breaks inserted as possible?)
|
||||
|
||||
- Track SymNode ops separately from proxy tracing: Right now SymNode
|
||||
operations are tracked as part of the proxy tracing - so when we disable
|
||||
proxy tracing we also disable SymNode tracing. But we don't have to do
|
||||
that - we could instead always have SymNodes track where they came from
|
||||
and just use that when needed. This solves the problem of tracing being
|
||||
temporarily turned off but doesn't help if an input isn't present after a
|
||||
graph break.
|
||||
|
||||
- Better decomposition: Right now the decomposition is pretty simple. We do
|
||||
have a sat-solver available to us so we could theoretically do a better
|
||||
job figuring out a "correct" decomposition. But that still relies on
|
||||
having the inputs available at all - which isn't a guarantee.
|
||||
"""
|
||||
|
||||
if (value := tracer.sympy_expr_tracker.get(expr)) is not None:
|
||||
assert not out
|
||||
return value.value
|
||||
|
||||
args = []
|
||||
for arg in expr.args:
|
||||
if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None:
|
||||
return None
|
||||
args.append(arg_value)
|
||||
args = tuple(args)
|
||||
|
||||
func: OpOverload | None = _sympy_handlers().get(expr.func) # type: ignore[assignment]
|
||||
if not func:
|
||||
# Handler not found
|
||||
return None
|
||||
|
||||
if out is None:
|
||||
out = func(*args)
|
||||
else:
|
||||
_sym_register(tracer, func, args, out)
|
||||
return out
|
||||
|
||||
|
||||
def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]:
|
||||
# val.detach() will also eventually call fast_detach(),
|
||||
# but this saves us a full trip into __torch_dispatch__
|
||||
@ -1112,6 +1234,7 @@ class _SymNodeDict:
|
||||
@dataclass
|
||||
class _SympyExprTrackerValue:
|
||||
proxy: _PySymProxyType
|
||||
value: PySymType
|
||||
|
||||
|
||||
class PythonKeyTracer(Tracer):
|
||||
|
Reference in New Issue
Block a user