mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move tensor implicit conversions to test_builtins.py (#55532)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55532

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D27729682

Pulled By: nikithamalgifb

fbshipit-source-id: d2517ee68b83e59cde87b8fb7d5bf7203f02cbc6
commit: d7d7556f17
parent: 5dba4ff786
committed by: Facebook GitHub Bot
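
For context, the tests moved here exercise TorchScript's implicit conversion of zero-dim tensors to Python scalars when an operator expects an int or float. A minimal sketch of that behavior (illustrative only, not part of this commit; the function name is ours):

    import torch

    @torch.jit.script
    def unsqueeze_with_tensor_dim(x, dim):
        # unsqueeze expects an int dim; when a zero-dim tensor is passed
        # instead, TorchScript inserts an implicit tensor-to-scalar conversion.
        return x.unsqueeze(dim)

    # A zero-dim integer tensor converts cleanly:
    print(unsqueeze_with_tensor_dim(torch.zeros(3), torch.tensor(0)).shape)  # torch.Size([1, 3])
    # A multi-element tensor such as torch.tensor([2, 2]) raises a RuntimeError
    # at call time, which is what test_scalar_to_num_conversions below checks.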
@@ -5,6 +5,7 @@ import unittest
 from typing import Dict, List
 
 import torch
+from torch.testing import FileCheck
 
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -219,3 +220,73 @@ class TestTensorBuiltins(JitTestCase):
             return a
 
         self.checkScript(fn6, (torch.zeros(2, dtype=torch.float32, device="cuda"),))
+
+    def test_tensor_item(self):
+        def test_scalar_cast(x):
+            scalar = x.item()
+            return int(scalar), float(scalar)
+
+        graph = torch.jit.script(test_scalar_cast).graph
+        FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
+        self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
+        self.checkScript(test_scalar_cast, (torch.tensor(1),))
+
+    def test_method_on_number(self):
+        def func():
+            c = 1
+            return c.add(1)
+        with self.assertRaisesRegex(RuntimeError, 'nonexistent attribute or method'):
+            torch.jit.script(func)
+
+    # testing implicit conversion of tensors to scalars to match function arguments
+    def test_scalar_to_num_conversions(self):
+        @torch.jit.script
+        def multiple_defs(x):
+            c = 1
+            x = x + c
+            return x
+
+        self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))
+
+        @torch.jit.script
+        def tensor_to_int_script(x, tensor):
+            return x.unsqueeze(tensor)
+
+        # location present in error message
+        with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"):
+            tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2]))
+
+        def tensor_to_int(x, tensor):
+            return x.unsqueeze(tensor)
+
+        @torch.jit.script
+        def tensor_to_float_script(x, tensor):
+            return x.addcmul(tensor, tensor, value=tensor)
+
+        def tensor_to_float(x, tensor):
+            return x.addcmul(tensor, tensor, value=tensor)
+
+        x = torch.zeros(10)
+        # float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
+        tensors = [torch.tensor(1.1),
+                   torch.tensor(1.1, requires_grad=True),
+                   torch.tensor(0),
+                   torch.tensor([2])]
+
+        script_funs = [tensor_to_int_script, tensor_to_float_script]
+        funs = [tensor_to_int, tensor_to_float]
+
+        # return the result, or whether exception was thrown
+        def test_func(func, x, tensor):
+            try:
+                result = func(x, tensor)
+            except RuntimeError as e:
+                result = True
+            except TypeError as e:
+                result = True
+            return result
+
+        # assert result or exception equal for each (function, inputs)
+        for tensor in tensors:
+            for i in range(len(script_funs)):
+                self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
@@ -4622,76 +4622,6 @@ a")
         test(backward=True)
         test(backward=True)
-
-    def test_tensor_item(self):
-        def test_scalar_cast(x):
-            scalar = x.item()
-            return int(scalar), float(scalar)
-
-        graph = torch.jit.script(test_scalar_cast).graph
-        FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
-        self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
-        self.checkScript(test_scalar_cast, (torch.tensor(1),))
-
-    def test_method_on_number(self):
-        def func():
-            c = 1
-            return c.add(1)
-        with self.assertRaisesRegex(RuntimeError, 'nonexistent attribute or method'):
-            torch.jit.script(func)
-
-    # testing implicit conversion of tensors to scalars to match function arguments
-    def test_scalar_to_num_conversions(self):
-        @torch.jit.script
-        def multiple_defs(x):
-            c = 1
-            x = x + c
-            return x
-
-        self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))
-
-        @torch.jit.script
-        def tensor_to_int_script(x, tensor):
-            return x.unsqueeze(tensor)
-
-        # location present in error message
-        with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"):
-            tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2]))
-
-        def tensor_to_int(x, tensor):
-            return x.unsqueeze(tensor)
-
-        @torch.jit.script
-        def tensor_to_float_script(x, tensor):
-            return x.addcmul(tensor, tensor, value=tensor)
-
-        def tensor_to_float(x, tensor):
-            return x.addcmul(tensor, tensor, value=tensor)
-
-        x = torch.zeros(10)
-        # float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
-        tensors = [torch.tensor(1.1),
-                   torch.tensor(1.1, requires_grad=True),
-                   torch.tensor(0),
-                   torch.tensor([2])]
-
-        script_funs = [tensor_to_int_script, tensor_to_float_script]
-        funs = [tensor_to_int, tensor_to_float]
-
-        # return the result, or whether exception was thrown
-        def test_func(func, x, tensor):
-            try:
-                result = func(x, tensor)
-            except RuntimeError as e:
-                result = True
-            except TypeError as e:
-                result = True
-            return result
-
-        # assert result or exception equal for each (function, inputs)
-        for tensor in tensors:
-            for i in range(len(script_funs)):
-                self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
 
     def test_module_copy_with_attributes(self):
         class Vocabulary(torch.jit.ScriptModule):
            def __init__(self, vocab_list):
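
As a usage note, the moved tests assert on the scripted graph with FileCheck; a minimal sketch of that pattern (illustrative, mirroring test_tensor_item above; the function name is ours):

    import torch
    from torch.testing import FileCheck

    @torch.jit.script
    def cast_item(x):
        scalar = x.item()
        return int(scalar), float(scalar)

    # Verify the graph constructs an (int, float) tuple from the casts.
    FileCheck().check("(int, float) = prim::TupleConstruct").run(cast_item.graph)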