Move tensor implicit conversions to test_builtins.py (#55532)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55532

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D27729682

Pulled By: nikithamalgifb

fbshipit-source-id: d2517ee68b83e59cde87b8fb7d5bf7203f02cbc6
This commit is contained in:
nikithamalgi
2021-04-13 07:11:55 -07:00
committed by Facebook GitHub Bot
parent 5dba4ff786
commit d7d7556f17
2 changed files with 71 additions and 70 deletions

View File

@ -5,6 +5,7 @@ import unittest
from typing import Dict, List
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -219,3 +220,73 @@ class TestTensorBuiltins(JitTestCase):
return a
self.checkScript(fn6, (torch.zeros(2, dtype=torch.float32, device="cuda"),))
def test_tensor_item(self):
def test_scalar_cast(x):
scalar = x.item()
return int(scalar), float(scalar)
graph = torch.jit.script(test_scalar_cast).graph
FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
self.checkScript(test_scalar_cast, (torch.tensor(1),))
def test_method_on_number(self):
def func():
c = 1
return c.add(1)
with self.assertRaisesRegex(RuntimeError, 'nonexistent attribute or method'):
torch.jit.script(func)
# testing implicit conversion of tensors to scalars to match function arguments
def test_scalar_to_num_conversions(self):
@torch.jit.script
def multiple_defs(x):
c = 1
x = x + c
return x
self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))
@torch.jit.script
def tensor_to_int_script(x, tensor):
return x.unsqueeze(tensor)
# location present in error message
with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"):
tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2]))
def tensor_to_int(x, tensor):
return x.unsqueeze(tensor)
@torch.jit.script
def tensor_to_float_script(x, tensor):
return x.addcmul(tensor, tensor, value=tensor)
def tensor_to_float(x, tensor):
return x.addcmul(tensor, tensor, value=tensor)
x = torch.zeros(10)
# float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
tensors = [torch.tensor(1.1),
torch.tensor(1.1, requires_grad=True),
torch.tensor(0),
torch.tensor([2])]
script_funs = [tensor_to_int_script, tensor_to_float_script]
funs = [tensor_to_int, tensor_to_float]
# return the result, or whether exception was thrown
def test_func(func, x, tensor):
try:
result = func(x, tensor)
except RuntimeError as e:
result = True
except TypeError as e:
result = True
return result
# assert result or exception equal for each (function, inputs)
for tensor in tensors:
for i in range(len(script_funs)):
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))

View File

@ -4622,76 +4622,6 @@ a")
test(backward=True)
test(backward=True)
def test_tensor_item(self):
def test_scalar_cast(x):
scalar = x.item()
return int(scalar), float(scalar)
graph = torch.jit.script(test_scalar_cast).graph
FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
self.checkScript(test_scalar_cast, (torch.tensor(1),))
def test_method_on_number(self):
def func():
c = 1
return c.add(1)
with self.assertRaisesRegex(RuntimeError, 'nonexistent attribute or method'):
torch.jit.script(func)
# testing implicit conversion of tensors to scalars to match function arguments
def test_scalar_to_num_conversions(self):
@torch.jit.script
def multiple_defs(x):
c = 1
x = x + c
return x
self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))
@torch.jit.script
def tensor_to_int_script(x, tensor):
return x.unsqueeze(tensor)
# location present in error message
with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"):
tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2]))
def tensor_to_int(x, tensor):
return x.unsqueeze(tensor)
@torch.jit.script
def tensor_to_float_script(x, tensor):
return x.addcmul(tensor, tensor, value=tensor)
def tensor_to_float(x, tensor):
return x.addcmul(tensor, tensor, value=tensor)
x = torch.zeros(10)
# float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
tensors = [torch.tensor(1.1),
torch.tensor(1.1, requires_grad=True),
torch.tensor(0),
torch.tensor([2])]
script_funs = [tensor_to_int_script, tensor_to_float_script]
funs = [tensor_to_int, tensor_to_float]
# return the result, or whether exception was thrown
def test_func(func, x, tensor):
try:
result = func(x, tensor)
except RuntimeError as e:
result = True
except TypeError as e:
result = True
return result
# assert result or exception equal for each (function, inputs)
for tensor in tensors:
for i in range(len(script_funs)):
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
def test_module_copy_with_attributes(self):
class Vocabulary(torch.jit.ScriptModule):
def __init__(self, vocab_list):