mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
rename ort to maia (#123265)
Fixes #123264 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123265 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
bffecb5aff
commit
5f5778476a
@ -26,11 +26,11 @@ except ImportError as e:
|
||||
try:
|
||||
if HAS_PYTEST:
|
||||
cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
|
||||
ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
|
||||
maia_extension = pytest.importorskip("torch_test_cpp_extension.maia")
|
||||
rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
|
||||
else:
|
||||
import torch_test_cpp_extension.cpp as cpp_extension
|
||||
import torch_test_cpp_extension.ort as ort_extension
|
||||
import torch_test_cpp_extension.maia as maia_extension
|
||||
import torch_test_cpp_extension.rng as rng_extension
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
@ -255,46 +255,46 @@ class TestPybindTypeCasters(common.TestCase):
|
||||
|
||||
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
class TestORTTensor(common.TestCase):
|
||||
class TestMAIATensor(common.TestCase):
|
||||
def test_unregistered(self):
|
||||
a = torch.arange(0, 10, device="cpu")
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not run"):
|
||||
b = torch.arange(0, 10, device="ort")
|
||||
b = torch.arange(0, 10, device="maia")
|
||||
|
||||
@skipIfTorchDynamo("dynamo cannot model ort device")
|
||||
@skipIfTorchDynamo("dynamo cannot model maia device")
|
||||
def test_zeros(self):
|
||||
a = torch.empty(5, 5, device="cpu")
|
||||
self.assertEqual(a.device, torch.device("cpu"))
|
||||
|
||||
b = torch.empty(5, 5, device="ort")
|
||||
self.assertEqual(b.device, torch.device("ort", 0))
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
b = torch.empty(5, 5, device="maia")
|
||||
self.assertEqual(b.device, torch.device("maia", 0))
|
||||
self.assertEqual(maia_extension.get_test_int(), 0)
|
||||
self.assertEqual(torch.get_default_dtype(), b.dtype)
|
||||
|
||||
c = torch.empty((5, 5), dtype=torch.int64, device="ort")
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
c = torch.empty((5, 5), dtype=torch.int64, device="maia")
|
||||
self.assertEqual(maia_extension.get_test_int(), 0)
|
||||
self.assertEqual(torch.int64, c.dtype)
|
||||
|
||||
def test_add(self):
|
||||
a = torch.empty(5, 5, device="ort", requires_grad=True)
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
a = torch.empty(5, 5, device="maia", requires_grad=True)
|
||||
self.assertEqual(maia_extension.get_test_int(), 0)
|
||||
|
||||
b = torch.empty(5, 5, device="ort")
|
||||
self.assertEqual(ort_extension.get_test_int(), 0)
|
||||
b = torch.empty(5, 5, device="maia")
|
||||
self.assertEqual(maia_extension.get_test_int(), 0)
|
||||
|
||||
c = a + b
|
||||
self.assertEqual(ort_extension.get_test_int(), 1)
|
||||
self.assertEqual(maia_extension.get_test_int(), 1)
|
||||
|
||||
def test_conv_backend_override(self):
|
||||
# To simplify tests, we use 4d input here to avoid doing view4d( which
|
||||
# needs more overrides) in _convolution.
|
||||
input = torch.empty(2, 4, 10, 2, device="ort", requires_grad=True)
|
||||
weight = torch.empty(6, 4, 2, 2, device="ort", requires_grad=True)
|
||||
bias = torch.empty(6, device="ort")
|
||||
input = torch.empty(2, 4, 10, 2, device="maia", requires_grad=True)
|
||||
weight = torch.empty(6, 4, 2, 2, device="maia", requires_grad=True)
|
||||
bias = torch.empty(6, device="maia")
|
||||
|
||||
# Make sure forward is overriden
|
||||
out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
|
||||
self.assertEqual(ort_extension.get_test_int(), 2)
|
||||
self.assertEqual(maia_extension.get_test_int(), 2)
|
||||
self.assertEqual(out.shape[0], input.shape[0])
|
||||
self.assertEqual(out.shape[1], weight.shape[0])
|
||||
|
||||
@ -302,7 +302,7 @@ class TestORTTensor(common.TestCase):
|
||||
# Double backward is dispatched to _convolution_double_backward.
|
||||
# It is not tested here as it involves more computation/overrides.
|
||||
grad = torch.autograd.grad(out, input, out, create_graph=True)
|
||||
self.assertEqual(ort_extension.get_test_int(), 3)
|
||||
self.assertEqual(maia_extension.get_test_int(), 3)
|
||||
self.assertEqual(grad[0].shape, input.shape)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user