[dynamo] utility to generate bytecode from template function (#127359)

This will be helpful in reducing some of the hardcoded and python-version-dependent bytecode generation in various places in dynamo - e.g. resume function generation and object reconstruction.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127359
Approved by: https://github.com/jansel
ghstack dependencies: #127329
This commit is contained in:
William Wen
2024-05-29 10:56:23 -07:00
committed by PyTorch MergeBot
parent 5d316c81be
commit d44ab8ba6d
3 changed files with 234 additions and 1 deletions

View File

@ -8,7 +8,7 @@ import unittest
import torch
import torch._dynamo.test_case
from torch._dynamo import bytecode_analysis, bytecode_transformation
from torch._dynamo.testing import skipIfNotPy311
from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312
class BytecodeTests(torch._dynamo.test_case.TestCase):
@ -414,6 +414,119 @@ def fn():
self.assertEqual(tab[0].end, 4)
self.assertEqual(tab[0].target, 6)
def test_bytecode_from_template(self):
def fn(d1):
for k, v in d1.items():
d2[k] = v
varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"}
insts = bytecode_transformation.bytecode_from_template(fn, varname_map)
for inst in insts:
self.assertIsNone(inst.starts_line)
if inst.opname.startswith("LOAD"):
self.assertNotIn(inst.argval, varname_map)
if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"):
self.assertIsNone(inst.arg)
self.assertFalse(inst.opname.startswith("RETURN"))
@skipIfNotPy311
def test_bytecode_from_template_noprefix(self):
# Test that 3.11+ prefix instructions are removed
def gen_fn():
cl = None
def fn():
return cl
return fn
fn = gen_fn()
dis_insts = list(dis.get_instructions(fn))
names = {inst.opname for inst in dis_insts}
self.assertIn("RESUME", names)
self.assertIn("COPY_FREE_VARS", names)
insts = bytecode_transformation.bytecode_from_template(fn)
names = {inst.opname for inst in insts}
self.assertNotIn("RESUME", names)
self.assertNotIn("COPY_FREE_VARS", names)
def test_bytecode_from_template_noreturn1(self):
# Test that functions with multiple returns will have their
# returns replaced with jumps to the end
def fn():
if x:
return y
z = 3
return z
dis_insts = list(dis.get_instructions(fn))
dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts))
self.assertGreater(len(dis_returns), 1)
self.assertTrue(dis_insts[-1].opname.startswith("RETURN"))
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
self.assertEqual(insts[-1].opname, "NOP")
self.assertEqual(len(dis_insts), len(insts))
for i0, i1 in zip(dis_insts, insts):
if i0.opname.startswith("RETURN"):
if i1 is insts[-1]:
continue
self.assertIn("JUMP", i1.opname)
self.assertIs(i1.target, insts[-1])
# Should work with 3.10, but testing with 3.11+ is sufficient.
# In 3.8, `fn` ends with a RETURN_VALUE.
@skipIfNotPy311
def test_bytecode_from_template_noreturn2(self):
# Test function that doesn't end with RETURN_VALUE
def fn():
if x:
return x
if x:
return x
raise RuntimeError
dis_insts = list(dis.get_instructions(fn))
self.assertFalse(dis_insts[-1].opname.startswith("RETURN"))
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
self.assertEqual(insts[-1].opname, "NOP")
self.assertEqual(insts[-2].opname, dis_insts[-1].opname)
self.assertEqual(len(dis_insts) + 1, len(insts))
for i0, i1 in zip(dis_insts, insts):
if i0.opname.startswith("RETURN"):
self.assertIn("JUMP", i1.opname)
self.assertIs(i1.target, insts[-1])
@skipIfNotPy312
def test_bytecode_from_template_noreturn_const(self):
# Test 3.12+ RETURN_CONST
def fn():
if x:
return 1
return 0
dis_insts = list(dis.get_instructions(fn))
dis_return_consts = list(
filter(lambda x: x.opname == "RETURN_CONST", dis_insts)
)
self.assertGreater(len(dis_return_consts), 1)
self.assertTrue(dis_insts[-1].opname == "RETURN_CONST")
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
self.assertEqual(insts[-1].opname, "NOP")
insts_i = 0
for i, inst in enumerate(dis_insts):
if inst.opname == "RETURN_CONST":
self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
insts_i += 1
if insts_i != len(insts) - 1:
self.assertIn("JUMP", insts[insts_i].opname)
self.assertIs(insts[insts_i].target, insts[-1])
insts_i += 1
class BytecodeHookTests(torch._dynamo.test_case.TestCase):
def test_bytecode_hook(self):

View File

@ -1117,6 +1117,23 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N
instructions[i].arg = idx
def clear_instruction_args(instructions):
# Clear the instruction arg for instructions that have argvals.
# Useful for using dis'd bytecode within generated bytecode.
for inst in instructions:
if (
inst.argval is not _NotProvided
and (
inst.opcode in HAS_LOCAL
or inst.opcode in HAS_NAME
or inst.opcode in HAS_FREE
or inst.opcode in HAS_CONST
)
and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR")
):
inst.arg = None
def get_code_keys() -> List[str]:
# Python 3.11 changes to code keys are not fully documented.
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
@ -1247,3 +1264,100 @@ def unique_id(name) -> str:
def is_generator(code: types.CodeType) -> bool:
co_generator = 0x20
return (code.co_flags & co_generator) > 0
def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True):
"""Generates bytecode from a template function `fn` for use in
dynamo bytecode generation.
For example, we can generate Python-version-independent bytecode
for looping through a dictionary and copying the values to a new dictionary.
def template(d1, d2):
for k, v in d1.items():
d2[k] = v
or a try block:
def template():
try:
dummy1
except:
dummy2
raise
dummy3
Args:
fn: a function template to generate bytecode from
varname_map: a mapping of `fn`'s varnames to new names. This
map will be applied to the generated bytecode's varnames.
For example, local variables in `fn` can be replaced with
new names that are generated by `OutputGraph.new_var`.
noreturn: remove all RETURN_* bytecodes and replace them with a jump
to the end of the bytecode.
noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive).
"""
insts = cleaned_instructions(fn.__code__)
clear_instruction_args(insts)
if noprefix:
for i, inst in enumerate(insts):
if inst.opname == "RESUME":
insts = insts[i + 1 :]
break
for inst in insts:
# If we don't reset starts_line, then the generated
# bytecode's line number will be based on fn's.
inst.starts_line = None
if varname_map and inst.argval in varname_map:
inst.argval = varname_map[inst.argval]
if noreturn:
if sys.version_info >= (3, 12):
# replace RETURN_CONST with LOAD_CONST RETURN_VALUE
new_insts = []
for inst in insts:
if inst.opname == "RETURN_CONST":
inst.opcode = dis.opmap["LOAD_CONST"]
inst.opname = "LOAD_CONST"
new_insts.append(inst)
# no need to propagate target/exn table
new_insts.append(create_instruction("RETURN_VALUE"))
else:
new_insts.append(inst)
insts = new_insts
returns = []
for inst in insts:
if inst.opname == "RETURN_VALUE":
returns.append(inst)
if len(returns) == 1 and returns[0] is insts[-1]:
# only 1 return at the end - just pop it
insts.pop(-1)
elif len(returns) > 0:
# create jump target - if the last inst is a return,
# we can replace it with a NOP and make that the jump target.
if insts[-1] is returns[-1]:
insts[-1].opname = "NOP"
insts[-1].opcode = dis.opmap["NOP"]
insts[-1].arg = None
insts[-1].argval = _NotProvided
returns.pop(-1)
else:
insts.append(create_instruction("NOP"))
# replace returns with jumps
for inst in returns:
# don't replace inst with new instruction
# due to targetting/exn table/etc.
jump_inst = create_jump_absolute(insts[-1])
inst.opname = jump_inst.opname
inst.opcode = jump_inst.opcode
inst.arg = jump_inst.arg
inst.argval = jump_inst.argval
inst.target = jump_inst.target
return insts

View File

@ -343,6 +343,12 @@ def skipIfNotPy311(fn):
return unittest.skip(fn)
def skipIfNotPy312(fn):
if sys.version_info >= (3, 12):
return fn
return unittest.skip(fn)
def xfailIfPy312(fn):
if sys.version_info >= (3, 12):
return unittest.expectedFailure(fn)