mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix torchrec multiprocess tests (#158159)
Summary: The new version of `get_device_tflops` imported something from testing, which imported common_utils.py, which disabled global flags. Test Plan: Fixing existing tests Rollback Plan: Differential Revision: D78192700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158159 Approved by: https://github.com/nipung90, https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
058fb1790f
commit
9cd521de4d
@ -2178,7 +2178,10 @@ def get_device_tflops(dtype: torch.dtype) -> float:
|
||||
|
||||
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
|
||||
|
||||
from torch.testing._internal.common_cuda import SM80OrLater
|
||||
SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
|
||||
8,
|
||||
0,
|
||||
)
|
||||
|
||||
assert dtype in (torch.float16, torch.bfloat16, torch.float32)
|
||||
|
||||
|
Reference in New Issue
Block a user