Introduce automatic wrapper to run DTensor tests under local tensor mode (#165383)

The wrapper makes it possible to share the test body implementation while eliminating the need to write the test class by hand. As an example, this change converts the whole DTensorTest to use local tensor mode.
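
For reference, a rough sketch of how the new helper is meant to be consumed in a test file. The import path, the test class name, and the skipped test name below are illustrative assumptions rather than lines taken from this change:

    # Hypothetical usage: reuse the test bodies of an existing DTensorTestBase
    # subclass under LocalTensorMode instead of spawning one process per rank.
    from torch.testing._internal.distributed._tensor.common_dtensor import (
        create_local_tensor_test_class,  # assumed location of the new helper
    )

    # DTensorTest is an existing test class whose test_* methods we want to reuse.
    DTensorTestWithLocalTensor = create_local_tensor_test_class(
        DTensorTest,
        skipped_tests=["test_some_unsupported_case"],  # hypothetical skip list
    )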

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165383
Approved by: https://github.com/ezyang
Dzmitry Huba
2025-10-14 06:08:00 +00:00
committed by PyTorch MergeBot
parent a856a17799
commit 5fbf93b774
3 changed files with 118 additions and 119 deletions

@@ -2,8 +2,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import functools
import itertools
import sys
import types
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass
from functools import partial, wraps
@@ -13,7 +16,12 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._local_tensor import LocalTensor
from torch.distributed._local_tensor import (
LocalIntNode,
LocalTensor,
LocalTensorMode,
maybe_run_for_local_tensor,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
@@ -687,3 +695,90 @@ class DTensorConverter:
return t
else:
raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")


class LocalDTensorTestBase(DTensorTestBase):
    def _get_local_tensor_mode(self):
        return LocalTensorMode(frozenset(range(0, self.world_size)))

    def setUp(self) -> None:
        super().setUp()
        torch.autograd._enable_record_function(False)

    def tearDown(self) -> None:
        super().tearDown()
        torch.autograd._enable_record_function(True)

    @property
    def rank(self):
        return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))

    @rank.setter
    def rank(self, rank):
        pass

    def join_or_run(self, fn):
        @wraps(fn)
        def wrapper(self):
            fn()

        return types.MethodType(wrapper, self)

    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
        dist.init_process_group("fake", rank=0, world_size=self.world_size)
        self._pg = dist.distributed_c10d._get_default_group()

    def destroy_pg(self, device_id: Optional[int] = None) -> None:
        dist.destroy_process_group(self._pg)
        self._pg = None

    def _spawn_processes(self) -> None:
        pass


def make_wrapped(fn, ctxs):
    @functools.wraps(fn)
    def wrapped(self):
        torch._dynamo.reset()
        stack = contextlib.ExitStack()
        for ctx in ctxs:
            if callable(ctx):
                stack.enter_context(ctx(self))
            else:
                stack.enter_context(ctx)
        out = fn(self)
        stack.close()
        return out

    return wrapped


def create_local_tensor_test_class(orig_cls, skipped_tests=None):
    if skipped_tests is None:
        skipped_tests = []

    dct = orig_cls.__dict__.copy()
    for name in list(dct.keys()):
        fn = dct[name]
        if not callable(fn):
            continue
        elif name in skipped_tests:
            dct[name] = lambda self: self.skipTest("Skipped test")
        elif name.startswith("test_"):
            ctxs = [
                lambda test: test._get_local_tensor_mode(),
            ]
            dct[name] = make_wrapped(fn, ctxs)

    cls = type(
        orig_cls.__name__ + "WithLocalTensor",
        (LocalDTensorTestBase,) + orig_cls.__bases__,
        dct,
    )
    cls.__file__ = __file__
    return cls


@maybe_run_for_local_tensor
def map_local_tensor_for_rank(tensor, rank, func):
    return func(tensor, rank)
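
A usage note, not part of the diff: under LocalTensorMode the decorated map_local_tensor_for_rank is expected to run the callback once per simulated rank, receiving that rank's local shard and rank value. A rough sketch of a call site inside a converted test body, where dtensor and expected_shards are hypothetical names standing in for a DTensor and a list of per-rank reference tensors built earlier in the test:

    # Hypothetical call site inside a test converted by create_local_tensor_test_class.
    local = dtensor.to_local()  # dtensor: a DTensor created earlier in the test (assumed)
    map_local_tensor_for_rank(
        local,
        self.rank,
        lambda shard, rank: self.assertEqual(shard, expected_shards[rank]),
    )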