mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] utility to generate bytecode from template function (#127359)
This will be helpful in reducing some of the hardcoded and python-version-dependent bytecode generation in various places in dynamo - e.g. resume function generation and object reconstruction. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127359 Approved by: https://github.com/jansel ghstack dependencies: #127329
This commit is contained in:
committed by
PyTorch MergeBot
parent
5d316c81be
commit
d44ab8ba6d
@ -8,7 +8,7 @@ import unittest
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch._dynamo import bytecode_analysis, bytecode_transformation
|
||||
from torch._dynamo.testing import skipIfNotPy311
|
||||
from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312
|
||||
|
||||
|
||||
class BytecodeTests(torch._dynamo.test_case.TestCase):
|
||||
@ -414,6 +414,119 @@ def fn():
|
||||
self.assertEqual(tab[0].end, 4)
|
||||
self.assertEqual(tab[0].target, 6)
|
||||
|
||||
def test_bytecode_from_template(self):
|
||||
def fn(d1):
|
||||
for k, v in d1.items():
|
||||
d2[k] = v
|
||||
|
||||
varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"}
|
||||
insts = bytecode_transformation.bytecode_from_template(fn, varname_map)
|
||||
for inst in insts:
|
||||
self.assertIsNone(inst.starts_line)
|
||||
if inst.opname.startswith("LOAD"):
|
||||
self.assertNotIn(inst.argval, varname_map)
|
||||
if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"):
|
||||
self.assertIsNone(inst.arg)
|
||||
self.assertFalse(inst.opname.startswith("RETURN"))
|
||||
|
||||
@skipIfNotPy311
|
||||
def test_bytecode_from_template_noprefix(self):
|
||||
# Test that 3.11+ prefix instructions are removed
|
||||
def gen_fn():
|
||||
cl = None
|
||||
|
||||
def fn():
|
||||
return cl
|
||||
|
||||
return fn
|
||||
|
||||
fn = gen_fn()
|
||||
|
||||
dis_insts = list(dis.get_instructions(fn))
|
||||
names = {inst.opname for inst in dis_insts}
|
||||
self.assertIn("RESUME", names)
|
||||
self.assertIn("COPY_FREE_VARS", names)
|
||||
|
||||
insts = bytecode_transformation.bytecode_from_template(fn)
|
||||
names = {inst.opname for inst in insts}
|
||||
self.assertNotIn("RESUME", names)
|
||||
self.assertNotIn("COPY_FREE_VARS", names)
|
||||
|
||||
def test_bytecode_from_template_noreturn1(self):
|
||||
# Test that functions with multiple returns will have their
|
||||
# returns replaced with jumps to the end
|
||||
def fn():
|
||||
if x:
|
||||
return y
|
||||
z = 3
|
||||
return z
|
||||
|
||||
dis_insts = list(dis.get_instructions(fn))
|
||||
dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts))
|
||||
self.assertGreater(len(dis_returns), 1)
|
||||
self.assertTrue(dis_insts[-1].opname.startswith("RETURN"))
|
||||
|
||||
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
|
||||
self.assertEqual(insts[-1].opname, "NOP")
|
||||
self.assertEqual(len(dis_insts), len(insts))
|
||||
for i0, i1 in zip(dis_insts, insts):
|
||||
if i0.opname.startswith("RETURN"):
|
||||
if i1 is insts[-1]:
|
||||
continue
|
||||
self.assertIn("JUMP", i1.opname)
|
||||
self.assertIs(i1.target, insts[-1])
|
||||
|
||||
# Should work with 3.10, but testing with 3.11+ is sufficient.
|
||||
# In 3.8, `fn` ends with a RETURN_VALUE.
|
||||
@skipIfNotPy311
|
||||
def test_bytecode_from_template_noreturn2(self):
|
||||
# Test function that doesn't end with RETURN_VALUE
|
||||
def fn():
|
||||
if x:
|
||||
return x
|
||||
if x:
|
||||
return x
|
||||
raise RuntimeError
|
||||
|
||||
dis_insts = list(dis.get_instructions(fn))
|
||||
self.assertFalse(dis_insts[-1].opname.startswith("RETURN"))
|
||||
|
||||
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
|
||||
self.assertEqual(insts[-1].opname, "NOP")
|
||||
self.assertEqual(insts[-2].opname, dis_insts[-1].opname)
|
||||
self.assertEqual(len(dis_insts) + 1, len(insts))
|
||||
for i0, i1 in zip(dis_insts, insts):
|
||||
if i0.opname.startswith("RETURN"):
|
||||
self.assertIn("JUMP", i1.opname)
|
||||
self.assertIs(i1.target, insts[-1])
|
||||
|
||||
@skipIfNotPy312
|
||||
def test_bytecode_from_template_noreturn_const(self):
|
||||
# Test 3.12+ RETURN_CONST
|
||||
def fn():
|
||||
if x:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
dis_insts = list(dis.get_instructions(fn))
|
||||
dis_return_consts = list(
|
||||
filter(lambda x: x.opname == "RETURN_CONST", dis_insts)
|
||||
)
|
||||
self.assertGreater(len(dis_return_consts), 1)
|
||||
self.assertTrue(dis_insts[-1].opname == "RETURN_CONST")
|
||||
|
||||
insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
|
||||
self.assertEqual(insts[-1].opname, "NOP")
|
||||
insts_i = 0
|
||||
for i, inst in enumerate(dis_insts):
|
||||
if inst.opname == "RETURN_CONST":
|
||||
self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
|
||||
insts_i += 1
|
||||
if insts_i != len(insts) - 1:
|
||||
self.assertIn("JUMP", insts[insts_i].opname)
|
||||
self.assertIs(insts[insts_i].target, insts[-1])
|
||||
insts_i += 1
|
||||
|
||||
|
||||
class BytecodeHookTests(torch._dynamo.test_case.TestCase):
|
||||
def test_bytecode_hook(self):
|
||||
|
@ -1117,6 +1117,23 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N
|
||||
instructions[i].arg = idx
|
||||
|
||||
|
||||
def clear_instruction_args(instructions):
|
||||
# Clear the instruction arg for instructions that have argvals.
|
||||
# Useful for using dis'd bytecode within generated bytecode.
|
||||
for inst in instructions:
|
||||
if (
|
||||
inst.argval is not _NotProvided
|
||||
and (
|
||||
inst.opcode in HAS_LOCAL
|
||||
or inst.opcode in HAS_NAME
|
||||
or inst.opcode in HAS_FREE
|
||||
or inst.opcode in HAS_CONST
|
||||
)
|
||||
and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR")
|
||||
):
|
||||
inst.arg = None
|
||||
|
||||
|
||||
def get_code_keys() -> List[str]:
|
||||
# Python 3.11 changes to code keys are not fully documented.
|
||||
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
|
||||
@ -1247,3 +1264,100 @@ def unique_id(name) -> str:
|
||||
def is_generator(code: types.CodeType) -> bool:
|
||||
co_generator = 0x20
|
||||
return (code.co_flags & co_generator) > 0
|
||||
|
||||
|
||||
def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True):
|
||||
"""Generates bytecode from a template function `fn` for use in
|
||||
dynamo bytecode generation.
|
||||
|
||||
For example, we can generate Python-version-independent bytecode
|
||||
for looping through a dictionary and copying the values to a new dictionary.
|
||||
|
||||
def template(d1, d2):
|
||||
for k, v in d1.items():
|
||||
d2[k] = v
|
||||
|
||||
|
||||
or a try block:
|
||||
|
||||
def template():
|
||||
try:
|
||||
dummy1
|
||||
except:
|
||||
dummy2
|
||||
raise
|
||||
dummy3
|
||||
|
||||
Args:
|
||||
fn: a function template to generate bytecode from
|
||||
varname_map: a mapping of `fn`'s varnames to new names. This
|
||||
map will be applied to the generated bytecode's varnames.
|
||||
For example, local variables in `fn` can be replaced with
|
||||
new names that are generated by `OutputGraph.new_var`.
|
||||
noreturn: remove all RETURN_* bytecodes and replace them with a jump
|
||||
to the end of the bytecode.
|
||||
noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive).
|
||||
"""
|
||||
insts = cleaned_instructions(fn.__code__)
|
||||
clear_instruction_args(insts)
|
||||
|
||||
if noprefix:
|
||||
for i, inst in enumerate(insts):
|
||||
if inst.opname == "RESUME":
|
||||
insts = insts[i + 1 :]
|
||||
break
|
||||
|
||||
for inst in insts:
|
||||
# If we don't reset starts_line, then the generated
|
||||
# bytecode's line number will be based on fn's.
|
||||
inst.starts_line = None
|
||||
if varname_map and inst.argval in varname_map:
|
||||
inst.argval = varname_map[inst.argval]
|
||||
|
||||
if noreturn:
|
||||
if sys.version_info >= (3, 12):
|
||||
# replace RETURN_CONST with LOAD_CONST RETURN_VALUE
|
||||
new_insts = []
|
||||
for inst in insts:
|
||||
if inst.opname == "RETURN_CONST":
|
||||
inst.opcode = dis.opmap["LOAD_CONST"]
|
||||
inst.opname = "LOAD_CONST"
|
||||
new_insts.append(inst)
|
||||
# no need to propagate target/exn table
|
||||
new_insts.append(create_instruction("RETURN_VALUE"))
|
||||
else:
|
||||
new_insts.append(inst)
|
||||
insts = new_insts
|
||||
|
||||
returns = []
|
||||
for inst in insts:
|
||||
if inst.opname == "RETURN_VALUE":
|
||||
returns.append(inst)
|
||||
|
||||
if len(returns) == 1 and returns[0] is insts[-1]:
|
||||
# only 1 return at the end - just pop it
|
||||
insts.pop(-1)
|
||||
elif len(returns) > 0:
|
||||
# create jump target - if the last inst is a return,
|
||||
# we can replace it with a NOP and make that the jump target.
|
||||
if insts[-1] is returns[-1]:
|
||||
insts[-1].opname = "NOP"
|
||||
insts[-1].opcode = dis.opmap["NOP"]
|
||||
insts[-1].arg = None
|
||||
insts[-1].argval = _NotProvided
|
||||
returns.pop(-1)
|
||||
else:
|
||||
insts.append(create_instruction("NOP"))
|
||||
|
||||
# replace returns with jumps
|
||||
for inst in returns:
|
||||
# don't replace inst with new instruction
|
||||
# due to targetting/exn table/etc.
|
||||
jump_inst = create_jump_absolute(insts[-1])
|
||||
inst.opname = jump_inst.opname
|
||||
inst.opcode = jump_inst.opcode
|
||||
inst.arg = jump_inst.arg
|
||||
inst.argval = jump_inst.argval
|
||||
inst.target = jump_inst.target
|
||||
|
||||
return insts
|
||||
|
@ -343,6 +343,12 @@ def skipIfNotPy311(fn):
|
||||
return unittest.skip(fn)
|
||||
|
||||
|
||||
def skipIfNotPy312(fn):
|
||||
if sys.version_info >= (3, 12):
|
||||
return fn
|
||||
return unittest.skip(fn)
|
||||
|
||||
|
||||
def xfailIfPy312(fn):
|
||||
if sys.version_info >= (3, 12):
|
||||
return unittest.expectedFailure(fn)
|
||||
|
Reference in New Issue
Block a user