Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00).
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via `git diff --ignore-all-space --ignore-blank-lines HEAD~1`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768. Approved by: https://github.com/jansel
28 lines
816 B
Python
28 lines
816 B
Python
# mypy: allow-untyped-defs
|
|
import torch.library
|
|
from torch import Tensor
|
|
from torch.autograd import Function
|
|
|
|
|
|
if not torch._running_with_deploy():
|
|
_test_lib_def = torch.library.Library("_inductor_test", "DEF")
|
|
_test_lib_def.define(
|
|
"realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag
|
|
)
|
|
|
|
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
|
|
for dispatch_key in ("CPU", "CUDA", "Meta"):
|
|
_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
|
|
|
|
class Realize(Function):
    """Autograd wrapper around ``torch.ops._inductor_test.realize``.

    The forward pass dispatches to the test op. The backward pass returns the
    incoming gradient unchanged, i.e. the op is treated as an identity for
    differentiation.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved on ``ctx``: the backward pass never reads it.
        realize_op = torch.ops._inductor_test.realize
        return realize_op(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient. The kernels registered for this op in this module
        # return ``x.clone()``, whose derivative w.r.t. ``x`` is the identity.
        passthrough = grad_output
        return passthrough
|
|
|
|
def realize(x: Tensor) -> Tensor:
    """Apply the ``Realize`` autograd function to ``x``.

    The result is produced by ``torch.ops._inductor_test.realize`` and
    gradients flow back to ``x`` unchanged.

    NOTE(review): presumably used by Inductor tests to force a tensor to be
    materialized at this point in a graph; confirm against callers.
    """
    result = Realize.apply(x)
    return result
|