mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163646 Approved by: https://github.com/jansel ghstack dependencies: #163626, #163627, #163629, #163643, #163644, #163645
111 lines, 3.3 KiB, Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
from typing import List, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
# Expose the helper modules kept under test/ by appending the parent of the
# directory that contains this (symlink-resolved) file to the import path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
class TestHash(JitTestCase):
|
|
def test_hash_tuple(self):
|
|
def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
|
|
return hash(t1) == hash(t2)
|
|
|
|
self.checkScript(fn, ((1, 2), (1, 2)))
|
|
self.checkScript(fn, ((1, 2), (3, 4)))
|
|
self.checkScript(fn, ((1, 2), (2, 1)))
|
|
|
|
def test_hash_tuple_nested_unhashable_type(self):
|
|
# Tuples may contain unhashable types like `list`, check that we error
|
|
# properly in that case.
|
|
@torch.jit.script
|
|
def fn_unhashable(t1: Tuple[int, List[int]]):
|
|
return hash(t1)
|
|
|
|
with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"):
|
|
fn_unhashable((1, [1]))
|
|
|
|
def test_hash_tensor(self):
|
|
"""Tensors should hash by identity"""
|
|
|
|
def fn(t1, t2):
|
|
return hash(t1) == hash(t2)
|
|
|
|
tensor1 = torch.tensor(1)
|
|
tensor1_clone = torch.tensor(1)
|
|
tensor2 = torch.tensor(2)
|
|
|
|
self.checkScript(fn, (tensor1, tensor1))
|
|
self.checkScript(fn, (tensor1, tensor1_clone))
|
|
self.checkScript(fn, (tensor1, tensor2))
|
|
|
|
def test_hash_none(self):
|
|
def fn():
|
|
n1 = None
|
|
n2 = None
|
|
return hash(n1) == hash(n2)
|
|
|
|
self.checkScript(fn, ())
|
|
|
|
def test_hash_bool(self):
|
|
def fn(b1: bool, b2: bool):
|
|
return hash(b1) == hash(b2)
|
|
|
|
self.checkScript(fn, (True, False))
|
|
self.checkScript(fn, (True, True))
|
|
self.checkScript(fn, (False, True))
|
|
self.checkScript(fn, (False, False))
|
|
|
|
def test_hash_float(self):
|
|
def fn(f1: float, f2: float):
|
|
return hash(f1) == hash(f2)
|
|
|
|
self.checkScript(fn, (1.2345, 1.2345))
|
|
self.checkScript(fn, (1.2345, 6.789))
|
|
self.checkScript(fn, (1.2345, float("inf")))
|
|
self.checkScript(fn, (float("inf"), float("inf")))
|
|
self.checkScript(fn, (1.2345, float("nan")))
|
|
self.checkScript(fn, (float("nan"), float("inf")))
|
|
|
|
def test_hash_int(self):
|
|
def fn(i1: int, i2: int):
|
|
return hash(i1) == hash(i2)
|
|
|
|
self.checkScript(fn, (123, 456))
|
|
self.checkScript(fn, (123, 123))
|
|
self.checkScript(fn, (123, -123))
|
|
self.checkScript(fn, (-123, -123))
|
|
self.checkScript(fn, (123, 0))
|
|
|
|
def test_hash_string(self):
|
|
def fn(s1: str, s2: str):
|
|
return hash(s1) == hash(s2)
|
|
|
|
self.checkScript(fn, ("foo", "foo"))
|
|
self.checkScript(fn, ("foo", "bar"))
|
|
self.checkScript(fn, ("foo", ""))
|
|
|
|
def test_hash_device(self):
|
|
def fn(d1: torch.device, d2: torch.device):
|
|
return hash(d1) == hash(d2)
|
|
|
|
gpu0 = torch.device("cuda:0")
|
|
gpu1 = torch.device("cuda:1")
|
|
cpu = torch.device("cpu")
|
|
self.checkScript(fn, (gpu0, gpu0))
|
|
self.checkScript(fn, (gpu0, gpu1))
|
|
self.checkScript(fn, (gpu0, cpu))
|
|
self.checkScript(fn, (cpu, cpu))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
raise_on_run_directly("test/test_jit.py")
|