mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
xpu: support custom ops with torch.library on xpu backend (#152879)
Fixes: https://github.com/intel/torch-xpu-ops/issues/1626
This PR started enabling of tests for `torch.library`, but more work is needed. Tests are using `torch._custom_ops` deprecated API planned for removal at pytorch 2.6 (not done). I think cleanup of pytorch would be nice before enabling more tests for xpu.
a2ccda3c60/torch/_custom_op/impl.py (L47)
CC: @EikanWang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152879
Approved by: https://github.com/EikanWang, https://github.com/malfet, https://github.com/guangyey, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
442aca44d6
commit
cd80f9a4c3
@ -167,6 +167,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
lib.impl("foo", Foo.apply, "Autograd")
|
||||
lib.impl("foo", foo_impl, "CPU")
|
||||
lib.impl("foo", foo_impl, "CUDA")
|
||||
lib.impl("foo", foo_impl, "XPU")
|
||||
|
||||
x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
|
||||
with self.assertRaisesRegex(
|
||||
@ -271,6 +272,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
lib.impl("foo", Foo.apply, "Autograd")
|
||||
lib.impl("foo", foo_impl, "CPU")
|
||||
lib.impl("foo", foo_impl, "CUDA")
|
||||
lib.impl("foo", foo_impl, "XPU")
|
||||
|
||||
x = torch.tensor([0, 1.0], requires_grad=True)
|
||||
with self.assertRaisesRegex(
|
||||
@ -312,6 +314,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
lib.impl("foo", Foo.apply, "Autograd")
|
||||
lib.impl("foo", foo_impl, "CPU")
|
||||
lib.impl("foo", foo_impl, "CUDA")
|
||||
lib.impl("foo", foo_impl, "XPU")
|
||||
lib.impl("foo", foo_meta, "Meta")
|
||||
|
||||
x = torch.tensor([0, 1.0], requires_grad=True)
|
||||
@ -343,6 +346,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
lib.impl("foo", Foo.apply, "Autograd")
|
||||
lib.impl("foo", foo_impl, "CPU")
|
||||
lib.impl("foo", foo_impl, "CUDA")
|
||||
lib.impl("foo", foo_impl, "XPU")
|
||||
lib.impl("foo", foo_meta, "Meta")
|
||||
|
||||
x = torch.tensor([0, 1.0])
|
||||
@ -369,6 +373,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
|
||||
lib.impl("foo", Foo.apply, "CPU")
|
||||
lib.impl("foo", Foo.apply, "CUDA")
|
||||
lib.impl("foo", Foo.apply, "XPU")
|
||||
lib.impl("foo", lambda x: x.clone(), "Meta")
|
||||
|
||||
x = torch.randn([], requires_grad=True)
|
||||
@ -462,6 +467,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
lib.impl("foo", Foo.apply, "Autograd")
|
||||
lib.impl("foo", foo_impl, "CPU")
|
||||
lib.impl("foo", foo_impl, "CUDA")
|
||||
lib.impl("foo", foo_impl, "XPU")
|
||||
|
||||
x = torch.randn(3, requires_grad=True, device=device)
|
||||
# Should not raise
|
||||
@ -511,6 +517,7 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
|
||||
lib.impl("foo", Foo.apply, "CPU")
|
||||
lib.impl("foo", Foo.apply, "CUDA")
|
||||
lib.impl("foo", Foo.apply, "XPU")
|
||||
|
||||
x = torch.randn(3, requires_grad=True, device=device)
|
||||
with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
|
||||
@ -4677,8 +4684,10 @@ class TestOpProfiles(TestCase):
|
||||
loaded = read_profiles_from_yaml(yaml_str)
|
||||
|
||||
|
||||
only_for = ("cpu", "cuda")
|
||||
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
|
||||
only_for = ("cpu", "cuda", "xpu")
|
||||
instantiate_device_type_tests(
|
||||
TestCustomOpTesting, globals(), only_for=only_for, allow_xpu=True
|
||||
)
|
||||
instantiate_parametrized_tests(TestCustomOp)
|
||||
instantiate_parametrized_tests(TestCustomOpAPI)
|
||||
|
||||
|
Reference in New Issue
Block a user