mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Adds suppressions so that pyrefly will typecheck cleanly: https://github.com/pytorch/pytorch/issues/163283 Split this directory into two PRs to keep them from being too large. Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the project-excludes field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: INFO 0 errors (6,884 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165062 Approved by: https://github.com/oulgen, https://github.com/mlazos
30 lines
862 B
Python
30 lines
862 B
Python
from typing import Any
|
|
|
|
import torch.library
|
|
from torch import Tensor
|
|
from torch.autograd import Function
|
|
|
|
|
|
_test_lib_def = torch.library.Library("_inductor_test", "DEF")
|
|
_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)
|
|
|
|
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
|
|
for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"):
|
|
_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
|
|
|
|
|
|
class Realize(Function):
    """Autograd wrapper around the ``_inductor_test::realize`` custom op.

    The op's registered kernel returns ``x.clone()``. The backward pass is
    therefore the identity: the first incoming gradient is returned
    unchanged.
    """

    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx: object, x: Tensor) -> Tensor:
        """Run ``torch.ops._inductor_test.realize`` on ``x``; ``ctx`` is unused."""
        realize_op = torch.ops._inductor_test.realize
        return realize_op(x)

    @staticmethod
    # types need to stay consistent with _SingleLevelFunction
    def backward(ctx: Any, *grad_output: Any) -> Any:
        """Pass the first output gradient straight through as the input gradient."""
        incoming_grad = grad_output[0]
        return incoming_grad
|
|
|
|
|
|
def realize(x: Tensor) -> Tensor:
    """Apply the ``Realize`` autograd function to ``x`` and return its output.

    Going through ``Realize.apply`` (rather than calling the raw op) keeps the
    call differentiable: gradients flow back to ``x`` unchanged.
    """
    realized = Realize.apply(x)
    return realized
|