mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153551 Approved by: https://github.com/mlazos
41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
from typing import Any
|
|
|
|
import torch
|
|
|
|
from ..decorators import substitute_in_graph
|
|
|
|
|
|
@substitute_in_graph(  # type: ignore[arg-type]
    torch.Tensor._make_subclass
)
def make_subclass(
    cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any
) -> Any:
    """Dynamo-traceable stand-in for ``torch.Tensor._make_subclass``.

    This only approximates the C implementation ``THPVariable_make_subclass``
    (https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650),
    which is expected to be close enough for most Dynamo tracing purposes.

    Args:
        cls: Tensor type to produce; ``torch.Tensor`` itself is special-cased.
        data: Source tensor. The result is built from ``data.detach()``, not
            from ``data`` itself.
        requires_grad: Value assigned to the detached tensor's
            ``requires_grad`` flag (only when it differs from the current one).
        **kwargs: Must be empty; anything else fails an ``assert``.

    Returns:
        ``torch.Tensor(detached)`` when ``cls`` is ``torch.Tensor``, otherwise
        ``detached.as_subclass(cls)``.
    """
    # Guard clause: only `requires_grad` is modelled by this polyfill.
    assert len(kwargs) == 0, (
        "_make_subclass only supports requires_grad as keyword arg"
    )

    with torch._C.DisableTorchFunctionSubclass():
        detached = data.detach()

        # Dynamo does not support mutating `requires_grad`, so only assign it
        # when the value would actually change.
        if detached.requires_grad != requires_grad:
            detached.requires_grad = requires_grad

        # Dynamo cannot yet upcast to the plain base type via `as_subclass`, so
        # build a base `torch.Tensor` directly in that case. Every other type
        # goes through `as_subclass`, which Dynamo understands; in C, both
        # `THPVariable_make_subclass` and `THPVariable_as_subclass` call
        # `THPVariable_NewWithVar`, so the two paths match at this point.
        return (
            torch.Tensor(detached)
            if cls is torch.Tensor
            else detached.as_subclass(cls)
        )
|
|
|
|
|
|
# Public surface of this module: the `torch.Tensor._make_subclass` polyfill.
__all__ = ["make_subclass"]
|