Ignore eager profiling code in training IR (#140826)
Differential Revision: [D66010452](https://our.internmc.facebook.com/intern/diff/D66010452/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140826
Approved by: https://github.com/zhxchen17
Committed by: PyTorch MergeBot
Parent: bf8709b08a
Commit: b86b5349cb
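In effect, after this change, exporting a module whose forward uses torch.profiler.record_function no longer leaves eager profiler enter/exit ops in the exported graph. A minimal sketch of that behavior (not part of the diff; the module name Foo and the tensor shapes are just illustrative):

```python
import torch
from torch.export import export


class Foo(torch.nn.Module):
    def forward(self, x):
        # record_function emits profiler enter/exit ops when traced; with this
        # change they no longer survive into the exported graph.
        with torch.profiler.record_function("foo"):
            return x.sin()


ep = export(Foo(), (torch.randn(5, 5),))
# The printed code should contain the sin call but no
# torch.ops.profiler._record_function_* calls.
print(ep.graph_module.code)
```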
@@ -7356,6 +7356,17 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
             mod(torch.randn(10, 10))
 
         export(mod, (torch.randn(10, 10),), strict=False)
+
+    def test_profiling_code(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                with torch.profiler.record_function("foo"):
+                    return x.sin()
+
+        ep = export(Foo(), (torch.randn(5, 5),))
+        FileCheck().check_count(
+            "torch.ops.profiler._record_function_enter_new.default", 0, exactly=True
+        ).run(ep.graph_module.code)
 
     def test_predispatch_cond(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
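For context, the assertion in the new test uses FileCheck from torch.testing, which pattern-matches over a string (here the generated Python source of the exported graph module); check_count(..., 0, exactly=True) asserts the pattern never appears. A standalone sketch, with the code string below standing in for ep.graph_module.code:

```python
from torch.testing import FileCheck

# Stand-in for ep.graph_module.code after export.
code = (
    "def forward(self, x):\n"
    "    sin = torch.ops.aten.sin.default(x)\n"
    "    return (sin,)\n"
)

# Passes: the profiler op never appears, and the sin op appears exactly once.
FileCheck().check_count(
    "torch.ops.profiler._record_function_enter_new.default", 0, exactly=True
).run(code)
FileCheck().check_count("aten.sin", 1, exactly=True).run(code)
```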
@@ -1505,6 +1505,22 @@ def _export_to_aten_ir_make_fx(
 
     hook.remove()  # type: ignore[possibly-undefined]
 
+    # In export, we ignore any op related to an eager-mode
+    # profiling call. The expectation is that either
+    # runtimes provide their own profiling OR the user
+    # wraps the compiled region in a profiling scope at a
+    # later stage.
+    def _is_impure(node):
+        if node.op == "call_function" and node.target in (
+            torch.ops.profiler._record_function_enter.default,
+            torch.ops.profiler._record_function_enter_new.default,
+            torch.ops.profiler._record_function_exit.default,
+        ):
+            return False
+        return True
+
+    gm.graph.eliminate_dead_code(_is_impure)
+
     # create graph signature
     input_names = _graph_input_names(gm)
     output_names = _graph_output_names(gm)
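The mechanism behind the change: fx.Graph.eliminate_dead_code accepts an is_impure_node callback, and the _is_impure predicate above declares the three profiler ops removable, so an unused enter/exit pair is dropped from the graph. A hand-built sketch of that mechanism (not from the PR; the graph here is constructed manually purely for illustration):

```python
import torch
import torch.fx as fx

# Hand-build a tiny graph resembling what pre-dispatch export used to capture:
# a profiler enter/exit pair wrapped around a sin() call.
graph = fx.Graph()
x = graph.placeholder("x")
handle = graph.call_function(
    torch.ops.profiler._record_function_enter_new.default, ("foo", None)
)
sin = graph.call_function(torch.ops.aten.sin.default, (x,))
graph.call_function(torch.ops.profiler._record_function_exit.default, (handle,))
graph.output(sin)
gm = fx.GraphModule(torch.nn.Module(), graph)


def _is_impure(node):
    # Same predicate as the patch: declare the profiler enter/exit ops pure so
    # dead-code elimination is allowed to drop them.
    if node.op == "call_function" and node.target in (
        torch.ops.profiler._record_function_enter.default,
        torch.ops.profiler._record_function_enter_new.default,
        torch.ops.profiler._record_function_exit.default,
    ):
        return False
    return True


print(gm.graph)  # contains the profiler enter/exit nodes
gm.graph.eliminate_dead_code(_is_impure)
gm.recompile()
print(gm.graph)  # only placeholder -> sin -> output remain
```

Note that the enter node only becomes dead once its exit user has been removed; eliminate_dead_code visits nodes in reverse order, so a single pass is enough to drop the pair.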