Ignore eager profiling code in training IR (#140826)

Differential Revision: [D66010452](https://our.internmc.facebook.com/intern/diff/D66010452/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140826
Approved by: https://github.com/zhxchen17
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2024-11-15 09:17:08 -08:00
committed by PyTorch MergeBot
parent bf8709b08a
commit b86b5349cb
2 changed files with 27 additions and 0 deletions

View File

@ -7356,6 +7356,17 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
mod(torch.randn(10, 10))
export(mod, (torch.randn(10, 10),), strict=False)
def test_profiling_code(self):
class Foo(torch.nn.Module):
def forward(self, x):
with torch.profiler.record_function("foo"):
return x.sin()
ep = export(Foo(), (torch.randn(5, 5),))
FileCheck().check_count(
"torch.ops.profiler._record_function_enter_new.default", 0, exactly=True
).run(ep.graph_module.code)
def test_predispatch_cond(self):
class Model(torch.nn.Module):
def __init__(self) -> None:

View File

@ -1505,6 +1505,22 @@ def _export_to_aten_ir_make_fx(
hook.remove() # type: ignore[possibly-undefined]
# In export, we ignore any op that is related to
# eager mode profiling call. The expectation is
# that either runtimes provide their own profiling
# OR user wrap the compiled region on a profiling in
# later stage.
def _is_impure(node):
if node.op == "call_function" and node.target in (
torch.ops.profiler._record_function_enter.default,
torch.ops.profiler._record_function_enter_new.default,
torch.ops.profiler._record_function_exit.default,
):
return False
return True
gm.graph.eliminate_dead_code(_is_impure)
# create graph signature
input_names = _graph_input_names(gm)
output_names = _graph_output_names(gm)