Files
pytorch/torch/_dynamo/tensor_version_op.py
Brian Hirsh e68f5087d8 update _unsafe_set_version_counter to accept lists of tensors (#137921)
See the comment [here](https://github.com/pytorch/pytorch/issues/132014#issuecomment-2379547400) - this PR updates `_unsafe_set_version_counter` to accept a list of tensors, for overhead-sensitive users (e.g. distributed) who need to hide version-counter (VC) bumps from autograd on a large list of tensors without paying for a separate Python->C++ transition on every tensor in the list.
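A minimal sketch of the intended pattern (tensor names are illustrative; the binding takes parallel lists of tensors and versions, matching the prim schema in the file below):

```python
import torch

tensors = [torch.randn(3) for _ in range(4)]

# Save the version counters, perform in-place ops that autograd should not
# observe, then restore every counter in a single Python->C++ transition.
versions = [t._version for t in tensors]
for t in tensors:
    t.mul_(2)  # each in-place op bumps the tensor's version counter
torch._C._autograd._unsafe_set_version_counter(tensors, versions)

assert all(t._version == v for t, v in zip(tensors, versions))
```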

I left the binding in pybind and used a `std::vector`. If we **really** need to optimize overhead even further, we could write a manual CPython binding.

I use this updated API in the next PR to fix FSDP2, so that it properly hides the VC of all `all_gather_buffer` tensors in its call to `split_with_sizes_copy.out(all_gather_buffers)`.
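A hedged sketch of that FSDP2 pattern (buffer names and sizes here are illustrative stand-ins, not the actual FSDP2 code):

```python
import torch

# Hypothetical stand-ins for FSDP2's all-gather output and per-parameter buffers.
all_gather_output = torch.randn(10)
all_gather_buffers = [torch.empty(4), torch.empty(6)]

versions = [t._version for t in all_gather_buffers]
# The out= variant mutates the buffers, bumping their version counters.
torch.split_with_sizes_copy(all_gather_output, [4, 6], out=all_gather_buffers)
# Hide all of those bumps from autograd with a single call.
torch._C._autograd._unsafe_set_version_counter(all_gather_buffers, versions)
```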

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137921
Approved by: https://github.com/awgu, https://github.com/albanD
2025-02-04 04:51:11 +00:00

60 lines
2.1 KiB
Python

# mypy: allow-untyped-defs
import torch
from torch._prims import _make_prim, RETURN_TYPE
from torch._subclasses import FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensorMode


_tensor_version = _make_prim(
    schema="_tensor_version(Tensor self) -> SymInt",
    return_type=RETURN_TYPE.NEW,
    meta=torch.ops.aten._version.default,
    impl_aten=torch.ops.aten._version.default,
    doc="Traceable unbacked SymInt version of torch.Tensor._version",
)


@_tensor_version.py_impl(FakeTensorMode)
def _tensor_version_fake(fake_mode, self_tensor):
    """
    The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the
    `._version` into an unbacked SymInt so that we don't need to specialize on the `._version`
    of input tensors to the graph.
    """
    return fake_mode.shape_env.create_unbacked_symint()


_unsafe_set_version_counter = _make_prim(
    schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()",
    return_type=RETURN_TYPE.NEW,
    meta=lambda tensors, versions: None,
    impl_aten=torch._C._autograd._unsafe_set_version_counter,
    doc="Traceable+SymInt version of torch._C._autograd._unsafe_set_version_counter",
)
torch.fx.node.has_side_effect(_unsafe_set_version_counter)
"""
When we functionalize _tensor_version + _unsafe_set_version_counter,
the ops disappear from the traced graph. We run them eagerly on the
fake tensors used for tracing, in order to get past asserts that would
fail in autograd.
Why is this ok?
1) Versions on functional tensors don't make any sense since you can't mutate a functional tensor.
2) The whole point of version munging is to trick autograd into doing what we want, and after
AotAtuograd there is no longer any need for these ops.
Note this is similar to how no_grad is handled.
"""


@_tensor_version.py_impl(FunctionalTensorMode)
def _tensor_version_functional(mode, self):
    return self._version


@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
def _unsafe_set_version_counter_functional(ctx, tensors, versions):
    torch._C._autograd._unsafe_set_version_counter(tensors, versions)
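

# A hedged illustration (assumption: these prims are reachable as
# torch.ops.prims._tensor_version / torch.ops.prims._unsafe_set_version_counter,
# which is how _make_prim registers them). Inside a dynamo capture, a hidden
# in-place update on a graph input `x` would be traced roughly as:
#
#   v = torch.ops.prims._tensor_version(x)  # unbacked SymInt under FakeTensorMode (impl above)
#   torch.ops.aten.mul_.Tensor(x, 2)        # the mutation whose VC bump we want to hide
#   torch.ops.prims._unsafe_set_version_counter([x], [v])
#
# Under FunctionalTensorMode both prims run eagerly (impls above) and drop out
# of the traced graph, as the docstring above explains.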