Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
As part of better engineering week, we would like to improve our type support to improve the dev experience in Dynamo.

This PR adds strict typing support to a set of utilities in Dynamo: `device_interface.py`, `resume_execution.py`, `tensor_version_op.py`, `test_case.py`, and `test_minifier_common.py`.

Running

```
mypy torch/_dynamo/device_interface.py torch/_dynamo/resume_execution.py torch/_dynamo/tensor_version_op.py torch/_dynamo/test_case.py torch/_dynamo/test_minifier_common.py --linecount-report /tmp/coverage_log
```

| | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | ------- | ------- | ------- | ------- | ------- | ------- |
| Main | 976 | 1672 | 58.37% | 76 | 112 | 67.86% |
| This PR | 1719 | 1719 | 100.00% | 112 | 112 | 100.00% |
| Delta | +743 | +47 | +41.63% | +36 | 0 | +32.14% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158593
Approved by: https://github.com/mlazos
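The annotation work itself is mechanical: parameters and return types are spelled out until mypy counts every line and function as annotated. A hypothetical before/after sketch of the kind of change involved (names are illustrative, not taken from the diff):

```python
# Before: mypy counts this function as unannotated.
def run_tests(needs=()):
    ...


# After: fully annotated, so it counts toward strict coverage.
def run_tests(needs: tuple[str, ...] = ()) -> None:
    ...
```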
71 lines · 2.7 KiB · Python
"""This module implements tensor version operations for Dynamo tracing.
|
|
|
|
It provides primitives for handling tensor versioning during tracing, particularly in the
|
|
context of functionalization where version operations are handled eagerly on fake tensors.
|
|
|
|
When we functionalize _tensor_version + _unsafe_set_version_counter, the ops disappear from
|
|
the traced graph. We run them eagerly on the fake tensors used for tracing, in order to get
|
|
past asserts that would fail in autograd.
|
|
|
|
Why is this ok?
|
|
1) Versions on functional tensors do not make any sense since you cannot mutate a functional
|
|
tensor.
|
|
2) The whole point of version munging is to trick autograd into doing what we want, and after
|
|
AotAutograd there is no longer any need for these ops.
|
|
|
|
Note this is similar to how no_grad is handled.
|
|
"""

from contextlib import AbstractContextManager
from typing import Any

import torch
from torch import SymInt
from torch._prims import _make_prim, RETURN_TYPE
from torch._subclasses import FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensorMode

_tensor_version = _make_prim(
    schema="_tensor_version(Tensor self) -> SymInt",
    return_type=RETURN_TYPE.NEW,
    meta=torch.ops.aten._version.default,
    impl_aten=torch.ops.aten._version.default,
    doc="Traceable unbacked SymInt version of torch.Tensor._version",
)

@_tensor_version.py_impl(FakeTensorMode)  # type: ignore[misc]
def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt:
    """
    The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the
    `._version` into an unbacked SymInt so that we don't need to specialize on the `._version`
    of input tensors to the graph.
    """
    assert fake_mode.shape_env is not None
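    # A fresh unbacked symbol carries no concrete hint, so nothing downstream
    # can guard on (i.e. specialize to) the actual version value.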
    return fake_mode.shape_env.create_unbacked_symint()

_unsafe_set_version_counter = _make_prim(
    schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()",
    return_type=RETURN_TYPE.NEW,
    meta=lambda self, version: None,
    impl_aten=torch._C._autograd._unsafe_set_version_counter,
    doc="Traceable+SymInt version of torch._C._autograd._unsafe_set_version_counter",
)
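# Mark the op as side-effecting so FX dead-code elimination keeps calls to it
# in the traced graph even though it produces no output.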
torch.fx.node.has_side_effect(_unsafe_set_version_counter)

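# Under FunctionalTensorMode the version read/write happens eagerly here, so
# these ops leave no trace in the functionalized graph (see module docstring).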
@_tensor_version.py_impl(FunctionalTensorMode)  # type: ignore[misc]
def _tensor_version_functional(mode: FunctionalTensorMode, self: Any) -> int:
    return self._version


@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)  # type: ignore[misc]
def _unsafe_set_version_counter_functional(
    ctx: AbstractContextManager[Any],
    tensors: tuple[torch.Tensor, ...],
    versions: tuple[int, ...],
) -> None:
    torch._C._autograd._unsafe_set_version_counter(tensors, versions)
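Outside of tracing, both prims fall through to their `impl_aten` callables, so their eager semantics are easy to check. A minimal sketch, assuming a PyTorch build where `torch._dynamo.tensor_version_op` is importable:

```python
import torch

from torch._dynamo.tensor_version_op import (
    _tensor_version,
    _unsafe_set_version_counter,
)

t = torch.ones(2)
v0 = _tensor_version(t)  # eagerly, this is just t._version
t.add_(1)                # an in-place op bumps the version counter
assert _tensor_version(t) == v0 + 1

# Restore the saved counter, mirroring the call at the end of this module.
_unsafe_set_version_counter((t,), (v0,))
assert t._version == v0
```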