pytorch/test/distributed/tensor/test_dynamic.py
Aaron Orenstein 5f21cc786a Teach ProxyTorchDispatchMode how to decompose sympy.Expr into known inputs (#164717)
In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem occurs because in sharding_prop we use FakeTensors to compute an operation's output size (so we don't have to use the full "real" data). We turn off proxy tracing while doing that, because we don't want the FakeTensor ops to end up in the graph, and we then use the computed size in later operations.
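Roughly, the pattern looks like this (an illustrative sketch, not the actual sharding_prop code; `probe_output_shape` is a made-up helper):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

fake_mode = FakeTensorMode()

def probe_output_shape(op, *args):
    # Run the op on FakeTensors to learn its output shape without
    # touching real data; proxy tracing is disabled so these probe
    # ops never land in the traced graph.
    with disable_proxy_modes_tracing():
        fake_args = [fake_mode.from_tensor(a) for a in args]
        out = op(*fake_args)
    return out.shape  # reused later by the "real" operation

print(probe_output_shape(torch.mm, torch.randn(4, 8), torch.randn(8, 16)))
# torch.Size([4, 16])
```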

Normally this is no problem, but when those sizes are dynamic shapes we have a problem: the proxy tracer wants to track the provenance of every shape operation (`s1*s2`), but since tracing is disabled it never sees the operation, so when the resulting shape is used later the proxy tracer gets confused (the SymNode appears out of nowhere).
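A minimal sketch of the failure pattern (using `make_fx` with symbolic tracing as a stand-in for the full compile pipeline):

```python
import torch
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
)

def f(x):
    # The SymNode for s0*s1 is created while proxy tracing is
    # disabled, so the tracer never sees where it came from.
    with disable_proxy_modes_tracing():
        n = x.shape[0] * x.shape[1]
    # Using n here forces the tracer to proxy a SymInt it has no
    # provenance for; this PR teaches ProxyTorchDispatchMode to
    # decompose the underlying sympy.Expr back into the known input
    # symbols (x's sizes) and rebuild the computation in the graph.
    return x.reshape(n)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 3))
print(gm.graph)
```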

At first we considered never disabling shape tracing, but that caused a slew of other downstream problems (a lot of code genuinely needs shape tracing to be disabled). Instead we add a "sym tracing override": when we disable proxy tracing, we surgically leave shape tracing enabled.
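Conceptually the override splits the single tracing switch in two. The names below (`tracing_enabled`, `sym_tracing_override`) are invented for illustration and are not the attributes the PR actually adds:

```python
import contextlib

@contextlib.contextmanager
def disable_tensor_tracing_only(mode):
    # Hypothetical sketch: 'mode' stands in for ProxyTorchDispatchMode.
    saved = mode.tracing_enabled
    mode.tracing_enabled = False       # tensor ops: not recorded
    mode.sym_tracing_override = True   # SymInt/shape ops: still recorded
    try:
        yield
    finally:
        mode.tracing_enabled = saved
        mode.sym_tracing_override = False
```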

After this change the dtensor embedding is "fixed", but it then runs afoul of a FakeTensor cache bug, which is fixed in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
2025-10-16 20:57:06 +00:00

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from unittest.mock import patch

import torch
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor.placement_types import Replicate
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu


class TestDynamic(DTensorTestBase):
    @requires_gpu
    @with_comms
    # FIXME: Currently broken when the fake tensor cache is enabled
    @parametrize("fake_tensor_cache_enabled", [False])
    def test_embedding(self, fake_tensor_cache_enabled):
        with patch.object(
            torch._dynamo.config,
            "fake_tensor_cache_enabled",
            fake_tensor_cache_enabled,
        ):
            device_mesh = self.build_device_mesh()
            placements = (Replicate(),)

            # A replicated DTensor weight for the embedding lookup.
            num_embeddings = 202048
            embedding_dim = 256
            weight = distribute_tensor(
                torch.rand(
                    [num_embeddings, embedding_dim],
                    dtype=torch.float32,
                    device=GPU_TYPE,
                    requires_grad=True,
                ),
                device_mesh,
                placements,  # [Replicate()]
            )

            def forward(input_batch_inputs_):
                to = weight.to(torch.float32)
                emb = torch.nn.functional.embedding(input_batch_inputs_, to)
                return emb

            arg0 = torch.randint(
                low=0, high=100, size=(2, 512), dtype=torch.int64, device=GPU_TYPE
            )
            arg0 = DTensor.from_local(arg0, device_mesh, placements)

            # dynamic=True makes the sizes SymInts, exercising the
            # sym-tracing path through DTensor's sharding propagation.
            compiled_forward = torch.compile(forward, fullgraph=True, dynamic=True)
            _out = compiled_forward(arg0)


instantiate_parametrized_tests(TestDynamic)

if __name__ == "__main__":
    run_tests()