Add support for the ONNX Runtime Eager Mode backend (#58248)

Summary:
This PR implements the necessary hooks/stubs/enums/etc for complete ONNX Runtime (ORT) Eager Mode integration. The actual extension will live out of tree at https://github.com/pytorch/ort.

We have been [working on this at Microsoft](https://github.com/microsoft/onnxruntime-pytorch/tree/eager-ort/torch_onnxruntime) for the last few months, and are finally ready to contribute the PyTorch core changes upstream (nothing major or exciting, just the usual boilerplate for adding new backends).

The ORT backend will allow us to ferry [almost] all torch ops into granular ONNX kernels that ORT will eagerly execute against any devices it supports (therefore, we only need a single ORT backend from a PyTorch perspective).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58248

Reviewed By: astaff

Differential Revision: D30344992

Pulled By: albanD

fbshipit-source-id: 69082b32121246340d686e16653626114b7714b2
This commit is contained in:
Aaron Bockover
2021-08-20 11:11:47 -07:00
committed by Facebook GitHub Bot
parent b95ce1591d
commit c78ab28441
38 changed files with 236 additions and 120 deletions

View File

@ -19,11 +19,11 @@ except ImportError as e:
try:
if HAS_PYTEST:
cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
else:
import torch_test_cpp_extension.cpp as cpp_extension
import torch_test_cpp_extension.msnpu as msnpu_extension
import torch_test_cpp_extension.ort as ort_extension
import torch_test_cpp_extension.rng as rng_extension
except ImportError as e:
raise RuntimeError(
@ -100,45 +100,45 @@ class TestCppExtensionAOT(common.TestCase):
self.assertFalse(has_value)
class TestMSNPUTensor(common.TestCase):
class TestORTTensor(common.TestCase):
def test_unregistered(self):
a = torch.arange(0, 10, device='cpu')
with self.assertRaisesRegex(RuntimeError, "Could not run"):
b = torch.arange(0, 10, device='msnpu')
b = torch.arange(0, 10, device='ort')
def test_zeros(self):
a = torch.empty(5, 5, device='cpu')
self.assertEqual(a.device, torch.device('cpu'))
b = torch.empty(5, 5, device='msnpu')
self.assertEqual(b.device, torch.device('msnpu', 0))
self.assertEqual(msnpu_extension.get_test_int(), 0)
b = torch.empty(5, 5, device='ort')
self.assertEqual(b.device, torch.device('ort', 0))
self.assertEqual(ort_extension.get_test_int(), 0)
self.assertEqual(torch.get_default_dtype(), b.dtype)
c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
self.assertEqual(msnpu_extension.get_test_int(), 0)
c = torch.empty((5, 5), dtype=torch.int64, device='ort')
self.assertEqual(ort_extension.get_test_int(), 0)
self.assertEqual(torch.int64, c.dtype)
def test_add(self):
a = torch.empty(5, 5, device='msnpu', requires_grad=True)
self.assertEqual(msnpu_extension.get_test_int(), 0)
a = torch.empty(5, 5, device='ort', requires_grad=True)
self.assertEqual(ort_extension.get_test_int(), 0)
b = torch.empty(5, 5, device='msnpu')
self.assertEqual(msnpu_extension.get_test_int(), 0)
b = torch.empty(5, 5, device='ort')
self.assertEqual(ort_extension.get_test_int(), 0)
c = a + b
self.assertEqual(msnpu_extension.get_test_int(), 1)
self.assertEqual(ort_extension.get_test_int(), 1)
def test_conv_backend_override(self):
# To simplify tests, we use 4d input here to avoid doing view4d( which
# needs more overrides) in _convolution.
input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
bias = torch.empty(6, device='msnpu')
input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True)
weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True)
bias = torch.empty(6, device='ort')
# Make sure forward is overriden
out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
self.assertEqual(msnpu_extension.get_test_int(), 2)
self.assertEqual(ort_extension.get_test_int(), 2)
self.assertEqual(out.shape[0], input.shape[0])
self.assertEqual(out.shape[1], weight.shape[0])
@ -146,7 +146,7 @@ class TestMSNPUTensor(common.TestCase):
# Double backward is dispatched to _convolution_double_backward.
# It is not tested here as it involves more computation/overrides.
grad = torch.autograd.grad(out, input, out, create_graph=True)
self.assertEqual(msnpu_extension.get_test_int(), 3)
self.assertEqual(ort_extension.get_test_int(), 3)
self.assertEqual(grad[0].shape, input.shape)