diff --git a/test/bench_mps_ops.py b/test/bench_mps_ops.py index 05a0587b8398..305b4c2a9851 100644 --- a/test/bench_mps_ops.py +++ b/test/bench_mps_ops.py @@ -90,7 +90,7 @@ def bench_reduction( return reduction_func(t, dim=0) f.__name__ = reduction_func.__name__ - f_c = torch.compile(f, dynamic=False) + f_c = torch.compile(f, dynamic=False, fullgraph=True) for size in (512, 1024, 2048, 4096): x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) @@ -116,7 +116,7 @@ def bench_scan( def f(t): return scan_func(t, dim=dim) - f_c = torch.compile(f, dynamic=False) + f_c = torch.compile(f, dynamic=False, fullgraph=True) for size in (32, 128, 512, 1024): f.__name__ = f"{scan_func.__name__}-dim{dim}-{size}x{size}" @@ -135,7 +135,7 @@ def bench_scan( def f_1d(t): return scan_func(t, dim=0) - f_1d_c = torch.compile(f_1d, dynamic=False) + f_1d_c = torch.compile(f_1d, dynamic=False, fullgraph=True) for size in (100, 10000, 1000000): f_1d.__name__ = f"{scan_func.__name__}-1d-{size}" @@ -204,4 +204,5 @@ def main() -> None: if __name__ == "__main__": + torch._dynamo.config.cache_size_limit = 2**16 main()