mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][CPP] Fix the test case of test_linear_reuse_kernels (#163723)
Fixes #163491. Add tolerances to make `test_linear_reuse_kernels` more stable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163723 Approved by: https://github.com/leslie-fang-intel
This commit is contained in:
@ -2936,8 +2936,18 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||||||
|
|
||||||
x = torch.randn(batch_size, in_features).to(dtype=dtype)
|
x = torch.randn(batch_size, in_features).to(dtype=dtype)
|
||||||
mod = M().to(dtype=dtype).eval()
|
mod = M().to(dtype=dtype).eval()
|
||||||
self.common(mod, (x))
|
with verify(dtype) as (atol, rtol):
|
||||||
_, code = run_and_get_cpp_code(mod, x)
|
ref_res = mod(x)
|
||||||
|
m = torch.compile(mod)
|
||||||
|
res, code = run_and_get_cpp_code(m, x)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
ref_res,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
equal_nan=True,
|
||||||
|
exact_dtype=True,
|
||||||
|
)
|
||||||
# Check that only 2 kernels are in the generated code
|
# Check that only 2 kernels are in the generated code
|
||||||
assert code.count("AMXState amx_state") == 2
|
assert code.count("AMXState amx_state") == 2
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user