[dynamo] Add support for nn.Parameter constructor (part 2) (#120965)

This handles the case where the tensor passed to the nn.Parameter() constructor is not a graph input.

The changes to the dynamo tests cover cases where we would previously fall back to eager.
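For example (a minimal sketch mirroring the new tests added in this commit), a function that constructs a Parameter from an intermediate tensor can now be compiled with fullgraph=True instead of graph breaking:

    import torch

    def fn(x):
        # x + 1 is an intermediate computed inside the graph, not a graph input
        p = torch.nn.Parameter(x + 1)
        return p, p.sin()

    opt_fn = torch.compile(fn, fullgraph=True)
    p, out = opt_fn(torch.randn(16))
    out.sum().backward()  # gradients flow into p as if it were a graph input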

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120965
Approved by: https://github.com/yanboliang
ghstack dependencies: #121735
Jason Ansel
2024-03-14 18:55:31 -07:00
committed by PyTorch MergeBot
parent 040b925753
commit 0b7d9711d4
99 changed files with 184 additions and 2 deletions


@@ -301,6 +301,36 @@ class DistributedPatternTests(TestCase):
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:


@@ -0,0 +1,50 @@
import torch

from torch._prims import _make_prim, RETURN_TYPE
from torch._prims_common import clone_preserve_strides

doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()

_bind_nn_parameter = _make_prim(
    schema="_bind_nn_parameter(Tensor self, Tensor placeholder) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=lambda self, placeholder: torch.nn.Parameter(
        clone_preserve_strides(self), placeholder.requires_grad
    ),
    impl_aten=lambda self, placeholder: placeholder.set_(self),
    doc=doc,
)
torch.fx.node.has_side_effect(_bind_nn_parameter)


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        return _bind_nn_parameter(tensor, placeholder)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    with torch.set_grad_enabled(placeholder.requires_grad):
        return TracableCreateParameter.apply(tensor, placeholder)


def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a placeholder to be passed to the above functions"""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result
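
In plain eager terms, the trick the prim implements boils down to something like the following hand-written sketch (illustrative only: the sizes and values are made up, and in practice dynamo, AOTAutograd, and inductor drive these steps during tracing and at runtime):

    import torch

    # A sacrificial, storage-less Parameter created before the graph runs:
    placeholder = torch.nn.Parameter(torch.empty(16), requires_grad=True)
    placeholder.untyped_storage().resize_(0)

    # At the point in the graph where nn.Parameter(data) was requested, the
    # placeholder is mutated to take over the intermediate tensor's storage:
    data = torch.randn(16) + 123
    with torch.no_grad():
        placeholder.set_(data)

    # The placeholder now behaves as if it had been a graph input all along,
    # so autograd accumulates gradients into it:
    placeholder.sin().sum().backward()
    print(placeholder.grad.shape)  # torch.Size([16])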


@@ -69,6 +69,7 @@ from .source import (
    LocalSource,
    ParamBufferSource,
    ShapeEnvSource,
    SyntheticLocalSource,
    TensorProperty,
    TensorPropertySource,
)
@@ -472,6 +473,28 @@ class OutputGraph(Checkpointable[OutputGraphState]):
        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))

    def synthetic_graph_input(self, fn, args):
        """
        call fn(*args) before the graph runs and turn the result into a fake input.
        """
        example_value = fn(*args)
        varname = self.new_var()
        cg = PyCodegen(self.root_tx)
        cg.load_import_from(
            fn.__module__,
            fn.__name__,
        )
        cg.foreach(map(variables.ConstantVariable.create, args))
        cg.call_function(len(args), True)
        cg.store(varname)
        self.pregraph_bytecode.extend(cg.get_instructions())
        source = SyntheticLocalSource(varname)
        result = VariableBuilder(self.root_tx, source)(example_value)
        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
            source
        )
        return result

    def add_cleanup_hook(self, fn: Callable[[], Any]):
        self.cleanup_hooks.append(fn)
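
For the Parameter use case, the pregraph bytecode emitted by synthetic_graph_input corresponds roughly to Python along these lines (a hand-written approximation: the local variable name is invented, dynamo generates its own, and the shape/dtype/device are illustrative):

    import torch
    from torch._dynamo.create_parameter_op import new_parameter_placeholder

    # Runs before the compiled graph is invoked; the stored result is then
    # wrapped as a fake input to the graph via SyntheticLocalSource.
    __synthetic_param_0 = new_parameter_placeholder(
        (16,), torch.float32, torch.device("cpu"), True
    )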


@@ -17,6 +17,7 @@ from torch._streambase import _StreamBase
from ..._guards import TracingContext
from .. import config, polyfill, variables
from ..codegen import PyCodegen
from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
@@ -840,7 +841,35 @@ Either create the tensor outside the compiled region, or do not set the tensor t
        if data.source:
            return cls._nn_param_via_prefix_insert(tx, data, requires_grad)
        unimplemented("Parameter() on non-input")

        try:
            shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
            dtype = data.var_getattr(tx, "dtype").as_python_constant()
            device = data.var_getattr(tx, "device").as_python_constant()
        except NotImplementedError as e:
            unimplemented(f"Parameter not python_constant: {e}")

        placeholder = tx.output.synthetic_graph_input(
            new_parameter_placeholder, [shape, dtype, device, requires_grad]
        )
        if data.requires_grad:
            data = data.call_method(tx, "detach", [], {})

        from .builder import wrap_fx_proxy

        result = wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_function",
                tracable_create_parameter,
                (data.as_proxy(), placeholder.as_proxy()),
                {},
            ),
        )
        assert isinstance(result, variables.TensorVariable)
        result.class_type = torch.nn.Parameter

        # In reconstruct() we should use the original parameter; the one returned by the graph will be an alias.
        result.source = placeholder.source
        return result

    @staticmethod
    def _nn_param_via_prefix_insert(tx, data, requires_grad):


@@ -181,6 +181,26 @@ def _output_node(gm: torch.fx.GraphModule) -> torch.fx.Node:
    return next(n for n in reversed(gm.graph.nodes) if n.op == "output")


def _input_node(gm: torch.fx.GraphModule, i: int) -> torch.fx.Node:
    """Fetch the i-th placeholder in the graph"""
    seen = 0
    for n in gm.graph.nodes:
        if n.op == "placeholder":
            if seen == i:
                return n
            seen += 1
    raise IndexError(f"input {i} does not exist, only {seen} inputs in graph")


def _can_detach(node: torch.fx.Node):
    """
    Avoid calling .detach() on inputs passed to _bind_nn_parameter()
    """
    from torch._dynamo.create_parameter_op import _bind_nn_parameter

    return all(n.target is not _bind_nn_parameter for n in node.users)


def aot_dispatch_autograd(
    flat_fn,
    flat_args: List[Any],
@@ -317,7 +337,7 @@ def aot_dispatch_autograd(
            == len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset
        )
        for i, (bw_out) in enumerate(bw_outs):
            if bw_out is None:
            if bw_out is None and _can_detach(_input_node(fx_g, i)):
                _indices_of_inps_to_detach.append(i)

        if aot_config.enable_log:


@@ -4496,6 +4496,29 @@ class ResizeStorageBytes(MutatingFirstArgExternKernel):
        mark_node_as_mutating(self, variable)


class BindNNParameter(ExternKernelAlloc):
    def __init__(self, variable, placeholder):
        variable.freeze_layout()
        super().__init__(
            variable.get_layout(),
            [variable, placeholder],
            python_kernel_name="torch.ops.prims._bind_nn_parameter",
        )
        V.graph.never_reuse_buffers.add(variable.data.get_name())
        V.graph.never_reuse_buffers.add(placeholder.get_name())
        V.graph.never_reuse_buffers.add(self.get_name())
        mark_node_as_mutating(self, variable, placeholder)

    def get_alias_names(self):
        return [self.inputs[0].get_name(), self.inputs[1].get_name()]

    def get_mutation_names(self):
        return [self.inputs[1].get_name()]

    def has_side_effects(self):
        return True


class ScatterFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly.


@@ -13,6 +13,7 @@ import torch
import torch.ao.quantization.fx._decomposed
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.create_parameter_op import _bind_nn_parameter
from torch._higher_order_ops.triton_kernel_wrap import (
    triton_kernel_wrapper_functional,
    triton_kernel_wrapper_mutation,
@@ -5924,6 +5925,12 @@ def resize_storage_bytes_(variable, new_size):
    return variable


@register_lowering(_bind_nn_parameter)
def create_nn_parameter(self, placeholder):
    self.realize()
    return TensorBox.create(ir.BindNNParameter(self, placeholder))


from torch._higher_order_ops.auto_functionalize import auto_functionalized

make_fallback(auto_functionalized)