mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for the ONNX Runtime Eager Mode backend (#58248)
Summary: This PR implements the necessary hooks/stubs/enums/etc for complete ONNX Runtime (ORT) Eager Mode integration. The actual extension will live out of tree at https://github.com/pytorch/ort. We have been [working on this at Microsoft](https://github.com/microsoft/onnxruntime-pytorch/tree/eager-ort/torch_onnxruntime) for the last few months, and are finally ready to contribute the PyTorch core changes upstream (nothing major or exciting, just the usual boilerplate for adding new backends). The ORT backend will allow us to ferry [almost] all torch ops into granular ONNX kernels that ORT will eagerly execute against any devices it supports (therefore, we only need a single ORT backend from a PyTorch perspective). Pull Request resolved: https://github.com/pytorch/pytorch/pull/58248 Reviewed By: astaff Differential Revision: D30344992 Pulled By: albanD fbshipit-source-id: 69082b32121246340d686e16653626114b7714b2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b95ce1591d
commit
c78ab28441
@ -19,11 +19,11 @@ except ImportError as e:
|
||||
try:
|
||||
if HAS_PYTEST:
|
||||
cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
|
||||
msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
|
||||
ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
|
||||
rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
|
||||
else:
|
||||
import torch_test_cpp_extension.cpp as cpp_extension
|
||||
import torch_test_cpp_extension.msnpu as msnpu_extension
|
||||
import torch_test_cpp_extension.ort as ort_extension
|
||||
import torch_test_cpp_extension.rng as rng_extension
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
@ -100,45 +100,45 @@ class TestCppExtensionAOT(common.TestCase):
|
||||
self.assertFalse(has_value)
|
||||
|
||||
|
||||
class TestMSNPUTensor(common.TestCase):
|
||||
class TestORTTensor(common.TestCase):
|
||||
def test_unregistered(self):
|
||||
a = torch.arange(0, 10, device='cpu')
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not run"):
|
||||
b = torch.arange(0, 10, device='msnpu')
|
||||
b = torch.arange(0, 10, device='ort')
|
||||
|
||||
def test_zeros(self):
|
||||
a = torch.empty(5, 5, device='cpu')
|
||||
self.assertEqual(a.device, torch.device('cpu'))
|
||||
|
||||
b = torch.empty(5, 5, device='msnpu')
|
||||
self.assertEqual(b.device, torch.device('msnpu', 0))
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
b = torch.empty(5, 5, device='ort')
|
||||
self.assertEqual(b.device, torch.device('ort', 0))
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
self.assertEqual(torch.get_default_dtype(), b.dtype)
|
||||
|
||||
c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
c = torch.empty((5, 5), dtype=torch.int64, device='ort')
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
self.assertEqual(torch.int64, c.dtype)
|
||||
|
||||
def test_add(self):
|
||||
a = torch.empty(5, 5, device='msnpu', requires_grad=True)
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
a = torch.empty(5, 5, device='ort', requires_grad=True)
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
|
||||
b = torch.empty(5, 5, device='msnpu')
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
b = torch.empty(5, 5, device='ort')
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
|
||||
c = a + b
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 1)
|
||||
self.assertEqual(ort_extension.get_test_int(), 1)
|
||||
|
||||
def test_conv_backend_override(self):
|
||||
# To simplify tests, we use 4d input here to avoid doing view4d( which
|
||||
# needs more overrides) in _convolution.
|
||||
input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
|
||||
weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
|
||||
bias = torch.empty(6, device='msnpu')
|
||||
input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True)
|
||||
weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True)
|
||||
bias = torch.empty(6, device='ort')
|
||||
|
||||
# Make sure forward is overriden
|
||||
out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 2)
|
||||
self.assertEqual(ort_extension.get_test_int(), 2)
|
||||
self.assertEqual(out.shape[0], input.shape[0])
|
||||
self.assertEqual(out.shape[1], weight.shape[0])
|
||||
|
||||
@ -146,7 +146,7 @@ class TestMSNPUTensor(common.TestCase):
|
||||
# Double backward is dispatched to _convolution_double_backward.
|
||||
# It is not tested here as it involves more computation/overrides.
|
||||
grad = torch.autograd.grad(out, input, out, create_graph=True)
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 3)
|
||||
self.assertEqual(ort_extension.get_test_int(), 3)
|
||||
self.assertEqual(grad[0].shape, input.shape)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user