torch.autograd.graph.increment_version: accept List[Tensor], use in AOTDispatcher (#132652)
The regression reported in https://github.com/pytorch/pytorch/issues/132281 was bisected to e4ace1a396. The main delta that commit introduced is that we now manually check `is_inference()` and call `increment_version()` (a pybind call) on every mutated input tensor to the graph.
This PR reduces that overhead by bundling all of those checks into a single pybind call, by:
(1) updating `torch.autograd.graph.increment_version()` to accept a `Union[Tensor, List[Tensor]]`
(2) updating its semantics to no-op when passed a tensor with no version counter, instead of erroring
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132652
Approved by: https://github.com/albanD
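For context, the new call surface after this change, as a minimal sketch (the version-counter arithmetic assumes ordinary, non-inference tensors):

```python
import torch

t = torch.ones(3)
before = t._version

torch.autograd.graph.increment_version(t)       # single Tensor, as before
torch.autograd.graph.increment_version([t, t])  # new: an iterable bumps each element

# one bump from the scalar call, two from the list call
assert t._version == before + 3
```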
commit e3394e5548
parent af67b8df6d
committed by PyTorch MergeBot
test/test_autograd.py

@@ -3982,10 +3982,12 @@ class TestAutograd(TestCase):
         with torch.inference_mode():
             a = torch.rand(5, requires_grad=True)
-            msg = "update to inference tensor outside InferenceMode"
-            with self.assertRaisesRegex(RuntimeError, msg):
-                torch.autograd.graph.increment_version(a)
+            # does not error
+            torch.autograd.graph.increment_version(a)
+
+        # does not error
+        torch.autograd.graph.increment_version(a)
 
     def test_no_grad_input(self):
         class MyFunction(Function):
             @staticmethod
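The behavior the updated test checks, shown standalone; a small sketch assuming standard inference_mode semantics (inference tensors carry no version counter):

```python
import torch

with torch.inference_mode():
    a = torch.rand(5, requires_grad=True)

assert a.is_inference()
# Previously this raised "update to inference tensor outside InferenceMode";
# with this PR it is a silent no-op, since `a` has no version counter to bump.
torch.autograd.graph.increment_version(a)
```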
torch/_C/__init__.pyi

@@ -1331,7 +1331,7 @@ def autocast_increment_nesting() -> _int: ...
 def autocast_decrement_nesting() -> _int: ...
 def is_autocast_cache_enabled() -> _bool: ...
 def set_autocast_cache_enabled(enabled: _bool) -> None: ...
-def _increment_version(tensor: Tensor) -> None: ...
+def _increment_version(tensors: Iterable[Tensor]) -> None: ...
 def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ...
 def is_anomaly_enabled() -> _bool: ...
 def is_anomaly_check_nan_enabled() -> _bool: ...
torch/_functorch/_aot_autograd/runtime_wrappers.py

@@ -295,10 +295,11 @@ def _create_runtime_wrapper(
         orig_inputs = {i: args[i] for i in epilogue_args_idx}
 
         if keep_input_mutations:
-            for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd:
-                arg = args[i]
-                if not arg.is_inference():  # inference tensors have no VC
-                    torch.autograd.graph.increment_version(arg)
+            mutated_args = (
+                args[i]
+                for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd
+            )
+            torch.autograd.graph.increment_version(mutated_args)
 
         if trace_joint:
             args_ = list(args)
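The same batching pattern in isolation; `bump_mutated_versions`, `args`, and `mutated_indices` are hypothetical stand-ins for the wrapper's locals:

```python
import torch

def bump_mutated_versions(args, mutated_indices):
    # A generator is fine here: the C++ binding walks it with PyObject_GetIter,
    # so all mutated inputs are handled in a single pybind crossing, and
    # inference tensors (which have no version counter) are skipped in C++.
    torch.autograd.graph.increment_version(args[i] for i in mutated_indices)

ts = [torch.zeros(2), torch.zeros(2)]
bump_mutated_versions(ts, [0, 1])
```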
torch/autograd/graph.py

@@ -215,7 +215,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
     return GradientEdge(grad_fn, tensor.output_nr)
 
 
-def increment_version(tensor: torch.Tensor) -> None:
+def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
     """Update autograd metadata tracking whether the given Tensor was modified in place.
 
     This is to enable more accurate error checking within the autograd engine.
@@ -223,11 +223,16 @@ def increment_version(tensor: torch.Tensor) -> None:
     when mark_dirty() is called appropriately so you only need to call this explicitly
     if you are doing inplace operation on the Tensor data in a way that Pytorch doesn't
     know about. For example a custom kernel that reads the Tensor data_ptr and modifies
-    the memory inplace based on this pointer.
+    the memory inplace based on this pointer. Can accept either a tensor, or a list of tensors.
 
     Note that incrementing the version counter multiple times for a single inplace operation
     is not problematic.
+
+    Note that if you pass in tensor constructed under torch.inference_mode(),
+    we will not bump its version counter (because your tensor does not have one).
     """
-    torch._C._increment_version(tensor)
+    if isinstance(tensor, torch.Tensor):
+        tensor = (tensor,)
+    torch._C._increment_version(tensor)
 
 
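A sketch of the scenario the docstring describes, using a ctypes write through data_ptr as a stand-in for a custom kernel (hypothetical example, not from the PR):

```python
import ctypes
import torch

t = torch.zeros(4, dtype=torch.float32)  # contiguous CPU tensor
# Mutate the storage behind autograd's back, as a custom kernel would.
buf = ctypes.cast(t.data_ptr(), ctypes.POINTER(ctypes.c_float))
buf[0] = 1.0
# Tell autograd the data changed so stale saved tensors can be detected.
torch.autograd.graph.increment_version(t)
```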
torch/csrc/Module.cpp

@@ -1254,11 +1254,23 @@ static PyObject* len_torch_dispatch_stack(PyObject* _unused, PyObject* args) {
   END_HANDLE_TH_ERRORS
 }
 
-PyObject* THPModule_increment_version(PyObject* _unused, PyObject* tensor) {
+PyObject* THPModule_increment_version(
+    PyObject* _unused,
+    PyObject* tensor_list) {
   HANDLE_TH_ERRORS
-  TORCH_CHECK(
-      THPVariable_Check(tensor), "increment_version expect a Tensor as input");
-  torch::autograd::increment_version((THPVariable_Unpack(tensor)));
+  auto iterator = THPObjectPtr(PyObject_GetIter(tensor_list));
+  TORCH_CHECK(iterator, "increment_version expect a Iterable[Tensor] as input");
+  auto item = THPObjectPtr(PyIter_Next(iterator));
+  while (item) {
+    TORCH_CHECK(
+        THPVariable_Check(item),
+        "increment_version expects each element of the iterable to be a tensor");
+    auto t = THPVariable_Unpack(item);
+    if (!t.is_inference()) {
+      torch::autograd::increment_version(t);
+    }
+    item = THPObjectPtr(PyIter_Next(iterator));
+  }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
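Since the binding only requires the iterator protocol, any iterable of tensors should work, not just lists; a quick illustration under that assumption:

```python
import torch

ts = (torch.ones(2), torch.ones(3))
torch.autograd.graph.increment_version(ts)              # tuple
torch.autograd.graph.increment_version(iter(ts))        # plain iterator
torch.autograd.graph.increment_version(x for x in ts)   # generator expression
```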