mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This makes OpHandler just a normal class using inheritance, and removes the typing workarounds that were needed because it wasn't one. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146257. Approved by: https://github.com/shunting314. ghstack dependencies: #146252, #146254, #146255
47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
# Owner(s): ["module: inductor"]
|
|
import unittest
|
|
|
|
from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
|
|
from torch._inductor.codegen.halide import HalideOverrides
|
|
from torch._inductor.codegen.mps import MetalOverrides
|
|
from torch._inductor.codegen.triton import TritonKernelOverrides
|
|
from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler
|
|
from torch._inductor.test_case import TestCase
|
|
|
|
|
|
class TestOpCompleteness(TestCase):
    """Check that each backend's ops handler covers exactly the ops in ``OP_NAMES``."""

    def verify_ops_handler_completeness(self, handler):
        """Assert that ``handler`` overrides every op in ``OP_NAMES`` and adds no others.

        An op counts as missing when ``handler`` has no attribute of that name,
        or when the attribute is the very same object found on ``OpsHandler``
        (i.e. it was inherited rather than overridden). Ops reported by
        ``list_ops(handler)`` that are not in ``OP_NAMES`` count as extras.

        Every missing op is collected and reported in a single failure, rather
        than stopping at the first one.

        Raises:
            AssertionError: (via ``self.fail``) if any op is missing or extra.
        """
        # `hasattr` guard: a handler lacking the op entirely should fail the
        # test cleanly instead of erroring out with AttributeError.
        missing_ops = sorted(
            op
            for op in OP_NAMES
            if not hasattr(handler, op)
            or getattr(handler, op) is getattr(OpsHandler, op)
        )
        if missing_ops:
            self.fail(f"{handler} must implement {', '.join(missing_ops)}")

        extra_ops = list_ops(handler) - OP_NAMES
        if extra_ops:
            # sorted() keeps the message deterministic across runs.
            self.fail(
                f"{handler} has extra ops: {sorted(extra_ops)}, add them to the "
                "OpsHandler class or prefix them with `_`"
            )

    def test_triton_overrides(self):
        self.verify_ops_handler_completeness(TritonKernelOverrides)

    def test_cpp_overrides(self):
        self.verify_ops_handler_completeness(CppOverrides)

    def test_cpp_vec_overrides(self):
        self.verify_ops_handler_completeness(CppVecOverrides)

    def test_halide_overrides(self):
        self.verify_ops_handler_completeness(HalideOverrides)

    @unittest.skip("MPS backend not yet finished")
    def test_metal_overrides(self):
        self.verify_ops_handler_completeness(MetalOverrides)
|
|
|
|
|
|
if __name__ == "__main__":
    # Import lazily so merely importing this module does not pull in the runner.
    from torch._inductor.test_case import run_tests as _run_inductor_tests

    _run_inductor_tests()
|