Revert "[Submodule] Remove deprecated USE_TBB option and TBB submodule (#127051)"
This reverts commit 699db7988d84d163ebb6919f78885e4630182a7a.

Reverted https://github.com/pytorch/pytorch/pull/127051 on behalf of https://github.com/PaliC because this PR needs to be synced using the import button, as there is a bug in our diff train ([comment](https://github.com/pytorch/pytorch/pull/127051#issuecomment-2138496995))
@@ -1497,6 +1497,8 @@ def disable_translation_validation_if_dynamic_shapes(fn):
 # See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135
 TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK")
 
+# True if CI is running TBB-enabled Pytorch
+IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "")
 
 # Dict of NumPy dtype -> torch dtype (when the correspondence exists)
 numpy_to_torch_dtype_dict = {
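As context for reviewers: the restored `IS_TBB` flag is just a substring check on the CI build string. A minimal sketch of how it evaluates, using a made-up `BUILD_ENVIRONMENT` value (real values are set by PyTorch CI, not by this PR):

    import os

    # Hypothetical CI build string, for illustration only.
    os.environ["BUILD_ENVIRONMENT"] = "linux-focal-py3.8-gcc9-tbb"

    IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "")
    print(IS_TBB)  # True for a build string containing "tbb", False otherwise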
@@ -1873,6 +1875,19 @@ def skipIfNoSciPy(fn):
             fn(*args, **kwargs)
     return wrapper
 
+
+def skipIfTBB(message="This test makes TBB sad"):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if IS_TBB:
+                raise unittest.SkipTest(message)
+            else:
+                fn(*args, **kwargs)
+        return wrapper
+    return dec_fn
+
+
 def skip_if_pytest(fn):
     @wraps(fn)
     def wrapped(*args, **kwargs):
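For reference, once this revert lands `skipIfTBB` is again importable from `torch.testing._internal.common_utils` and behaves like the other skip decorators there. A minimal usage sketch (the test class and test name are invented for illustration):

    import unittest

    import torch
    from torch.testing._internal.common_utils import skipIfTBB

    class TestParallelOps(unittest.TestCase):  # hypothetical test class
        @skipIfTBB("this test makes TBB sad")
        def test_sum(self):
            # Skipped with the given message on TBB builds (IS_TBB is True);
            # runs normally everywhere else.
            self.assertEqual(torch.ones(8).sum().item(), 8.0)

    if __name__ == "__main__":
        unittest.main()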
@@ -4708,6 +4723,24 @@ dtype_abbrs = {
 }
 
 
+def set_single_threaded_if_parallel_tbb(fn):
+    """Set test to be single threaded for parallel tbb.
+
+    See https://github.com/pytorch/pytorch/issues/64571#issuecomment-914691883
+    """
+    if not IS_TBB:
+        return fn
+
+    @wraps(fn)
+    def wrap_fn(*args, **kwargs):
+        num_threads = torch.get_num_threads()
+        torch.set_num_threads(1)
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            torch.set_num_threads(num_threads)
+    return wrap_fn
+
 
 @functools.lru_cache
 def get_cycles_per_ms() -> float:
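Similarly, the restored `set_single_threaded_if_parallel_tbb` is a no-op on non-TBB builds; on TBB builds it pins the wrapped test to one intra-op thread and restores the previous thread count afterwards, even if the test raises. A minimal usage sketch (the test function is invented for illustration):

    import torch
    from torch.testing._internal.common_utils import set_single_threaded_if_parallel_tbb

    @set_single_threaded_if_parallel_tbb
    def test_dot():  # hypothetical test
        a, b = torch.randn(1000), torch.randn(1000)
        # On a TBB build, torch.get_num_threads() == 1 inside this body.
        assert torch.allclose(a.dot(b), (a * b).sum())

    test_dot()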