mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Minor fix in jit tests to pass TorchDynamo (#79903)
@jansel Pull Request resolved: https://github.com/pytorch/pytorch/pull/79903 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
0656e9e595
commit
24f550b0cb
@ -173,9 +173,6 @@ class TestJit(JitCommonTestCase):
|
||||
|
||||
@_alias_ops((op for op in op_db if op.aliases))
|
||||
def test_jit_alias_remapping(self, device, dtype, op):
|
||||
# Required to avoid undefined value: tensor error in JIT compilation of the function template
|
||||
tensor = torch.tensor
|
||||
|
||||
# NOTE: only tests on first sample
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
sample = first_sample(self, samples)
|
||||
@ -240,6 +237,11 @@ class TestJit(JitCommonTestCase):
|
||||
args=", ".join(args),
|
||||
args_kw=", ".join(args_kw),
|
||||
)
|
||||
|
||||
# Required to avoid undefined value: tensor error in JIT
|
||||
# compilation of the function template
|
||||
script = script.replace("tensor(", "torch.tensor(")
|
||||
|
||||
scripted = torch.jit.CompilationUnit(script)._fn
|
||||
|
||||
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
|
||||
|
Reference in New Issue
Block a user