pytorch/torch/fx/experimental/_constant_symnode.py
Dzmitry Huba 5e58420dff LocalTensor (#164537)
A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks.  A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally.  When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards.  A LocalTensor is associated with a list of ranks which specify
which ranks it holds local tensors for.
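A minimal sketch of the idea in plain Python follows. It assumes a toy representation in which each rank's shard is a list of floats keyed by rank. `ToyLocalTensor`, `map_shards`, and `all_reduce_sum` are hypothetical names for illustration, not the LocalTensor API. A plain op is applied to every shard independently, and a sum all-reduce is computed directly from the local shards instead of by communicating.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyLocalTensor:
    # Hypothetical stand-in for LocalTensor: one shard per simulated rank.
    shards: Dict[int, List[float]]

    @property
    def ranks(self) -> List[int]:
        # The ranks this tensor holds local data for.
        return sorted(self.shards)

    def map_shards(self, fn: Callable[[float], float]) -> "ToyLocalTensor":
        # A "plain" op: apply fn elementwise to every rank's shard independently.
        return ToyLocalTensor(
            {rank: [fn(x) for x in shard] for rank, shard in self.shards.items()}
        )

    def all_reduce_sum(self) -> "ToyLocalTensor":
        # A collective: compute the mathematically equivalent result from the
        # local shards (elementwise sum across ranks) and give every rank a copy.
        total = [sum(vals) for vals in zip(*self.shards.values())]
        return ToyLocalTensor({rank: list(total) for rank in self.shards})


t = ToyLocalTensor({0: [1.0, 2.0], 1: [3.0, 4.0]})
doubled = t.map_shards(lambda x: 2 * x)
reduced = doubled.all_reduce_sum()
print(reduced.shards)  # {0: [8.0, 12.0], 1: [8.0, 12.0]}
```

Because every shard lives in one process, the "collective" is just a loop over the shards. This is what makes the approach convenient for debugging, and also what makes it slow.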

NB, this is NOT a DataParallel-like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes;
the overhead is almost certainly too high to keep eight GPUs busy (even the C++
autograd needs multithreading to keep up!). (It might be possible to trace
through this with torch.compile and then compile it with CUDA graphs, but this
is currently a non-goal.)

In order to handle MPMD, we provide a helper decorator that allows you to
run a side-effect-free function for each LocalTensor shard and combine the
results back into a LocalTensor or LocalIntNode.
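The per-shard pattern can be sketched in plain Python, assuming shards are given as a dict from rank to an int. The decorator `per_shard` and the container `PerRankInt` (standing in for LocalIntNode) are hypothetical names for illustration, not the helper added in this PR.

```python
from dataclasses import dataclass
from functools import wraps
from typing import Callable, Dict


@dataclass
class PerRankInt:
    # Hypothetical stand-in for LocalIntNode: one int per simulated rank.
    values: Dict[int, int]


def per_shard(fn: Callable[[int], int]) -> Callable[[Dict[int, int]], PerRankInt]:
    # Hypothetical decorator (not the PR's helper): call fn once per rank's
    # shard and pack the per-rank results into a single PerRankInt.
    @wraps(fn)
    def wrapper(shards: Dict[int, int]) -> PerRankInt:
        # fn runs once per rank, so it must not have side effects.
        return PerRankInt({rank: fn(shard) for rank, shard in shards.items()})

    return wrapper


@per_shard
def padded_length(n: int) -> int:
    # Rank-dependent result: round each rank's local length up to a multiple of 4.
    return (n + 3) // 4 * 4


print(padded_length({0: 5, 1: 8, 2: 1}).values)  # {0: 8, 1: 8, 2: 4}
```

The side-effect-free requirement matters because the function body executes once per simulated rank. Any mutation or I/O inside it would be repeated world_size times.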

Note: This PR converts all DTensor ops and some DTensor tests to illustrate
intended usage and ensure correctness. More tests will be converted in
subsequent PRs. During test conversion we aim to share as much test logic as
possible between the multi-process / multi-threaded tests and the local tensor
tests. We would like developers to be able to run both flavors of the tests.

Note: This work is based on the original proposal
by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537
Approved by: https://github.com/ezyang
2025-10-12 20:06:41 +00:00

from typing import *  # noqa: F403


# Python version of c10/core/ConstantSymNodeImpl.cpp
# This needs to exist because the Python version of nested int is not compatible
# with the C++ version of constant symnode.
class ConstantIntNode:
    def __init__(self, val: int):
        self.val = val

    def is_constant(self) -> bool:
        return True

    def maybe_as_int(self) -> int:
        return self.val

    def is_int(self) -> bool:
        return True

    def is_float(self) -> bool:
        return False

    def is_bool(self) -> bool:
        return False

    def is_nested_int(self) -> bool:
        return False

    def clone(self) -> "ConstantIntNode":
        return self

    def _str(self) -> str:
        return str(self.val)

    def __str__(self) -> str:
        return self._str()

    def __repr__(self) -> str:
        return self._str()

    def _graph_repr(self) -> str:
        return self._str()

    # Arithmetic and equality delegate to the other operand's implementation.
    def add(self, other: Any) -> Any:
        return other.add(self)

    def mul(self, other: Any) -> Any:
        return other.mul(self)

    def eq(self, other: Any) -> Any:
        return other.eq(self)

    def ne(self, other: Any) -> Any:
        return other.ne(self)

    # Ordering comparisons are mirrored because the operands swap sides:
    # self > other is evaluated as other < self, and so on.
    def gt(self, other: Any) -> Any:
        return other.lt(self)

    def lt(self, other: Any) -> Any:
        return other.gt(self)

    def le(self, other: Any) -> Any:
        return other.ge(self)

    def ge(self, other: Any) -> Any:
        return other.le(self)

    def is_symbolic(self) -> bool:
        return False

    def constant_int(self) -> int:
        return self.val
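Every binary method on `ConstantIntNode` hands the work to the other operand, so the other node type (for example, the Python nested int) decides the result. `add`, `mul`, `eq`, and `ne` are symmetric and can be forwarded unchanged. The ordering comparisons, however, must be mirrored because the operands swap sides: `a > b` is answered as `b < a`, and `a <= b` as `b >= a`. The sketch below checks this with a trimmed copy of the class above (comparison methods only) and a hypothetical `PlainIntNode` that compares plain integers. `PlainIntNode` is an illustration, not a PyTorch class.

```python
from typing import Any


class ConstantIntNode:
    # Trimmed copy of the class above: only the comparison delegation.
    def __init__(self, val: int):
        self.val = val

    def constant_int(self) -> int:
        return self.val

    def gt(self, other: Any) -> Any:
        return other.lt(self)

    def lt(self, other: Any) -> Any:
        return other.gt(self)

    def le(self, other: Any) -> Any:
        return other.ge(self)

    def ge(self, other: Any) -> Any:
        return other.le(self)


class PlainIntNode:
    # Hypothetical node that answers comparisons against a constant node.
    # It is called with the constant as its argument, so lt(c) means "self < c".
    def __init__(self, val: int):
        self.val = val

    def lt(self, other: ConstantIntNode) -> bool:
        return self.val < other.constant_int()

    def gt(self, other: ConstantIntNode) -> bool:
        return self.val > other.constant_int()

    def le(self, other: ConstantIntNode) -> bool:
        return self.val <= other.constant_int()

    def ge(self, other: ConstantIntNode) -> bool:
        return self.val >= other.constant_int()


three, five = ConstantIntNode(3), PlainIntNode(5)
# 3 > 5, 3 < 5, 3 <= 5, 3 >= 5
print(three.gt(five), three.lt(five), three.le(five), three.ge(five))
# False True True False
```

If the delegation were not mirrored (for example, if `gt` forwarded to `other.gt(self)`), `three.gt(five)` would evaluate `5 > 3` and wrongly report that 3 is greater than 5.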