Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
update _unsafe_set_version_counter to accept lists of tensors (#137921)
See the comment [here](https://github.com/pytorch/pytorch/issues/132014#issuecomment-2379547400) (cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @XilunWu @rec).

This PR updates `_unsafe_set_version_counter` to accept a list of tensors, for overhead-sensitive users (e.g. distributed) who need to hide version-counter (VC) bumps from autograd on a large list of tensors without paying the cost of a separate Python -> C++ crossing for every tensor in the list. I left the binding in pybind and used a `std::vector`. If we **really** need to optimize overhead even further, we could write a manual CPython binding.

I use this updated API in the next PR to fix FSDP2, so that it properly hides the VC of all `all_gather_buffer` tensors in its call to `split_with_sizes_copy.out(all_gather_buffers)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137921
Approved by: https://github.com/awgu, https://github.com/albanD
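To make the new calling convention concrete, here is a minimal sketch of the batched form, mirroring the updated tests in this diff; the tensor names `a` and `b` are illustrative and not from the PR:

```python
import torch

# Two illustrative buffers whose in-place version bumps we want to hide.
a = torch.randn(4)
b = torch.randn(4)

# Snapshot the version counters before mutating.
versions = (a._version, b._version)

a.mul_(2)
b.add_(1)

# A single Python -> C++ crossing restores every version counter;
# the tensor and version sequences must have the same length.
torch._C._autograd._unsafe_set_version_counter((a, b), versions)

assert (a._version, b._version) == versions
```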
Committed by: PyTorch MergeBot
Parent: 425aca40a4
Commit: e68f5087d8
@@ -242,7 +242,7 @@ class DistributedPatternTests(TestCase):
             x = x.sin()
             v = w._version
             w.copy_(x + 1)
-            torch._C._autograd._unsafe_set_version_counter(w, v)
+            torch._C._autograd._unsafe_set_version_counter((w,), (v,))
             return w, v
 
         for v in (3, 0, 1):
@@ -266,7 +266,7 @@ class DistributedPatternTests(TestCase):
         with torch.no_grad():
             v = w._version
             w.copy_(x)
-            torch._C._autograd._unsafe_set_version_counter(w, v)
+            torch._C._autograd._unsafe_set_version_counter((w,), (v,))
             return r
 
         w1 = torch.randn(1, requires_grad=True)
@@ -4799,10 +4799,18 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
         # version counter doesn't change inside of the context manager
         self.assertEqual(2, x._version)
 
-        torch._C._autograd._unsafe_set_version_counter(x, 0)
+        torch._C._autograd._unsafe_set_version_counter((x,), (0,))
         self.assertEqual(0, x._version)
         with self.assertRaisesRegex(RuntimeError, "Cannot set"):
-            torch._C._autograd._unsafe_set_version_counter(x, -1)
+            torch._C._autograd._unsafe_set_version_counter((x,), (-1,))
+
+        y = torch.ones(2, requires_grad=True).clone()
+        with torch.autograd._unsafe_preserve_version_counter((x, y)):
+            x.mul_(2)
+            y.mul_(3)
+        # version counter doesn't change inside of the context manager
+        self.assertEqual(0, x._version)
+        self.assertEqual(0, y._version)
 
     def test_current_node(self):
         pr = []
@@ -115,7 +115,9 @@ def _push_saved_tensors_default_hooks(
     unpack_hook: Callable[[Any], torch.Tensor],
 ) -> None: ...
 def _pop_saved_tensors_default_hooks() -> None: ...
-def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
+def _unsafe_set_version_counter(
+    t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
+) -> None: ...
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
 def _profiler_type() -> ActiveProfilerType: ...
@@ -25,7 +25,7 @@ def _tensor_version_fake(fake_mode, self_tensor):
 
 
 _unsafe_set_version_counter = _make_prim(
-    schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()",
+    schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()",
     return_type=RETURN_TYPE.NEW,
     meta=lambda self, version: None,
     impl_aten=torch._C._autograd._unsafe_set_version_counter,
@@ -55,5 +55,5 @@ def _tensor_version_functional(mode, self):
 
 
 @_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
-def _unsafe_set_version_counter_functional(ctx, self, version):
-    torch._C._autograd._unsafe_set_version_counter(self, version)
+def _unsafe_set_version_counter_functional(ctx, tensors, versions):
+    torch._C._autograd._unsafe_set_version_counter(tensors, versions)
@@ -1872,7 +1872,7 @@ class BuiltinVariable(VariableTracker):
                 version = x._version
                 if version > 0:
                     version = version - 1
-                torch._C._autograd._unsafe_set_version_counter(x, version)
+                torch._C._autograd._unsafe_set_version_counter((x,), (version,))
                 return x
 
             tx.output.create_proxy(
@@ -975,20 +975,39 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
     Wraps torch.autograd._unsafe_preserve_version_counter
     """
 
     @staticmethod
+    def _create_lambda_from_tensors(tx, tensors):
+        if isinstance(tensors, variables.TensorVariable):
+            versions = variables.TupleVariable(
+                [x.var_getattr(tx, "_version") for x in [tensors]]
+            )
+            tensors = variables.TupleVariable([tensors])
+        else:
+            versions = variables.TupleVariable(
+                [x.var_getattr(tx, "_version") for x in tensors.items]
+            )
+        return PreserveVersionContextVariable(tensors, versions)
+
+    @staticmethod
     def constructor(tx):
         return variables.LambdaVariable(
-            lambda tensor: PreserveVersionContextVariable(
-                tensor,
-                tensor.var_getattr(tx, "_version"),
+            lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors(
+                tx, tensors
             )
         )
 
-    def __init__(self, tensor, prev_version, **kwargs) -> None:
+    def __init__(self, tensors, prev_versions, **kwargs) -> None:
         kwargs.setdefault("target_values", None)
         super().__init__(**kwargs)
-        self.tensor = tensor
-        self.prev_version = prev_version
+        self.tensors = tensors
+        self.prev_versions = prev_versions
+        # The context manager accepts Union[Tensor, Tuple[Tensor]]
+        if isinstance(self.tensors, variables.TensorVariable):
+            self.tensors = variables.TupleVariable([self.tensors])
+        if isinstance(
+            self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable)
+        ):
+            self.prev_versions = variables.TupleVariable([self.prev_versions])
 
     def enter(self, tx):
         pass
@@ -998,7 +1017,7 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
 
         return variables.TorchInGraphFunctionVariable(
             _unsafe_set_version_counter
-        ).call_function(tx, [self.tensor, self.prev_version], {})
+        ).call_function(tx, [self.tensors, self.prev_versions], {})
 
     def reconstruct(self, codegen):
         unimplemented(
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any
+from typing import Any, Tuple, Union
 
 import torch
 from torch.utils._contextlib import (
@@ -386,12 +386,13 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
 
     """
 
-    def __init__(self, tensor: torch.Tensor) -> None:
-        self.tensor = tensor
-        self.prev_version = tensor._version
+    def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
+        self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
+        assert isinstance(self.tensors, tuple)
+        self.prev_versions = tuple(t._version for t in self.tensors)
 
     def __enter__(self) -> None:
         pass
 
     def __exit__(self, *args) -> None:
-        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
+        torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
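For the context-manager form updated above, a short usage sketch modeled on the new test in this diff (`x` and `y` are illustrative names):

```python
import torch

x = torch.ones(2, requires_grad=True).clone()
y = torch.ones(2, requires_grad=True).clone()

# Passing a tuple snapshots each tensor's version counter on construction
# and restores all of them on exit, hiding the in-place bumps from autograd.
with torch.autograd._unsafe_preserve_version_counter((x, y)):
    x.mul_(2)
    y.mul_(3)

assert x._version == 0 and y._version == 0
```

A single tensor is still accepted for backwards compatibility; per the `__init__` above, it is normalized to a one-element tuple internally.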
@@ -388,10 +388,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
     return activities;
   });
 
-  m.def("_unsafe_set_version_counter", [](const at::Tensor& t, int64_t i) {
-    auto vc = torch::autograd::impl::version_counter(t);
-    vc.set_version(i);
-  });
+  m.def(
+      "_unsafe_set_version_counter",
+      [](const std::vector<at::Tensor>& tensors,
+         const std::vector<int64_t>& versions) {
+        auto tensors_len = tensors.size();
+        auto versions_len = versions.size();
+        TORCH_CHECK(
+            tensors_len == versions_len,
+            "tensors_len=",
+            tensors_len,
+            ", versions_len=",
+            versions_len);
+        for (const auto i : c10::irange(tensors_len)) {
+          auto vc = torch::autograd::impl::version_counter(tensors[i]);
+          vc.set_version(versions[i]);
+        }
+      });
 
   m.def("_enable_profiler_legacy", enableProfilerLegacy);
   py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
@@ -290,35 +290,37 @@ def foreach_all_gather_copy_out(
         out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
     else:
         out = [t.view(world_size, -1) for t in split_with_sizes_out]
-    torch.ops.fsdp.split_with_sizes_copy(
-        all_gather_output, all_gather_input_split_sizes, dim=1, out=out
-    )
+    with torch.autograd._unsafe_preserve_version_counter(tuple(out)):
+        torch.ops.fsdp.split_with_sizes_copy(
+            all_gather_output, all_gather_input_split_sizes, dim=1, out=out
+        )
 
     for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
         # Chunk-cat from the temporary to the final all-gather output tensors
         shard_dim = fsdp_param.fsdp_placement.dim
-        for param_all_gather_output, target_all_gather_output in zip(
-            param_all_gather_outputs, fsdp_param.all_gather_outputs
-        ):
-            padded_sharded_size = (
-                fsdp_param.padded_sharded_param_size
-                if fsdp_param.sharded_state == ShardedState.SHARDED
-                else cast(
-                    torch.Tensor, fsdp_param._sharded_post_forward_param_data
-                ).size()
-            )
-            pre_param_size = list(padded_sharded_size)
-            pre_param_size[0] *= world_size
-            chunks = torch.chunk(
-                param_all_gather_output.view(pre_param_size), world_size, dim=0
-            )
-            post_param_size = list(padded_sharded_size)
-            post_param_size[shard_dim] *= world_size
-            cat_out = target_all_gather_output.view(post_param_size)
-            torch.cat(chunks, dim=shard_dim, out=cat_out)
-            torch._C._autograd._unsafe_set_version_counter(
-                target_all_gather_output, target_all_gather_output._version - 1
-            )
+        with torch.autograd._unsafe_preserve_version_counter(
+            tuple(fsdp_param.all_gather_outputs)
+        ):
+            for param_all_gather_output, target_all_gather_output in zip(
+                param_all_gather_outputs, fsdp_param.all_gather_outputs
+            ):
+                padded_sharded_size = (
+                    fsdp_param.padded_sharded_param_size
+                    if fsdp_param.sharded_state == ShardedState.SHARDED
+                    else cast(
+                        torch.Tensor, fsdp_param._sharded_post_forward_param_data
+                    ).size()
+                )
+                pre_param_size = list(padded_sharded_size)
+                pre_param_size[0] *= world_size
+                chunks = torch.chunk(
+                    param_all_gather_output.view(pre_param_size), world_size, dim=0
+                )
+                post_param_size = list(padded_sharded_size)
+                post_param_size[shard_dim] *= world_size
+                cat_out = target_all_gather_output.view(post_param_size)
+                torch.cat(chunks, dim=shard_dim, out=cat_out)
 
 
 @torch.no_grad()