Revert "[Submodule] Remove deprecated USE_TBB option and TBB submodule (#127051)"

This reverts commit 699db7988d84d163ebb6919f78885e4630182a7a.

Reverted https://github.com/pytorch/pytorch/pull/127051 on behalf of https://github.com/PaliC due to This PR needs to be synced using the import button as there is a bug in our diff train ([comment](https://github.com/pytorch/pytorch/pull/127051#issuecomment-2138496995))
This commit is contained in:
PyTorch MergeBot
2024-05-30 01:16:57 +00:00
parent 1abcac9dab
commit 67739d8c6f
34 changed files with 863 additions and 19 deletions

View File

@ -1497,6 +1497,8 @@ def disable_translation_validation_if_dynamic_shapes(fn):
# See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135
TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK")
# True if CI is running TBB-enabled Pytorch
IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "")
# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
numpy_to_torch_dtype_dict = {
@ -1873,6 +1875,19 @@ def skipIfNoSciPy(fn):
fn(*args, **kwargs)
return wrapper
def skipIfTBB(message="This test makes TBB sad"):
def dec_fn(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if IS_TBB:
raise unittest.SkipTest(message)
else:
fn(*args, **kwargs)
return wrapper
return dec_fn
def skip_if_pytest(fn):
@wraps(fn)
def wrapped(*args, **kwargs):
@ -4708,6 +4723,24 @@ dtype_abbrs = {
}
def set_single_threaded_if_parallel_tbb(fn):
"""Set test to be single threaded for parallel tbb.
See https://github.com/pytorch/pytorch/issues/64571#issuecomment-914691883
"""
if not IS_TBB:
return fn
@wraps(fn)
def wrap_fn(*args, **kwargs):
num_threads = torch.get_num_threads()
torch.set_num_threads(1)
try:
return fn(*args, **kwargs)
finally:
torch.set_num_threads(num_threads)
return wrap_fn
@functools.lru_cache
def get_cycles_per_ms() -> float: