torch.autograd.graph.increment_version: accept List[Tensor], use in AOTDispatcher (#132652)

The regression reported in https://github.com/pytorch/pytorch/issues/132281 pinpoints e4ace1a396 as the cause. The main change that commit introduced is that we now manually check `is_inference()` and call `increment_version()` (a pybind call) on every mutated input tensor to the graph.

This PR attempts to reduce that overhead by bundling all of those checks into a single pybind call (see the sketch below), by:

(1) updating `torch.autograd.graph.increment_version()` to accept a `Union[Tensor, List[Tensor]]`

(2) updating its semantics to silently no-op when a tensor has no version counter (e.g. one constructed under inference mode), instead of erroring
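
As a rough sketch of the before/after call pattern from the caller's side (illustrative only; the tensor names are made up, and this code is not taken from the PR):

    import torch

    a = torch.rand(3)
    b = torch.rand(3)
    with torch.inference_mode():
        c = torch.rand(3)  # inference tensor: has no version counter

    # Before this PR: one pybind call per tensor, plus a manual
    # is_inference() check in Python to avoid erroring on c.
    for t in (a, b, c):
        if not t.is_inference():
            torch.autograd.graph.increment_version(t)

    # After this PR: a single pybind call; inference tensors are
    # silently skipped on the C++ side.
    torch.autograd.graph.increment_version([a, b, c])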

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132652
Approved by: https://github.com/albanD
Author: Brian Hirsh
Date: 2024-08-05 14:20:10 -07:00
Committed by: PyTorch MergeBot
Parent: af67b8df6d
Commit: e3394e5548
5 changed files with 33 additions and 13 deletions

View File

@@ -3982,10 +3982,12 @@ class TestAutograd(TestCase):
         with torch.inference_mode():
             a = torch.rand(5, requires_grad=True)
-        msg = "update to inference tensor outside InferenceMode"
-        with self.assertRaisesRegex(RuntimeError, msg):
-            torch.autograd.graph.increment_version(a)
+            # does not error
+            torch.autograd.graph.increment_version(a)
+
+        # does not error
+        torch.autograd.graph.increment_version(a)

     def test_no_grad_input(self):
         class MyFunction(Function):
             @staticmethod
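
The behavior this updated test relies on can be checked directly: the counter that increment_version bumps is visible through the internal Tensor._version attribute. A minimal sketch, assuming _version (the value the autograd engine consults) stays readable from Python:

    import torch

    t = torch.zeros(5)
    before = t._version
    torch.autograd.graph.increment_version(t)
    assert t._version == before + 1

    with torch.inference_mode():
        inf = torch.rand(5)

    # Inference tensors have no version counter; this is now a
    # silent no-op rather than a RuntimeError.
    torch.autograd.graph.increment_version(inf)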

View File

@@ -1331,7 +1331,7 @@ def autocast_increment_nesting() -> _int: ...
 def autocast_decrement_nesting() -> _int: ...
 def is_autocast_cache_enabled() -> _bool: ...
 def set_autocast_cache_enabled(enabled: _bool) -> None: ...
-def _increment_version(tensor: Tensor) -> None: ...
+def _increment_version(tensors: Iterable[Tensor]) -> None: ...
 def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ...
 def is_anomaly_enabled() -> _bool: ...
 def is_anomaly_check_nan_enabled() -> _bool: ...

View File

@@ -295,10 +295,11 @@ def _create_runtime_wrapper(
     orig_inputs = {i: args[i] for i in epilogue_args_idx}

     if keep_input_mutations:
-        for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd:
-            arg = args[i]
-            if not arg.is_inference():  # inference tensors have no VC
-                torch.autograd.graph.increment_version(arg)
+        mutated_args = (
+            args[i]
+            for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd
+        )
+        torch.autograd.graph.increment_version(mutated_args)

     if trace_joint:
         args_ = list(args)
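
Note that mutated_args above is a generator expression rather than a list, so the hot path never materializes an intermediate list; the new binding only needs something PyObject_GetIter accepts. A small sketch of the call shapes that should all be equivalent under the new semantics:

    import torch

    tensors = [torch.rand(2) for _ in range(3)]

    torch.autograd.graph.increment_version(tensors)             # list
    torch.autograd.graph.increment_version(tuple(tensors))      # tuple
    torch.autograd.graph.increment_version(t for t in tensors)  # generator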

View File

@@ -215,7 +215,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
     return GradientEdge(grad_fn, tensor.output_nr)


-def increment_version(tensor: torch.Tensor) -> None:
+def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
     """Update autograd metadata tracking whether the given Tensor was modified in place.

     This is to enable more accurate error checking within the autograd engine.
@@ -223,11 +223,16 @@ def increment_version(tensor: torch.Tensor) -> None:
     when mark_dirty() is called appropriately so you only need to call this explicitly
     if you are doing inplace operation on the Tensor data in a way that PyTorch doesn't
     know about. For example a custom kernel that reads the Tensor data_ptr and modifies
-    the memory inplace based on this pointer.
+    the memory inplace based on this pointer. Can accept either a tensor, or a list of tensors.

     Note that incrementing the version counter multiple times for a single inplace operation
     is not problematic.
+
+    Note that if you pass in a tensor constructed under torch.inference_mode(),
+    we will not bump its version counter (because your tensor does not have one).
     """
-    torch._C._increment_version(tensor)
+    if isinstance(tensor, torch.Tensor):
+        tensor = (tensor,)
+    torch._C._increment_version(tensor)
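
To make the docstring's custom-kernel scenario concrete, here is a hedged sketch of an out-of-band mutation that PyTorch cannot observe, using a shared-memory numpy view as a stand-in for a custom kernel writing through data_ptr:

    import torch

    t = torch.zeros(4)

    # t.numpy() shares memory with t; this in-place write bypasses
    # all torch operators, so the version counter is not bumped.
    t.numpy()[:] = 1.0

    # Tell autograd about the out-of-band mutation.
    torch.autograd.graph.increment_version(t)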

View File

@@ -1254,11 +1254,23 @@ static PyObject* len_torch_dispatch_stack(PyObject* _unused, PyObject* args) {
   END_HANDLE_TH_ERRORS
 }

-PyObject* THPModule_increment_version(PyObject* _unused, PyObject* tensor) {
+PyObject* THPModule_increment_version(
+    PyObject* _unused,
+    PyObject* tensor_list) {
   HANDLE_TH_ERRORS
-  TORCH_CHECK(
-      THPVariable_Check(tensor), "increment_version expect a Tensor as input");
-  torch::autograd::increment_version((THPVariable_Unpack(tensor)));
+  auto iterator = THPObjectPtr(PyObject_GetIter(tensor_list));
+  TORCH_CHECK(iterator, "increment_version expects an Iterable[Tensor] as input");
+  auto item = THPObjectPtr(PyIter_Next(iterator));
+  while (item) {
+    TORCH_CHECK(
+        THPVariable_Check(item),
+        "increment_version expects each element of the iterable to be a tensor");
+    auto t = THPVariable_Unpack(item);
+    if (!t.is_inference()) {
+      torch::autograd::increment_version(t);
+    }
+    item = THPObjectPtr(PyIter_Next(iterator));
+  }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }