mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944 Approved by: https://github.com/Skylion007
778 lines
28 KiB
Python
778 lines
28 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import NamedTuple
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
from torch._C._functorch import (
|
|
_unwrap_for_grad,
|
|
_wrap_for_grad,
|
|
current_level,
|
|
TransformType,
|
|
)
|
|
from torch._functorch.apis import vmap
|
|
from torch._functorch.utils import enable_single_level_autograd_function
|
|
from torch._functorch.vmap import (
|
|
_add_batch_dim,
|
|
_broadcast_to_and_flatten,
|
|
restore_vmap,
|
|
unwrap_batched,
|
|
wrap_batched,
|
|
)
|
|
from torch._ops import HigherOrderOperator
|
|
from torch.autograd.forward_ad import _set_fwd_grad_enabled
|
|
|
|
|
|
# autograd.Function technically runs before the regular PyTorch dispatcher.
|
|
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
|
|
# work with it. One day we might decide to change this, but until then,
|
|
# we need to give the illusion that autograd.Function runs before those things.
|
|
#
|
|
# We do this by creating a custom HigherOrderOperator that only functorch
|
|
# dispatches specially.
|
|
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
|
|
def __init__(self) -> None:
|
|
super().__init__("custom_function_call")
|
|
|
|
def __call__(self, autograd_function, *args, **kwargs):
|
|
# When custom_function_call is done dispatching through functorch,
|
|
# it should just invoke the autograd.Function. This is consistent
|
|
# with the autograd.Function behavior of being invoked before the
|
|
# PyTorch dispatcher.
|
|
#
|
|
# This will lead us into trouble later down the line, but this is
|
|
# pre-existing. There is an invariant that a function traced by
|
|
# make_fx should have the same behavior when provided the same
|
|
# Tensor. However, make_fx sees autograd.Function as a composite
|
|
# (because autograd.Function happens before the Python dispatch key)
|
|
# and only traces the forward pass.
|
|
if torch._C._are_functorch_transforms_active():
|
|
return super().__call__(autograd_function, *args, **kwargs)
|
|
return autograd_function.apply(*args, **kwargs)
|
|
|
|
|
|
# "custom_function_call"
|
|
# This is the mechanism for an autograd.Function that works with functorch transforms.
|
|
# It wraps an autograd.Function; interactions with functorch transforms are defined
|
|
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
|
|
# dispatcher.
|
|
# Module-level operator instance; the per-transform rules in this file are
# registered on it through `custom_function_call.py_impl(...)`.
custom_function_call = CustomFunctionHigherOrderOperator()
|
|
|
|
|
|
# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
|
|
# (autograd.Function that only works with a single layer (level) of functorch) that:
|
|
# - unwraps the inputs
|
|
# - redispatches to custom_function_call
|
|
# - wraps the outputs
|
|
# and whose backward pass calls the original autograd.Function's backward.
|
|
#
|
|
# Why do we need to redispatch to custom_function_call?
|
|
# -----------------------------------------------------
|
|
# This is consistent with how ATen operators work with functorch's grad transform:
|
|
# they always redispatch to the original operator.
|
|
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
|
|
#
|
|
# grad1 will:
|
|
# - set up the autograd graph
|
|
# - unwrap the inputs
|
|
# - redispatch to at::sin (*)
|
|
# - rewrap the outputs on the return
|
|
#
|
|
# On the redispatch in (*), grad0 will:
|
|
# - set up the autograd graph
|
|
# - unwrap the inputs
|
|
# - redispatch to at::sin
|
|
# - rewrap the outputs on the return
|
|
#
|
|
# To "set up the autograd graph", we generate a _SingleLevelFunction
|
|
# and apply it.
|
|
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    """Grad and Jvp rule for custom_function_call.

    Generates a single-level autograd.Function for ``autograd_function`` at the
    interpreter's level and applies it to ``operands`` while single-level
    autograd functions are enabled.

    Args:
        interpreter: the functorch interpreter handling this transform level.
        autograd_function: the user's autograd.Function.
        *operands: arguments of the original call.

    Returns:
        The result of applying the generated function to ``operands``.
    """
    single_level_function = generate_single_level_function(
        interpreter, autograd_function
    )
    with enable_single_level_autograd_function():
        return single_level_function.apply(*operands)
|
|
|
|
|
|
def generate_single_level_function(interpreter, autograd_function):
    """Create a ``_SingleLevelFunction`` subclass wrapping ``autograd_function``.

    The generated class is named ``f"{autograd_function.__name__}Generated"``.
    Its ``forward`` unwraps Tensor operands at the interpreter's level,
    re-dispatches to ``custom_function_call`` with the interpreter lowered, and
    re-wraps the outputs; ``setup_context``, ``backward`` and ``jvp`` delegate to
    the corresponding staticmethods of ``autograd_function``.

    Args:
        interpreter: provides ``level()`` and the ``lower()`` context manager.
        autograd_function: the user's autograd.Function.

    Returns:
        The dynamically generated class.
    """
    level = interpreter.level()

    def unwrap_tensor(tensor):
        return _unwrap_for_grad(tensor, level)

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output):
        return _wrap_for_grad(output, level)

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor, unwrap_tensor, operands
        )
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(
                autograd_function, *unwrapped_operands
            )

        return wrap_outputs_maintaining_identity(
            unwrapped_output, unwrapped_operands, operands, wrap_fn
        )

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        return autograd_function.backward(ctx, *grads)

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        return autograd_function.jvp(ctx, *tangents)

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    namespace = {
        "forward": staticmethod(forward),
        "backward": staticmethod(backward),
        "jvp": staticmethod(jvp),
        "setup_context": staticmethod(setup_context),
    }
    return type(
        f"{autograd_function.__name__}Generated",
        (torch.autograd.function._SingleLevelFunction,),
        namespace,
    )
|
|
|
|
|
|
# wrap_outputs_maintaining_identity handles outputs from the vmap,
|
|
# backward (vjp), and jvp staticmethod. The way it distinguishes
|
|
# between the vmap case and the {backward, jvp} case is if the out_dims
|
|
# are specified or not.
|
|
#
|
|
# NB: we cannot use out_dims=None as the deciding factor. This is because
|
|
# out_dims=None can still happen in the vmap staticmethod! What the
|
|
# user is saying in that case is that their output does not have a
|
|
# dimension that is being vmapped over, which is valid.
|
|
# Sentinel used as the default `out_dims` of wrap_outputs_maintaining_identity;
# it marks the "out_dims were not passed" case, as distinct from out_dims=None.
NO_OUT_DIMS = "not specified"
|
|
|
|
|
|
# NOTE [mark_dirty object identity check]
|
|
# autograd.Function's ctx.mark_dirty expects a returned input
|
|
# to have the same object identity as the input.
|
|
# Mode-only functorch will greatly simplify this logic.
|
|
def wrap_outputs_maintaining_identity(
    outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
):
    """Wrap each Tensor leaf of ``outputs``, preserving input object identity.

    Non-Tensor leaves are passed through unchanged. A Tensor leaf that is the
    very same object as a leaf of ``unwrapped_inputs`` is replaced by the
    positionally matching leaf of ``orig_inputs``. Every other Tensor leaf goes
    through ``wrap_fn``.

    Args:
        outputs: pytree of outputs to wrap.
        unwrapped_inputs: sequence of unwrapped inputs (flattened leaf-wise).
        orig_inputs: sequence of the original inputs, same leaf order.
        wrap_fn: called as ``wrap_fn(output)`` when ``out_dims`` is not given,
            or as ``wrap_fn(output, out_dim)`` when it is.
        out_dims: optional per-output dims; broadcast to the structure of
            ``outputs``.

    Returns:
        A pytree with the same structure as ``outputs``.

    Raises:
        RuntimeError: if ``out_dims`` is given but cannot be broadcast to the
            structure of ``outputs``.
    """
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    # Keyed by object identity so that a returned input can be mapped back to
    # the exact original object.
    orig_input_by_unwrapped_id = {}
    for unwrapped_leaf, orig_leaf in zip(flat_unwrapped_inputs, flat_orig_inputs):
        orig_input_by_unwrapped_id[id(unwrapped_leaf)] = orig_leaf

    flat_outputs, spec = pytree.tree_flatten(outputs)

    if out_dims != NO_OUT_DIMS:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )

        def wrap_leaf(index, leaf):
            return wrap_fn(leaf, flat_out_dims[index])

    else:

        def wrap_leaf(index, leaf):
            return wrap_fn(leaf)

    wrapped_leaves = []
    for index, leaf in enumerate(flat_outputs):
        if not isinstance(leaf, torch.Tensor):
            wrapped_leaves.append(leaf)
        elif id(leaf) in orig_input_by_unwrapped_id:
            wrapped_leaves.append(orig_input_by_unwrapped_id[id(leaf)])
        else:
            wrapped_leaves.append(wrap_leaf(index, leaf))

    return pytree.tree_unflatten(wrapped_leaves, spec)
|
|
|
|
|
|
# NOTE: [functorch vjp and autograd interaction]
|
|
# There's an edge case with the functorch vjp and autograd interaction
|
|
# that will eventually be fixed by mode-only functorch.
|
|
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
|
|
# so we (the framework) need to do it manually. Regular PyTorch operators
|
|
# automatically do so, so this is consistent.
|
|
#
|
|
# class MyExp(torch.autograd.Function):
|
|
# @staticmethod
|
|
# def forward(x):
|
|
# return x.exp()
|
|
#
|
|
# @staticmethod
|
|
# def setup_context(ctx, inputs, output):
|
|
# y = output
|
|
# ctx.save_for_backward(y)
|
|
#
|
|
# @staticmethod
|
|
# def backward(ctx, gy):
|
|
# y, = ctx.saved_tensors
|
|
# return MyMul.apply(gy, y)
|
|
#
|
|
# x = torch.randn([], requires_grad=True)
|
|
# gy = torch.randn([], requires_grad=True)
|
|
# _, vjp_fn = vjp(MyExp.apply, x)
|
|
# result = vjp_fn(gy)
|
|
#
|
|
# MyMul is an autograd.Function that is not shown here.
|
|
# It saves a `y` for backward (since gy requires grad).
|
|
#
|
|
# in vjp_fn(gy), we get:
|
|
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
|
|
# Because the y that is saved for backward by MyExp is a GradTensorWrapper
|
|
# but is now dead since we are outside the vjp context.
|
|
#
|
|
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
|
|
# will automatically unwrap the GradTensorWrapper when applied.
|
|
# But since autograd.Function technically sits above the regular PyTorch
|
|
# dispatcher, it doesn't get this treatment. So we manually do
|
|
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
|
|
|
|
|
|
class VmapInfo(NamedTuple):
    """Information about the active vmap level, passed to a vmap rule."""

    # Filled from interpreter.batch_size() by custom_function_call_vmap_helper.
    batch_size: int
    # Filled from interpreter.randomness() by custom_function_call_vmap_helper.
    randomness: str
|
|
|
|
|
|
def has_overridden_vmap_rule(autograd_function):
    """Return True if ``autograd_function.vmap`` is not the base ``Function.vmap``.

    The comparison is by object identity against
    ``torch.autograd.Function.vmap``.
    """
    base_vmap_rule = torch.autograd.Function.vmap
    return autograd_function.vmap is not base_vmap_rule
|
|
|
|
|
|
def validate_vmap_returns_tuple_of_two_elements(result):
    """Check that ``result`` is a tuple of exactly two elements.

    Args:
        result: the value returned by a vmap staticmethod.

    Raises:
        RuntimeError: if ``result`` is not a tuple, or is a tuple whose length
            is not 2.
    """
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if isinstance(result, tuple):
        if len(result) == 2:
            return
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")
    raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
|
|
|
|
|
|
@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
    """Vmap rule for custom_function_call.

    Uses the generated rule when ``autograd_function.generate_vmap_rule`` is
    truthy, otherwise the function's own overridden ``vmap`` staticmethod.

    Args:
        interpreter: the functorch interpreter handling this vmap level.
        autograd_function: the user's autograd.Function.
        *operands: positional arguments of the original call.
        **kwargs: keyword arguments of the original call; must not contain
            Tensors.

    Raises:
        NotImplementedError: if any leaf of ``kwargs`` is a Tensor.
        RuntimeError: if the function sets ``generate_vmap_rule`` and also
            overrides ``vmap``, or does neither.
    """
    kwarg_leaves = torch.utils._pytree.tree_flatten(kwargs)[0]
    if any(isinstance(leaf, torch.Tensor) for leaf in kwarg_leaves):
        raise NotImplementedError(
            "Run vmap on autograd.Function with kwarg-only Tensor args. "
            "Please do not pass kwarg-only Tensors to autograd.Function. "
            f"Got: {kwargs}"
        )

    wants_generated_rule = autograd_function.generate_vmap_rule
    has_custom_rule = has_overridden_vmap_rule(autograd_function)

    if wants_generated_rule and has_custom_rule:
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            "it has both generate_vmap_rule=True and an overridden vmap "
            "staticmethod. Please set generate_vmap_rule=False or delete "
            "the overridden vmap staticmethod to avoid ambiguity. "
            "For more details, please see "
            "https://pytorch.org/docs/main/notes/extending.func.html"
        )

    if wants_generated_rule:
        return custom_function_call_vmap_generate_rule(
            interpreter, autograd_function, *operands
        )

    if not has_custom_rule:
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            "it does not have vmap support. Please override and implement the "
            "vmap staticmethod or set generate_vmap_rule=True. "
            "For more details, please see "
            "https://pytorch.org/docs/main/notes/extending.func.html"
        )

    return custom_function_call_vmap_helper(
        interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
    )
|
|
|
|
|
|
def custom_function_call_vmap_helper(
    interpreter, vmap_function, op, *operands, **kwargs
):
    """Run a user-provided vmap rule for ``op`` at the interpreter's level.

    Args:
        interpreter: provides ``level()``, ``batch_size()``, ``randomness()``
            and the ``lower()`` context manager.
        vmap_function: called as
            ``vmap_function(info, in_dims, *unwrapped_operands, **kwargs)`` and
            expected to return an ``(output, out_dims)`` tuple.
        op: the autograd.Function class, or another callable operator.
        *operands: positional arguments, possibly batched at this level.
        **kwargs: keyword arguments forwarded to ``vmap_function`` (and to
            ``op`` in the non-autograd.Function skip path).

    Returns:
        The outputs, re-batched at this level according to ``out_dims``.
    """
    level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    # We're either in the autograd.Function case (vmap staticmethod)
    # or the torch.library.register_vmap case.
    autograd_function_case = isinstance(op, torch.autograd.function.FunctionMeta)

    def lower_to_next():
        if not autograd_function_case:
            return torch._C._ExcludeDispatchKeyGuard(
                torch._C.DispatchKeySet(torch._C.DispatchKey.FuncTorchBatched)
            )
        return interpreter.lower()

    unwrapped_operands, in_dims = unwrap_batched(operands, level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    nothing_batched = pytree.tree_all(lambda dim: dim is None, in_dims)
    if nothing_batched:
        with lower_to_next():
            if autograd_function_case:
                return custom_function_call(op, *operands)
            return op(*operands, **kwargs)

    with lower_to_next():
        result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        if out_dim is None:
            return output
        return _add_batch_dim(output, out_dim, level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims
    )
|
|
|
|
|
|
def unpack_outputs(outputs):
    """Split ``outputs`` into ``(outputs_without_out_dims, out_dims)``.

    ``out_dims`` is the last element of ``outputs``. If it is a tuple, all
    preceding elements are returned (as a slice of ``outputs``); otherwise only
    the first element of ``outputs`` is returned.
    """
    out_dims = outputs[-1]
    payload = outputs[:-1] if isinstance(out_dims, tuple) else outputs[0]
    return payload, out_dims
|
|
|
|
|
|
def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    """Apply the automatically generated vmap rule for ``autograd_function``.

    Unwraps ``operands`` at the interpreter's level, builds the vmapped
    autograd.Function, calls it with the interpreter lowered, and re-batches
    the outputs at the interpreter's level.

    Args:
        interpreter: provides ``level()``, ``batch_size()``, ``randomness()``
            and the ``lower()`` context manager.
        autograd_function: the user's autograd.Function.
        *operands: positional arguments, possibly batched at this level.

    Returns:
        The batched outputs.
    """
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    batch_size = interpreter.batch_size()
    randomness = interpreter.randomness()
    vmapped_function = vmapify_autograd_function(
        autograd_function, in_dims, batch_size, randomness
    )

    with interpreter.lower():
        raw_outputs = custom_function_call(vmapped_function, *unwrapped_operands)

    # The vmapped function returns its out_dims as the last tuple element.
    assert isinstance(raw_outputs, tuple)
    unbatched_outputs, out_dims = unpack_outputs(raw_outputs)
    return wrap_batched(unbatched_outputs, out_dims, interpreter.level())
|
|
|
|
|
|
@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(
    interpreter, autograd_function, *operands, **kwargs
):
    """Functionalize rule for custom_function_call (not yet implemented).

    The signature mirrors the other rules registered in this file
    (``interpreter, autograd_function, *operands``) and also accepts keyword
    arguments, so that every call reaches the explicit error below rather than
    failing argument binding with a ``TypeError``.

    Args:
        interpreter: the functorch interpreter handling this transform level.
        autograd_function: the user's autograd.Function.
        *operands: positional arguments of the original call (unused).
        **kwargs: keyword arguments of the original call (unused).

    Raises:
        RuntimeError: always.
    """
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")
|
|
|
|
|
|
def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    """Build an autograd.Function that runs ``autograd_function`` under vmap.

    The generated class is named ``f"Vmapped{autograd_function.__name__}"`` and
    has ``generate_vmap_rule = True``. Its ``forward`` returns the original
    outputs followed by their ``out_dims`` as the last element. Per-class
    bookkeeping is stored on ``ctx`` in dicts keyed by ``id(Generated)``:
    ``_pt_input_shapes``, ``_pt_saved_tensors_bdims_stack`` and ``_pt_out_dims``.

    Args:
        autograd_function: the user's autograd.Function.
        in_dims: batch dims of the inputs, passed to ``restore_vmap``.
        batch_size: batch size passed to ``restore_vmap`` and ``reductify``.
        randomness: randomness setting passed to ``restore_vmap``.

    Returns:
        The dynamically generated autograd.Function subclass.
    """

    def store_on_ctx(ctx, attr_name, key, value):
        # Lazily create the dict attribute, then record the entry for `key`.
        if not hasattr(ctx, attr_name):
            setattr(ctx, attr_name, {})
        getattr(ctx, attr_name)[key] = value

    def forward(*operands):
        vmapped_forward = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness
        )
        outputs, out_dims = vmapped_forward(*operands)
        # out_dims always travels as the trailing element of the return value.
        if isinstance(outputs, torch.Tensor):
            return outputs, out_dims
        return (*outputs, out_dims)

    def setup_context(ctx, inputs, outputs):
        outputs, out_dims = unpack_outputs(outputs)
        key = id(Generated)

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batchedtensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            input_shapes = tuple(
                inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs
            )
            store_on_ctx(ctx, "_pt_input_shapes", key, input_shapes)
            store_on_ctx(
                ctx,
                "_pt_saved_tensors_bdims_stack",
                key,
                wrapped_ctx._pt_saved_tensors_bdims,
            )

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        vmapped_setup_context = restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )
        vmapped_setup_context(inputs, outputs)

        store_on_ctx(ctx, "_pt_out_dims", key, out_dims)

    def jvp(ctx, *tangents):
        key = id(Generated)

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        vmapped_jvp = restore_vmap(
            jvp_no_context,
            (ctx._pt_saved_tensors_bdims_stack[key], tangent_in_dims),
            batch_size,
            randomness,
        )
        out_tangents, out_tangents_dims = vmapped_jvp(ctx.saved_tensors, tangents)

        result = reductify(
            out_tangents, out_tangents_dims, ctx._pt_out_dims[key], batch_size
        )
        # The trailing None is the tangent for the out_dims element that
        # forward appends to its outputs.
        if isinstance(result, torch.Tensor):
            return result, None
        return (*result, None)

    def backward(ctx, *grad_outputs):
        key = id(Generated)
        # Drop the last grad: it corresponds to the out_dims element of forward.
        grad_outputs_ = grad_outputs[:-1]

        stored_out_dims = ctx._pt_out_dims[key]
        if not isinstance(stored_out_dims, tuple):
            stored_out_dims = (stored_out_dims,)

        # A None grad has no batch dim to map over.
        grad_outputs_in_dims = tuple(
            None if grad_output is None else in_dim
            for grad_output, in_dim in zip(grad_outputs_, stored_out_dims)
        )

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        vmapped_backward = restore_vmap(
            backward_no_context,
            ((ctx._pt_saved_tensors_bdims_stack[key], grad_outputs_in_dims),),
            batch_size,
            randomness,
        )
        grad_ins, grad_ins_dims = vmapped_backward((ctx.saved_tensors, grad_outputs_))
        return reductify(
            grad_ins, grad_ins_dims, in_dims, batch_size, ctx._pt_input_shapes[key]
        )

    namespace = {
        "forward": staticmethod(forward),
        "backward": staticmethod(backward),
        "jvp": staticmethod(jvp),
        "setup_context": staticmethod(setup_context),
        "generate_vmap_rule": True,
    }
    Generated = type(
        f"Vmapped{autograd_function.__name__}",
        (torch.autograd.Function,),
        namespace,
    )

    return Generated
|
|
|
|
|
|
# tangents might be None, so we need to replace
|
|
# the corresponding in_dims with None.
|
|
def get_tangents_in_dims(input_dims, tangents):
|
|
flat_in_dims, spec = pytree.tree_flatten(input_dims)
|
|
flat_tangents = pytree.arg_tree_leaves(*tangents)
|
|
result = [
|
|
None if tangent is None else in_dim
|
|
for in_dim, tangent in zip(flat_in_dims, flat_tangents)
|
|
]
|
|
return pytree.tree_unflatten(result, spec)
|
|
|
|
|
|
# NOTE: [Why do we need to run setup_context under a vmap?]
|
|
# Consider the following autograd.Function
|
|
#
|
|
# class Sum(torch.autograd.Function):
|
|
# @staticmethod
|
|
# def forward(x):
|
|
# return x.sum()
|
|
# @staticmethod
|
|
# def setup_context(ctx, inputs, outputs):
|
|
# ctx.x_shape = inputs[0].shape
|
|
# @staticmethod
|
|
# def backward(ctx, gy):
|
|
# return gy.expand(ctx.x_shape)
|
|
#
|
|
# x = torch.randn(B, 4)
|
|
# in_dims = 0
|
|
# vmap(Sum.apply, in_dims)(x)
|
|
#
|
|
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
|
|
#
|
|
# class VmappedSum(torch.autograd.Function):
|
|
# @staticmethod
|
|
# def forward(x):
|
|
# return vmap(Sum.forward, in_dims)(x)
|
|
#
|
|
# @staticmethod
|
|
# def setup_context(ctx, inputs, outputs):
|
|
# Sum.setup_context(ctx, inputs, outputs)
|
|
#
|
|
# @staticmethod
|
|
# def backward(ctx, gy):
|
|
# def backward_no_context(gy):
|
|
# return gy.expand(ctx.x_shape)
|
|
#
|
|
# dims = (0,)
|
|
# gx = vmap(backward_no_context, dims)(gy)
|
|
# return gx
|
|
#
|
|
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
|
|
# and we're doing:
|
|
#
|
|
# def backward_no_context(gy):
|
|
# return gy.expand([B, 4])
|
|
#
|
|
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
|
|
#
|
|
# This gives us the wrong result (gx has shape [B, B, 4], but it should
|
|
# have shape [4]). Performing vmap over setup_context means the shape
|
|
# saved has shape [4] and leads to a correct result shape for gx.
|
|
|
|
|
|
# Wraps a ctx object. Forwards all attr accesses to the underlying object
|
|
# except for the attrs in _pt_reserved_attrs
|
|
class WrappedCtx:
    """Proxy around a ctx object.

    Attribute reads that are not found on the wrapper are forwarded to the
    wrapped ctx, and attribute writes are forwarded too, except for the names
    listed in ``_pt_reserved_attrs``, which are stored on the wrapper itself.
    """

    # Names stored on the wrapper instead of being forwarded to the inner ctx.
    _pt_reserved_attrs: tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")

    def __init__(self, ctx):
        """Wrap ``ctx``.

        Raises:
            RuntimeError: if ``ctx`` is not itself a WrappedCtx and already has
                an attribute named like one of the reserved attrs.
        """
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f"PyTorch reserves the {reserved_attrs} field on ctx. "
                    "Please name your fields on ctx something else to avoid name "
                    "collision."
                )
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If the failing name is
        # `_pt_inner_ctx` itself, the wrapper has not been initialised (e.g. the
        # instance was created without __init__, as copy.copy does); forwarding
        # would re-enter __getattr__ forever, so report it as missing instead.
        if name == "_pt_inner_ctx":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        # Reserved names live in the wrapper's own __dict__; everything else is
        # written through to the wrapped ctx.
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)
|
|
|
|
|
|
# Wraps ctx to create a new ctx object that overrides saved_tensors.
|
|
class CtxWithSavedTensors(WrappedCtx):
    """WrappedCtx whose ``saved_tensors`` returns a caller-supplied value.

    All other attribute access follows WrappedCtx's forwarding rules.
    """

    _pt_reserved_attrs = ("_pt_new_saved_tensors",) + WrappedCtx._pt_reserved_attrs

    def __init__(self, ctx, new_saved_tensors):
        """Wrap ``ctx`` and remember ``new_saved_tensors`` on the wrapper."""
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        """The ``new_saved_tensors`` value given at construction."""
        return self._pt_new_saved_tensors
|
|
|
|
|
|
class CtxCustomSave(WrappedCtx):
    """WrappedCtx that unwraps batched tensors before saving them.

    ``save_for_backward`` / ``save_for_forward`` unwrap their arguments at the
    level given at construction, forward the unwrapped tensors to the inner
    ctx, and record the batch dims in ``_pt_saved_tensors_bdims``.
    """

    _pt_reserved_attrs = (
        "_pt_saved_tensors_bdims",
        "_pt_current_level",
    ) + WrappedCtx._pt_reserved_attrs

    def __init__(self, ctx, current_level):
        """Wrap ``ctx``; ``current_level`` is the level used for unwrapping."""
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    # NOTE(review): both save methods write the same `_pt_saved_tensors_bdims`
    # slot, so the most recent call wins - confirm this is intended for
    # Functions that call both.
    def save_for_backward(self, *tensors):
        """Unwrap ``tensors``, save them for backward on the inner ctx, record bdims."""
        plain_tensors, batch_dims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*plain_tensors)
        self._pt_saved_tensors_bdims = batch_dims

    def save_for_forward(self, *tensors):
        """Unwrap ``tensors``, save them for forward on the inner ctx, record bdims."""
        plain_tensors, batch_dims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*plain_tensors)
        self._pt_saved_tensors_bdims = batch_dims
|
|
|
|
|
|
def reductify(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    """Apply ``reductify_leaf`` element-wise and return the results as a tuple.

    ``grad_input``, ``grad_input_bdim`` and ``input_bdim`` are each promoted to
    a one-element tuple when they are not already tuples, then zipped together
    with the target shapes.

    Args:
        grad_input: a gradient or tuple of gradients.
        grad_input_bdim: batch dim(s) of ``grad_input``.
        input_bdim: batch dim(s) of the corresponding inputs.
        batch_size: batch size forwarded to ``reductify_leaf``.
        target_shape_without_bdim_to_reduce_to: optional sequence of per-leaf
            target shapes; when None, every leaf gets None.

    Returns:
        A tuple with one entry per zipped leaf.
    """

    def as_tuple(value):
        return value if isinstance(value, tuple) else (value,)

    grad_inputs = as_tuple(grad_input)
    grad_input_bdims = as_tuple(grad_input_bdim)
    input_bdims = as_tuple(input_bdim)

    target_shapes = target_shape_without_bdim_to_reduce_to
    if target_shapes is None:
        target_shapes = (None,) * len(grad_inputs)

    reduced = []
    for gi, gi_bdim, i_bdim, maybe_ishape in zip(
        grad_inputs, grad_input_bdims, input_bdims, target_shapes
    ):
        reduced.append(reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape))
    return tuple(reduced)
|
|
|
|
|
|
def reductify_leaf(
|
|
grad_input,
|
|
grad_input_bdim,
|
|
input_bdim,
|
|
batch_size,
|
|
target_shape_without_bdim_to_reduce_to=None,
|
|
):
|
|
if grad_input is None:
|
|
return None
|
|
|
|
if grad_input_bdim is None and input_bdim is None:
|
|
return grad_input
|
|
|
|
if grad_input_bdim is not None and input_bdim is None:
|
|
return grad_input.sum(grad_input_bdim)
|
|
|
|
# NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
|
|
# For reverse-mode AD,
|
|
# given a grad_input and input, it is valid for the user to return a
|
|
# grad_input that has a broadcasted shape when compared to the input.
|
|
# In this situation, autograd automatically reduces the grad_input to
|
|
# the shape of the input.
|
|
#
|
|
# However, when input_bdim is not None, we have problems.
|
|
#
|
|
# [example 1]
|
|
# grad_input: Tensor[3, 4], input: Tensor[B, 4]
|
|
# We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
|
|
# from [B, 4].
|
|
#
|
|
# [example 2]
|
|
# grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
|
|
# We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
|
|
# from [B, 4].
|
|
#
|
|
# This means that we need to also reduce the grad_input to the shape of the
|
|
# input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag;
|
|
# if not-None then we do the reducing manually, otherwise, we do not do a reduction.
|
|
assert input_bdim is not None
|
|
|
|
if grad_input_bdim is None:
|
|
grad_input = grad_input.unsqueeze(input_bdim)
|
|
new_shape = list(grad_input.shape)
|
|
new_shape[input_bdim] = batch_size
|
|
grad_input = grad_input.expand(new_shape)
|
|
grad_input_bdim = input_bdim
|
|
|
|
if target_shape_without_bdim_to_reduce_to is not None:
|
|
return vmap(
|
|
torch.Tensor.sum_to_size,
|
|
in_dims=(grad_input_bdim, None),
|
|
out_dims=input_bdim,
|
|
)(grad_input, target_shape_without_bdim_to_reduce_to)
|
|
|
|
if input_bdim != grad_input_bdim:
|
|
grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
|
|
return grad_input
|
|
|
|
|
|
def autograd_function_forward_rewritten(original_forward, original_setup_context):
    """Combine a ctx-less forward and a setup_context into one ctx-taking forward.

    Args:
        original_forward: called as ``original_forward(*args, **kwargs)``.
        original_setup_context: called as
            ``original_setup_context(ctx, args, output)`` after the forward.

    Returns:
        A function ``new_forward(ctx, *args, **kwargs)`` that returns the
        forward's output.
    """

    def new_forward(ctx, *args, **kwargs):
        forward_result = original_forward(*args, **kwargs)
        # Only the positional args are handed to setup_context as its inputs.
        original_setup_context(ctx, args, forward_result)
        return forward_result

    return new_forward
|
|
|
|
|
|
class AutogradFunctionApply(HigherOrderOperator):
|
|
def __init__(self) -> None:
|
|
super().__init__("autograd_function_apply")
|
|
|
|
def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
|
|
saved_values = None
|
|
args_tensor_mask = fwd_kwargs["args_tensor_mask"]
|
|
non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
|
|
length_of_tensor_args = sum(args_tensor_mask)
|
|
# Filter out the original tensor args from fwd_args,
|
|
# lifted freevars should not be args of ApplyTemplate.apply
|
|
# since we don't need to calculate the gradients of them.
|
|
new_fwd_args = fwd_args[:length_of_tensor_args]
|
|
|
|
class ApplyTemplate(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, *args):
|
|
nonlocal saved_values
|
|
output, saved_values = fwd(None, *fwd_args)
|
|
|
|
# If users call ctx.mark_non_differentiable() in the original fwd function.
|
|
if len(non_differentiable_idx) > 0:
|
|
non_differentiable_output = []
|
|
for i, x in enumerate(output):
|
|
if i in non_differentiable_idx:
|
|
non_differentiable_output.append(x)
|
|
ctx.mark_non_differentiable(*non_differentiable_output)
|
|
|
|
return output
|
|
|
|
@staticmethod
|
|
def backward(ctx, *grad):
|
|
return bwd(None, *grad, *saved_values)
|
|
|
|
return ApplyTemplate.apply(*new_fwd_args)
|
|
|
|
|
|
# Module-level instance of the AutogradFunctionApply operator.
autograd_function_apply = AutogradFunctionApply()
|