mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Ignore eager profiling code in training IR (#140826)
Differential Revision: [D66010452](https://our.internmc.facebook.com/intern/diff/D66010452/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140826 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
bf8709b08a
commit
b86b5349cb
@ -7356,6 +7356,17 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
||||
mod(torch.randn(10, 10))
|
||||
export(mod, (torch.randn(10, 10),), strict=False)
|
||||
|
||||
def test_profiling_code(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
with torch.profiler.record_function("foo"):
|
||||
return x.sin()
|
||||
|
||||
ep = export(Foo(), (torch.randn(5, 5),))
|
||||
FileCheck().check_count(
|
||||
"torch.ops.profiler._record_function_enter_new.default", 0, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
|
||||
def test_predispatch_cond(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
@ -1505,6 +1505,22 @@ def _export_to_aten_ir_make_fx(
|
||||
|
||||
hook.remove() # type: ignore[possibly-undefined]
|
||||
|
||||
# In export, we ignore any op that is related to
|
||||
# eager mode profiling call. The expectation is
|
||||
# that either runtimes provide their own profiling
|
||||
# OR user wrap the compiled region on a profiling in
|
||||
# later stage.
|
||||
def _is_impure(node):
|
||||
if node.op == "call_function" and node.target in (
|
||||
torch.ops.profiler._record_function_enter.default,
|
||||
torch.ops.profiler._record_function_enter_new.default,
|
||||
torch.ops.profiler._record_function_exit.default,
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
gm.graph.eliminate_dead_code(_is_impure)
|
||||
|
||||
# create graph signature
|
||||
input_names = _graph_input_names(gm)
|
||||
output_names = _graph_output_names(gm)
|
||||
|
Reference in New Issue
Block a user