Introduce automatic wrapper to run DTensor tests under local tensor mode (#165383)
The wrapper makes it possible to share the test body implementation while eliminating the need to write a test class by hand. As an example, this change converts the whole DTensorTest to use local tensor mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165383
Approved by: https://github.com/ezyang
This commit is contained in:
committed by PyTorch MergeBot
parent a856a17799
commit 5fbf93b774
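Before the diff itself, a minimal usage sketch of the new wrapper, assuming the helpers land in common_dtensor.py alongside DTensorTestBase (the skipped-test name below is hypothetical; DTensorTest is the class the PR actually converts):

    # Minimal sketch: generate a variant of an existing DTensor test class
    # whose test bodies all run under LocalTensorMode in a single process.
    from torch.testing._internal.distributed._tensor.common_dtensor import (
        create_local_tensor_test_class,
    )

    DTensorTestWithLocalTensor = create_local_tensor_test_class(
        DTensorTest,
        skipped_tests=["test_requires_real_comms"],  # hypothetical skip entry
    )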
@@ -2,8 +2,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
+import contextlib
+import functools
 import itertools
 import sys
+import types
 from collections.abc import Callable, Iterator, Sequence
 from dataclasses import dataclass
 from functools import partial, wraps
 
@@ -13,7 +16,12 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._local_tensor import LocalTensor
+from torch.distributed._local_tensor import (
+    LocalIntNode,
+    LocalTensor,
+    LocalTensorMode,
+    maybe_run_for_local_tensor,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
@@ -687,3 +695,90 @@ class DTensorConverter:
             return t
         else:
             raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")
+
+
+class LocalDTensorTestBase(DTensorTestBase):
+    def _get_local_tensor_mode(self):
+        return LocalTensorMode(frozenset(range(0, self.world_size)))
+
+    def setUp(self) -> None:
+        super().setUp()
+        torch.autograd._enable_record_function(False)
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        torch.autograd._enable_record_function(True)
+
+    @property
+    def rank(self):
+        return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
+
+    @rank.setter
+    def rank(self, rank):
+        pass
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            fn()
+
+        return types.MethodType(wrapper, self)
+
+    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
+        dist.init_process_group("fake", rank=0, world_size=self.world_size)
+        self._pg = dist.distributed_c10d._get_default_group()
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        dist.destroy_process_group(self._pg)
+        self._pg = None
+
+    def _spawn_processes(self) -> None:
+        pass
+
+
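The base class above swaps the multi-process harness for a single in-process run: _spawn_processes is a no-op, init_pg creates a fake process group, and rank becomes a SymInt holding a distinct value per simulated rank. A rough sketch of a test written directly against it, with hypothetical class and test names:

    # Hypothetical sketch: the body runs once in-process while LocalTensorMode
    # simulates every rank; `self.rank` yields a per-virtual-rank value.
    import torch
    from torch.testing._internal.distributed._tensor.common_dtensor import (
        LocalDTensorTestBase,
    )

    class MyLocalDTensorTest(LocalDTensorTestBase):
        @property
        def world_size(self) -> int:
            return 2

        def test_rank_dependent_value(self):
            with self._get_local_tensor_mode():
                t = torch.ones(2, 2) * self.rank  # differs per simulated rank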
+def make_wrapped(fn, ctxs):
+    @functools.wraps(fn)
+    def wrapped(self):
+        torch._dynamo.reset()
+        stack = contextlib.ExitStack()
+        for ctx in ctxs:
+            if callable(ctx):
+                stack.enter_context(ctx(self))
+            else:
+                stack.enter_context(ctx)
+        out = fn(self)
+        stack.close()
+        return out
+
+    return wrapped
+
+
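make_wrapped composes an arbitrary list of contexts around a test body; entries that are callables are invoked with the test instance first, so a context can depend on per-test state. A small usage sketch (everything other than make_wrapped itself is hypothetical):

    # Hypothetical sketch: wrap a shared test body in a per-instance context
    # plus a plain context manager; callables receive the test instance,
    # other entries are entered as-is.
    import contextlib

    def test_body(self):
        pass  # the shared test implementation

    wrapped = make_wrapped(
        test_body,
        ctxs=[
            lambda test: test._get_local_tensor_mode(),  # callable -> ctx(self)
            contextlib.nullcontext(),                    # entered directly
        ],
    )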
+def create_local_tensor_test_class(orig_cls, skipped_tests=None):
+    if skipped_tests is None:
+        skipped_tests = []
+
+    dct = orig_cls.__dict__.copy()
+    for name in list(dct.keys()):
+        fn = dct[name]
+        if not callable(fn):
+            continue
+        elif name in skipped_tests:
+            dct[name] = lambda self: self.skipTest("Skipped test")
+        elif name.startswith("test_"):
+            ctxs = [
+                lambda test: test._get_local_tensor_mode(),
+            ]
+            dct[name] = make_wrapped(fn, ctxs)
+
+    cls = type(
+        orig_cls.__name__ + "WithLocalTensor",
+        (LocalDTensorTestBase,) + orig_cls.__bases__,
+        dct,
+    )
+    cls.__file__ = __file__
+    return cls
+
+
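The factory copies the original class dict and rebuilds the type with LocalDTensorTestBase placed first among the bases, so its setUp, init_pg, and rank overrides win in the MRO. A quick check of the generated type, assuming the factory is applied to a DTensorTestBase subclass named DTensorTest:

    # Hypothetical check: the generated class is renamed and resolves the
    # local-tensor overrides first in its MRO.
    Converted = create_local_tensor_test_class(DTensorTest)
    assert Converted.__name__ == "DTensorTestWithLocalTensor"
    assert Converted.__mro__[1] is LocalDTensorTestBase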
+@maybe_run_for_local_tensor
+def map_local_tensor_for_rank(tensor, rank, func):
+    return func(tensor, rank)
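map_local_tensor_for_rank gives shared test bodies a rank-aware escape hatch: judging by the decorator's name, a call on local tensors is fanned out once per simulated rank with that rank's values, while plain tensors take the ordinary path. A hedged fragment from inside a test body (the indexing function is hypothetical):

    # Hypothetical fragment: select a per-rank row of a tensor. With plain
    # tensors this is an ordinary call; with LocalTensors the decorator
    # presumably applies `func` for each simulated rank.
    row = map_local_tensor_for_rank(tensor, self.rank, lambda t, r: t[r])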