mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] utility to generate bytecode from template function (#127359)
This will help reduce some of the hardcoded, Python-version-dependent bytecode generation in various places in Dynamo, e.g. resume function generation and object reconstruction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127359
Approved by: https://github.com/jansel
ghstack dependencies: #127329
This commit is contained in:
committed by
PyTorch MergeBot
parent
5d316c81be
commit
d44ab8ba6d
@ -8,7 +8,7 @@ import unittest
|
|||||||
import torch
|
import torch
|
||||||
import torch._dynamo.test_case
|
import torch._dynamo.test_case
|
||||||
from torch._dynamo import bytecode_analysis, bytecode_transformation
|
from torch._dynamo import bytecode_analysis, bytecode_transformation
|
||||||
from torch._dynamo.testing import skipIfNotPy311
|
from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312
|
||||||
|
|
||||||
|
|
||||||
class BytecodeTests(torch._dynamo.test_case.TestCase):
|
class BytecodeTests(torch._dynamo.test_case.TestCase):
|
||||||
@ -414,6 +414,119 @@ def fn():
|
|||||||
self.assertEqual(tab[0].end, 4)
|
self.assertEqual(tab[0].end, 4)
|
||||||
self.assertEqual(tab[0].target, 6)
|
self.assertEqual(tab[0].target, 6)
|
||||||
|
|
||||||
|
def test_bytecode_from_template(self):
    """Check the basic clean-up done by `bytecode_from_template`.

    The template is converted with a variable-name map and every
    resulting instruction is inspected: line info must be cleared,
    mapped names must no longer appear on LOAD* instructions, their
    numeric `arg` must be reset (except for LOAD_GLOBAL / LOAD_ATTR),
    and no RETURN* instruction may remain.
    """

    # `d2` is deliberately undefined: the template is only disassembled,
    # never called.
    def fn(d1):
        for k, v in d1.items():
            d2[k] = v

    varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"}
    insts = bytecode_transformation.bytecode_from_template(fn, varname_map)
    for inst in insts:
        # The template's line numbers must not carry over.
        self.assertIsNone(inst.starts_line)
        if inst.opname.startswith("LOAD"):
            # The original template names must have been renamed away.
            self.assertNotIn(inst.argval, varname_map)
            if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"):
                self.assertIsNone(inst.arg)
        self.assertFalse(inst.opname.startswith("RETURN"))
|
||||||
|
|
||||||
|
@skipIfNotPy311
def test_bytecode_from_template_noprefix(self):
    """Test that 3.11+ prefix instructions are removed.

    A closure is used as the template so that its raw bytecode contains
    both COPY_FREE_VARS and RESUME; neither may survive the conversion.
    """

    def gen_fn():
        cl = None

        def fn():
            return cl

        return fn

    fn = gen_fn()

    # Sanity check: the raw bytecode really does contain the prefix.
    raw_opnames = {inst.opname for inst in dis.get_instructions(fn)}
    self.assertIn("RESUME", raw_opnames)
    self.assertIn("COPY_FREE_VARS", raw_opnames)

    insts = bytecode_transformation.bytecode_from_template(fn)
    template_opnames = {inst.opname for inst in insts}
    self.assertNotIn("RESUME", template_opnames)
    self.assertNotIn("COPY_FREE_VARS", template_opnames)
|
||||||
|
|
||||||
|
def test_bytecode_from_template_noreturn1(self):
    """Test that functions with multiple returns will have their
    returns replaced with jumps to the end.

    The template ends with a RETURN* instruction, so the instruction
    count is expected to stay the same and the last instruction is
    expected to become the NOP that the other returns jump to.
    """

    # `x` and `y` are deliberately undefined: the template is only
    # disassembled, never called.
    def fn():
        if x:
            return y
        z = 3
        return z

    dis_insts = list(dis.get_instructions(fn))
    dis_returns = [
        inst for inst in dis_insts if inst.opname.startswith("RETURN")
    ]
    self.assertGreater(len(dis_returns), 1)
    self.assertTrue(dis_insts[-1].opname.startswith("RETURN"))

    # noprefix=False keeps both instruction lists index-aligned.
    insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
    self.assertEqual(insts[-1].opname, "NOP")
    self.assertEqual(len(dis_insts), len(insts))
    for i0, i1 in zip(dis_insts, insts):
        if i0.opname.startswith("RETURN"):
            if i1 is insts[-1]:
                # The trailing return is the jump target itself.
                continue
            self.assertIn("JUMP", i1.opname)
            self.assertIs(i1.target, insts[-1])
|
||||||
|
|
||||||
|
# Should work with 3.10, but testing with 3.11+ is sufficient.
# In 3.8, `fn` ends with a RETURN_VALUE.
@skipIfNotPy311
def test_bytecode_from_template_noreturn2(self):
    """Test a function that doesn't end with RETURN_VALUE.

    Because the last instruction is not a return, one extra NOP is
    expected at the end, and every return is expected to become a jump
    to it.
    """

    # `x` is deliberately undefined: the template is only disassembled,
    # never called.
    def fn():
        if x:
            return x
        if x:
            return x
        raise RuntimeError

    dis_insts = list(dis.get_instructions(fn))
    self.assertFalse(dis_insts[-1].opname.startswith("RETURN"))

    # noprefix=False keeps both instruction lists index-aligned.
    insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
    self.assertEqual(insts[-1].opname, "NOP")
    self.assertEqual(insts[-2].opname, dis_insts[-1].opname)
    self.assertEqual(len(dis_insts) + 1, len(insts))
    for i0, i1 in zip(dis_insts, insts):
        if i0.opname.startswith("RETURN"):
            self.assertIn("JUMP", i1.opname)
            self.assertIs(i1.target, insts[-1])
|
||||||
|
|
||||||
|
@skipIfNotPy312
def test_bytecode_from_template_noreturn_const(self):
    """Test 3.12+ RETURN_CONST.

    Each RETURN_CONST in the raw bytecode is expected to turn into a
    LOAD_CONST followed by a jump to the final NOP; the last one is
    followed directly by that NOP instead.
    """

    # `x` is deliberately undefined: the template is only disassembled,
    # never called.
    def fn():
        if x:
            return 1
        return 0

    dis_insts = list(dis.get_instructions(fn))
    dis_return_consts = [
        inst for inst in dis_insts if inst.opname == "RETURN_CONST"
    ]
    self.assertGreater(len(dis_return_consts), 1)
    self.assertEqual(dis_insts[-1].opname, "RETURN_CONST")

    insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
    self.assertEqual(insts[-1].opname, "NOP")
    # `insts_i` walks `insts` in step with `dis_insts`; it advances one
    # extra slot for each RETURN_CONST, which occupies two instructions.
    insts_i = 0
    for inst in dis_insts:
        if inst.opname == "RETURN_CONST":
            self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
            insts_i += 1
            if insts_i != len(insts) - 1:
                self.assertIn("JUMP", insts[insts_i].opname)
                self.assertIs(insts[insts_i].target, insts[-1])
        insts_i += 1
|
||||||
|
|
||||||
|
|
||||||
class BytecodeHookTests(torch._dynamo.test_case.TestCase):
|
class BytecodeHookTests(torch._dynamo.test_case.TestCase):
|
||||||
def test_bytecode_hook(self):
|
def test_bytecode_hook(self):
|
||||||
|
@ -1117,6 +1117,23 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N
|
|||||||
instructions[i].arg = idx
|
instructions[i].arg = idx
|
||||||
|
|
||||||
|
|
||||||
|
def clear_instruction_args(instructions):
    """Clear the instruction arg for instructions that have argvals.

    Useful for using dis'd bytecode within generated bytecode.

    Only instructions whose opcode is in HAS_LOCAL, HAS_NAME, HAS_FREE or
    HAS_CONST and that carry an `argval` are touched. LOAD_GLOBAL,
    LOAD_ATTR and LOAD_SUPER_ATTR always keep their `arg`.
    NOTE(review): presumably because their arg encodes flag bits in
    addition to the name index - confirm against the `dis` docs.

    Args:
        instructions: iterable of instruction objects; modified in place.
    """
    keeps_arg = ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR")
    for inst in instructions:
        if inst.argval is _NotProvided or inst.opname in keeps_arg:
            continue
        if (
            inst.opcode in HAS_LOCAL
            or inst.opcode in HAS_NAME
            or inst.opcode in HAS_FREE
            or inst.opcode in HAS_CONST
        ):
            inst.arg = None
|
||||||
|
|
||||||
|
|
||||||
def get_code_keys() -> List[str]:
|
def get_code_keys() -> List[str]:
|
||||||
# Python 3.11 changes to code keys are not fully documented.
|
# Python 3.11 changes to code keys are not fully documented.
|
||||||
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
|
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
|
||||||
@ -1247,3 +1264,100 @@ def unique_id(name) -> str:
|
|||||||
def is_generator(code: types.CodeType) -> bool:
    """Return True if `code` has the generator flag set in `co_flags`.

    Args:
        code: the code object to inspect.
    """
    # Same value as the standard library's `inspect.CO_GENERATOR`.
    co_generator = 0x20
    return (code.co_flags & co_generator) != 0
|
||||||
|
|
||||||
|
|
||||||
|
def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True):
    """Generates bytecode from a template function `fn` for use in
    dynamo bytecode generation.

    For example, we can generate Python-version-independent bytecode
    for looping through a dictionary and copying the values to a new dictionary.

        def template(d1, d2):
            for k, v in d1.items():
                d2[k] = v

    or a try block:

        def template():
            try:
                dummy1
            except:
                dummy2
                raise
            dummy3

    Args:
        fn: a function template to generate bytecode from
        varname_map: a mapping of `fn`'s varnames to new names. This
            map will be applied to the generated bytecode's varnames.
            For example, local variables in `fn` can be replaced with
            new names that are generated by `OutputGraph.new_var`.
        noreturn: remove all RETURN_* bytecodes and replace them with a jump
            to the end of the bytecode.
        noprefix: remove prefix bytecodes (all bytecode before the first
            RESUME, inclusive). Nothing is removed if there is no RESUME.

    Returns:
        The list of instructions obtained from `fn.__code__` with the
        requested transformations applied.
    """
    insts = cleaned_instructions(fn.__code__)
    clear_instruction_args(insts)

    if noprefix:
        resume_idx = next(
            (idx for idx, inst in enumerate(insts) if inst.opname == "RESUME"),
            None,
        )
        if resume_idx is not None:
            insts = insts[resume_idx + 1 :]

    for inst in insts:
        # If we don't reset starts_line, then the generated
        # bytecode's line number will be based on fn's.
        inst.starts_line = None
        if varname_map and inst.argval in varname_map:
            inst.argval = varname_map[inst.argval]

    if not noreturn:
        return insts

    if sys.version_info >= (3, 12):
        # replace RETURN_CONST with LOAD_CONST RETURN_VALUE
        expanded = []
        for inst in insts:
            expanded.append(inst)
            if inst.opname == "RETURN_CONST":
                # Mutate in place so that anything referring to this
                # instruction keeps pointing at it.
                inst.opcode = dis.opmap["LOAD_CONST"]
                inst.opname = "LOAD_CONST"
                # no need to propagate target/exn table
                expanded.append(create_instruction("RETURN_VALUE"))
        insts = expanded

    returns = [inst for inst in insts if inst.opname == "RETURN_VALUE"]
    if not returns:
        return insts

    if len(returns) == 1 and returns[0] is insts[-1]:
        # only 1 return at the end - just pop it
        insts.pop()
        return insts

    # create jump target - if the last inst is a return,
    # we can replace it with a NOP and make that the jump target.
    if returns[-1] is insts[-1]:
        end = insts[-1]
        end.opname = "NOP"
        end.opcode = dis.opmap["NOP"]
        end.arg = None
        end.argval = _NotProvided
        returns.pop()
    else:
        end = create_instruction("NOP")
        insts.append(end)

    # replace returns with jumps
    for ret in returns:
        # don't replace inst with new instruction
        # due to targeting/exn table/etc.
        jump = create_jump_absolute(end)
        ret.opname = jump.opname
        ret.opcode = jump.opcode
        ret.arg = jump.arg
        ret.argval = jump.argval
        ret.target = jump.target

    return insts
|
||||||
|
@ -343,6 +343,12 @@ def skipIfNotPy311(fn):
|
|||||||
return unittest.skip(fn)
|
return unittest.skip(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def skipIfNotPy312(fn):
    """Decorator that skips `fn` unless running on Python 3.12 or newer.

    Args:
        fn: the test function to decorate.

    Returns:
        `fn` unchanged on Python >= 3.12, otherwise a `unittest`-skipped
        wrapper around it.
    """
    if sys.version_info >= (3, 12):
        return fn
    # Give an explicit reason so the skip is explained in test reports.
    return unittest.skip("requires Python 3.12+")(fn)
|
||||||
|
|
||||||
|
|
||||||
def xfailIfPy312(fn):
|
def xfailIfPy312(fn):
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
return unittest.expectedFailure(fn)
|
return unittest.expectedFailure(fn)
|
||||||
|
Reference in New Issue
Block a user