pytorch/torch/_dynamo/tensor_version_op.py
Lucas Kabela 656885b614 [Dynamo][Better Engineering] Type devices, resume_execution and testing utils (#158593)
As part of Better Engineering week, we would like to improve our type support to improve the developer experience in Dynamo.

This PR adds strict typing support to a set of utilities in Dynamo: `device_interface.py`, `resume_execution.py`, `tensor_version_op.py`, `test_case.py`, and `test_minifier_common.py`.

Running the following command produces the coverage numbers below:
```
mypy torch/_dynamo/device_interface.py torch/_dynamo/resume_execution.py torch/_dynamo/tensor_version_op.py torch/_dynamo/test_case.py torch/_dynamo/test_minifier_common.py  --linecount-report /tmp/coverage_log
```

| Branch | Lines Annotated | Lines Total | % Lines Covered | Funcs Annotated | Funcs Total | % Funcs Covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  976 | 1672 | 58.37% | 76 | 112 | 67.86% |
| This PR | 1719 | 1719 | 100.00% | 112 | 112 | 100.00% |
| Delta    | +743 | +47 | +41.63% | +36 | 0 | +32.14% |
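A quick way to sanity-check the table: each "% covered" column is the first count divided by the total (976 / 1672 ≈ 58.37%, 76 / 112 ≈ 67.86%). The first count in each group is therefore the number of *annotated* lines or functions. The sketch below recomputes the percentages and deltas from the raw counts. The `Coverage` helper is hypothetical and is not part of mypy's report format.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Coverage:
    """Hypothetical helper: annotated vs. total counts for one table row."""

    annotated: int
    total: int

    @property
    def percent(self) -> float:
        # Share of annotated items, rounded to two decimals like the table.
        return round(100 * self.annotated / self.total, 2)


# Raw counts copied from the "Main" and "This PR" rows of the table.
main_lines, pr_lines = Coverage(976, 1672), Coverage(1719, 1719)
main_funcs, pr_funcs = Coverage(76, 112), Coverage(112, 112)

for name, before, after in [
    ("lines", main_lines, pr_lines),
    ("funcs", main_funcs, pr_funcs),
]:
    # These differences correspond to the table's "Delta" row.
    print(
        f"{name}: {before.percent}% -> {after.percent}% "
        f"(+{after.annotated - before.annotated} annotated, "
        f"+{round(after.percent - before.percent, 2)} points)"
    )
```

The printed differences (+743 lines, +41.63 points; +36 functions, +32.14 points) match the Delta row.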

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158593
Approved by: https://github.com/mlazos
2025-07-18 18:22:06 +00:00


"""This module implements tensor version operations for Dynamo tracing.

It provides primitives for handling tensor versioning during tracing, particularly in the
context of functionalization, where version operations are handled eagerly on fake tensors.

When we functionalize _tensor_version + _unsafe_set_version_counter, the ops disappear from
the traced graph. We run them eagerly on the fake tensors used for tracing in order to get
past asserts that would otherwise fail in autograd.

Why is this ok?
1) Versions on functional tensors are meaningless, since a functional tensor cannot be
   mutated.
2) The whole point of version munging is to trick autograd into doing what we want, and after
   AOTAutograd there is no longer any need for these ops.

Note that this is similar to how no_grad is handled.
"""

from contextlib import AbstractContextManager
from typing import Any

import torch
from torch import SymInt
from torch._prims import _make_prim, RETURN_TYPE
from torch._subclasses import FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensorMode


_tensor_version = _make_prim(
    schema="_tensor_version(Tensor self) -> SymInt",
    return_type=RETURN_TYPE.NEW,
    meta=torch.ops.aten._version.default,
    impl_aten=torch.ops.aten._version.default,
    doc="Traceable unbacked SymInt version of torch.Tensor._version",
)


@_tensor_version.py_impl(FakeTensorMode)  # type: ignore[misc]
def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt:
    """
    The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the
    `._version` into an unbacked SymInt so that we don't need to specialize on the `._version`
    of input tensors to the graph.
    """
    assert fake_mode.shape_env is not None
    return fake_mode.shape_env.create_unbacked_symint()


_unsafe_set_version_counter = _make_prim(
    schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()",
    return_type=RETURN_TYPE.NEW,
    meta=lambda self, version: None,
    impl_aten=torch._C._autograd._unsafe_set_version_counter,
    doc="Traceable+SymInt version of torch._C._autograd._unsafe_set_version_counter",
)
torch.fx.node.has_side_effect(_unsafe_set_version_counter)


@_tensor_version.py_impl(FunctionalTensorMode)  # type: ignore[misc]
def _tensor_version_functional(mode: FunctionalTensorMode, self: Any) -> int:
    return self._version


@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)  # type: ignore[misc]
def _unsafe_set_version_counter_functional(
    ctx: AbstractContextManager[Any],
    tensors: tuple[torch.Tensor, ...],
    versions: tuple[int, ...],
) -> None:
    torch._C._autograd._unsafe_set_version_counter(tensors, versions)
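The module docstring rests on eager autograd tracking a per-tensor version counter (`Tensor._version`) and asserting on it at backward time. `_tensor_version` reads that counter (its `impl_aten` is `torch.ops.aten._version.default`), and `_unsafe_set_version_counter` overwrites it. The minimal eager-mode sketch below assumes a recent PyTorch build. It shows the counter advancing on in-place ops and the autograd check that version munging is meant to get past. It deliberately does not call the private `_unsafe_set_version_counter` itself, and all helper names are hypothetical.

```python
import torch


def version_delta_after_inplace_ops(n: int) -> int:
    """Hypothetical helper: how far `n` in-place ops move a tensor's version counter."""
    t = torch.zeros(3)
    start = t._version  # measure relative to creation, whatever the factory did
    for _ in range(n):
        t.add_(1)  # each in-place op bumps the counter once
    return t._version - start


def eager_version_op_matches(t: torch.Tensor) -> bool:
    """Hypothetical helper: the aten op behind `_tensor_version` reads the same counter."""
    return torch.ops.aten._version.default(t) == t._version


def saved_tensor_mutation_is_caught() -> bool:
    """Hypothetical helper: does autograd reject a mutated saved tensor at backward?"""
    x = torch.ones(3, requires_grad=True)
    y = x.exp()  # exp saves its output for the backward pass
    y.add_(1)  # in-place op: y._version no longer matches the version saved by exp
    try:
        y.sum().backward()
    except RuntimeError:
        # Autograd's saved-tensor version check fires here.
        return True
    return False
```

Functionalization removes in-place ops from the graph, so there is nothing left for these counters to guard. The docstring gives this as the reason the version ops can run eagerly on fake tensors and then disappear from the traced graph.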