[DataPipe] Enable profiler record context in __next__ branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79757

Approved by: https://github.com/ejguan
This commit is contained in:
Kevin Tse
2022-06-16 18:53:15 -04:00
committed by PyTorch MergeBot
parent 25ca006707
commit e8ed16f3c0

View File

@ -171,7 +171,8 @@ def hook_iterator(namespace, profile_name):
@functools.wraps(next_func) @functools.wraps(next_func)
def wrap_next(*args, **kwargs): def wrap_next(*args, **kwargs):
if torch.autograd._profiler_enabled(): if torch.autograd._profiler_enabled():
return next_func(*args, **kwargs) with profiler_record_fn_context():
return next_func(*args, **kwargs)
else: else:
return next_func(*args, **kwargs) return next_func(*args, **kwargs)