[test][scan] refactor inductor test and prepare for adding bw tests

[ghstack-poisoned]
This commit is contained in:
Yidi Wu
2025-08-26 15:19:20 -07:00
parent e06d1d6610
commit 46abb2dae4

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: inductor"]
import contextlib
import itertools
import unittest
@ -1755,27 +1755,56 @@ class ScanTests(TestCase):
inputs,
device,
dynamic,
requires_grad=False,
autograd=False,
):
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
model
)
import copy
inputs = [inp.requires_grad_(autograd) for inp in inputs]
inputs = [inp.to(device=device) for inp in inputs]
model = model.to(device=device)
cloned_inputs = [inp.clone() for inp in inputs]
grad_ctx = contextlib.nullcontext() if requires_grad else torch.no_grad()
with grad_ctx:
result = model(scan, *cloned_inputs)
result_exp = model(_fake_scan, *cloned_inputs)
for p in model.parameters():
p.requires_grad_(autograd)
result_compiled = compiled_model(scan, *cloned_inputs)
result_compiled_exp = compiled_model(_fake_scan, *cloned_inputs)
model1 = copy.deepcopy(model)
model2 = copy.deepcopy(model)
model3 = copy.deepcopy(model)
model4 = copy.deepcopy(model)
torch.compile(fullgraph=True, dynamic=dynamic)(model)
self.assertEqual(result, result_exp)
def _run_model(model, inputs):
cloned_inputs = [
inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
]
fw_result = model(*cloned_inputs)
loss = loss_fn(fw_result)
if autograd:
loss.backward()
return (
fw_result,
loss,
[
inp.grad
for inp in cloned_inputs
if isinstance(inp, torch.Tensor)
],
{n: p.grad for n, p in model.named_parameters()},
)
else:
return fw_result, loss
result_exp = _run_model(model1, [_fake_scan] + inputs)
result_eager = _run_model(model2, [scan] + inputs)
result_compiled = _run_model(
torch.compile(fullgraph=True, dynamic=dynamic)(model3), [scan] + inputs
)
result_compiled_exp = _run_model(
torch.compile(fullgraph=True, dynamic=dynamic)(model4),
[_fake_scan] + inputs,
)
self.assertEqual(result_exp, result_eager)
self.assertEqual(result_exp, result_compiled)
self.assertEqual(result_compiled, result_compiled_exp)
self.assertEqual(result_exp, result_compiled_exp)
def _compare_result(
self,