# File: pytorch/torch/_dynamo/polyfills/tensor.py (41 lines, 1.4 KiB, Python)

from typing import Any
import torch
from ..decorators import substitute_in_graph
@substitute_in_graph(  # type: ignore[arg-type]
    torch.Tensor._make_subclass
)
def make_subclass(
    cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any
) -> Any:
    """Python stand-in for ``torch.Tensor._make_subclass`` used while tracing.

    Registered through ``substitute_in_graph`` so Dynamo can trace this
    function in place of the C entry point. It is a rough approximation of
    ``THPVariable_make_subclass`` (see the link below): the input is detached,
    its ``requires_grad`` flag is set, and it is re-wrapped as ``cls``.

    Args:
        cls: Tensor type to produce. ``torch.Tensor`` itself is special-cased.
        data: Tensor that is detached and re-wrapped.
        requires_grad: Value of ``requires_grad`` on the returned tensor.
        **kwargs: Must be empty. Only ``requires_grad`` is accepted as a
            keyword argument.

    Returns:
        ``torch.Tensor(detached)`` when ``cls`` is ``torch.Tensor``,
        otherwise ``detached.as_subclass(cls)``.

    Raises:
        AssertionError: If extra keyword arguments are given. The check is an
            ``assert``, so it is skipped under ``python -O``.
    """
    # Reference C implementation:
    # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650
    with torch._C.DisableTorchFunctionSubclass():
        assert not kwargs, (
            "_make_subclass only supports requires_grad as keyword arg"
        )

        detached = data.detach()

        # Only touch the flag when it actually differs: Dynamo does not
        # support mutating `requires_grad`, so skip the no-op write.
        if detached.requires_grad != requires_grad:
            detached.requires_grad = requires_grad

        if cls is not torch.Tensor:
            # `as_subclass` is traceable by Dynamo, and its C implementation
            # (`THPVariable_as_subclass`) and `THPVariable_make_subclass` both
            # end in `THPVariable_NewWithVar`, so they agree here.
            return detached.as_subclass(cls)

        # Dynamo cannot yet upcast to the base tensor type via `as_subclass`,
        # so the plain `torch.Tensor` case goes through the constructor.
        return torch.Tensor(detached)
# Public API of this polyfill module: only the traceable `_make_subclass` stand-in.
__all__ = ["make_subclass"]