Disable RPC profiling for kineto profilers

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76234

RPC profiling is only enabled when the profiler is of legacy type.

Differential Revision: [D35484579](https://our.internmc.facebook.com/intern/diff/D35484579/)

Approved by: https://github.com/H-Huang
This commit is contained in:
Rohan Varma
2022-04-26 08:13:12 -07:00
committed by PyTorch MergeBot
parent 5cd880f4c0
commit ec62901a2c
3 changed files with 42 additions and 5 deletions

View File

@ -604,7 +604,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
torch._C._log_api_usage_once("torch.distributed.rpc_remote")
qualified_name = torch.jit._builtins._find_builtin(func)
dst_worker_info = _to_worker_info(to)
should_profile = torch.autograd._profiler_enabled()
should_profile = _get_should_profile()
ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)
@ -657,7 +657,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
qualified_name = torch.jit._builtins._find_builtin(func)
dst_worker_info = _to_worker_info(to)
should_profile = torch.autograd._profiler_enabled()
should_profile = _get_should_profile()
ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
@ -879,6 +879,14 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
_thread_local_var.future_list.append(fut)
return fut
def _get_should_profile():
# Legacy profiler should be enabled. RPC profiling is not supported with
# Kineto profiler.
ActiveProfilerType = torch._C._autograd.ActiveProfilerType
return (
torch.autograd._profiler_enabled() and
torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
)
def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
ctx_manager = contextlib.suppress()