Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[dynamo] Add support for nn.Parameter constructor (part 2) (#120965)
This handles the case where the tensor isn't an input. The changes to dynamo tests are cases where we would previously fall back to eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120965
Approved by: https://github.com/yanboliang
ghstack dependencies: #121735
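As a hedged illustration of what "the tensor isn't an input" means (this mirrors the new test_nn_param_return3 below; the snippet itself is not part of the PR): the tensor handed to the nn.Parameter constructor is an intermediate computed inside the compiled region, and with this change the construction is captured in the graph instead of falling back to eager.

    import torch

    def fn(x):
        # x + 123 is an intermediate value, not a graph input
        p = torch.nn.Parameter(x + 123)
        return p, p.sin()

    opt = torch.compile(fn, fullgraph=True)  # fullgraph=True: no graph break allowed
    p, r = opt(torch.randn(16))
    r.sum().backward()            # gradients flow into the freshly constructed parameter
    print(p.grad is not None)     # True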
Committed by: PyTorch MergeBot
Parent: 040b925753
Commit: 0b7d9711d4

New empty files:
  test/dynamo_skips/TestNN.test_padding_list
  test/dynamo_skips/TestNN.test_vector_to_parameters
@@ -301,6 +301,36 @@ class DistributedPatternTests(TestCase):
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
torch/_dynamo/create_parameter_op.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import torch
from torch._prims import _make_prim, RETURN_TYPE
from torch._prims_common import clone_preserve_strides

doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()

_bind_nn_parameter = _make_prim(
    schema="_bind_nn_parameter(Tensor self, Tensor placeholder) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=lambda self, placeholder: torch.nn.Parameter(
        clone_preserve_strides(self), placeholder.requires_grad
    ),
    impl_aten=lambda self, placeholder: placeholder.set_(self),
    doc=doc,
)
torch.fx.node.has_side_effect(_bind_nn_parameter)


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        return _bind_nn_parameter(tensor, placeholder)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    with torch.set_grad_enabled(placeholder.requires_grad):
        return TracableCreateParameter.apply(tensor, placeholder)


def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a placeholder to be passed to the above functions"""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result
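The key trick in this file is routing gradients to the sacrificial placeholder. Below is a minimal, self-contained analogue of TracableCreateParameter for illustration only (it returns a copy instead of mutating the placeholder via set_ the way _bind_nn_parameter does):

    import torch

    class RouteGradToPlaceholder(torch.autograd.Function):
        @staticmethod
        def forward(ctx, data, placeholder):
            # The real prim mutates `placeholder` to share storage with `data`;
            # a plain clone is enough to show the gradient routing.
            return data.clone()

        @staticmethod
        def backward(ctx, grad):
            return None, grad  # gradient is attributed to `placeholder`, not `data`

    placeholder = torch.nn.Parameter(torch.empty(4))
    data = torch.randn(4) + 123.0                      # does not require grad
    out = RouteGradToPlaceholder.apply(data, placeholder)
    out.sin().sum().backward()
    print(placeholder.grad)                            # gradient landed on the placeholder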
@@ -69,6 +69,7 @@ from .source import (
    LocalSource,
    ParamBufferSource,
    ShapeEnvSource,
    SyntheticLocalSource,
    TensorProperty,
    TensorPropertySource,
)
@@ -472,6 +473,28 @@ class OutputGraph(Checkpointable[OutputGraphState]):

        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))

    def synthetic_graph_input(self, fn, args):
        """
        call fn(*args) before the graph runs and turn the result into a fake input.
        """
        example_value = fn(*args)
        varname = self.new_var()
        cg = PyCodegen(self.root_tx)
        cg.load_import_from(
            fn.__module__,
            fn.__name__,
        )
        cg.foreach(map(variables.ConstantVariable.create, args))
        cg.call_function(len(args), True)
        cg.store(varname)
        self.pregraph_bytecode.extend(cg.get_instructions())
        source = SyntheticLocalSource(varname)
        result = VariableBuilder(self.root_tx, source)(example_value)
        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
            source
        )
        return result

    def add_cleanup_hook(self, fn: Callable[[], Any]):
        self.cleanup_hooks.append(fn)
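A hedged sketch of what synthetic_graph_input amounts to at runtime (the local variable name below is illustrative; Dynamo generates a fresh name via new_var()): the pregraph bytecode it emits is roughly equivalent to the following Python, executed eagerly before the compiled graph runs. The resulting value is then wrapped by VariableBuilder as a regular graph input, with its guards removed because Dynamo itself produced the value and there is nothing to guard on.

    import torch
    from torch._dynamo.create_parameter_op import new_parameter_placeholder

    # Rough equivalent of the generated pregraph bytecode for
    # synthetic_graph_input(new_parameter_placeholder, [shape, dtype, device, requires_grad]):
    __synthetic_local_0 = new_parameter_placeholder((16,), torch.float32, "cpu", True)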
@@ -17,6 +17,7 @@ from torch._streambase import _StreamBase
from ..._guards import TracingContext
from .. import config, polyfill, variables
from ..codegen import PyCodegen
from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
@@ -840,7 +841,35 @@ Either create the tensor outside the compiled region, or do not set the tensor t
         if data.source:
             return cls._nn_param_via_prefix_insert(tx, data, requires_grad)
 
-        unimplemented("Parameter() on non-input")
+        try:
+            shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
+            dtype = data.var_getattr(tx, "dtype").as_python_constant()
+            device = data.var_getattr(tx, "device").as_python_constant()
+        except NotImplementedError as e:
+            unimplemented(f"Parameter not python_constant: {e}")
+
+        placeholder = tx.output.synthetic_graph_input(
+            new_parameter_placeholder, [shape, dtype, device, requires_grad]
+        )
+        if data.requires_grad:
+            data = data.call_method(tx, "detach", [], {})
+
+        from .builder import wrap_fx_proxy
+
+        result = wrap_fx_proxy(
+            tx,
+            tx.output.create_proxy(
+                "call_function",
+                tracable_create_parameter,
+                (data.as_proxy(), placeholder.as_proxy()),
+                {},
+            ),
+        )
+        assert isinstance(result, variables.TensorVariable)
+        result.class_type = torch.nn.Parameter
+        # reconstruct() should use the original parameter; the one returned by the graph is an alias.
+        result.source = placeholder.source
+        return result
 
     @staticmethod
     def _nn_param_via_prefix_insert(tx, data, requires_grad):
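One hedged way to see the effect of the branch above is to capture the graph Dynamo produces for a non-input Parameter construction using a custom backend; the exact node names and ordering are not guaranteed, but the graph should contain a call to tracable_create_parameter wired to the synthetic placeholder input.

    import torch

    captured = []

    def capture_backend(gm, example_inputs):
        captured.append(gm)       # keep the Dynamo-level FX graph for inspection
        return gm.forward

    def fn(x):
        p = torch.nn.Parameter(x + 123)
        return p, p.sin()

    torch.compile(fn, backend=capture_backend, fullgraph=True)(torch.randn(16))
    captured[0].graph.print_tabular()   # look for the tracable_create_parameter call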
@@ -181,6 +181,26 @@ def _output_node(gm: torch.fx.GraphModule) -> torch.fx.Node:
    return next(n for n in reversed(gm.graph.nodes) if n.op == "output")


def _input_node(gm: torch.fx.GraphModule, i: int) -> torch.fx.Node:
    """Fetch the i-th placeholder in the graph"""
    seen = 0
    for n in gm.graph.nodes:
        if n.op == "placeholder":
            if seen == i:
                return n
            seen += 1
    raise IndexError(f"input {i} does not exist, only {seen} inputs in graph")


def _can_detach(node: torch.fx.Node):
    """
    Avoid calling .detach() on inputs passed to _bind_nn_parameter()
    """
    from torch._dynamo.create_parameter_op import _bind_nn_parameter

    return all(n.target is not _bind_nn_parameter for n in node.users)


def aot_dispatch_autograd(
    flat_fn,
    flat_args: List[Any],
@@ -317,7 +337,7 @@
             == len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset
         )
         for i, (bw_out) in enumerate(bw_outs):
-            if bw_out is None:
+            if bw_out is None and _can_detach(_input_node(fx_g, i)):
                 _indices_of_inps_to_detach.append(i)
 
         if aot_config.enable_log:
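The new helpers let aot_dispatch_autograd skip detaching any forward input that feeds _bind_nn_parameter (detaching it would interfere with the placeholder binding set up above). A small, self-contained illustration of what _input_node computes — a standalone re-implementation for demonstration, not an import of the helper itself:

    import torch
    import torch.fx

    def nth_placeholder(gm: torch.fx.GraphModule, i: int) -> torch.fx.Node:
        # Walk the graph and return the i-th placeholder (i.e. the i-th input).
        seen = 0
        for n in gm.graph.nodes:
            if n.op == "placeholder":
                if seen == i:
                    return n
                seen += 1
        raise IndexError(f"input {i} does not exist, only {seen} inputs in graph")

    def f(a, b):
        return a * 2 + b

    gm = torch.fx.symbolic_trace(f)
    print(nth_placeholder(gm, 1))   # the placeholder node for `b`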
@@ -4496,6 +4496,29 @@ class ResizeStorageBytes(MutatingFirstArgExternKernel):
        mark_node_as_mutating(self, variable)


class BindNNParameter(ExternKernelAlloc):
    def __init__(self, variable, placeholder):
        variable.freeze_layout()
        super().__init__(
            variable.get_layout(),
            [variable, placeholder],
            python_kernel_name="torch.ops.prims._bind_nn_parameter",
        )
        V.graph.never_reuse_buffers.add(variable.data.get_name())
        V.graph.never_reuse_buffers.add(placeholder.get_name())
        V.graph.never_reuse_buffers.add(self.get_name())
        mark_node_as_mutating(self, variable, placeholder)

    def get_alias_names(self):
        return [self.inputs[0].get_name(), self.inputs[1].get_name()]

    def get_mutation_names(self):
        return [self.inputs[1].get_name()]

    def has_side_effects(self):
        return True


class ScatterFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly.
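The python_kernel_name above is the eager fallback for this IR node; it resolves to the prim registered in torch/_dynamo/create_parameter_op.py. A quick, hedged sanity check (assuming that module is importable in your build):

    import torch
    from torch._dynamo.create_parameter_op import _bind_nn_parameter  # importing registers the prim

    print(torch.ops.prims._bind_nn_parameter)  # the op BindNNParameter falls back to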
@@ -13,6 +13,7 @@ import torch
import torch.ao.quantization.fx._decomposed
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.create_parameter_op import _bind_nn_parameter
from torch._higher_order_ops.triton_kernel_wrap import (
    triton_kernel_wrapper_functional,
    triton_kernel_wrapper_mutation,
@@ -5924,6 +5925,12 @@ def resize_storage_bytes_(variable, new_size):
    return variable


@register_lowering(_bind_nn_parameter)
def create_nn_parameter(self, placeholder):
    self.realize()
    return TensorBox.create(ir.BindNNParameter(self, placeholder))


from torch._higher_order_ops.auto_functionalize import auto_functionalized

make_fallback(auto_functionalized)