[torchfuzz] Support EagerVsFullGraphDynamicCompileWithNumericsCheck (#164432)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164432
Approved by: https://github.com/pianpwk
This commit is contained in:
bobrenjc93
2025-10-02 10:36:20 -07:00
committed by PyTorch MergeBot
parent 2a760dc51e
commit 7617b113ad
2 changed files with 30 additions and 2 deletions

View File

@ -23,3 +23,31 @@ class EagerVsFullGraphDynamicCompileCheck(Check):
"result_compiled = compiled_program(*args)",
"print('✅ compile success')",
]
class EagerVsFullGraphDynamicCompileWithNumericsCheck(Check):
"""Check that runs eager and compiled, compares forward numerics."""
def codegen(self, args_tuple: str) -> list[str]:
return [
f"args = {args_tuple}",
"out_eager = fuzzed_program(*args)",
"out_eager.sum().backward()",
"print('Eager Success! ✅')",
"compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)",
"out_compiled = compiled_program(*args)",
"out_compiled.sum().backward()",
"print('Compile Success! ✅')",
"out_eager_sum = out_eager.sum()",
"out_compiled_sum = out_compiled.sum()",
"diff = (out_eager_sum - out_compiled_sum).abs().item()",
"rel_diff = diff / (out_eager_sum.abs().item() + 1e-12) * 100",
"print(f'Relative diff (sum): {rel_diff:.6f}%')",
"if rel_diff > 5:",
" print(f'❌ Forward output sums differ significantly (relative)!')",
" print('out_eager_sum:', out_eager_sum.item())",
" print('out_compiled_sum:', out_compiled_sum.item())",
" print('Absolute diff:', diff)",
" print('Relative diff (%):', rel_diff)",
" import sys; sys.exit(1)",
]

View File

@ -107,7 +107,7 @@ class FuzzTemplate:
class DefaultFuzzTemplate(FuzzTemplate):
def __init__(self):
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck
from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck
super().__init__(
supported_ops=[
@ -125,7 +125,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
"torch.bmm",
"torch.matmul",
],
check=EagerVsFullGraphDynamicCompileCheck(),
check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
)
def spec_distribution(self):