rename ort to maia (#123265)

Fixes #123264

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123265
Approved by: https://github.com/albanD
Author:       Ashwin Hari
Date:         2024-04-23 00:33:20 +00:00
Committed by: PyTorch MergeBot
Commit:       5f5778476a (parent bffecb5aff)

39 changed files with 183 additions and 189 deletions

@@ -26,11 +26,11 @@ except ImportError as e:
 try:
     if HAS_PYTEST:
         cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
-        ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
+        maia_extension = pytest.importorskip("torch_test_cpp_extension.maia")
         rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
     else:
         import torch_test_cpp_extension.cpp as cpp_extension
-        import torch_test_cpp_extension.ort as ort_extension
+        import torch_test_cpp_extension.maia as maia_extension
         import torch_test_cpp_extension.rng as rng_extension
 except ImportError as e:
     raise RuntimeError(
@@ -255,46 +255,46 @@ class TestPybindTypeCasters(common.TestCase):
 @torch.testing._internal.common_utils.markDynamoStrictTest
-class TestORTTensor(common.TestCase):
+class TestMAIATensor(common.TestCase):
     def test_unregistered(self):
         a = torch.arange(0, 10, device="cpu")
         with self.assertRaisesRegex(RuntimeError, "Could not run"):
-            b = torch.arange(0, 10, device="ort")
+            b = torch.arange(0, 10, device="maia")

-    @skipIfTorchDynamo("dynamo cannot model ort device")
+    @skipIfTorchDynamo("dynamo cannot model maia device")
     def test_zeros(self):
         a = torch.empty(5, 5, device="cpu")
         self.assertEqual(a.device, torch.device("cpu"))

-        b = torch.empty(5, 5, device="ort")
-        self.assertEqual(b.device, torch.device("ort", 0))
-        self.assertEqual(ort_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device="maia")
+        self.assertEqual(b.device, torch.device("maia", 0))
+        self.assertEqual(maia_extension.get_test_int(), 0)
         self.assertEqual(torch.get_default_dtype(), b.dtype)

-        c = torch.empty((5, 5), dtype=torch.int64, device="ort")
-        self.assertEqual(ort_extension.get_test_int(), 0)
+        c = torch.empty((5, 5), dtype=torch.int64, device="maia")
+        self.assertEqual(maia_extension.get_test_int(), 0)
         self.assertEqual(torch.int64, c.dtype)

     def test_add(self):
-        a = torch.empty(5, 5, device="ort", requires_grad=True)
-        self.assertEqual(ort_extension.get_test_int(), 0)
+        a = torch.empty(5, 5, device="maia", requires_grad=True)
+        self.assertEqual(maia_extension.get_test_int(), 0)

-        b = torch.empty(5, 5, device="ort")
-        self.assertEqual(ort_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device="maia")
+        self.assertEqual(maia_extension.get_test_int(), 0)

         c = a + b
-        self.assertEqual(ort_extension.get_test_int(), 1)
+        self.assertEqual(maia_extension.get_test_int(), 1)

     def test_conv_backend_override(self):
         # To simplify tests, we use 4d input here to avoid doing view4d( which
         # needs more overrides) in _convolution.
-        input = torch.empty(2, 4, 10, 2, device="ort", requires_grad=True)
-        weight = torch.empty(6, 4, 2, 2, device="ort", requires_grad=True)
-        bias = torch.empty(6, device="ort")
+        input = torch.empty(2, 4, 10, 2, device="maia", requires_grad=True)
+        weight = torch.empty(6, 4, 2, 2, device="maia", requires_grad=True)
+        bias = torch.empty(6, device="maia")
         # Make sure forward is overriden
         out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
-        self.assertEqual(ort_extension.get_test_int(), 2)
+        self.assertEqual(maia_extension.get_test_int(), 2)
         self.assertEqual(out.shape[0], input.shape[0])
         self.assertEqual(out.shape[1], weight.shape[0])
@@ -302,7 +302,7 @@ class TestORTTensor(common.TestCase):
         # Double backward is dispatched to _convolution_double_backward.
         # It is not tested here as it involves more computation/overrides.
         grad = torch.autograd.grad(out, input, out, create_graph=True)
-        self.assertEqual(ort_extension.get_test_int(), 3)
+        self.assertEqual(maia_extension.get_test_int(), 3)
         self.assertEqual(grad[0].shape, input.shape)
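
For context, a minimal sketch of what downstream test code looks like after this rename. It assumes the torch_test_cpp_extension.maia test extension changed in this PR has been built and is importable, and it mirrors the test_add flow in the diff above; it is an illustration, not part of the change.

# Sketch only: mirrors TestMAIATensor.test_add, assuming the maia test
# extension from this PR is built and importable in the current environment.
import torch
import torch_test_cpp_extension.maia as maia_extension

# The backend is now addressed as device="maia" instead of device="ort".
a = torch.empty(5, 5, device="maia", requires_grad=True)
b = torch.empty(5, 5, device="maia")
c = a + b  # dispatched to the extension's registered add override
assert maia_extension.get_test_int() == 1  # counter the test checks after add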