mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Issue: #93684

In previous PRs #95849 and #99560 we redirected `numpy.*` and `<tensor>.numpy()` calls to `torch_np.*` methods and attributes by creating a `NumpyNdarrayVariable` for those calls.

We need to handle `NumpyNdarrayVariable` when a graph break happens. This PR did two things:

1. In `codegen.py`, we made sure we can reconstruct the value wrapped by `NumpyNdarrayVariable` as a `torch_np.ndarray` on the stack whenever we recompile the subgraph.
2. In `builder.py`, we can wrap the value as a `NumpyNdarrayVariable` and save it as a graph input.

-----

Starting from commit 6:

## A new design for supporting numpy in dynamo

In short, the core concept doesn't change: we still convert `numpy` API calls to `torch_np` API calls. However, instead of wrapping a `torch_np.ndarray` in `NumpyNdarrayVariable`, the new design wraps a `torch.Tensor`.

The reason for this change is that we need to keep `torch.Tensor` everywhere in the captured graph, so that it works well with dynamo's backends. See the discussion in https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/142 for details.
### Flow

This example shows how we think about dynamo working on a simple function:

```python
def f(x: torch.Tensor, y: torch.Tensor):
    a, b = x.numpy(), y.numpy()
    c = np.add(a, b)
    return torch.from_numpy(c)
```

```
              +------------+               +------------+
 torch.Tensor |            | numpy.ndarray |            |
--------------|  .numpy()  |---------------|            |
              |            |               |            |               +------------------+
              +------------+               | numpy.add  | numpy.ndarray |                  | torch.Tensor
              +------------+               |            |---------------| torch.from_numpy |--------------
 torch.Tensor |            | numpy.ndarray |            |               |                  |
--------------|  .numpy()  |---------------|            |               +------------------+
              |            |               |            |
              +------------+               +------------+

              +------------+              +--------------+
 torch.Tensor |            | torch.Tensor |              |
--------------| .detach()  |--------------|              |
              |            |              |              |                  +----------------+              +------------+
              +------------+              |              | torch_np.ndarray |                | torch.Tensor |            | torch.Tensor
              +------------+              | torch_np.add |------------------| util.to_tensor |--------------| .detach()  |--------------
 torch.Tensor |            | torch.Tensor |              |                  |                |              |            |
--------------| .detach()  |--------------|              |                  +----------------+              +------------+
              |            |              |              |
              +------------+              +--------------+
                                       |                                                        |
                                       |                wrapper on torch_np.add                 |
                                       +--------------------------------------------------------+
```

### Approach

`torch_np` APIs accept both `torch_np.ndarray` and `torch.Tensor`. What we need is a wrapper around these APIs that converts the return value back to `torch.Tensor`. This way only the wrapper shows up in the captured graph, with `torch.Tensor`s as inputs and a `torch.Tensor` as output.

If we hit a graph break or have traced to the end of the program, we need to inspect every `NumpyNdarrayVariable` on the stack and convert it back to `numpy.ndarray`, so that the compiled version behaves the same as the eager version.
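The wrapper idea from the Approach section can be sketched without importing `torch` or `torch_np`. Everything below is a hypothetical stand-in for illustration only: `Tensor` plays the role of `torch.Tensor`, `NdArray` the role of `torch_np.ndarray`, `np_add` the role of `torch_np.add`, and `TensorReturningWrapper` is an assumed name, not the class added in `utils.py`.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class Tensor:
    """Hypothetical stand-in for torch.Tensor."""
    data: List[float]


@dataclass
class NdArray:
    """Hypothetical stand-in for torch_np.ndarray; it wraps a Tensor."""
    tensor: Tensor


def np_add(a: Any, b: Any) -> NdArray:
    """Stand-in for torch_np.add: accepts Tensor or NdArray, returns NdArray."""
    ta = a.tensor if isinstance(a, NdArray) else a
    tb = b.tensor if isinstance(b, NdArray) else b
    return NdArray(Tensor([x + y for x, y in zip(ta.data, tb.data)]))


def to_tensor(value: Any) -> Any:
    """Unwrap NdArray results (or lists/tuples of them) back to Tensor."""
    if isinstance(value, NdArray):
        return value.tensor
    if isinstance(value, (list, tuple)):
        return type(value)(to_tensor(v) for v in value)
    return value


class TensorReturningWrapper:
    """Calls the wrapped function and converts its output to Tensor."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return to_tensor(self.fn(*args, **kwargs))


# Only the wrapper is visible to the caller: Tensors in, Tensor out.
wrapped_add = TensorReturningWrapper(np_add)
result = wrapped_add(Tensor([1.0, 2.0]), Tensor([3.0, 4.0]))
```

Because the caller only ever sees `Tensor` values, a graph that records calls to `wrapped_add` never contains the intermediate array type, which is the property the new design relies on.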
### Examples

Here's an example of the graph generated:

```python
def fn(x: np.ndarray, y: np.ndarray):
    a = x.real
    b = y.real
    torch._dynamo.graph_break()
    return np.add(a, 1), np.add(b, 1)
```

Graph generated:

```
[2023-05-16 10:31:48,737] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
 __compiled_fn_0 <eval_with_key>.0
opcode         name            target                                                      args                    kwargs
-------------  --------------  ----------------------------------------------------------  ----------------------  --------
placeholder    l_x_            L_x_                                                        ()                      {}
placeholder    l_y_            L_y_                                                        ()                      {}
call_function  from_numpy      <built-in method from_numpy of type object at 0x12b1fdc80>  (l_x_,)                 {}
call_function  from_numpy_1    <built-in method from_numpy of type object at 0x12b1fdc80>  (l_y_,)                 {}
call_function  attr_wrapper    <function attr_wrapper at 0x12e8693a0>                      (from_numpy, 'real')    {}
call_function  attr_wrapper_1  <function attr_wrapper at 0x12e8693a0>                      (from_numpy_1, 'real')  {}
output         output          output                                                      ((),)                   {}

[2023-05-16 10:31:48,908] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
 __compiled_fn_2 <eval_with_key>.1
opcode         name           target                                                      args                             kwargs
-------------  -------------  ----------------------------------------------------------  -------------------------------  --------
placeholder    l_a_           L_a_                                                        ()                               {}
placeholder    l_b_           L_b_                                                        ()                               {}
call_function  from_numpy     <built-in method from_numpy of type object at 0x12b1fdc80>  (l_a_,)                          {}
call_function  from_numpy_1   <built-in method from_numpy of type object at 0x12b1fdc80>  (l_b_,)                          {}
call_function  wrapped_add    <Wrapped function <original add>>                           (from_numpy, 1)                  {}
call_function  wrapped_add_1  <Wrapped function <original add>>                           (from_numpy_1, 1)                {}
output         output         output                                                      ((wrapped_add, wrapped_add_1),)  {}
```

### Changes

* `codegen.py`: reconstruct `numpy.ndarray` from `NumpyNdarrayVariable` by adding bytecode that calls `utils.to_numpy_helper()`.
* `output_graph.py`: remove legacy code that did what `codegen.py` now does, but only handled the return case and not the graph break case.
* `utils.py`: add helpers to convert `numpy.ndarray` to `torch.Tensor` and vice versa. Also add a wrapper class that takes a function; in `__call__` it calls the function and converts its output to `torch.Tensor` (or a list of them).
* `builder.py`: add a method to wrap `numpy.ndarray` graph inputs into `NumpyNdarrayVariable`, by calling `torch.from_numpy` in the proxy.
* `misc.py`: `numpy` API calls go into `NumpyVariable`; we find the function with the same name in the `torch_np` module, then wrap it with the wrapper defined in `utils.py`.
* `tensor.py`, `torch.py`: proxy `tensor.numpy()` to `torch.detach()` but wrap the result in `NumpyNdarrayVariable`. Similarly, `torch.from_numpy()` becomes `torch.detach()` with the result wrapped in `TensorVariable`. In `NumpyNdarrayVariable`, do the same `torch_np.ndarray` to `torch.Tensor` wrapping for attributes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100839

Approved by: https://github.com/ezyang
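As a rough illustration of the `codegen.py` change: for a `NumpyNdarrayVariable`, the generated bytecode loads the conversion helper, loads the graph output list, indexes it with the entry's index, and calls the helper with that one argument. The sketch below expresses the same thing as ordinary Python and uses the standard `dis` module to list the instructions the interpreter uses for it. `to_numpy_stub` and `resume_after_graph_break` are hypothetical names for this example, and the exact opcode names vary between Python versions.

```python
import dis
from typing import Any, List, Tuple


def to_numpy_stub(value: Any) -> Tuple[str, Any]:
    """Hypothetical stand-in for the helper that converts a tensor to an ndarray."""
    return ("ndarray", value)


def resume_after_graph_break(graph_out: List[Any]) -> Tuple[str, Any]:
    # Roughly the expression the emitted bytecode evaluates for one
    # NumpyNdarrayVariable whose graph output index is 0.
    return to_numpy_stub(graph_out[0])


# Inspect the instruction names the interpreter uses for that expression.
opnames = [ins.opname for ins in dis.get_instructions(resume_after_graph_break)]
restored = resume_after_graph_break(["tensor-0"])
```

Printing `opnames` on a given interpreter shows a load of the helper, a load of the output list and index, a subscript operation, and a call, which is the same ordering `PyCodegen.__call__` appends for this variable type.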
347 lines
12 KiB
Python
import collections
import dataclasses
import re
import sys
import types
from typing import List

import torch.nn
from . import utils

from .bytecode_transformation import (
    create_call_function,
    create_dup_top,
    create_instruction,
    create_load_global,
    create_rot_n,
    Instruction,
)
from .exc import unimplemented
from .source import AttrSource, GeneratorStateSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
)


@dataclasses.dataclass
class GraphOutputEntry:
    index: int
    variable: VariableTracker

    def merge(self, other: VariableTracker):
        # merge in any extra guards
        self.variable = self.variable.add_options(other)


class PyCodegen:
    """
    Helper class used for constructing Python bytecode
    """

    def __init__(
        self,
        tx=None,
        root: torch.nn.Module = None,
        graph_output_var: str = None,
        tempvars=None,
    ):
        self.root = root
        self.top_of_stack = None
        self.uses = collections.Counter()
        self.graph_outputs = collections.OrderedDict()
        self._output: List[Instruction] = []
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def __call__(self, value, allow_cache=True):
        """Generate code such that top-of-stack (TOS) is set to value"""
        if isinstance(value, Source):
            self._output.extend(value.reconstruct(self))
            self.clear_tos()
            return

        self.tx.output.guards.update(value.guards)

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if self.top_of_stack is value:
            output.append(create_dup_top())
            return

        if allow_cache:
            if value.mutable_local and value.mutable_local in self.tempvars:
                output.append(self.create_load(self.tempvars[value.mutable_local]))
                self.top_of_stack = value
                return
            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if (
            value.source is not None
            and allow_cache
            and not isinstance(value.source, GeneratorStateSource)
        ):
            output.extend(value.source.reconstruct(self))
        elif value.is_python_constant() and is_safe_constant(
            value.as_python_constant()
        ):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                TensorWithTFOverrideVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            if isinstance(value, TensorWithTFOverrideVariable):
                # unwrap back to tensor
                value = value.tensor_variable
            graph_outputs_key = id(value.proxy)
            if graph_outputs_key not in graph_outputs:
                graph_outputs[graph_outputs_key] = GraphOutputEntry(
                    len(graph_outputs), value
                )
            else:
                graph_outputs[graph_outputs_key].merge(value)
            if isinstance(value, NumpyNdarrayVariable):
                self.load_import_from(utils.__name__, "to_numpy_helper")
            output.append(self.create_load(self.graph_output_var))
            output.append(
                self._create_load_const(graph_outputs[graph_outputs_key].index)
            )
            output.append(create_instruction("BINARY_SUBSCR"))
            if isinstance(value, NumpyNdarrayVariable):
                output.extend(create_call_function(1, False))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
                output.extend(
                    [self.create_load_attr("item")] + create_call_function(0, True)
                )
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_output(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                output.extend(value.reconstruct(self))
            except NotImplementedError:
                unimplemented(f"reconstruct: {value}")
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value

    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        if value.mutable_local:
            self.tempvars[value.mutable_local] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def setup_globally_cached(self, name, value, push_null):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert id(f_globals[name]) == id(value)
        else:
            f_globals[name] = value
        return [self.create_load_global(name, push_null, add=True)]

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self):
        return self._output

    def create_load(self, name):
        if name in self.cell_and_freevars():
            return create_instruction("LOAD_DEREF", argval=name)
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)

    def create_load_closure(self, name):
        assert name in self.cell_and_freevars()
        return create_instruction("LOAD_CLOSURE", argval=name)

    def create_store(self, name):
        if name in self.cell_and_freevars():
            return create_instruction("STORE_DEREF", argval=name)
        assert name in self.code_options["co_varnames"]
        return create_instruction("STORE_FAST", argval=name)

    def create_load_global(self, name, push_null, add=False):
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_load_global(name, push_null)

    def create_load_const(self, value):
        assert is_safe_constant(value), f"unsafe constant {value}"
        return self._create_load_const(value)

    def _create_load_const(self, value):
        return create_instruction("LOAD_CONST", argval=value)

    create_load_output = _create_load_const

    def create_load_attr(self, name):
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("LOAD_ATTR", argval=name)

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]

    def load_function_name(self, fn_name, push_null, num_on_stack=0):
        """Load the global fn_name on the stack num_on_stack down"""
        output = []
        if push_null and sys.version_info >= (3, 11):
            output.extend(
                [create_instruction("PUSH_NULL"), *self.rot_n(num_on_stack + 1)]
            )
        output.extend(
            [
                self.create_load_global(fn_name, False, add=True),
                *self.rot_n(num_on_stack + 1),
            ]
        )
        return output

    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # desired rotate bytecode doesn't exist, generate equivalent bytecode
            return [
                create_instruction("BUILD_TUPLE", arg=n),
                self._create_load_const(rot_n_helper(n)),
                *create_rot_n(2),
                create_instruction("CALL_FUNCTION_EX", arg=0),
                create_instruction("UNPACK_SEQUENCE", arg=n),
            ]

    def pop_null(self):
        # POP_TOP doesn't work for null, so we pop nulls by pushing in a
        # nop function, calling it (which consumes the null), and popping the result.
        assert sys.version_info >= (3, 11)
        return [
            self._create_load_const(lambda: None),
            *create_call_function(0, False),
            create_instruction("POP_TOP"),
        ]

    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output
        if sys.version_info >= (3, 11) and push_null:
            output.append(create_instruction("PUSH_NULL"))
            output.extend(self.rot_n(num_on_stack + 1))
        for var in freevars:
            assert var in self.cell_and_freevars()
            output.append(create_instruction("LOAD_CLOSURE", argval=var))
        output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
        output.append(self.create_load_const(code))
        if sys.version_info < (3, 11):
            output.append(self.create_load_const(fn_name))
        output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
        output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()

    def create_load_python_module(self, mod, push_null):
        """
        Generate a LOAD_GLOBAL instruction to fetch a given python module.
        """
        root_globals = self.tx.output.root_globals
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if root_globals.get(name, None) is mod:
            return self.create_load_global(name, push_null, add=True)
        mangled_name = f"___module_{name}_{id(mod)}"
        if mangled_name not in root_globals:
            self.tx.output.install_global(mangled_name, mod)
        return self.create_load_global(mangled_name, push_null, add=True)

    def make_call_generated_code(self, fn_name: str) -> List[Instruction]:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs
        for arg in graphargs:
            if arg.is_unspecialized:
                self.extend_output(
                    [
                        self.create_load_python_module(torch, True),
                        self.create_load_attr("as_tensor"),
                    ]
                )
                self.extend_output(arg.load(self))
                self.extend_output(create_call_function(1, False))
            else:
                self.extend_output(arg.load(self))

        self.extend_output(create_call_function(len(graphargs), False))

    def load_import_from(self, module_name, object_name):
        self.extend_output(
            AttrSource(self.tx.import_source(module_name), object_name).reconstruct(
                self
            )
        )

    def create_call_function_kw(self, nargs, kw_names, push_null):
        if sys.version_info >= (3, 11):
            output = create_call_function(nargs, push_null)
            assert output[-2].opname == "PRECALL"
            kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
            output.insert(-2, kw_names_inst)
            return output
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", arg=nargs),
        ]