# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.distributed as dist
import torch.nn as nn
from torch._C import FileCheck
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
)
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfHpu,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_HPU,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import checkpoint


dev_type = torch.device(get_devtype())


class SimpleModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.mlp_0 = MLPModule(device)
        self.mlp_1 = MLPModule(device)

    def forward(self, input):
        return self.mlp_1(self.mlp_0(input))


def extract_graph(fx_g, _, graph_cell):
    graph_cell[0] = fx_g.code
    return fx_g


# Make a custom compiler that runs aot autograd but extracts the fw and bw graphs
fw_graph_cell = [None]
bw_graph_cell = [None]
fw_compiler = functools.partial(extract_graph, graph_cell=fw_graph_cell)
bw_compiler = functools.partial(extract_graph, graph_cell=bw_graph_cell)

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd


aot_eager_graph = aot_autograd(
    fw_compiler=fw_compiler,
    bw_compiler=bw_compiler,
    partition_fn=min_cut_rematerialization_partition,
)


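# Note: TestDTensorCompile below runs in a single process against a
# FakeStore-backed "fake" process group, so collectives are traced/compiled
# without performing real communication.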
class TestDTensorCompile(torch._dynamo.test_case.TestCase):
    def setUp(self):
        super(
            type(self), self
        ).setUp()  # use explicit params for compiled autograd test wrapping
        fake_store = FakeStore()
        dist.init_process_group(
            "fake", store=fake_store, rank=0, world_size=self.world_size
        )

    def tearDown(self):
        super(
            type(self), self
        ).tearDown()  # use explicit params for compiled autograd test wrapping
        dist.destroy_process_group()

    @property
    def device_type(self) -> str:
        return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu"

    @property
    def world_size(self) -> int:
        return 2

    def test_dtensor_basic(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x):
            return x * x + 2

        param = torch.randn(4, 4, requires_grad=True)
        x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)

        res = fn(x)
        res.to_local().sum().backward()

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_dtensor_basic_export(self):
        mesh = DeviceMesh("cuda", torch.arange(self.world_size))

        param = torch.randn(4, 4)
        param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.buffer = torch.nn.Buffer(param_x)

            def forward(self, x):
                inter = self.buffer + DTensor.from_local(
                    x, mesh, [Shard(0)], run_check=False
                )
                return inter.to_local()

        torch.utils._pytree.register_constant(
            torch.distributed.tensor._dtensor_spec.DTensorSpec
        )
        torch.utils._pytree.register_constant(DeviceMesh)

        ep = torch.export.export_for_training(
            Foo(), (torch.randn(4, 4, dtype=torch.float64),), strict=False
        )
        self.assertExpectedInline(
            str(ep.graph_module.code).strip(),
            """\
def forward(self, b_buffer, x):
    _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
    to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None
    view_as = torch.ops.aten.view_as.default(to, to); to = None
    dtensor___init__0 = self.dtensor___init__0
    dtensor_const_func_spec0 = self.dtensor_const_func_spec0
    flat_apply = torch.ops.higher_order.flat_apply(dtensor_const_func_spec0, dtensor___init__0, view_as, False); dtensor_const_func_spec0 = dtensor___init__0 = view_as = None
    add = torch.ops.aten.add.Tensor(b_buffer, flat_apply); b_buffer = flat_apply = None
    access_subclass_inner_tensor_default_4 = torch.ops.export.access_subclass_inner_tensor.default(add, '_local_tensor'); add = None
    view_as_1 = torch.ops.aten.view_as.default(access_subclass_inner_tensor_default_4, access_subclass_inner_tensor_default_4); access_subclass_inner_tensor_default_4 = None
    return (view_as_1,)""",  # noqa: B950
        )

        self.assertExpectedInline(
            str(ep.run_decompositions({}).graph_module.code).strip(),
            """\
def forward(self, b_parametrizations_buffer_original0, x):
    _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
    _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
    view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
    add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
    view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
    return (view_1,)""",  # noqa: B950
        )

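    # Placement objects (Shard/Replicate/Partial/_StridedShard) are plain Python
    # values; with fullgraph=True below, their predicates (is_shard(),
    # is_replicate(), is_partial()) must be traceable without graph breaks.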
    def test_placement_compile(self):
        def fn(x):
            a = 0
            if x.is_replicate():
                a += 1
            if x.is_shard():
                a += 2
                if x.dim < 0:
                    raise RuntimeError("dim < 0")
            if x.is_shard(0):
                a += 2
            if x.is_shard(dim=0):
                a += 2
            if x.is_shard(dim=None):
                a += 2
            if x.is_partial():
                a += 3
            return a

        compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)
        split_factors = [2, 3, 4]
        for x in [Shard(0), Replicate(), Partial()] + [
            _StridedShard(0, split_factor=s) for s in split_factors
        ]:
            opt_fn = fn(x)
            compiled_out = compiled_fn(x)
            self.assertEqual(opt_fn, compiled_out)

    def test_device_mesh_compile(self):
        def fn(x: DeviceMesh):
            # test size()
            a = x.size()
            b = x.size(0)
            c = x.size(mesh_dim=0)
            size = a + b + c
            # test get_coordinate()
            coord = x.get_coordinate()
            # test get_group()
            group0 = x.get_group(0)
            group1 = x.get_group(mesh_dim=1)
            return size, coord, group0, group1

        # Can't be fullgraph=True because ProcessGroup is not reconstructible in dynamo
        compiled_fn = torch.compile(backend="aot_eager")(fn)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).unsqueeze(1))
        opt_fn = fn(mesh)
        compiled_out = compiled_fn(mesh)
        self.assertEqual(opt_fn, compiled_out)

    def test_fakify_dtensor(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in DTensor as inputs/outputs to the function
        def fn(x):
            return x

        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dynamo_dtensor(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x):
            return x * x + 2

        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

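    # The next test marks dim 0 of the DTensor input as dynamic to exercise
    # dynamic-shape handling for tensor subclass inputs.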
    @skipIfHpu
    def test_dtensor_dynamic(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x):
            return (
                torch.mul(x, x)
                .redistribute(device_mesh=x.device_mesh, placements=[Replicate()])
                .to_local()[0]
            )

        x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False)
        torch._dynamo.mark_dynamic(x, 0)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_attribute_access_on_intermediate(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            tmp = x * 2
            if tmp.placements[0].is_shard():
                return tmp._local_tensor + 2
            else:
                return tmp._local_tensor + 3

        x = DTensor.from_local(torch.ones(4), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_constructor_w_graph_break(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        x = torch.randn(64, 32, requires_grad=True)
        spec = DTensorSpec(
            mesh,
            (Replicate(), Shard(0)),
            tensor_meta=TensorMeta(
                shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype
            ),
        )

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x):
            print("graph break!")
            return DTensor(
                x,
                spec,
                requires_grad=x.requires_grad,
            )

        fn(x)
        torch.compile(fn, backend="eager")(x)

    def test_dtensor_constructor_w_dynamo_disable(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        x = torch.randn(32, requires_grad=True)
        spec = DTensorSpec(
            mesh,
            (Replicate(),),
            tensor_meta=TensorMeta(shape=torch.Size([32]), stride=(1,), dtype=x.dtype),
        )

        @torch._dynamo.disable(recursive=False)
        def fn(x):
            print("foo")
            return DTensor(
                x,
                spec,
                requires_grad=x.requires_grad,
            )

        out = fn(x)
        out2 = torch.compile(fn, backend="eager")(x)
        self.assertEqual(out, out2)

    def test_dtensor_noncontiguous_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x, y, z):
            x_transposed = x.permute(0, 2, 1).contiguous()
            tmp = torch._C._nn.linear(x_transposed, y, z)
            return tmp.permute(0, 2, 1)

        x_inner = torch.randn(4, 16, 4, requires_grad=True)
        y_inner = torch.randn(4, 16, requires_grad=True)
        z_inner = torch.randn(4, requires_grad=True)
        x = DTensor.from_local(x_inner, mesh, [Shard(1)], run_check=False)
        y = DTensor.from_local(y_inner, mesh, [Shard(1)], run_check=False)
        z = DTensor.from_local(z_inner, mesh, [Replicate()], run_check=False)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(x, y, z)
        out.contiguous().sum().backward()

    def test_dynamo_dtensor_from_local(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # create DTensor inside fn and run some compute
        def fn(x):
            dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
            return dt.to_local() + 2

        # below is the op approach for reference
        # from torch.distributed._tensor.api import _FromTorchTensor
        # def from_local_tensor(x):
        #     return _FromTorchTensor.apply(x, mesh, [Replicate()], False)

        # _dt_lib_def = torch.library.Library("dtensor", "DEF")
        # _dt_lib_def.define("from_local(Tensor self) -> Tensor")

        # _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
        # _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")

        x = torch.ones(1, requires_grad=True)
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        res = opt_fn(x)
        # backward should work as well
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # test if user calls from_local with mesh/placements as kwargs and that should still work
        def from_local_kwargs_fn(x):
            dt = DTensor.from_local(
                x, device_mesh=mesh, placements=[Replicate()], run_check=False
            )
            return dt.to_local() + 2

        ref = from_local_kwargs_fn(x)
        opt_kwargs_fn = torch.compile(from_local_kwargs_fn, backend=cnt, fullgraph=True)
        res = opt_kwargs_fn(x)
        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 2)

    def test_dynamo_dtensor_from_local_dynamic_shapes(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # Case 1: all dims dynamic
        def fn(x):
            dt = DTensor.from_local(
                x,
                mesh,
                [Replicate()],
                run_check=False,
                shape=x.shape,
                stride=x.stride(),
            )
            return dt.to_local() + 2

        inp = torch.randn(4, 6, requires_grad=True)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # Case 2: only sizes are dynamic, strides are static
        def fn(x):
            dt = DTensor.from_local(
                x, mesh, [Replicate()], run_check=False, shape=x.shape, stride=(1,)
            )
            return dt.to_local() + 2

        inp = torch.randn(4, requires_grad=True)
        torch._dynamo.mark_dynamic(inp, 0)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # Case 3: both sizes and strides have a mix of dynamic and static dims
        def fn(x):
            dt = DTensor.from_local(
                x,
                mesh,
                [Replicate()],
                run_check=False,
                shape=(x.shape[0], x.shape[1], 2),
                stride=(x.stride()[0], 2, 1),
            )
            return dt.to_local() + 2

        inp = torch.randn(4, 6, 2, requires_grad=True)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

    def test_dynamo_dtensor_recompile(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x):
            return torch.mul(x, x)

        x = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(0)], run_check=False)
        x2 = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(0)], run_check=False)
        x3 = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(1)], run_check=False)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=False)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(fn(x2), opt_fn(x2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(fn(x3), opt_fn(x3))
        self.assertEqual(cnt.frame_count, 2)

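    # In the next two tests, dim 1 of size 257 is not evenly divisible by the
    # re-initialized world size of 2, so Shard(1) produces unevenly sized local
    # shards (the "unbalanced" case the comments below describe).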
    @skipIfHpu
    def test_dtensor_partial_placement_redistribute_unbalanced_correct_strides(self):
        # Partial -> Shard on an unbalanced tensor results in:
        # - A contiguous DTensor
        # - where the inner _local_tensor is noncontiguous
        placement = Shard(1)

        def fn(x):
            out = x.redistribute(mesh, [placement])
            return out

        # Temporarily ignore setUp(), and use rank3 graphs during tracing
        dist.destroy_process_group()
        fake_store = FakeStore()
        dist.init_process_group("fake", store=fake_store, rank=3, world_size=2)
        mesh = DeviceMesh(self.device_type, [1, 3])

        x = torch.randn(10, 257, 160, requires_grad=True)
        x_dt = DTensor.from_local(
            x,
            mesh,
            [Partial()],
            run_check=False,
            shape=(10, 257, 160),
            stride=(41120, 160, 1),
        )

        # tmp_dt has an inner, non-contiguous tensor, and is an autograd non-leaf
        tmp_dt = fn(x_dt)
        fake_mode = torch._subclasses.FakeTensorMode()
        tmp_dt_fake = fake_mode.from_tensor(tmp_dt)
        self.assertEqual(tmp_dt.shape, tmp_dt_fake.shape)
        self.assertEqual(tmp_dt.stride(), tmp_dt_fake.stride())
        self.assertEqual(tmp_dt._local_tensor.shape, tmp_dt_fake._local_tensor.shape)
        self.assertEqual(
            tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride()
        )

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent(self):
        # Partial -> Shard on an unbalanced tensor results in:
        # - A contiguous DTensor
        # - where the inner _local_tensor is noncontiguous
        # When this tensor is a fwd graph output,
        # AOTAutograd needs to make sure we trace the backward
        # with a contiguous tangent
        placement = Shard(1)

        def fn(x):
            out = x.redistribute(mesh, [placement])
            return out

        # Temporarily ignore setUp(), and use rank3 graphs during tracing
        dist.destroy_process_group()
        fake_store = FakeStore()
        dist.init_process_group("fake", store=fake_store, rank=3, world_size=2)
        mesh = DeviceMesh(self.device_type, [1, 3])

        x = torch.randn(10, 257, 160, requires_grad=True)
        x_dt = DTensor.from_local(
            x,
            mesh,
            [Partial()],
            run_check=False,
            shape=(10, 257, 160),
            stride=(41120, 160, 1),
        )

        out_dt = torch.compile(fn)(x_dt)
        # If we don't properly contiguify our traced tangents,
        # this fails with an inductor stride assert
        out_dt.to_local().sum().backward()

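    # Note: fn below closes over `dt` (created afterwards) rather than using its
    # argument, so both the eager and compiled calls see the same DTensor; the
    # test targets the to_local(grad_placements=...) kwarg.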
    def test_dynamo_to_local_kwargs(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return dt.to_local(grad_placements=[Shard(0)]) + 2

        fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
        x = torch.ones(4)
        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)

        out_ref = fn(dt)
        out_test = fn_opt(dt)
        self.assertEqual(out_ref, out_test)

    def test_dynamo_to_local_kwargs_forward_hook(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fw_hook(module, inp, out):
            tmp = out.to_local(grad_placements=out.placements) + 2
            return DTensor.from_local(tmp, mesh, out.placements, run_check=False)

        mod = torch.nn.Linear(4, 4)
        mod.register_forward_hook(fw_hook)
        mod.weight = torch.nn.Parameter(
            DTensor.from_local(mod.weight, mesh, [Replicate()], run_check=False)
        )
        mod.bias = torch.nn.Parameter(
            DTensor.from_local(mod.bias, mesh, [Replicate()], run_check=False)
        )
        opt_mod = torch.compile(mod, backend="aot_eager", fullgraph=True)

        x = torch.ones(4, 4)
        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)

        out_ref = mod(dt)
        out_test = opt_mod(dt)
        self.assertEqual(out_ref, out_test)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_different_gradient_placement(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x, y, z):
            permute = x.permute(0, 2, 1)
            permute2 = permute.contiguous()
            layer_norm = torch.nn.functional.layer_norm(permute2, (4,), y, z, 1e-05)
            out = layer_norm.permute(0, 2, 1)
            return out

        x = torch.randn(4, 2, 4, requires_grad=True, device="cuda")
        x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)

        y = torch.randn(4, requires_grad=True, device="cuda")
        y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)

        z = torch.randn(4, requires_grad=True, device="cuda")
        z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)

        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        tmp_dt = opt_fn(x_dt, y_dt, z_dt)
        out_dt = torch.matmul(tmp_dt, x_dt).permute(0, 2, 1)
        out_dt.sum().backward()

    def test_dynamo_dtensor_from_local_redistribute(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in tensor as inputs/outputs, create DTensor and run redistribute
        # (allgather collective) inside the fn
        def fn(x):
            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
            return dt.redistribute(mesh, [Replicate()]).to_local() + 2

        x = torch.ones(1)
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

        def redistribute_kwargs_fn(x):
            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
            return (
                dt.redistribute(device_mesh=mesh, placements=[Replicate()]).to_local()
                + 2
            )

        x = torch.ones(1)
        ref = redistribute_kwargs_fn(x)
        opt_kwargs_fn = torch.compile(
            redistribute_kwargs_fn, backend=cnt, fullgraph=True
        )
        res = opt_kwargs_fn(x)
        self.assertEqual(res, ref)

    @skipIfHpu
    def test_dynamo_dtensor_from_local_redistribute_async(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        from torch.distributed._functional_collectives import AsyncCollectiveTensor

        # pass in tensor as inputs/outputs, create DTensor and run redistribute
        # (allgather collective) inside the fn
        def fn(x):
            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
            out = dt.redistribute(mesh, [Replicate()], async_op=True).to_local()
            if isinstance(out, AsyncCollectiveTensor):
                return out.wait()
            else:
                return out

        x = torch.ones(1)
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

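    # In the next test, fn has no return value; the test only checks dynamo's
    # guarding/recompile behavior as the closed-over mesh and placement change.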
    def test_dtensor_dont_recompile_on_same_placement_devicemesh(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        @torch.compile(backend=cnt)
        def fn(x):
            DTensor.from_local(x, mesh, [placement], run_check=False)

        x = torch.ones(4, 4, requires_grad=True)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Shard(1)
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Shard(1)
        # no recompile, placement is unchanged
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Partial()
        # recompile since placement is different
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Partial()
        # no recompile, placement is unchanged
        fn(x)

        # 2 total frames (one for Partial(), one for Shard())
        self.assertEqual(cnt.frame_count, 2)

    def test_dtensor_dynamo_device_mesh_attrs(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in tensor as inputs/outputs, create DTensor and run redistribute
        # (allgather collective) inside the fn
        def fn(x_dt):
            if x_dt.device_mesh.device_type == "cuda":
                return x_dt + 1
            else:
                return x_dt + 2

        x = torch.ones(4, 4)
        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        ref = fn(x_dt)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x_dt)
        self.assertEqual(ref, res)

    @skipIfHpu
    def test_graph_input_is_async(self):
        from torch.distributed._functional_collectives import AsyncCollectiveTensor

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return x.sin().sin()

        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)

        x = torch.randn(4, 4, requires_grad=True)
        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
        x2 = x2.to_local()
        self.assertTrue(isinstance(x2, AsyncCollectiveTensor))
        opt_fn(x2)
        # The important part: we get a wait_tensor() in the graph.
        # At runtime, the input to the graph is an AsyncCollectiveTensor,
        # and inside the graph we need to issue a wait() to synchronize.
        self.assertExpectedInline(
            str(fw_graph_cell[0]).strip(),
            """\
def forward(self, primals_1):
    wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
    sin = torch.ops.aten.sin.default(wait_tensor)
    sin_1 = torch.ops.aten.sin.default(sin); sin = None
    return (sin_1, primals_1, wait_tensor)""",
        )

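    # In the next test, the backward() calls themselves are the assertion: a
    # TwoTensor tangent wrapping AsyncCollectiveTensors should not error under
    # compile.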
    @skipIfTorchDynamo()
    def test_unwrap_async_collective_tensor_tangent(self):
        from torch.distributed._functional_collectives import AsyncCollectiveTensor

        def fn(x):
            return x.clone()

        ref_x = TwoTensor(
            torch.randn(2, 3, requires_grad=True), torch.randn(2, 3, requires_grad=True)
        )
        ref_y = fn(ref_x)

        ref_y.backward(gradient=TwoTensor(torch.randn(2, 3), torch.randn(2, 3)))

        fn_comp = torch.compile(fn, fullgraph=True)

        x = TwoTensor(
            torch.randn(2, 3, requires_grad=True), torch.randn(2, 3, requires_grad=True)
        )
        y = fn_comp(x)
        y.backward(gradient=TwoTensor(torch.randn(2, 3), torch.randn(2, 3)))

        x2 = TwoTensor(
            torch.randn(2, 3, requires_grad=True), torch.randn(2, 3, requires_grad=True)
        )
        y2 = fn_comp(x2)
        y2.backward(
            gradient=TwoTensor(
                AsyncCollectiveTensor(torch.randn(2, 3)),
                AsyncCollectiveTensor(torch.randn(2, 3)),
            )
        )

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_partial_placement_graph_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return x + x

        x = torch.randn(4, 4, requires_grad=True)
        x_dt = DTensor.from_local(x, mesh, [Partial()], run_check=False)

        y = torch.randn(4, 4, requires_grad=True)
        y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)

        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        tmp_dt = opt_fn(x_dt)
        out_dt = torch.matmul(tmp_dt, y_dt)
        out_dt.sum().backward()

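    # Shared body for the two comm-reordering tests below; the FileCheck at the
    # end asserts that the all-gather emitted for the colwise linears shows up
    # (and is waited on) before the first matmul in the generated inductor code.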
    def _test_tp_compile_comm_reordering(self):
        class FakeAttention(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.wq = nn.Linear(16, 16)
                self.wk = nn.Linear(16, 16)
                self.wv = nn.Linear(16, 16)
                self.wo = nn.Linear(16, 16)

            def forward(self, x):
                xq = self.wq(x)
                xk = self.wk(x)
                xv = self.wv(x)
                # fake attention:
                xo = xq + xk + xv
                return self.wo(xo)

        class FakeTransformerBlock(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attn = FakeAttention()

            def forward(self, x):
                return self.attn(x)

        class FakeTransformer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = FakeTransformerBlock()

            def forward(self, input):
                return self.block(input)

        model = FakeTransformer().to(self.device_type)

        tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

        # apply sequence parallel
        parallel_plan = {
            "attn": PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=Replicate()
            ),
            "attn.wq": ColwiseParallel(),
            "attn.wk": ColwiseParallel(),
            "attn.wv": ColwiseParallel(),
            "attn.wo": RowwiseParallel(output_layouts=Shard(0)),
        }

        parallelize_module(
            module=model.block,
            device_mesh=tp_mesh,
            parallelize_plan=parallel_plan,
        )

        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(model, backend=cnt, fullgraph=True)
        inp = torch.rand(20, 16).to(self.device_type)
        out = compiled_model(inp)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, 1)

        code = run_and_get_triton_code(compiled_model, inp)
        FileCheck().check(
            "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
        ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
            "extern_kernels.mm(buf0,"
        ).run(code)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(1)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    def test_tp_compile_comm_reordering(self):
        self._test_tp_compile_comm_reordering()

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(1)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @torch._inductor.config.patch("graph_partition", True)
    def test_tp_compile_comm_reordering_graph_partition(self):
        self._test_tp_compile_comm_reordering()


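# Unlike TestDTensorCompile above, the E2E tests below go through
# DTensorTestBase/with_comms and therefore run with a real multi-process
# backend (world_size 4).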
@instantiate_parametrized_tests
class TestDTensorCompileE2E(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    # multiprocess relies on pickling the source code
    # so compiled autograd tests can't dynamically wrap this class
    def _bwd_ctx(self, use_ca):
        if not use_ca:
            return contextlib.nullcontext()
        return torch._dynamo.compiled_autograd._enable(torch.compile)

    @with_comms
    @parametrize("is_seq_parallel", [True, False])
    @parametrize("use_ca", [True, False])
    def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model = SimpleModel(self.device_type)

        colwise_style = (
            ColwiseParallel(input_layouts=Shard(0))
            if is_seq_parallel
            else ColwiseParallel()
        )
        rowwise_style = (
            RowwiseParallel(output_layouts=Shard(0))
            if is_seq_parallel
            else RowwiseParallel()
        )

        if is_seq_parallel:
            # use input preparation to test out the compile of it
            prepare_module_input = PrepareModuleInput(
                input_layouts=Shard(0),
                desired_input_layouts=Replicate(),
            )
            prepare_module_out = PrepareModuleOutput(
                output_layouts=Replicate(),
                desired_output_layouts=Shard(0),
            )
            plan = {
                "mlp_0": prepare_module_input,
                "mlp_0.net1": ColwiseParallel(),
                "mlp_0.net2": rowwise_style,
                "mlp_1.net1": colwise_style,
                "mlp_1.net2": RowwiseParallel(),
                "mlp_1": prepare_module_out,
            }
        else:
            plan = {
                "mlp_0.net1": colwise_style,
                "mlp_0.net2": rowwise_style,
                "mlp_1.net1": colwise_style,
                "mlp_1.net2": rowwise_style,
            }

        model = parallelize_module(
            model,
            mesh,
            parallelize_plan=plan,
        )
        rng_seed = self.rank if is_seq_parallel else 0
        torch.manual_seed(rng_seed)
        inp = torch.rand(20, 10, device=self.device_type)
        out = model(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
        compiled_out = compiled_mod(inp)
        with self._bwd_ctx(use_ca):
            compiled_out.sum().backward()
        self.assertEqual(compiled_out, out)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("use_ca", [True, False])
    def test_2d_fsdp_tp_compile(self, use_ca):
        data_parallel_size = 2
        model = SimpleModel(self.device_type)
        model_copy = copy.deepcopy(model)

        # 2-D mesh is [dp, tp]
        twod_mesh = init_device_mesh(
            self.device_type,
            (data_parallel_size, self.world_size // data_parallel_size),
            mesh_dim_names=["dp", "tp"],
        )

        inp = torch.rand(20, 10, device=self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, twod_mesh["tp"], parallelize_plan)
        eager_2d = FSDP(
            tp_model,
            device_id=dev_type.type,
            use_orig_params=True,
            device_mesh=twod_mesh["dp"],
        )
        out = eager_2d(inp)
        tp_model2 = parallelize_module(
            model_copy,
            twod_mesh["tp"],
            parallelize_plan,
        )
        fsdp_2d = FSDP(
            tp_model2,
            device_id=dev_type.type,
            use_orig_params=True,
            device_mesh=twod_mesh["dp"],
        )

        # TODO: once aot autograd support is ready we can just use default backend
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_2d = torch.compile(fsdp_2d, backend=cnt)
        compiled_output = compiled_2d(inp)
        with self._bwd_ctx(use_ca):
            compiled_output.sum().backward()

        self.assertEqual(out, compiled_output)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("use_ca", [True, False])
    def test_2d_fsdp_tp_ac_compile(self, use_ca):
        dp_degree = 2
        tp_degree = self.world_size // dp_degree
        model = SimpleModel(self.device_type)
        model_copy = copy.deepcopy(model)

        # 2-D mesh is [dp, tp]
        mesh_2d = init_device_mesh(
            self.device_type,
            mesh_shape=(dp_degree, tp_degree),
            mesh_dim_names=("dp", "tp"),
        )

        inp = torch.rand(20, 10, device=self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
        tp_model = checkpoint_wrapper(
            tp_model,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            checkpoint_fn=checkpoint,
            use_reentrant=False,
        )
        eager_2d = FSDP(tp_model, device_mesh=mesh_2d["dp"], use_orig_params=True)

        tp_model2 = parallelize_module(model_copy, mesh_2d["tp"], parallelize_plan)
        fsdp_2d = FSDP(
            tp_model2,
            device_mesh=mesh_2d["dp"],
            use_orig_params=True,
        )
        # TODO: once aot autograd support is ready we can just use default backend
        compiled_2d = torch.compile(fsdp_2d, backend="aot_eager")

        # forward pass
        out = eager_2d(inp)
        compiled_output = compiled_2d(inp)
        self.assertEqual(out, compiled_output)

        # backward pass
        out.sum().backward()
        with self._bwd_ctx(use_ca):
            compiled_output.sum().backward()

        # compare the gradients:
        for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
            self.assertEqual(n.grad, p.grad)

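    # The last test checks that a redistribute (Shard -> Replicate) inside the
    # compiled function also produces correct gradients, compared against eager.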
    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("use_ca", [True, False])
    def test_compile_dtensor_redistribute_backward(self, use_ca):
        mesh = DeviceMesh(
            device_type=self.device_type, mesh=torch.arange(self.world_size)
        )

        def fn(x, y):
            dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
            dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
            dt_out = torch.matmul(dt, dt2)
            dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
            return dt_out_redistribute.to_local()

        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)

        x_ref = torch.arange(8, requires_grad=True, dtype=torch.float32)
        y_ref = torch.arange(8, requires_grad=True, dtype=torch.float32)
        ref = fn(x_ref, y_ref)

        x = torch.arange(8, requires_grad=True, dtype=torch.float32)
        y = torch.arange(8, requires_grad=True, dtype=torch.float32)
        res = opt_fn(x, y)

        self.assertEqual(res, ref)

        # Now run and assert the backward + gradients
        ref.sum().backward()
        with self._bwd_ctx(use_ca):
            res.sum().backward()

        self.assertEqual(x_ref.grad, x.grad)
        self.assertEqual(y_ref.grad, y.grad)


if __name__ == "__main__":
    run_tests()