Files
pytorch/test/custom_operator/test_custom_ops.py
Pritam Damania cdc56d0b6c Support c10::optional<Tensor> in custom C++ autograd function. (#37700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37700

Certain autograd functions can have optional Tensor arguments. For
this purpose it would be nice to support c10::optional<Tensor> as an argument
for C++ autograd functions.

I've added the appropriate overload to ExtractVariables to ensure this works.
For an example, you can look at D21272807 in terms of how this is used.
ghstack-source-id: 103541789

Test Plan: waitforbuildbot

Differential Revision: D21363491

fbshipit-source-id: 0c8665e9bfe279e6b9ab84a889524fea11fa971c
2020-05-06 01:59:51 -07:00

94 lines
3.3 KiB
Python

import os.path
import tempfile
import unittest
import torch
from torch import ops
from model import Model, get_custom_op_library_path
class TestCustomOperators(unittest.TestCase):
    """Exercises the custom C++ operators registered by the test library.

    Every test assumes the shared library built from the custom-operator
    test sources is discoverable via ``get_custom_op_library_path()``; it
    registers ops under the ``custom`` namespace of ``torch.ops``.
    """

    def setUp(self):
        # Load (or re-confirm) the op library before each test so the
        # `custom` namespace is populated.
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        self.assertIn(self.library_path, ops.loaded_libraries)

    def test_calling_custom_op_string(self):
        # op2 compares two strings: negative when the first sorts before
        # the second, zero when they are equal.
        self.assertLess(ops.custom.op2("abc", "def"), 0)
        self.assertEqual(ops.custom.op2("abc", "abc"), 0)

    def test_calling_custom_op(self):
        result = ops.custom.op(torch.ones(5), 2.0, 3)
        self.assertEqual(type(result), list)
        self.assertEqual(len(result), 3)
        expected = torch.ones(5) * 2
        for item in result:
            self.assertTrue(item.allclose(expected))

        # The same op with its trailing arguments defaulted returns a
        # single, unscaled tensor.
        result = ops.custom.op_with_defaults(torch.ones(5))
        self.assertEqual(type(result), list)
        self.assertEqual(len(result), 1)
        self.assertTrue(result[0].allclose(torch.ones(5)))

    def test_calling_custom_op_with_autograd(self):
        x = torch.randn((5, 5), requires_grad=True)
        y = torch.randn((5, 5), requires_grad=True)
        out = ops.custom.op_with_autograd(x, 2, y)
        self.assertTrue(out.allclose(x + 2 * y + x * y))

        grad_out = torch.ones((), requires_grad=True)
        out.sum().backward(grad_out, False, True)
        grad = torch.ones(5, 5)
        self.assertTrue(torch.allclose(x.grad, y + grad))
        self.assertTrue(torch.allclose(y.grad, x + grad * 2))

        # Exercise the optional third tensor argument
        # (c10::optional<Tensor> on the C++ side).
        x.grad.zero_()
        y.grad.zero_()
        z = torch.randn((5, 5), requires_grad=True)
        out = ops.custom.op_with_autograd(x, 2, y, z)
        self.assertTrue(out.allclose(x + 2 * y + x * y + z))

        grad_out = torch.ones((), requires_grad=True)
        out.sum().backward(grad_out, False, True)
        self.assertTrue(torch.allclose(x.grad, y + grad))
        self.assertTrue(torch.allclose(y.grad, x + grad * 2))
        self.assertTrue(torch.allclose(z.grad, grad))

    def test_calling_custom_op_with_autograd_in_nograd_mode(self):
        # Forward must still work (and produce the same values) when
        # gradient tracking is globally disabled.
        with torch.no_grad():
            x = torch.randn((5, 5), requires_grad=True)
            y = torch.randn((5, 5), requires_grad=True)
            out = ops.custom.op_with_autograd(x, 2, y)
            self.assertTrue(out.allclose(x + 2 * y + x * y))

    def test_calling_custom_op_inside_script_module(self):
        result = Model().forward(torch.ones(5))
        self.assertTrue(result.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        model = Model()
        # NamedTemporaryFile keeps the file open, and Windows refuses a
        # second open of the same file; create it with delete=False,
        # close it right away, and remove it ourselves afterwards.
        file = tempfile.NamedTemporaryFile(delete=False)
        try:
            file.close()
            model.save(file.name)
            loaded = torch.jit.load(file.name)
        finally:
            os.unlink(file.name)

        result = loaded.forward(torch.ones(5))
        self.assertTrue(result.allclose(torch.ones(5) + 1))
# Allow running this test file directly: `python test_custom_ops.py`.
if __name__ == "__main__":
    unittest.main()