mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] Fix symbolic values and numpy implementation (#135786)
1. Remove `__eq__` to make `SymbolicTensor` hashable and test for that 2. Update the `__array__` method so that it works for tensor on GPU Fixes https://github.com/pytorch/pytorch/issues/135700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135786 Approved by: https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
dddaadac6c
commit
d67cc58181
22
test/onnx/exporter/test_tensors.py
Normal file
22
test/onnx/exporter/test_tensors.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Owner(s): ["module: onnx"]
|
||||
"""Unit tests for the _tensors module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import onnxscript
|
||||
|
||||
from torch.onnx._internal.exporter import _tensors
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
class SymbolicTensorTest(common_utils.TestCase):
|
||||
def test_it_is_hashable(self):
|
||||
tensor = _tensors.SymbolicTensor(
|
||||
opset=onnxscript.values.Opset(domain="test", version=1)
|
||||
)
|
||||
self.assertEqual(hash(tensor), hash(tensor))
|
||||
self.assertIn(tensor, {tensor})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
@ -99,8 +99,9 @@ class TorchTensor(ir.Tensor):
|
||||
|
||||
def __array__(self, dtype: Any = None) -> np.ndarray:
|
||||
# numpy() calls __array__ in ir.Tensor
|
||||
self.raw: torch.Tensor
|
||||
if self.dtype == ir.DataType.BFLOAT16:
|
||||
return self.raw.view(torch.uint16).__array__(dtype)
|
||||
return self.raw.view(torch.uint16).numpy(force=True).__array__(dtype)
|
||||
if self.dtype in {
|
||||
ir.DataType.FLOAT8E4M3FN,
|
||||
ir.DataType.FLOAT8E4M3FNUZ,
|
||||
@ -108,8 +109,8 @@ class TorchTensor(ir.Tensor):
|
||||
ir.DataType.FLOAT8E5M2FNUZ,
|
||||
}:
|
||||
# TODO: Use ml_dtypes
|
||||
return self.raw.view(torch.uint8).__array__(dtype)
|
||||
return self.raw.__array__(dtype)
|
||||
return self.raw.view(torch.uint8).numpy(force=True).__array__(dtype)
|
||||
return self.raw.numpy(force=True).__array__(dtype)
|
||||
|
||||
def tobytes(self) -> bytes:
|
||||
# Implement tobytes to support native PyTorch types so we can use types like bloat16
|
||||
|
@ -88,9 +88,6 @@ class SymbolicTensor(ir.Value):
|
||||
def __le__(self, other):
|
||||
return self._opset.LessOrEqual(self, other)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._opset.Equal(self, other)
|
||||
|
||||
def __ge__(self, other):
|
||||
return self._opset.GreaterOrEqual(self, other)
|
||||
|
||||
|
Reference in New Issue
Block a user