mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
A LocalTensor is a tensor subclass which simulates a tensor that is distributed across SPMD ranks. A LocalTensor might be of size N, but in fact there are world_size shards/replicas of it stored internally. When you do a plain PyTorch operation on it, we apply the operation to each shard; when you do a collective, we do the mathematically equivalent operation on the local shards. A LocalTensor is associated with a list of ranks which specifies which ranks it holds local tensors for. NB: this is NOT a DataParallel-like abstraction where you can run operations on multiple different GPUs. It is intended purely for *debugging* purposes; the overhead is almost certainly too high to keep eight GPUs busy (even the C++ autograd needs multithreading to keep up!). (It might potentially be possible to trace through this with torch.compile and then compile it with CUDA graphs, but this is currently a non-goal.) In order to handle MPMD, we provide a helper decorator that allows you to run a function with no side effects for each LocalTensor shard and combine the results back into a LocalTensor or LocalIntNode. Note: This PR converts all DTensor ops and some DTensor tests to illustrate intended usage and ensure correctness. In a subsequent PR, more tests will be converted. During test conversion we aim to share as much test logic as possible between multi-process / multi-threaded and local tensor tests. We would like developers to be able to run both flavors of the tests. Note: This work is based on the original proposal by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753). Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537 Approved by: https://github.com/ezyang
73 lines
1.6 KiB
Python
73 lines
1.6 KiB
Python
from typing import * # noqa: F403
|
|
|
|
|
|
# Python version of c10/core/ConstantSymNodeImpl.cpp
# This needs to exist because the Python version of nested int is not compatible
# with the C++ version of constant symnode.
class ConstantIntNode:
    """Node-shaped wrapper around a concrete integer value.

    Every "what kind of node are you?" query has a fixed answer: this node is
    a constant, non-symbolic int. Every binary operation is delegated to the
    *other* operand, with the operands swapped. The other node's
    implementation therefore decides what the result is:

    * ``add``, ``mul``, ``eq`` and ``ne`` are symmetric, so they call the
      same-named method on ``other``.
    * Ordered comparisons flip direction: ``self > other`` is evaluated as
      ``other < self``, and so on.
    """

    def __init__(self, val: int):
        # The wrapped value, read back by the accessors and string helpers.
        self.val = val

    # ------------------------------------------------------------------
    # Kind predicates -- each answer is fixed for this node type.
    # ------------------------------------------------------------------

    def is_constant(self) -> bool:
        """Report that the value is a constant (always ``True``)."""
        return True

    def is_symbolic(self) -> bool:
        """Report that no symbolic expression is involved (always ``False``)."""
        return False

    def is_int(self) -> bool:
        """Report that this node holds an int (always ``True``)."""
        return True

    def is_float(self) -> bool:
        """Report that this node is not a float (always ``False``)."""
        return False

    def is_bool(self) -> bool:
        """Report that this node is not a bool (always ``False``)."""
        return False

    def is_nested_int(self) -> bool:
        """Report that this node is not a nested int (always ``False``)."""
        return False

    # ------------------------------------------------------------------
    # Value access.
    # ------------------------------------------------------------------

    def maybe_as_int(self) -> int:
        """Return the wrapped value; for this node it is always available."""
        return self.val

    def constant_int(self) -> int:
        """Return the wrapped value."""
        return self.val

    def clone(self) -> "ConstantIntNode":
        """Return this very object rather than a copy."""
        return self

    # ------------------------------------------------------------------
    # Textual forms -- all of them funnel through ``_str``.
    # ------------------------------------------------------------------

    def _str(self) -> str:
        """Render the wrapped value using the builtin ``str``."""
        return str(self.val)

    def __str__(self) -> str:
        return self._str()

    def __repr__(self) -> str:
        return self._str()

    def _graph_repr(self) -> str:
        """Representation used when the node is shown in a graph."""
        return self._str()

    # ------------------------------------------------------------------
    # Symmetric operations: hand off to ``other`` unchanged.
    # ------------------------------------------------------------------

    def add(self, other: Any) -> Any:
        """Compute ``self + other`` as ``other + self``."""
        return other.add(self)

    def mul(self, other: Any) -> Any:
        """Compute ``self * other`` as ``other * self``."""
        return other.mul(self)

    def eq(self, other: Any) -> Any:
        """Compute ``self == other`` as ``other == self``."""
        return other.eq(self)

    def ne(self, other: Any) -> Any:
        """Compute ``self != other`` as ``other != self``."""
        return other.ne(self)

    # ------------------------------------------------------------------
    # Ordered comparisons: swap operands, so flip the operator.
    # ------------------------------------------------------------------

    def gt(self, other: Any) -> Any:
        """Compute ``self > other`` as ``other < self``."""
        return other.lt(self)

    def lt(self, other: Any) -> Any:
        """Compute ``self < other`` as ``other > self``."""
        return other.gt(self)

    def le(self, other: Any) -> Any:
        """Compute ``self <= other`` as ``other >= self``."""
        return other.ge(self)

    def ge(self, other: Any) -> Any:
        """Compute ``self >= other`` as ``other <= self``."""
        return other.le(self)
|