diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py
index ed50a02d7b27..b61f35515b4b 100644
--- a/test/inductor/test_distributed_patterns.py
+++ b/test/inductor/test_distributed_patterns.py
@@ -242,7 +242,7 @@ class DistributedPatternTests(TestCase):
             x = x.sin()
             v = w._version
             w.copy_(x + 1)
-            torch._C._autograd._unsafe_set_version_counter(w, v)
+            torch._C._autograd._unsafe_set_version_counter((w,), (v,))
             return w, v

         for v in (3, 0, 1):
@@ -266,7 +266,7 @@ class DistributedPatternTests(TestCase):
             with torch.no_grad():
                 v = w._version
                 w.copy_(x)
-                torch._C._autograd._unsafe_set_version_counter(w, v)
+                torch._C._autograd._unsafe_set_version_counter((w,), (v,))
             return r

         w1 = torch.randn(1, requires_grad=True)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 6fd359fbc943..7756425e7092 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4799,10 +4799,18 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
             # version counter doesn't change inside of the context manager
             self.assertEqual(2, x._version)

-        torch._C._autograd._unsafe_set_version_counter(x, 0)
+        torch._C._autograd._unsafe_set_version_counter((x,), (0,))
         self.assertEqual(0, x._version)
         with self.assertRaisesRegex(RuntimeError, "Cannot set"):
-            torch._C._autograd._unsafe_set_version_counter(x, -1)
+            torch._C._autograd._unsafe_set_version_counter((x,), (-1,))
+
+        y = torch.ones(2, requires_grad=True).clone()
+        with torch.autograd._unsafe_preserve_version_counter((x, y)):
+            x.mul_(2)
+            y.mul_(3)
+            # version counter doesn't change inside of the context manager
+            self.assertEqual(0, x._version)
+            self.assertEqual(0, y._version)

     def test_current_node(self):
         pr = []
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index f756828ed6c9..457929cb72ae 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -115,7 +115,9 @@ def _push_saved_tensors_default_hooks(
     unpack_hook: Callable[[Any], torch.Tensor],
 ) -> None: ...
 def _pop_saved_tensors_default_hooks() -> None: ...
-def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
+def _unsafe_set_version_counter(
+    t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
+) -> None: ...
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
 def _profiler_type() -> ActiveProfilerType: ...
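For reference, a minimal eager-mode sketch of the batched binding exercised by the updated tests above; the tensor names here are illustrative and not part of the patch:

```python
import torch

a = torch.zeros(2)
b = torch.zeros(2)
a.add_(1)  # bumps a._version from 0 to 1
b.add_(1)  # bumps b._version from 0 to 1

# With the batched binding, both version counters are rewound in one call;
# the tensor tuple and the version tuple must have the same length.
torch._C._autograd._unsafe_set_version_counter((a, b), (0, 0))
assert a._version == 0 and b._version == 0
```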
diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py
index 889b2450409f..63fa3b439009 100644
--- a/torch/_dynamo/tensor_version_op.py
+++ b/torch/_dynamo/tensor_version_op.py
@@ -25,7 +25,7 @@ def _tensor_version_fake(fake_mode, self_tensor):


 _unsafe_set_version_counter = _make_prim(
-    schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()",
+    schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()",
     return_type=RETURN_TYPE.NEW,
     meta=lambda self, version: None,
     impl_aten=torch._C._autograd._unsafe_set_version_counter,
@@ -55,5 +55,5 @@ def _tensor_version_functional(mode, self):


 @_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
-def _unsafe_set_version_counter_functional(ctx, self, version):
-    torch._C._autograd._unsafe_set_version_counter(self, version)
+def _unsafe_set_version_counter_functional(ctx, tensors, versions):
+    torch._C._autograd._unsafe_set_version_counter(tensors, versions)
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 79e7666916a7..3131d44c8822 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1872,7 +1872,7 @@ class BuiltinVariable(VariableTracker):
             version = x._version
             if version > 0:
                 version = version - 1
-            torch._C._autograd._unsafe_set_version_counter(x, version)
+            torch._C._autograd._unsafe_set_version_counter((x,), (version,))
             return x

         tx.output.create_proxy(
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index cbe9939014aa..c251433390f7 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -975,20 +975,39 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
     Wraps torch.autograd._unsafe_preserve_version_counter
     """

+    @staticmethod
+    def _create_lambda_from_tensors(tx, tensors):
+        if isinstance(tensors, variables.TensorVariable):
+            versions = variables.TupleVariable(
+                [x.var_getattr(tx, "_version") for x in [tensors]]
+            )
+            tensors = variables.TupleVariable([tensors])
+        else:
+            versions = variables.TupleVariable(
+                [x.var_getattr(tx, "_version") for x in tensors.items]
+            )
+        return PreserveVersionContextVariable(tensors, versions)
+
     @staticmethod
     def constructor(tx):
         return variables.LambdaVariable(
-            lambda tensor: PreserveVersionContextVariable(
-                tensor,
-                tensor.var_getattr(tx, "_version"),
+            lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors(
+                tx, tensors
             )
         )

-    def __init__(self, tensor, prev_version, **kwargs) -> None:
+    def __init__(self, tensors, prev_versions, **kwargs) -> None:
         kwargs.setdefault("target_values", None)
         super().__init__(**kwargs)
-        self.tensor = tensor
-        self.prev_version = prev_version
+        self.tensors = tensors
+        self.prev_versions = prev_versions
+        # The context manager accepts Union[Tensor, Tuple[Tensor]]
+        if isinstance(self.tensors, variables.TensorVariable):
+            self.tensors = variables.TupleVariable([self.tensors])
+        if isinstance(
+            self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable)
+        ):
+            self.prev_versions = variables.TupleVariable([self.prev_versions])

     def enter(self, tx):
         pass
@@ -998,7 +1017,7 @@ class PreserveVersionContextVariable(ContextWrappingVariable):

         return variables.TorchInGraphFunctionVariable(
             _unsafe_set_version_counter
-        ).call_function(tx, [self.tensor, self.prev_version], {})
+        ).call_function(tx, [self.tensors, self.prev_versions], {})

     def reconstruct(self, codegen):
         unimplemented(
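As a rough illustration of what the Dynamo changes above are meant to support, here is a sketch (not a test from this patch; backend, shapes, and names are arbitrary) of tracing the public context manager over a tuple of tensors:

```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(a, b, x):
    # Both tensors are mutated in place, but their version counters are
    # captured on entry and restored on exit of the context manager.
    with torch.autograd._unsafe_preserve_version_counter((a, b)):
        a.copy_(x.sin())
        b.copy_(x.cos())
    return a + b

a, b, x = torch.zeros(3), torch.zeros(3), torch.randn(3)
va, vb = a._version, b._version
fn(a, b, x)
assert (a._version, b._version) == (va, vb)
```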
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
index 3cc7593cf6a5..6aa932c01136 100644
--- a/torch/autograd/grad_mode.py
+++ b/torch/autograd/grad_mode.py
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any
+from typing import Any, Tuple, Union

 import torch
 from torch.utils._contextlib import (
@@ -386,12 +386,13 @@

     """

-    def __init__(self, tensor: torch.Tensor) -> None:
-        self.tensor = tensor
-        self.prev_version = tensor._version
+    def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
+        self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
+        assert isinstance(self.tensors, tuple)
+        self.prev_versions = tuple(t._version for t in self.tensors)

     def __enter__(self) -> None:
         pass

     def __exit__(self, *args) -> None:
-        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
+        torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 9a8375480374..e293cc4be21d 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -388,10 +388,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
         return activities;
       });

-  m.def("_unsafe_set_version_counter", [](const at::Tensor& t, int64_t i) {
-    auto vc = torch::autograd::impl::version_counter(t);
-    vc.set_version(i);
-  });
+  m.def(
+      "_unsafe_set_version_counter",
+      [](const std::vector<at::Tensor>& tensors,
+         const std::vector<int64_t>& versions) {
+        auto tensors_len = tensors.size();
+        auto versions_len = versions.size();
+        TORCH_CHECK(
+            tensors_len == versions_len,
+            "tensors_len=",
+            tensors_len,
+            ", versions_len=",
+            versions_len);
+        for (const auto i : c10::irange(tensors_len)) {
+          auto vc = torch::autograd::impl::version_counter(tensors[i]);
+          vc.set_version(versions[i]);
+        }
+      });

   m.def("_enable_profiler_legacy", enableProfilerLegacy);
   py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
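Two behavioral points of the grad_mode.py and init.cpp changes above, shown as a hedged sketch rather than as tests from this patch: the context manager still accepts a single tensor (normalized to a 1-tuple), and the new binding rejects tensor/version tuples of different lengths via the TORCH_CHECK:

```python
import torch

t = torch.zeros(2)
# Single-tensor form is still accepted; __init__ wraps it into a 1-tuple.
with torch.autograd._unsafe_preserve_version_counter(t):
    t.add_(1)
assert t._version == 0

# Mismatched tuple lengths trip the TORCH_CHECK in the binding.
try:
    torch._C._autograd._unsafe_set_version_counter((t,), (0, 0))
except RuntimeError:
    pass  # tensors_len != versions_len is rejected
```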
diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py
index 1dbe91eb0622..210945af2241 100644
--- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py
+++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py
@@ -290,35 +290,37 @@ def foreach_all_gather_copy_out(
         out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
     else:
         out = [t.view(world_size, -1) for t in split_with_sizes_out]
-    torch.ops.fsdp.split_with_sizes_copy(
-        all_gather_output, all_gather_input_split_sizes, dim=1, out=out
-    )
+    with torch.autograd._unsafe_preserve_version_counter(tuple(out)):
+        torch.ops.fsdp.split_with_sizes_copy(
+            all_gather_output, all_gather_input_split_sizes, dim=1, out=out
+        )

     for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
         # Chunk-cat from the temporary to the final all-gather output tensors
         shard_dim = fsdp_param.fsdp_placement.dim
-        for param_all_gather_output, target_all_gather_output in zip(
-            param_all_gather_outputs, fsdp_param.all_gather_outputs
+
+        with torch.autograd._unsafe_preserve_version_counter(
+            tuple(fsdp_param.all_gather_outputs)
         ):
-            padded_sharded_size = (
-                fsdp_param.padded_sharded_param_size
-                if fsdp_param.sharded_state == ShardedState.SHARDED
-                else cast(
-                    torch.Tensor, fsdp_param._sharded_post_forward_param_data
-                ).size()
-            )
-            pre_param_size = list(padded_sharded_size)
-            pre_param_size[0] *= world_size
-            chunks = torch.chunk(
-                param_all_gather_output.view(pre_param_size), world_size, dim=0
-            )
-            post_param_size = list(padded_sharded_size)
-            post_param_size[shard_dim] *= world_size
-            cat_out = target_all_gather_output.view(post_param_size)
-            torch.cat(chunks, dim=shard_dim, out=cat_out)
-            torch._C._autograd._unsafe_set_version_counter(
-                target_all_gather_output, target_all_gather_output._version - 1
-            )
+            for param_all_gather_output, target_all_gather_output in zip(
+                param_all_gather_outputs, fsdp_param.all_gather_outputs
+            ):
+                padded_sharded_size = (
+                    fsdp_param.padded_sharded_param_size
+                    if fsdp_param.sharded_state == ShardedState.SHARDED
+                    else cast(
+                        torch.Tensor, fsdp_param._sharded_post_forward_param_data
+                    ).size()
+                )
+                pre_param_size = list(padded_sharded_size)
+                pre_param_size[0] *= world_size
+                chunks = torch.chunk(
+                    param_all_gather_output.view(pre_param_size), world_size, dim=0
+                )
+                post_param_size = list(padded_sharded_size)
+                post_param_size[shard_dim] *= world_size
+                cat_out = target_all_gather_output.view(post_param_size)
+                torch.cat(chunks, dim=shard_dim, out=cat_out)


 @torch.no_grad()
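The FSDP change above replaces the per-tensor version-counter rewind with a single context manager around the whole copy-out. A simplified, hypothetical sketch of the same pattern outside FSDP (buffer names and sizes are made up; FSDP itself uses the fused torch.ops.fsdp.split_with_sizes_copy op rather than the plain copies shown here):

```python
import torch

# Pre-allocated destination buffers, standing in for the all-gather outputs.
outs = [torch.empty(4), torch.empty(6)]
flat = torch.randn(10)  # stands in for the flat all-gather output buffer

# Capture every destination's version counter on entry and restore it on exit,
# so autograd never observes the in-place writes into the buffers.
with torch.autograd._unsafe_preserve_version_counter(tuple(outs)):
    for dst, chunk in zip(outs, torch.split(flat, [4, 6])):
        dst.copy_(chunk)

assert all(o._version == 0 for o in outs)
```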