update _unsafe_set_version_counter to accept lists of tensors (#137921)

See the comment [here](https://github.com/pytorch/pytorch/issues/132014#issuecomment-2379547400) - this PR updates `_unsafe_set_version_counter` to accept a list of tensors, for overhead-sensitive users (e.g. distributed) who need to hide VC bumps from autograd across a large list of tensors without paying the cost of a separate python->C++ transition for every tensor in the list.
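A rough usage sketch of the updated API (the tensor list `bufs` and the `no_grad` mutation are illustrative; the call shape follows the updated tests and type stub below):

```python
import torch

bufs = [torch.randn(4) for _ in range(3)]   # hypothetical list of tensors
prev_versions = [t._version for t in bufs]  # snapshot versions before mutating

with torch.no_grad():
    for t in bufs:
        t.add_(1)                           # each in-place op bumps that tensor's VC

# A single python->C++ transition restores every version counter at once,
# instead of one _unsafe_set_version_counter call per tensor.
torch._C._autograd._unsafe_set_version_counter(bufs, prev_versions)
assert [t._version for t in bufs] == prev_versions
```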

I left the binding in pybind and used a `std::vector`. If we **really** need to optimize overhead even further, we could write a manual CPython binding.

I use this updated API in the next PR to fix FSDP2, so that it properly hides the version-counter bumps on all `all_gather_buffer` tensors in its call to `split_with_sizes_copy.out(all_gather_buffers)`.
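A minimal sketch of that pattern (buffer shapes are made up, and the plain `copy_` loop stands in for FSDP2's real `torch.ops.fsdp.split_with_sizes_copy(..., out=...)` call shown in the diff below):

```python
import torch

all_gather_output = torch.randn(12)          # illustrative flat all-gather result
split_sizes = [4, 4, 4]
out = [torch.empty(s) for s in split_sizes]  # per-parameter output buffers
prev = [t._version for t in out]

# Restore every buffer's version counter on exit with one call, so autograd
# never observes the VC bumps caused by the writes below.
with torch.autograd._unsafe_preserve_version_counter(tuple(out)):
    for dst, src in zip(out, all_gather_output.split(split_sizes)):
        dst.copy_(src)                       # stand-in for the real copy-out op

assert [t._version for t in out] == prev
```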

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137921
Approved by: https://github.com/awgu, https://github.com/albanD
Author: Brian Hirsh
Date: 2025-02-03 16:16:29 -08:00
Committed by: PyTorch MergeBot
Parent: 425aca40a4
Commit: e68f5087d8

9 changed files with 94 additions and 49 deletions

View File

@ -242,7 +242,7 @@ class DistributedPatternTests(TestCase):
x = x.sin()
v = w._version
w.copy_(x + 1)
torch._C._autograd._unsafe_set_version_counter(w, v)
torch._C._autograd._unsafe_set_version_counter((w,), (v,))
return w, v
for v in (3, 0, 1):
@ -266,7 +266,7 @@ class DistributedPatternTests(TestCase):
with torch.no_grad():
v = w._version
w.copy_(x)
torch._C._autograd._unsafe_set_version_counter(w, v)
torch._C._autograd._unsafe_set_version_counter((w,), (v,))
return r
w1 = torch.randn(1, requires_grad=True)

View File

@ -4799,10 +4799,18 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
# version counter doesn't change inside of the context manager
self.assertEqual(2, x._version)
torch._C._autograd._unsafe_set_version_counter(x, 0)
torch._C._autograd._unsafe_set_version_counter((x,), (0,))
self.assertEqual(0, x._version)
with self.assertRaisesRegex(RuntimeError, "Cannot set"):
torch._C._autograd._unsafe_set_version_counter(x, -1)
torch._C._autograd._unsafe_set_version_counter((x,), (-1,))
y = torch.ones(2, requires_grad=True).clone()
with torch.autograd._unsafe_preserve_version_counter((x, y)):
x.mul_(2)
y.mul_(3)
# version counter doesn't change inside of the context manager
self.assertEqual(0, x._version)
self.assertEqual(0, y._version)
def test_current_node(self):
pr = []

View File

@ -115,7 +115,9 @@ def _push_saved_tensors_default_hooks(
unpack_hook: Callable[[Any], torch.Tensor],
) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
def _unsafe_set_version_counter(
t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
) -> None: ...
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...

View File

@ -25,7 +25,7 @@ def _tensor_version_fake(fake_mode, self_tensor):
_unsafe_set_version_counter = _make_prim(
schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()",
schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()",
return_type=RETURN_TYPE.NEW,
meta=lambda self, version: None,
impl_aten=torch._C._autograd._unsafe_set_version_counter,
@ -55,5 +55,5 @@ def _tensor_version_functional(mode, self):
@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
def _unsafe_set_version_counter_functional(ctx, self, version):
torch._C._autograd._unsafe_set_version_counter(self, version)
def _unsafe_set_version_counter_functional(ctx, tensors, versions):
torch._C._autograd._unsafe_set_version_counter(tensors, versions)

View File

@ -1872,7 +1872,7 @@ class BuiltinVariable(VariableTracker):
version = x._version
if version > 0:
version = version - 1
torch._C._autograd._unsafe_set_version_counter(x, version)
torch._C._autograd._unsafe_set_version_counter((x,), (version,))
return x
tx.output.create_proxy(

View File

@ -975,20 +975,39 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
Wraps torch.autograd._unsafe_preserve_version_counter
"""
@staticmethod
def _create_lambda_from_tensors(tx, tensors):
if isinstance(tensors, variables.TensorVariable):
versions = variables.TupleVariable(
[x.var_getattr(tx, "_version") for x in [tensors]]
)
tensors = variables.TupleVariable([tensors])
else:
versions = variables.TupleVariable(
[x.var_getattr(tx, "_version") for x in tensors.items]
)
return PreserveVersionContextVariable(tensors, versions)
@staticmethod
def constructor(tx):
return variables.LambdaVariable(
lambda tensor: PreserveVersionContextVariable(
tensor,
tensor.var_getattr(tx, "_version"),
lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors(
tx, tensors
)
)
def __init__(self, tensor, prev_version, **kwargs) -> None:
def __init__(self, tensors, prev_versions, **kwargs) -> None:
kwargs.setdefault("target_values", None)
super().__init__(**kwargs)
self.tensor = tensor
self.prev_version = prev_version
self.tensors = tensors
self.prev_versions = prev_versions
# The context manager accepts Union[Tensor, Tuple[Tensor]]
if isinstance(self.tensors, variables.TensorVariable):
self.tensors = variables.TupleVariable([self.tensors])
if isinstance(
self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable)
):
self.prev_versions = variables.TupleVariable([self.prev_versions])
def enter(self, tx):
pass
@ -998,7 +1017,7 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
return variables.TorchInGraphFunctionVariable(
_unsafe_set_version_counter
).call_function(tx, [self.tensor, self.prev_version], {})
).call_function(tx, [self.tensors, self.prev_versions], {})
def reconstruct(self, codegen):
unimplemented(

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any
from typing import Any, Tuple, Union
import torch
from torch.utils._contextlib import (
@ -386,12 +386,13 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
"""
def __init__(self, tensor: torch.Tensor) -> None:
self.tensor = tensor
self.prev_version = tensor._version
def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
assert isinstance(self.tensors, tuple)
self.prev_versions = tuple(t._version for t in self.tensors)
def __enter__(self) -> None:
pass
def __exit__(self, *args) -> None:
torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
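Assuming the updated signature above, a quick illustration of both accepted forms of the context manager (tensors `a`/`b` are made up):

```python
import torch

a, b = torch.randn(2), torch.randn(2)

# Single-tensor form still works; it is wrapped into a 1-tuple internally.
with torch.autograd._unsafe_preserve_version_counter(a):
    a.add_(1)

# New tuple form: both version counters are restored with one call on exit.
va, vb = a._version, b._version
with torch.autograd._unsafe_preserve_version_counter((a, b)):
    a.mul_(2)
    b.mul_(3)
assert (a._version, b._version) == (va, vb)
```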

View File

@ -388,10 +388,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
return activities;
});
m.def("_unsafe_set_version_counter", [](const at::Tensor& t, int64_t i) {
auto vc = torch::autograd::impl::version_counter(t);
vc.set_version(i);
});
m.def(
"_unsafe_set_version_counter",
[](const std::vector<at::Tensor>& tensors,
const std::vector<int64_t>& versions) {
auto tensors_len = tensors.size();
auto versions_len = versions.size();
TORCH_CHECK(
tensors_len == versions_len,
"tensors_len=",
tensors_len,
", versions_len=",
versions_len);
for (const auto i : c10::irange(tensors_len)) {
auto vc = torch::autograd::impl::version_counter(tensors[i]);
vc.set_version(versions[i]);
}
});
m.def("_enable_profiler_legacy", enableProfilerLegacy);
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")

View File

@ -290,35 +290,37 @@ def foreach_all_gather_copy_out(
out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
else:
out = [t.view(world_size, -1) for t in split_with_sizes_out]
torch.ops.fsdp.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=1, out=out
)
with torch.autograd._unsafe_preserve_version_counter(tuple(out)):
torch.ops.fsdp.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=1, out=out
)
for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
# Chunk-cat from the temporary to the final all-gather output tensors
shard_dim = fsdp_param.fsdp_placement.dim
for param_all_gather_output, target_all_gather_output in zip(
param_all_gather_outputs, fsdp_param.all_gather_outputs
with torch.autograd._unsafe_preserve_version_counter(
tuple(fsdp_param.all_gather_outputs)
):
padded_sharded_size = (
fsdp_param.padded_sharded_param_size
if fsdp_param.sharded_state == ShardedState.SHARDED
else cast(
torch.Tensor, fsdp_param._sharded_post_forward_param_data
).size()
)
pre_param_size = list(padded_sharded_size)
pre_param_size[0] *= world_size
chunks = torch.chunk(
param_all_gather_output.view(pre_param_size), world_size, dim=0
)
post_param_size = list(padded_sharded_size)
post_param_size[shard_dim] *= world_size
cat_out = target_all_gather_output.view(post_param_size)
torch.cat(chunks, dim=shard_dim, out=cat_out)
torch._C._autograd._unsafe_set_version_counter(
target_all_gather_output, target_all_gather_output._version - 1
)
for param_all_gather_output, target_all_gather_output in zip(
param_all_gather_outputs, fsdp_param.all_gather_outputs
):
padded_sharded_size = (
fsdp_param.padded_sharded_param_size
if fsdp_param.sharded_state == ShardedState.SHARDED
else cast(
torch.Tensor, fsdp_param._sharded_post_forward_param_data
).size()
)
pre_param_size = list(padded_sharded_size)
pre_param_size[0] *= world_size
chunks = torch.chunk(
param_all_gather_output.view(pre_param_size), world_size, dim=0
)
post_param_size = list(padded_sharded_size)
post_param_size[shard_dim] *= world_size
cat_out = target_all_gather_output.view(post_param_size)
torch.cat(chunks, dim=shard_dim, out=cat_out)
@torch.no_grad()