mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torchfuzz] Support EagerVsFullGraphDynamicCompileWithNumericsCheck (#164432)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164432 Approved by: https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
2a760dc51e
commit
7617b113ad
@ -23,3 +23,31 @@ class EagerVsFullGraphDynamicCompileCheck(Check):
|
||||
"result_compiled = compiled_program(*args)",
|
||||
"print('✅ compile success')",
|
||||
]
|
||||
|
||||
|
||||
class EagerVsFullGraphDynamicCompileWithNumericsCheck(Check):
|
||||
"""Check that runs eager and compiled, compares forward numerics."""
|
||||
|
||||
def codegen(self, args_tuple: str) -> list[str]:
|
||||
return [
|
||||
f"args = {args_tuple}",
|
||||
"out_eager = fuzzed_program(*args)",
|
||||
"out_eager.sum().backward()",
|
||||
"print('Eager Success! ✅')",
|
||||
"compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)",
|
||||
"out_compiled = compiled_program(*args)",
|
||||
"out_compiled.sum().backward()",
|
||||
"print('Compile Success! ✅')",
|
||||
"out_eager_sum = out_eager.sum()",
|
||||
"out_compiled_sum = out_compiled.sum()",
|
||||
"diff = (out_eager_sum - out_compiled_sum).abs().item()",
|
||||
"rel_diff = diff / (out_eager_sum.abs().item() + 1e-12) * 100",
|
||||
"print(f'Relative diff (sum): {rel_diff:.6f}%')",
|
||||
"if rel_diff > 5:",
|
||||
" print(f'❌ Forward output sums differ significantly (relative)!')",
|
||||
" print('out_eager_sum:', out_eager_sum.item())",
|
||||
" print('out_compiled_sum:', out_compiled_sum.item())",
|
||||
" print('Absolute diff:', diff)",
|
||||
" print('Relative diff (%):', rel_diff)",
|
||||
" import sys; sys.exit(1)",
|
||||
]
|
||||
|
@ -107,7 +107,7 @@ class FuzzTemplate:
|
||||
|
||||
class DefaultFuzzTemplate(FuzzTemplate):
|
||||
def __init__(self):
|
||||
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck
|
||||
from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck
|
||||
|
||||
super().__init__(
|
||||
supported_ops=[
|
||||
@ -125,7 +125,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
|
||||
"torch.bmm",
|
||||
"torch.matmul",
|
||||
],
|
||||
check=EagerVsFullGraphDynamicCompileCheck(),
|
||||
check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
|
||||
)
|
||||
|
||||
def spec_distribution(self):
|
||||
|
Reference in New Issue
Block a user