[dynamo][compile-time] Cache the cleaned instructions while inlining (#153333)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153333
Approved by: https://github.com/StrongerXi, https://github.com/jansel, https://github.com/williamwen42
Authored by Animesh Jain on 2025-05-13 23:00:28 -07:00; committed by PyTorch MergeBot
parent 0139ce9303
commit 864a5f4434
4 changed files with 29 additions and 14 deletions
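In brief: every time Dynamo inlines a function, it re-runs cleaned_instructions (plus propagate_line_nums) on that function's code object, even when the same function is inlined repeatedly within one compilation. This commit memoizes the result on the per-compilation TracingContext, keyed by the code object. A minimal, self-contained sketch of that pattern follows; CompilationContext, clean_instructions, and instructions_for are simplified stand-ins for illustration, not the real Dynamo API:

import dis
from types import CodeType


class CompilationContext:
    """Stand-in for Dynamo's per-compilation TracingContext."""

    def __init__(self):
        # Code objects are hashable and immutable, so they are safe
        # memoization keys for the lifetime of one compilation.
        self.previously_cleaned_instructions = {}


def clean_instructions(code: CodeType) -> list:
    # Stand-in for Dynamo's cleaned_instructions + propagate_line_nums,
    # i.e. the expensive pass this commit avoids repeating.
    return list(dis.get_instructions(code))


def instructions_for(ctx: CompilationContext, code: CodeType) -> list:
    # Get-or-compute: inlining the same function twice hits the cache.
    cached = ctx.previously_cleaned_instructions.get(code)
    if cached is None:
        cached = clean_instructions(code)
        ctx.previously_cleaned_instructions[code] = cached
    return cached


ctx = CompilationContext()
fn_code = (lambda x: x * 2).__code__
# Second lookup returns the identical cached list instead of recomputing.
assert instructions_for(ctx, fn_code) is instructions_for(ctx, fn_code)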

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

@@ -6,7 +6,7 @@ add_loop_eager_dynamic,compile_time_instruction_count,5928000000,0.025
-add_loop_inductor,compile_time_instruction_count,29570000000,0.015
+add_loop_inductor,compile_time_instruction_count,29400000000,0.015
@@ -16,7 +16,9 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44480000000,0.025
 add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,974800000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,954500000,0.015
@@ -52,11 +54,11 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2079000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5940000000,0.015
@@ -72,4 +74,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10270000000,0.015
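For readers unfamiliar with this file: each row reads benchmark_name,metric,expected_value,relative_noise_margin, and CI flags a benchmark when the measured instruction count drifts outside the margin, at which point the pinned expected value is updated as above. A hedged sketch of that acceptance check; within_noise is a hypothetical stand-in, and the real harness under benchmarks/dynamo/pr_time_benchmarks may use a slightly different formula:

def within_noise(measured: int, expected: int, noise_margin: float) -> bool:
    # Relative tolerance: 0.015 allows roughly +/-1.5% drift around the
    # pinned value. (Assumption: the real check may differ in detail.)
    return abs(measured - expected) <= expected * noise_margin


# basic_modules_ListOfLinears_eager: 974.8e6 -> 954.5e6 is a ~2.1% drop,
# outside the 1.5% window, which is why the pinned value is re-set above.
assert not within_noise(954_500_000, 974_800_000, 0.015)
assert within_noise(964_500_000, 954_500_000, 0.015)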


test/dynamo/test_logging.py

@@ -612,7 +612,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
         fn_opt = torch.compile(f, backend="eager")
         fn_opt(torch.randn(3, 3))
-        self.assertEqual(len(records), 4)
+        self.assertEqual(len(records), 3)
         messages = [
             "\n".join(record.getMessage().split("\n")[-2:]) for record in records
         ]
@@ -636,12 +636,6 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
         #     return g(g(x))
         #            ~^^^^^^""",
         # )
-        self.assertExpectedInline(
-            messages[3],
-            """\
-    return x * 2
-           ~~^~~""",
-        )

     @skipIfNotPy311
     @make_logging_test(trace_call=True)

torch/_dynamo/symbolic_convert.py

@@ -3985,8 +3985,25 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         f_builtins = f_globals["__builtins__"]
         if not isinstance(f_builtins, dict):
             f_builtins = f_builtins.__dict__
-        instructions = cleaned_instructions(code)
-        propagate_line_nums(instructions)
+        # Get the cached instructions. These instructions are safe to cache
+        # because we don't mutate them in transform_code_object (those
+        # instructions belong to the topmost InstructionTranslator). Also, we
+        # have to be careful not to use _cached_cleaned_instructions here,
+        # because that cache is global, whereas we want this cache to stay
+        # alive only for the duration of a single compilation.
+        tracing_ctx = parent.output.tracing_context
+        instructions = None
+
+        if tracing_ctx:
+            if tracing_ctx.previously_cleaned_instructions.get(code):
+                instructions = tracing_ctx.previously_cleaned_instructions[code]
+
+        if instructions is None:
+            instructions = cleaned_instructions(code)
+            propagate_line_nums(instructions)
+            if tracing_ctx:
+                tracing_ctx.previously_cleaned_instructions[code] = instructions
         super().__init__(
             output=parent.output,
             f_locals={},
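One observation on the lookup above: the get followed by a second indexing performs two hash lookups, and a present-but-falsy entry would read as a miss. A sketch of equivalent logic with a single dict.get (this assumes a cached entry is never an empty list, which holds for real code objects):

tracing_ctx = parent.output.tracing_context
instructions = (
    tracing_ctx.previously_cleaned_instructions.get(code) if tracing_ctx else None
)
if instructions is None:
    instructions = cleaned_instructions(code)
    propagate_line_nums(instructions)
    if tracing_ctx:
        tracing_ctx.previously_cleaned_instructions[code] = instructions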

torch/_guards.py

@@ -826,6 +826,7 @@ class TracingContext:
         self.module_context = ModuleContext()
         self.global_context = GlobalContext()
         self.previously_inlined_functions = dict()
+        self.previously_cleaned_instructions = dict()
         self.fake_mode = fake_mode
         self.frame_summary_stack = []
         # This is morally part of frame_summary_stack, but it is kept separate
@@ -872,6 +873,7 @@ class TracingContext:
         # for the context on clearing global context.
         self.global_context.global_state = {}
         self.previously_inlined_functions.clear()
+        self.previously_cleaned_instructions.clear()

     @staticmethod
     @contextmanager
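To make the lifetime explicit: the cache lives on the per-compilation TracingContext and is wiped in clear() alongside previously_inlined_functions, so cached instruction lists never leak across compilations. A toy illustration of that lifecycle (ToyTracingContext is a stand-in, not the real class):

class ToyTracingContext:
    def __init__(self):
        self.previously_cleaned_instructions = {}

    def clear(self):
        # Reset together with the rest of the per-compilation state so a
        # later compilation never sees instruction lists from a prior run.
        self.previously_cleaned_instructions.clear()


ctx = ToyTracingContext()
code = compile("x + 1", "<toy>", "eval")  # code objects are hashable keys
ctx.previously_cleaned_instructions[code] = ["...cleaned instructions..."]
ctx.clear()
assert not ctx.previously_cleaned_instructions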