From 9cd521de4dad5fc6bca94e253a9334b9a521acb0 Mon Sep 17 00:00:00 2001 From: Gabriel Ferns Date: Tue, 15 Jul 2025 05:44:33 +0000 Subject: [PATCH] Fix torchrec multiprocess tests (#158159) Summary: The new version of `get_device_tflops` imported `SM80OrLater` from `torch.testing._internal.common_cuda`, which in turn imported common_utils.py, which disabled global flags. This change computes `SM80OrLater` inline from `torch.cuda` instead, so the testing modules are no longer imported. Test Plan: Fixing existing tests Rollback Plan: Differential Revision: D78192700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158159 Approved by: https://github.com/nipung90, https://github.com/huydhn --- torch/_inductor/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 10701d0d8b2d..d22d67cecff2 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2178,7 +2178,10 @@ def get_device_tflops(dtype: torch.dtype) -> float: from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops - from torch.testing._internal.common_cuda import SM80OrLater + SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + 8, + 0, + ) assert dtype in (torch.float16, torch.bfloat16, torch.float32)