Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
In a training library we hit a weird conflict between DTensor, dynamic shapes, and proxy tensor. The problem occurs because sharding_prop uses FakeTensors to compute an operation's output size (so we don't have to use the full "real" data). We turn off proxy tracing while doing that, because we don't want the FakeTensor ops to end up in the graph, and then use the resulting size in later operations. Normally this is no problem, but when those sizes are dynamic shapes we hit trouble: the proxy tracer wants to track the provenance of every shape operation (`s1*s2`), but since tracing is disabled it never sees the operation, so when the resulting shape is used later the proxy tracer gets confused (the SymNode appears out of nowhere). At first we considered never disabling shape tracing, but that caused a slew of other downstream problems (lots of code actually needs shape tracing to be disabled), so instead we add a "sym tracing override": when we surgically disable proxy tracing, we leave shape tracing enabled. After this change the DTensor embedding test is "fixed", but then runs afoul of a FakeTensor cache bug, which is fixed in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
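As a rough illustration of the interaction described above, here is a self-contained toy sketch. The `ToyTracer`, `SymInt`, `trace_disabled`, and `keep_sym_tracing` names are invented for this example and are not PyTorch's actual proxy-tensor APIs; they only model the idea of a separate "sym tracing" switch that can stay on while proxy tracing is off.

```python
# Toy model (NOT PyTorch internals) of the failure mode and the fix.
from contextlib import contextmanager


class ToyTracer:
    def __init__(self):
        self.proxy_tracing = True   # records tensor ops into the graph
        self.sym_tracing = True     # records provenance of symbolic shape ops
        self.known_syms = set()     # symbolic values the tracer has seen being created

    def record_sym(self, sym):
        if self.sym_tracing:
            self.known_syms.add(sym)

    def use_sym(self, sym):
        # Later consumers require that every symbolic value has recorded provenance.
        if sym not in self.known_syms:
            raise RuntimeError(f"symbolic value {sym!r} appeared out of nowhere")


TRACER = ToyTracer()


class SymInt:
    """A toy symbolic integer whose arithmetic reports to the tracer."""

    def __init__(self, name):
        self.name = name
        TRACER.record_sym(self)

    def __mul__(self, other):
        return SymInt(f"({self.name}*{other.name})")

    def __repr__(self):
        return self.name


@contextmanager
def trace_disabled(keep_sym_tracing=False):
    """Turn off proxy tracing; optionally keep sym tracing on (the 'override')."""
    old_proxy, old_sym = TRACER.proxy_tracing, TRACER.sym_tracing
    TRACER.proxy_tracing = False
    TRACER.sym_tracing = keep_sym_tracing
    try:
        yield
    finally:
        TRACER.proxy_tracing, TRACER.sym_tracing = old_proxy, old_sym


s1, s2 = SymInt("s1"), SymInt("s2")

# Before the fix: a size computed under disabled tracing is unknown later.
with trace_disabled():
    size = s1 * s2              # provenance of s1*s2 is never recorded
try:
    TRACER.use_sym(size)
except RuntimeError as e:
    print("without override:", e)

# With the "sym tracing override": proxy tracing stays off, shape tracing stays on.
with trace_disabled(keep_sym_tracing=True):
    size = s1 * s2              # provenance is recorded even though proxy tracing is off
TRACER.use_sym(size)            # no error
print("with override:", size)
```

In the real code the players are proxy mode tracing and SymNodes; the PR applies the analogous "keep shape tracing on" behavior at the point where sharding_prop disables proxy tracing.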
67 lines · 2.1 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

from unittest.mock import patch

import torch
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor.placement_types import Replicate
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu


class TestDynamic(DTensorTestBase):
    @requires_gpu
    @with_comms
    # FIXME: Currently broken for fake tensor cache
    @parametrize("fake_tensor_cache_enabled", [False])
    def test_embedding(self, fake_tensor_cache_enabled):
        # Compile a DTensor embedding lookup with dynamic shapes -- the scenario
        # that previously tripped the proxy tracer as described above.
        with patch.object(
            torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled
        ):
            device_mesh = self.build_device_mesh()

            placements = (Replicate(),)

            num_embeddings = 202048
            embedding_dim = 256
            weight = distribute_tensor(
                torch.rand(
                    [num_embeddings, embedding_dim],
                    dtype=torch.float32,
                    device=GPU_TYPE,
                    requires_grad=True,
                ),
                device_mesh,
                placements,  # [Replicate()]
            )

            def forward(input_batch_inputs_):
                to = weight.to(torch.float32)
                emb = torch.nn.functional.embedding(input_batch_inputs_, to)
                return emb

            arg0 = torch.randint(
                low=0, high=100, size=(2, 512), dtype=torch.int64, device=GPU_TYPE
            )
            arg0 = DTensor.from_local(arg0, device_mesh, placements)

            compiled_forward = torch.compile(forward, fullgraph=True, dynamic=True)
            _out = compiled_forward(arg0)


instantiate_parametrized_tests(TestDynamic)


if __name__ == "__main__":
    run_tests()