dynamo tracing perf: cache cleaned_instructions: 33.7 -> 30.0 (#143070)

See #143056 for overall docs.

This PR: Cache the interesting/expensive bits of `cleaned_instructions()`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143070
Approved by: https://github.com/jansel
This commit is contained in:
Aaron Orenstein
2024-12-23 08:32:43 -08:00
committed by PyTorch MergeBot
parent 51a7ecde80
commit 3df12d38cf
2 changed files with 63 additions and 5 deletions

View File

@ -2,10 +2,22 @@
import copy
import dataclasses
import dis
import functools
import itertools
import sys
import types
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
from ..utils._backport_slots import dataclass_slots
from .bytecode_analysis import (
@ -1453,6 +1465,52 @@ def populate_kw_names_argval(instructions, consts):
# If safe=True, we do not make any bytecode modifications.
# Mainly used for debugging bytecode_transformation (see debug_checks)
def cleaned_instructions(code, safe=False) -> List[Instruction]:
instructions = _cached_cleaned_instructions(code, safe)
# We have a lot of code that implicitly mutates the instruction array. We
# could do better here by making the copies explicit when necessary.
return _clone_instructions(instructions)
# Copy an instructions array, making sure to remap the individual instruction targets.
def _clone_instructions(instructions):
# This is super hot and this is the fastest way to do this (tried copy.copy
# and dataclasses.replace).
copied = [
Instruction(
i.opcode,
i.opname,
i.arg,
i.argval,
i.offset,
i.starts_line,
i.is_jump_target,
i.positions,
i.target,
i.exn_tab_entry,
i.argrepr,
)
for i in instructions
]
remap = dict(zip(instructions, copied))
# Handle `None` in the remapper so we don't need an extra `if`.
remap[None] = None
for i in copied:
i.target = remap[i.target]
if entry := i.exn_tab_entry:
i.exn_tab_entry = InstructionExnTabEntry(
remap[entry.start],
remap[entry.end],
remap[entry.target],
entry.depth,
entry.lasti,
)
return copied
@functools.lru_cache
def _cached_cleaned_instructions(code, safe=False) -> Sequence[Instruction]:
instructions = list(map(convert_instruction, dis.get_instructions(code)))
check_offsets(instructions)
if sys.version_info >= (3, 11):