Fix torchrec multiprocess tests (#158159)

Summary: The new version of `get_device_tflops` imported something from testing, which imported common_utils.py, which disabled global flags.

Test Plan:
Fixing existing tests

Rollback Plan:

Differential Revision: D78192700

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158159
Approved by: https://github.com/nipung90, https://github.com/huydhn
This commit is contained in:
Gabriel Ferns
2025-07-15 05:44:33 +00:00
committed by PyTorch MergeBot
parent 058fb1790f
commit 9cd521de4d

View File

@ -2178,7 +2178,10 @@ def get_device_tflops(dtype: torch.dtype) -> float:
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
from torch.testing._internal.common_cuda import SM80OrLater
SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
8,
0,
)
assert dtype in (torch.float16, torch.bfloat16, torch.float32)