[dynamo] fix add_push_null callsites with CALL_FUNCTION_EX (#132329)

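Also fix a bug in `PyCodegen.add_push_null`: on Python <= 3.12, where the NULL is pushed right before the instructions that load the callable, a `DUP_TOP` emitted by `gen_fn` could accidentally duplicate that NULL instead of the object that was on the stack before it.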

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132329
Approved by: https://github.com/anijain2305
This commit is contained in:
William Wen
2024-07-31 14:39:33 -07:00
committed by PyTorch MergeBot
parent 0016be8051
commit 625af2d27c
3 changed files with 22 additions and 12 deletions

View File

@ -11,6 +11,7 @@ import torch.nn
from . import utils
from .bytecode_transformation import (
add_push_null,
add_push_null_call_function_ex,
create_call_function,
create_call_method,
create_dup_top,
@ -83,7 +84,7 @@ class PyCodegen:
res = value.reconstruct(self)
assert res is None, f"reconstruct!=None {value}"
def add_push_null(self, gen_fn):
def add_push_null(self, gen_fn, call_function_ex=False):
"""
`gen_fn` generates instructions via PyCodegen methods
that push a single callable to the stack.
@ -95,11 +96,19 @@ class PyCodegen:
with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
"""
old_len = len(self._output)
if sys.version_info < (3, 13):
# If TOS is not cleared first, gen_fn may emit a DUP_TOP.
# On Python <= 3.12 the NULL is pushed right before the
# generated instructions, so that DUP_TOP would duplicate
# the NULL instead of the object that was on the stack before it.
self.clear_tos()
gen_fn()
# inplace modify self._output
added_insts = self._output[old_len:]
del self._output[old_len:]
self._output.extend(add_push_null(added_insts))
if call_function_ex:
self._output.extend(add_push_null_call_function_ex(added_insts))
else:
self._output.extend(add_push_null(added_insts))
if sys.version_info >= (3, 13):
# NULL will be at top of stack
self.clear_tos()