Minor fix in jit tests to pass TorchDynamo (#79903)

@jansel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79903
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain
2022-06-21 00:42:02 +00:00
committed by PyTorch MergeBot
parent 0656e9e595
commit 24f550b0cb

View File

@ -173,9 +173,6 @@ class TestJit(JitCommonTestCase):
@_alias_ops((op for op in op_db if op.aliases))
def test_jit_alias_remapping(self, device, dtype, op):
# Required to avoid undefined value: tensor error in JIT compilation of the function template
tensor = torch.tensor
# NOTE: only tests on first sample
samples = op.sample_inputs(device, dtype, requires_grad=True)
sample = first_sample(self, samples)
@ -240,6 +237,11 @@ class TestJit(JitCommonTestCase):
args=", ".join(args),
args_kw=", ".join(args_kw),
)
# Required to avoid undefined value: tensor error in JIT
# compilation of the function template
script = script.replace("tensor(", "torch.tensor(")
scripted = torch.jit.CompilationUnit(script)._fn
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):