Files
pytorch/test/custom_operator/test_custom_ops.py
Xuehai Pan 548c460bf1 [BE][Easy][7/19] enforce style for empty lines in import segments in test/[a-c]*/ and test/[q-z]*/ (#129758)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129758
Approved by: https://github.com/ezyang
2024-07-31 10:54:03 +00:00

160 lines
5.5 KiB
Python

# Owner(s): ["module: unknown"]
import os.path
import sys
import tempfile
import unittest
from model import get_custom_op_library_path, Model
import torch
import torch._library.utils as utils
from torch import ops
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
# Import the "pointwise" module once at module load, before any test runs.
# NOTE(review): presumably it supplies Python-side registrations for ops in
# the custom library (tests below match "pointwise" in error messages) —
# confirm against the library sources. `ops` is the same object as torch.ops.
ops.import_module("pointwise")
class TestCustomOperators(TestCase):
    """Tests for the operators exposed by the custom-op shared library.

    ``setUp`` loads the library returned by ``get_custom_op_library_path()``.
    The tests then exercise its ops through ``ops.custom`` /
    ``torch.ops.custom`` in eager mode, on meta tensors, under ``make_fx`` and
    ``torch.compile``, and from ``Model``, which is also saved and reloaded
    via ``torch.jit``.
    """

    def setUp(self):
        """Run the base-class setup, then load the custom-op library under test."""
        # The torch.testing TestCase base class has its own per-test setUp;
        # chain to it so overriding setUp here does not silently skip it.
        super().setUp()
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        """The library loaded in setUp is recorded in ``ops.loaded_libraries``."""
        self.assertIn(self.library_path, ops.loaded_libraries)

    def test_op_with_no_abstract_impl_pystub(self):
        """Call ``custom.tan`` on a meta tensor.

        When ``utils.requires_set_python_module()`` is true, the call is
        expected to raise a RuntimeError mentioning "pointwise". Otherwise it
        is only checked not to raise.
        """
        x = torch.randn(3, device="meta")
        if utils.requires_set_python_module():
            with self.assertRaisesRegex(RuntimeError, "pointwise"):
                torch.ops.custom.tan(x)
        else:
            # Smoketest: only checks that the call does not raise.
            torch.ops.custom.tan(x)

    def test_op_with_incorrect_abstract_impl_pystub(self):
        """``custom.cos`` on a meta tensor raises a RuntimeError mentioning "pointwise"."""
        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(RuntimeError, "pointwise"):
            torch.ops.custom.cos(x)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_dynamo_pystub_suggestion(self):
        """A fullgraph compile of ``custom.asin`` fails and suggests ``import nonexistent``."""
        x = torch.randn(3)

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return torch.ops.custom.asin(x)

        with self.assertRaisesRegex(
            RuntimeError,
            r"unsupported operator: .* you may need to `import nonexistent`",
        ):
            f(x)

    def test_abstract_impl_pystub_faketensor(self):
        """Symbolic ``make_fx`` of ``custom.nonzero`` works only after importing ``my_custom_ops``."""
        from functorch import make_fx

        x = torch.randn(3, device="cpu")
        # The module must not be imported yet; otherwise the first half of this
        # test would not exercise the "no implementation available" path.
        self.assertNotIn("my_custom_ops", sys.modules)

        with self.assertRaises(
            torch._subclasses.fake_tensor.UnsupportedOperatorException
        ):
            make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)

        torch.ops.import_module("my_custom_ops")
        gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)
        # assertExpectedInline takes (actual, expect). Passing the traced code
        # as `actual` and the literal as `expect` keeps failure messages the
        # right way round and lets EXPECTTEST_ACCEPT=1 rewrite the literal.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1):
    nonzero = torch.ops.custom.nonzero.default(arg0_1);  arg0_1 = None
    return nonzero""",
        )

    def test_abstract_impl_pystub_meta(self):
        """``custom.sin`` on a meta tensor fails until ``my_custom_ops2`` is imported."""
        x = torch.randn(3, device="meta")
        self.assertNotIn("my_custom_ops2", sys.modules)

        with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
            torch.ops.custom.sin.default(x)

        torch.ops.import_module("my_custom_ops2")
        # Smoketest: after the import the same call must not raise.
        torch.ops.custom.sin.default(x)

    def test_calling_custom_op_string(self):
        """``custom.op2`` returns < 0 for ("abc", "def") and 0 for two equal strings."""
        output = ops.custom.op2("abc", "def")
        self.assertLess(output, 0)
        output = ops.custom.op2("abc", "abc")
        self.assertEqual(output, 0)

    def test_calling_custom_op(self):
        """``custom.op`` and ``custom.op_with_defaults`` return lists of the expected tensors."""
        output = ops.custom.op(torch.ones(5), 2.0, 3)
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 3)
        for tensor in output:
            self.assertTrue(tensor.allclose(torch.ones(5) * 2))

        output = ops.custom.op_with_defaults(torch.ones(5))
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 1)
        self.assertTrue(output[0].allclose(torch.ones(5)))

    def test_calling_custom_op_with_autograd(self):
        """``custom.op_with_autograd`` computes x + 2*y + x*y (+ z) and the matching gradients."""
        x = torch.randn((5, 5), requires_grad=True)
        y = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y)
        self.assertTrue(output.allclose(x + 2 * y + x * y))

        go = torch.ones((), requires_grad=True)
        # Same flags as the original positional call (go, False, True).
        output.sum().backward(go, retain_graph=False, create_graph=True)
        grad = torch.ones(5, 5)

        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)

        # Test with optional arg.
        x.grad.zero_()
        y.grad.zero_()
        z = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y, z)
        self.assertTrue(output.allclose(x + 2 * y + x * y + z))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, retain_graph=False, create_graph=True)

        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)
        self.assertEqual(z.grad, grad)

    def test_calling_custom_op_with_autograd_in_nograd_mode(self):
        """``custom.op_with_autograd`` still returns x + 2*y + x*y under ``torch.no_grad``."""
        with torch.no_grad():
            x = torch.randn((5, 5), requires_grad=True)
            y = torch.randn((5, 5), requires_grad=True)
            output = ops.custom.op_with_autograd(x, 2, y)
            self.assertTrue(output.allclose(x + 2 * y + x * y))

    def test_calling_custom_op_inside_script_module(self):
        """``Model().forward(ones(5))`` returns ones(5) + 1."""
        model = Model()
        output = model.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        """A ``Model`` saved and reloaded with ``torch.jit.load`` still returns ones(5) + 1."""
        model = Model()
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually.
        file = tempfile.NamedTemporaryFile(delete=False)
        try:
            file.close()
            model.save(file.name)
            loaded = torch.jit.load(file.name)
        finally:
            os.unlink(file.name)

        output = loaded.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))
if __name__ == "__main__":
    # Executed directly as a script: hand the module's tests to PyTorch's runner.
    run_tests()