Mirror of https://github.com/pytorch/pytorch.git
Summary: When we call torch.inference_mode, the Autograd key appears to be skipped, so the custom op that export uses is not decomposed properly before subclass dispatching starts. We fix this by forcing the op to be desugared at the Python key.

Test Plan: test

Differential Revision: D71599541

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149698
Approved by: https://github.com/bdhirsh
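A minimal sketch of the dispatch behavior described above. The namespace demo_inference, the op marker, and the calls list are hypothetical and used only for illustration: a kernel registered at the Autograd key handles ordinary calls, but torch.inference_mode() excludes the autograd dispatch keys, so dispatch falls through to the next registration. This is why the op in the file below also needs a registration at the Python key for the subclass-tracing path.

import torch

# Hypothetical toy namespace and op, for illustration only.
_demo_lib = torch.library.Library("demo_inference", "FRAGMENT")  # noqa: TOR901
_demo_lib.define("marker(Tensor x) -> Tensor")

calls = []


@torch.library.impl(_demo_lib, "marker", "Autograd")
def _marker_autograd(x: torch.Tensor) -> torch.Tensor:
    calls.append("Autograd")
    return x.clone()


@torch.library.impl(_demo_lib, "marker", "CompositeExplicitAutograd")
def _marker_backend(x: torch.Tensor) -> torch.Tensor:
    calls.append("CompositeExplicitAutograd")
    return x.clone()


x = torch.randn(2)
torch.ops.demo_inference.marker(x)  # normal dispatch: the Autograd kernel handles the call
with torch.inference_mode():
    # The autograd dispatch keys are excluded from the local key set here,
    # so the Autograd kernel is skipped and the backend kernel runs instead.
    torch.ops.demo_inference.marker(x)

print(calls)  # expected: ['Autograd', 'CompositeExplicitAutograd']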
27 lines
960 B
Python
import torch


lib = torch.library.Library("export", "FRAGMENT")  # noqa: TOR901

lib.define(
    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)


@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
# When running under torch.inference_mode(), we seem to skip the Autograd key,
# so we should desugar this op as soon as we start tracing to post-dispatch.
@torch.library.impl(lib, "access_subclass_inner_tensor", "Python")
def _access_subclass_inner_tensor(
    src_subclass_tensor: torch.Tensor, attr: str
) -> torch.Tensor:
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass

    assert is_traceable_wrapper_subclass(src_subclass_tensor)
    val = getattr(src_subclass_tensor, attr, None)
    if val is None or not isinstance(val, torch.Tensor):
        raise RuntimeError(
            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
        )
    return val
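A usage sketch for the op above, assuming the definitions in this file have been executed (for example by importing the module). WrappedTensor and its inner attribute name elem are hypothetical stand-ins for a traceable wrapper subclass; only torch.ops.export.access_subclass_inner_tensor comes from the file itself. Under torch.inference_mode() the Autograd registration is skipped, but the Python-key registration still desugars the call into a plain attribute access before subclass dispatching starts.

import torch
from torch.utils._pytree import tree_map


class WrappedTensor(torch.Tensor):
    # Hypothetical minimal traceable wrapper subclass holding one inner tensor.
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    # __tensor_flatten__/__tensor_unflatten__ are what make this class pass
    # the is_traceable_wrapper_subclass check in the op implementation above.
    def __tensor_flatten__(self):
        return ["elem"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return WrappedTensor(inner_tensors["elem"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the inner tensors, run the op, and rewrap tensor outputs.
        def unwrap(t):
            return t.elem if isinstance(t, WrappedTensor) else t

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(
            lambda t: WrappedTensor(t) if isinstance(t, torch.Tensor) else t, out
        )


with torch.inference_mode():
    wrapped = WrappedTensor(torch.randn(3))
    # Autograd is skipped here, but the "Python" registration still runs and
    # returns the inner tensor directly.
    inner = torch.ops.export.access_subclass_inner_tensor(wrapped, "elem")
    print(type(inner))  # <class 'torch.Tensor'>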