[ONNX] Fix symbolic values and numpy implementation (#135786)

1. Remove `__eq__` to make `SymbolicTensor` hashable and test for that
2. Update the `__array__` method so that it works for tensor on GPU

Fixes https://github.com/pytorch/pytorch/issues/135700
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135786
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu
2024-09-12 14:24:43 +00:00
committed by PyTorch MergeBot
parent dddaadac6c
commit d67cc58181
3 changed files with 26 additions and 6 deletions

View File

@ -0,0 +1,22 @@
# Owner(s): ["module: onnx"]
"""Unit tests for the _tensors module."""
from __future__ import annotations
import onnxscript
from torch.onnx._internal.exporter import _tensors
from torch.testing._internal import common_utils
class SymbolicTensorTest(common_utils.TestCase):
def test_it_is_hashable(self):
tensor = _tensors.SymbolicTensor(
opset=onnxscript.values.Opset(domain="test", version=1)
)
self.assertEqual(hash(tensor), hash(tensor))
self.assertIn(tensor, {tensor})
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -99,8 +99,9 @@ class TorchTensor(ir.Tensor):
def __array__(self, dtype: Any = None) -> np.ndarray:
# numpy() calls __array__ in ir.Tensor
self.raw: torch.Tensor
if self.dtype == ir.DataType.BFLOAT16:
return self.raw.view(torch.uint16).__array__(dtype)
return self.raw.view(torch.uint16).numpy(force=True).__array__(dtype)
if self.dtype in {
ir.DataType.FLOAT8E4M3FN,
ir.DataType.FLOAT8E4M3FNUZ,
@ -108,8 +109,8 @@ class TorchTensor(ir.Tensor):
ir.DataType.FLOAT8E5M2FNUZ,
}:
# TODO: Use ml_dtypes
return self.raw.view(torch.uint8).__array__(dtype)
return self.raw.__array__(dtype)
return self.raw.view(torch.uint8).numpy(force=True).__array__(dtype)
return self.raw.numpy(force=True).__array__(dtype)
def tobytes(self) -> bytes:
# Implement tobytes to support native PyTorch types so we can use types like bloat16

View File

@ -88,9 +88,6 @@ class SymbolicTensor(ir.Value):
def __le__(self, other):
return self._opset.LessOrEqual(self, other)
def __eq__(self, other):
return self._opset.Equal(self, other)
def __ge__(self, other):
return self._opset.GreaterOrEqual(self, other)