mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Example:

```
graph():
    %arg0 : [#users=3] = placeholder[target=arg0]
    %arg_guard_equality_check : [#users=1] = call_function[target=torch._tensor_equal](args = (%arg0, (1, 1, 2), (2, 2, 1), torch.float32), kwargs = {})
    %_assert_true : [#users=0] = call_function[target=torch._assert_true](args = (%arg_guard_equality_check, Guard evaluation failed equality check for arg0), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%arg0, 1), kwargs = {})
    return ([arg0, arg0], (add, add))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84617
Approved by: https://github.com/jansel
33 lines
989 B
Python
33 lines
989 B
Python
#!/usr/bin/env python3
|
|
# Owner(s): ["module: internals"]
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
class TestComparisonUtils(TestCase):
|
|
def test_all_equal_no_assert(self):
|
|
t = torch.tensor([0.5])
|
|
torch._assert_tensor_metadata(t, [1], [1], torch.float)
|
|
|
|
def test_all_equal_no_assert_nones(self):
|
|
t = torch.tensor([0.5])
|
|
torch._assert_tensor_metadata(t, None, None, None)
|
|
|
|
def test_assert_dtype(self):
|
|
t = torch.tensor([0.5])
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
torch._assert_tensor_metadata(t, None, None, torch.int32)
|
|
|
|
def test_assert_strides(self):
|
|
t = torch.tensor([0.5])
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
torch._assert_tensor_metadata(t, None, [3], torch.float)
|
|
|
|
def test_assert_sizes(self):
|
|
t = torch.tensor([0.5])
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
torch._assert_tensor_metadata(t, [3], [1], torch.float)
|