pytorch/test/functorch/test_eager_transforms.py
ekamiti 9e473fd868 Make adding Buffers more like adding Parameters (#125971)
Add semantics for creating a buffer object that mirror those for creating a parameter. This is done by introducing a new Buffer class that can be used for type disambiguation. The underlying registration functionality is unchanged: the register_buffer method has not been modified, and using the Buffer type simply leads to register_buffer being called. The persistent argument on the Buffer type indicates whether the buffer should be persistent. The remaining non-test changes teach inductor and dynamo to recognize the new Buffer type, and the test changes verify that Buffer can be used as a drop-in replacement for register_buffer. Normal tensors can still be used as buffers, so these changes are intended to be backwards compatible.

Fixes #35735

Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125971
Approved by: https://github.com/albanD, https://github.com/anijain2305, https://github.com/mlazos
2024-07-31 10:32:40 +00:00
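A minimal sketch of the semantics described in the commit message above (the module, buffer names, and shapes here are illustrative, not taken from the test file):

import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning a Buffer registers it on the module, analogous to
        # assigning a Parameter; under the hood register_buffer is called.
        self.running_mean = nn.Buffer(torch.zeros(3))
        # persistent=False keeps the buffer out of the state_dict.
        self.scratch = nn.Buffer(torch.zeros(3), persistent=False)
        # Plain tensors can still be registered the pre-existing way.
        self.register_buffer("legacy", torch.zeros(3))

m = Example()
assert "running_mean" in dict(m.named_buffers())
assert "legacy" in dict(m.named_buffers())
assert "scratch" not in m.state_dict()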


# Owner(s): ["module: functorch"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import math
import os
import subprocess
import sys
import unittest
import warnings
from functools import partial, wraps
# NB: numpy is a testing dependency!
import numpy as np
from common_utils import expectedFailureIf
import functorch
import torch
import torch.autograd.forward_ad as fwAD
import torch.nn as nn
import torch.nn.functional as F
from functorch import (
combine_state_for_ensemble,
grad,
grad_and_value,
hessian,
jacfwd,
jacrev,
jvp,
make_functional,
make_functional_with_buffers,
make_fx,
vjp,
vmap,
)
from functorch.experimental import functionalize, replace_all_batch_norm_modules_
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dynamo import allow_in_graph
from torch._functorch.eager_transforms import _slice_argnums
from torch._functorch.make_functional import (
functional_init,
functional_init_with_buffers,
)
from torch._functorch.utils import enable_single_level_autograd_function
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.func import functional_call, linearize, stack_module_state
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
SM70OrLater,
TEST_CUDA,
tf32_on_and_off,
with_tf32_off,
)
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
)
from torch.testing._internal.common_dtype import get_all_fp_dtypes
from torch.testing._internal.common_utils import (
freeze_rng_state,
instantiate_parametrized_tests,
IS_FBCODE,
IS_WINDOWS,
markDynamoStrictTest,
parametrize,
run_tests,
skipIfRocm,
skipIfTorchDynamo,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
xfailIfTorchDynamo,
)
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
USE_TORCHVISION = False
try:
import torchvision # noqa: F401
USE_TORCHVISION = True
except ImportError:
warnings.warn(
"Couldn't import torchvision. Some of our tests use it, try "
"to install it with commands from pytorch.org, post-fixed with "
"`--no-deps` to avoid overwriting the pytorch installation",
UserWarning,
)
class VmapTearDownMixin:
def tearDown(self):
# Ensure that, in the case of a test failure, the next test won't fail
# because a previous call to _vmap_increment_nesting wasn't undone
# (e.g. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1
# and the call to increment nesting is not undone)
if not TEST_WITH_TORCHDYNAMO:
return
warn = False
while ci := torch._C._functorch.peek_interpreter_stack():
if ci.key() == torch._C._functorch.TransformType.Vmap:
warn = True
torch._C._functorch._vmap_decrement_nesting()
else:
break
if warn:
msg = (
"Interpreter stack is not empty. Test should have called "
"'torch._C._functorch._vmap_decrement_nesting()'"
)
warnings.warn(msg)
# TestCase for _slice_argnums, an important helper function
@markDynamoStrictTest
class TestSliceArgnums(TestCase):
def test_invalid_argnum_type(self):
x = torch.randn(3)
args = (x,)
with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
_slice_argnums(args, 0.0)
with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
_slice_argnums(args, [0])
with self.assertRaisesRegex(RuntimeError, "must be int"):
_slice_argnums(args, (0.0,))
args = (0.1, 1.1, 2.1, 3.1, 4.1)
with self.assertRaisesRegex(RuntimeError, "must be int"):
_slice_argnums(args, ((0, 1), 2))
def test_out_of_bounds_argnum_values(self):
x = torch.randn(3)
args = (x,)
with self.assertRaisesRegex(RuntimeError, "positional inputs"):
_slice_argnums(args, 1)
with self.assertRaisesRegex(RuntimeError, "positional inputs"):
_slice_argnums(args, -2)
with self.assertRaisesRegex(RuntimeError, "positional inputs"):
_slice_argnums(args, (-2,))
def test_not_enough_argnums(self):
x = torch.randn(3)
args = (x,)
with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
_slice_argnums(args, ())
def test_duplicate_argnums(self):
x = torch.randn(3)
args = (x, x)
with self.assertRaisesRegex(RuntimeError, "must be unique"):
_slice_argnums(args, (0, 0))
with self.assertRaisesRegex(RuntimeError, "must be unique"):
_slice_argnums(args, (0, -2))
def test_flat_args_with_positive_int_argnum(self):
args = (0.1, 1.1, 2.1, 3.1, 4.1)
res = _slice_argnums(args, 0)
self.assertEqual(res, (0.1,))
res = _slice_argnums(args, 4)
self.assertEqual(res, (4.1,))
def test_flat_args_with_negative_int_argnum(self):
args = (0.1, 1.1, 2.1, 3.1, 4.1)
res = _slice_argnums(args, -1)
self.assertEqual(res, (4.1,))
res = _slice_argnums(args, -5)
self.assertEqual(res, (0.1,))
def test_flat_args_with_tuple_argnum(self):
args = (0.1, 1.1, 2.1, 3.1, 4.1)
res = _slice_argnums(args, (0, 1, 2, 3, 4))
self.assertEqual(res, args)
res = _slice_argnums(args, (0, -3))
self.assertEqual(res, (0.1, 2.1))
def test_pytree_args(self):
args = ((0.1, 1.1), 2.0, [3.1])
res = _slice_argnums(args, 0)
self.assertEqual(res, args[0:1])
res = _slice_argnums(args, (0,))
self.assertEqual(res, args[0:1])
res = _slice_argnums(args, -1)
self.assertEqual(res, args[-1:])
res = _slice_argnums(args, (0, -2))
self.assertEqual(res, args[0:2])
def test_argnums_reorders(self):
args = ((0.1, 1.1, 2.1), 3.1, 4.1)
res = _slice_argnums(args, (1, 0))
self.assertEqual(res, (args[1], args[0]))
def _get_weights_and_functional_call(net, mechanism):
if mechanism == "make_functional":
return make_functional(net)
else:
assert mechanism == "functional_call"
# this makes it so the function from make_functional and this call have the same signature
def net_func(weights, data):
return functional_call(net, weights, (data,))
return net_func, dict(net.named_parameters())
def _get_weights_and_functional_call_with_buffers(net, mechanism):
if mechanism == "make_functional":
return make_functional_with_buffers(net)
else:
assert mechanism == "functional_call"
# this makes it so the function from make_functional and this call have the same signature
def net_func(weights, buffers, data):
return functional_call(net, (weights, buffers), (data,))
return net_func, dict(net.named_parameters()), dict(net.named_buffers())
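# Usage sketch for the two helpers above (illustrative; the module and shapes
# are hypothetical): both mechanisms yield a callable with the same signature,
#
#   net = nn.Linear(4, 2)
#   fn, weights = _get_weights_and_functional_call(net, "functional_call")
#   out = fn(weights, torch.randn(8, 4))
#
# so the tests below can parametrize over "make_functional" and
# "functional_call" without branching at each call site.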
@markDynamoStrictTest
class TestGradTransform(TestCase):
def test_primitive(self, device):
x = torch.randn([], device=device)
result = grad(torch.sin)(x)
self.assertEqual(result, torch.cos(x))
def test_composite_simple(self, device):
x = torch.randn(2, 3, 4, device=device)
result = grad(lambda x: torch.flatten(x).sum())(x)
self.assertEqual(result, torch.ones_like(x))
def test_fn_with_kwargs(self, device):
def foo(x, y):
return (x * y).sum()
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
expected = grad(foo)(x, y)
result = grad(foo)(x, y=y)
self.assertEqual(result, expected)
def test_composite_complicated(self, device):
x = torch.randn(3, device=device)
y = torch.randn(3, 5, device=device)
def foo(x, y):
result = x @ y
return result.sum()
result = grad(foo)(x, y)
x.requires_grad_()
out = foo(x, y)
(expected,) = torch.autograd.grad(out, x)
self.assertEqual(result, expected)
def test_composite_two_ops(self, device):
N, C = 2, 5
y = torch.randn(N, C, device=device)
targets = torch.randint(0, C, (N,), device=device)
def foo(y, targets):
return F.cross_entropy(y, targets)
result = grad(foo)(y, targets)
y.requires_grad_()
(expected,) = torch.autograd.grad(foo(y, targets), y)
self.assertEqual(result, expected)
def _test_attributes(self, get_attr_lambda, device):
x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
expected = get_attr_lambda(x)
def foo(x):
self.assertEqual(get_attr_lambda(x), expected)
return x.sum()
grad(foo)(x)
def test_shape(self, device):
self._test_attributes(lambda x: x.shape, device)
def test_dtype(self, device):
self._test_attributes(lambda x: x.dtype, device)
def test_is_cuda(self, device):
self._test_attributes(lambda x: x.is_cuda, device)
def test_numel(self, device):
self._test_attributes(lambda x: x.numel(), device)
def test_inplace(self, device):
x = torch.randn([], device=device)
def foo(x):
return x.clone().sin_()
result = grad(foo)(x)
self.assertEqual(result, x.cos())
def test_inplace_on_view(self, device):
x = torch.randn(3, device=device)
def foo(x):
y = x.clone()
y0 = y[0]
y0.sin_()
return y.sum()
result = grad(foo)(x)
x.requires_grad_()
out = foo(x)
(expected,) = torch.autograd.grad(out, x)
self.assertEqual(result, expected)
def test_inplace_on_view_base(self, device):
x = torch.randn(3, device=device)
def foo(x):
y = x.clone()
y0 = y[0]
y.sin_()
return y0
result = grad(foo)(x)
x.requires_grad_()
out = foo(x)
(expected,) = torch.autograd.grad(out, x)
self.assertEqual(result, expected)
def test_inplace_on_captures(self, device):
x = torch.tensor([1.0, 2.0, 3.0], device=device)
captured = torch.randn(3, device=device)
def foo(x):
captured.copy_(x)
return (x * captured).sum()
with self.assertRaisesRegex(RuntimeError, "mutate a captured Tensor"):
grad(foo)(x)
def test_nesting_simple(self, device):
x = torch.randn([], device=device)
result = grad(grad(torch.sin))(x)
self.assertEqual(result, -torch.sin(x))
@skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613")
def test_escaped_wrappers_are_marked_as_dead(self, device):
x = torch.randn([], device=device)
escaped = []
def foo(x):
y = x.sin()
escaped.append(y)
return y
grad(foo)(x)
self.assertEqual(torch._C._functorch.dlevel(escaped[0]), -1)
@skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613")
def test_escaped_wrappers_are_ignored(self, device):
x = torch.randn([], device=device)
escaped = []
def foo(x):
y = x.sin()
escaped.append(y)
return y
grad(foo)(x)
something = escaped[0].sum()
self.assertEqual(torch._C._functorch.dlevel(something), 0)
self.assertEqual(something, x.sin().sum())
def test_manual_seed_inside_grad(self, device):
x = torch.randn([], device=device)
def f(x):
torch.manual_seed(0)
return x * torch.randn_like(x)
with freeze_rng_state():
result = grad(f)(x)
x.requires_grad_()
(expected,) = torch.autograd.grad(f(x), x)
self.assertEqual(result, expected)
def test_vjp(self, device):
x = torch.randn([], device=device)
out, vjp_fn = vjp(torch.sin, x)
self.assertEqual(out, x.sin())
v = torch.randn([], device=device)
(result,) = vjp_fn(v)
self.assertEqual(result, v * x.cos())
def test_vjp_two_outputs(self, device):
def f(x):
return x, x
result, vjp_fn = vjp(f, torch.tensor(1.0))
vjp_fn(result)
def test_conj_bit(self):
x = torch.tensor(1 + 1j)
def foo(x):
assert not x.is_conj()
y = x.conj()
assert y.is_conj()
return y.abs()
res = grad(foo)(x)
with torch.no_grad():
self.assertEqual(res, torch.ones_like(res) * torch.sgn(x))
def test_composed_with_autograd(self, device):
x = torch.randn([], requires_grad=True, device=device)
y = grad(torch.sin)(x)
(result,) = torch.autograd.grad(y, x)
self.assertEqual(result, -x.sin())
def test_grad_of_vjp_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)
def foo(x, y):
out, vjp_fn = vjp(torch.sin, x)
return grad(lambda y: vjp_fn(y)[0])(y)
result = foo(x, y)
expected = x.cos()
self.assertEqual(result, expected)
def test_vjp_of_grad_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)
def foo(x, y):
out, vjp_fn = vjp(grad(torch.sin), x)
return vjp_fn(y)[0]
result = foo(x, y)
expected = -y * x.sin()
self.assertEqual(result, expected)
def test_grad_of_vjp_of_grad_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)
def foo(x, y):
df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
return grad(lambda y: vjp_fn(y)[0])(y)
result = foo(x, y)
expected = x.cos()
self.assertEqual(result, expected)
def test_views(self, device):
x = torch.randn([], requires_grad=True, device=device)
y = torch.randn([], requires_grad=True, device=device)
def silly_sin(x):
x = x.view([])
x = x.sin()
return x
def foo(x, y):
z1 = grad(silly_sin)(x)
z2 = torch.cos(y)
return z1 + z2
result = foo(x, y)
grads = torch.autograd.grad(result, [x, y])
self.assertEqual(grads[0], -x.sin())
self.assertEqual(grads[1], -y.sin())
def test_view_inplace_simple(self, device):
def foo(x):
x = x.clone()
x.view([]).sin_()
return x
x = torch.randn([], requires_grad=True, device=device)
result = grad(foo)(x)
self.assertEqual(result, x.cos())
def test_invalid_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
with self.assertRaisesRegex(RuntimeError, "but only"):
grad(torch.mul, argnums=-3)(x, y)
with self.assertRaisesRegex(RuntimeError, "but only"):
grad(torch.mul, argnums=2)(x, y)
with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
grad(torch.mul, argnums=[0])(x, y)
with self.assertRaisesRegex(RuntimeError, "must be int"):
grad(torch.mul, argnums=("0",))(x, y)
with self.assertRaisesRegex(RuntimeError, "must be unique"):
grad(torch.mul, argnums=(0, 0))(x, y)
with self.assertRaisesRegex(RuntimeError, "must be unique"):
grad(torch.mul, argnums=(0, -2))(x, y)
def test_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
gx = grad(torch.mul, argnums=0)(x, y)
self.assertEqual(gx, y)
gy = grad(torch.mul, argnums=1)(x, y)
self.assertEqual(gy, x)
(gx,) = grad(torch.mul, argnums=(0,))(x, y)
self.assertEqual(gx, y)
gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
self.assertEqual(gx, y)
self.assertEqual(gy, x)
def test_out_of_order_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
gy, gx = grad(torch.mul, argnums=(1, 0))(x, y)
self.assertEqual(gx, y)
self.assertEqual(gy, x)
def test_negative_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
gx = grad(torch.mul, argnums=-2)(x, y)
self.assertEqual(gx, y)
gy = grad(torch.mul, argnums=-1)(x, y)
self.assertEqual(gy, x)
(gx,) = grad(torch.mul, argnums=(-2,))(x, y)
self.assertEqual(gx, y)
gx, gy = grad(torch.mul, argnums=(-2, -1))(x, y)
self.assertEqual(gx, y)
self.assertEqual(gy, x)
def test_grad_pytree_inputs(self, device):
x = torch.randn([], device=device)
def f(a, b):
x, y = a
return 1 * x + 2 * y + 3 * b["foo"]
args = ((x, x), {"foo": x})
gx, gy = grad(f)(*args)
self.assertEqual(gx, torch.tensor(1.0, device=device))
self.assertEqual(gy, torch.tensor(2.0, device=device))
((gx, gy),) = grad(f, argnums=(0,))(*args)
self.assertEqual(gx, torch.tensor(1.0, device=device))
self.assertEqual(gy, torch.tensor(2.0, device=device))
(gx, gy), gz = grad(f, argnums=(0, 1))(*args)
self.assertEqual(gx, torch.tensor(1.0, device=device))
self.assertEqual(gy, torch.tensor(2.0, device=device))
self.assertEqual(gz["foo"], torch.tensor(3.0, device=device))
def test_grad_aux_tensor(self, device):
x = torch.randn(3, device=device)
with self.assertRaisesRegex(
RuntimeError,
r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple",
):
grad(lambda t: [t, t], has_aux=True)(x)
with self.assertRaisesRegex(
RuntimeError,
r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple",
):
grad(lambda t: (t, t + 2, t + 3), has_aux=True)(x)
def f(t):
y = t.sin()
return y.sum(), t.cos()
out, aux = grad(f, has_aux=True)(x)
self.assertEqual(aux, x.cos())
self.assertEqual(out, x.cos())
def test_grad_aux_pytree(self, device):
def f(x):
y = x.sin()
return y.sum(), {"a": x.cos(), "b": [x.tan()]}
x = torch.randn(3, device=device)
out, aux = grad(f, has_aux=True)(x)
_, expected_aux = f(x)
self.assertEqual(aux, expected_aux)
self.assertEqual(out, x.cos())
for aux in [1, 1.0, "abc"]:
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = grad(lambda x: (x.sum(), aux), has_aux=True)(x)
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = grad(lambda x: (x.sum(), [x, aux]), has_aux=True)(x)
def test_zero_grad(self, device):
def f(x):
return (x["a"] ** 2.0).sum()
inps = {
"a": torch.randn(10, device=device) + 3,
"b": torch.randn(10, device=device),
}
grads = grad(f)(inps)
self.assertNotEqual(grads["a"].sum(), 0.0)
self.assertEqual(grads["b"].sum(), 0.0)
def test_unrelated_grad(self, device):
x = torch.tensor(1.0, device=device)
y = torch.tensor(2.0, device=device)
def unrelated(x):
return y
result = grad(unrelated)(x)
self.assertEqual(result, torch.zeros_like(x))
def test_unrelated_vjp(self, device):
x = torch.tensor(1.0, device=device)
y = torch.tensor(2.0, device=device)
v = torch.tensor(1.0, device=device)
def unrelated(x):
return y
out, vjp_fn = vjp(unrelated, x)
result = vjp_fn(v)
expected = (torch.zeros_like(x),)
self.assertEqual(result, expected)
def test_unrelated_vjp_multiple_inputs_outputs(self, device):
w = torch.tensor(3.0, device=device)
x = torch.tensor(4.0, device=device)
y = torch.tensor(2.0, device=device)
v = torch.tensor(1.0, device=device)
def unrelated(w, x):
return y, y, x
out, vjp_fn = vjp(unrelated, w, x)
result = vjp_fn((v, v, v))
expected = (torch.zeros_like(x), torch.ones_like(x))
self.assertEqual(result, expected)
# TODO: https://github.com/zou3519/functorch/issues/12
@onlyCPU
def test_unrelated_hessian(self, device):
N = 5
M = 3
W = torch.randn(N, M, device=device)
def f(x):
return W @ x
x = torch.randn(M)
result = jacrev(jacrev(f))(x)
expected = torch.zeros(N, M, M, device=device)
self.assertEqual(result, expected)
def test_vjp_pytree_input(self, device):
def f(x):
return x[0] * x[1][0]
x = torch.randn([], device=device)
v = torch.randn([], device=device)
out, vjp_fn = vjp(f, (x, (x, x)))
self.assertEqual(out, x * x)
result = vjp_fn(v)
self.assertEqual(result, ((x * v, (x * v, 0.0)),))
def test_vjp_pytree_output(self, device):
def f(x):
return x, (x, x)
x = torch.randn([], device=device)
v1 = torch.randn([], device=device)
v2 = torch.randn([], device=device)
v3 = torch.randn([], device=device)
_, vjp_fn = vjp(f, x)
(result,) = vjp_fn((v1, (v2, v3)))
self.assertEqual(result, v1 + v2 + v3)
def test_vjp_outputs_can_any_pytree(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)
for output in [None, ()]:
with self.assertRaisesRegex(
RuntimeError,
r"vjp\(f, \*primals\): Expected f to be a function that has non-empty output",
):
_, vjp_fn = vjp(lambda _: output, x)
vjp_fn(t)
for output in [1, True, 12.2, "abc"]:
with self.assertRaisesRegex(
RuntimeError,
r"vjp\(f, \*primals\): expected f\(\*primals\) to return only tensors",
):
_, vjp_fn = vjp(lambda _: output, x)
vjp_fn(t)
# Check list output
output, vjp_fn = vjp(lambda x: [x, x.sum()], x)
(vjp_out,) = vjp_fn([t, t.sum()])
assert isinstance(output, list) and len(output) == 2
assert isinstance(vjp_out, torch.Tensor)
# Check dict output
output, vjp_fn = vjp(lambda x: {"x": x, "xsum": x.sum()}, x)
(vjp_out,) = vjp_fn({"x": t, "xsum": t.sum()})
assert isinstance(output, dict) and len(output) == 2 and "xsum" in output
assert isinstance(vjp_out, torch.Tensor)
def composite_output(x):
out = x.sum()
return [
(out, {"a": x, "out": [x, out]}),
]
output, vjp_fn = vjp(composite_output, x)
(vjp_out,) = vjp_fn(
[
(t.sum(), {"a": t, "out": [t, t.sum()]}),
]
)
assert isinstance(output, list)
assert isinstance(output[0], tuple) and isinstance(output[0][1], dict)
assert isinstance(vjp_out, torch.Tensor)
def test_vjp_pytree_error(self, device):
def f(x):
return x, (x, x)
x = torch.randn([], device=device)
v1 = torch.randn([], device=device)
v2 = torch.randn([], device=device)
v3 = torch.randn([], device=device)
_, vjp_fn = vjp(f, x)
with self.assertRaisesRegex(RuntimeError, "Expected pytree structure"):
(result,) = vjp_fn(((v1, (v2, v3)),))
def test_vjp_aux_tensor(self, device):
x = torch.randn(3, device=device)
with self.assertRaisesRegex(
RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple"
):
vjp(lambda t: [t, t], x, has_aux=True)
with self.assertRaisesRegex(
RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple"
):
vjp(lambda t: (t, t + 2, t + 3), x, has_aux=True)
def f(t):
y = t.sin()
return y, t.cos()
out, vjp_fn, aux = vjp(f, x, has_aux=True)
self.assertEqual(aux, x.cos())
self.assertEqual(out, x.sin())
v = torch.randn(3, device=device)
(grad_x,) = vjp_fn(v)
self.assertEqual(grad_x, v * x.cos())
def test_vjp_aux_pytree(self, device):
def f(x):
y = x.sin()
return y, {"a": x.cos(), "b": [x.tan()]}
x = torch.randn(3, device=device)
out, vjp_fn, aux = vjp(f, x, has_aux=True)
expected_out, expected_aux = f(x)
self.assertEqual(out, expected_out)
self.assertEqual(aux, expected_aux)
v = torch.randn(3, device=device)
(grad_x,) = vjp_fn(v)
self.assertEqual(grad_x, v * x.cos())
for aux in [1, 1.0, "abc"]:
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = vjp(lambda x: (x, aux), x, has_aux=True)
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = vjp(lambda x: (x, [x, aux]), x, has_aux=True)
def test_functional_init(self, device):
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes
self.fc1 = nn.Linear(2, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x
B = 10
weights, fn, _ = functional_init(MLPClassifier, (B,), device=device)(32, 2)
inputs = torch.randn(B, 7, 2, device=device)
vmap(fn)(weights, (inputs,))
def test_functional_init_with_buffers(self, device):
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes
self.fc1 = nn.Linear(2, self.hidden_dim)
self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.bn(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x
B = 10
weights, buffers, fn, _, _ = functional_init_with_buffers(
MLPClassifier, [B], device=device
)(32, 2)
inputs = torch.randn(B, 7, 2, device=device)
vmap(fn)(weights, buffers, (inputs,))
def test_advanced_indexing(self, device):
def f(value):
log_prob = torch.ones((), device=device)
val = torch.zeros(()) > 0
log_prob[val] = 0
return value
result = grad(f)(torch.randn((), device=device))
self.assertEqual(result, torch.ones_like(result))
def f2(value):
value = value.clone()
value[value > 0] = 0
return value.sum()
x = torch.randn(100, device=device)
result = grad(f2)(x)
self.assertEqual(result, (x <= 0).type_as(x))
def test_tensor_ctor_inside_grad(self, device):
def foo(x):
return x * torch.tensor(2.0, device=device)
x = torch.tensor(3.14, device=device)
functorch.grad(foo)(x)
@parametrize(
"op_list_data",
[
subtest(
(
[
vmap,
],
[(4, 2), (64, 3, 32, 32)],
),
name="vmap",
),
subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name="vmap_vmap"),
subtest(
(
[
grad,
],
[(0,), [], (4, 2), (64, 3, 32, 32)],
),
name="grad",
),
subtest(
(
[grad, grad],
[
[],
],
),
name="grad_grad",
),
subtest(([vmap, grad], [(4, 2)]), name="vmap_grad"),
],
)
def test_tensor_print(self, device, op_list_data):
op_list, shapes = op_list_data
for dt in get_all_fp_dtypes():
data = [torch.randn(s, dtype=dt, device=device) for s in shapes]
for x in data:
buf = None
def foo(t):
nonlocal buf
buf = repr(t)
return t.mean()
fn = foo
bdim = 0
for op in reversed(op_list):
if op == vmap:
fn = op(fn, in_dims=bdim)
bdim += 1
else:
fn = op(fn)
expected = f"{repr(x)}"
level = 0
for op in op_list:
level += 1
if op == grad:
expected = f"GradTrackingTensor(lvl={level}, value={expected})"
elif op == vmap:
bdim -= 1
expected = (
f"BatchedTensor(lvl={level}, bdim={bdim}, value={expected})"
)
fn(x)
buf = buf.replace("\n", "").replace(" ", "")
expected = expected.replace("\n", "").replace(" ", "")
self.assertEqual(expected, buf)
def test_print_captured_tensor_inside_transform(self, device):
x = torch.tensor([1.0, 2.0, 3.0], device=device)
out = None
def f(y):
nonlocal out
out = repr(x)
return y
vjp(f, torch.randn(4, device=device))
self.assertEqual(out, repr(x))
def test_no_grad_outside(self, device):
x = torch.randn([], device=device, requires_grad=True)
with torch.no_grad():
y = grad(torch.sin)(x)
self.assertEqual(y, x.cos())
self.assertFalse(y.requires_grad)
def test_no_grad_inside(self, device):
def f(x):
with torch.no_grad():
shift = x**2
return x**2 - shift
x = torch.randn([], device=device)
y = grad(f)(x)
self.assertEqual(y, 2 * x)
y = grad(grad(f))(x)
self.assertEqual(y, 2)
x = torch.randn([], device=device, requires_grad=True)
y = grad(f)(x)
(z,) = torch.autograd.grad(y, x)
self.assertEqual(z, 2)
def test_no_grad_mixed(self, device):
def f(x):
with torch.no_grad():
shift = x**2
return x**2 - shift
x = torch.randn([], device=device, requires_grad=True)
with torch.no_grad():
y = grad(f)(x)
self.assertEqual(y, 2 * x)
self.assertFalse(y.requires_grad)
def test_no_grad_nested_simple(self, device):
def h(x):
with torch.no_grad():
shift = grad(lambda x: 0.25 * x**4)(x)
return x**3 - shift
x = torch.tensor(1.5, device=device, requires_grad=True)
y = grad(h)(x)
self.assertEqual(y, 3 * x**2)
(z,) = torch.autograd.grad(y, x)
self.assertEqual(z, 6 * x)
def test_no_grad_nested_complicated(self, device):
def f(x):
with torch.no_grad():
shift = x**3
return x**3 - shift
def g(x):
r1 = grad(f)(x)
with torch.no_grad():
shift = grad(f)(x)
return r1 - shift
x = torch.randn([], requires_grad=True, device=device)
y = grad(g)(x)
# The only differentiable part of g is x ** 3
self.assertEqual(y, 6 * x)
(z,) = torch.autograd.grad(y, x)
self.assertEqual(z, 6)
def test_no_grad_value(self, device):
def h(x):
with torch.no_grad():
gvalue, value = grad_and_value(lambda x: x**3)(x)
return x**3 - value
x = torch.tensor(1.6, device=device, requires_grad=True)
y = grad(h)(x)
self.assertEqual(y, 3 * x**2)
(z,) = torch.autograd.grad(y, x)
self.assertEqual(z, 6 * x)
def test_no_grad_outside_vjp(self, device):
def h(x):
return x**2
x = torch.tensor(2.0, requires_grad=True, device=device)
with torch.no_grad():
out, vjp_fn = vjp(h, x)
(y,) = vjp_fn(torch.tensor(1.0, device=device))
self.assertEqual(y, 2 * x)
self.assertFalse(y.requires_grad)
self.assertFalse(out.requires_grad)
def test_no_grad_outside_vjp_fn(self, device):
def h(x):
return x**2
x = torch.tensor(3.14, requires_grad=True, device=device)
out, vjp_fn = vjp(h, x)
with torch.no_grad():
(y,) = vjp_fn(torch.tensor(1.0, device=device))
self.assertEqual(y, 2 * x)
self.assertFalse(y.requires_grad)
self.assertTrue(out.requires_grad)
(z,) = torch.autograd.grad(out, x)
self.assertEqual(z, 2 * x)
def test_no_grad_outside_vjp_only(self, device):
def h(x):
return x**2
x = torch.tensor(3.14, requires_grad=True, device=device)
with torch.no_grad():
out, vjp_fn = vjp(h, x)
(y,) = vjp_fn(torch.tensor(1.0, device=device))
self.assertEqual(y, 2 * x)
self.assertFalse(out.requires_grad)
# This one is a little weird...
self.assertTrue(y.requires_grad)
(z,) = torch.autograd.grad(y, x)
self.assertEqual(z, 2)
@markDynamoStrictTest
class TestAutogradFunction(TestCase):
def test_set_materialize_grads(self, device):
class A(torch.autograd.Function):
@staticmethod
def forward(x, y):
return x, y
@staticmethod
def setup_context(ctx, inputs, output):
ctx.set_materialize_grads(False)
@staticmethod
def backward(ctx, gx, gy):
self.assertIsNotNone(gx)
self.assertIsNone(gy)
return gx, gy
def f(y, x):
x, y = A.apply(x, y)
return x**2
x = torch.tensor(2.0, device=device)
y = torch.tensor(3.0, device=device)
# grad differentiates w.r.t. arg 0 by default
grad(f)(y, x)
grad(grad(f))(y, x)
@parametrize("inner_requires_grad", [True, False])
@parametrize("save_for", ["jvp", "vjp"])
@parametrize("save_tensors", ["input", "output", "neither"])
@parametrize("mark_dirty", [True, False])
def test_function_returns_input(
self, device, inner_requires_grad, save_for, save_tensors, mark_dirty
):
class A(torch.autograd.Function):
@staticmethod
def forward(x):
return x
@staticmethod
def setup_context(ctx, inputs, output):
if save_for == "jvp":
save_fn = ctx.save_for_forward
else:
save_fn = ctx.save_for_backward
if mark_dirty:
ctx.mark_dirty(inputs[0])
if save_tensors == "input":
save_fn(inputs[0])
elif save_tensors == "output":
save_fn(output)
elif save_tensors == "neither":
pass
@staticmethod
def backward(ctx, grad_output):
return grad_output
@staticmethod
def jvp(ctx, x_t):
# NB: the logic to check ctx.save_for_forward happens
# before we reach this!
if mark_dirty:
ret = x_t.add_(0)
else:
ret = x_t.view_as(x_t)
return ret
def fn(x):
return A.apply(x.clone())
err_msg = "A input that has been returned as-is"
a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad)
a_t = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad)
if save_tensors in ("input", "output") and not mark_dirty:
with self.assertRaisesRegex(RuntimeError, err_msg):
grad(fn)(a)
with self.assertRaisesRegex(RuntimeError, err_msg):
jvp(fn, (a,), (a_t,))
else:
grad(fn)(a)
jvp(fn, (a,), (a_t,))
a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad).clone()
a_t = torch.tensor(
2.0, device=device, requires_grad=inner_requires_grad
).clone()
if save_tensors in ("input", "output") and not mark_dirty:
with self.assertRaisesRegex(RuntimeError, err_msg):
A.apply(a)
with self.assertRaisesRegex(RuntimeError, err_msg):
with fwAD.dual_level():
A.apply(fwAD.make_dual(a, a_t))
else:
b = A.apply(a)
if mark_dirty:
self.assertTrue(a is b)
if not (
mark_dirty and save_for == "vjp" and save_tensors in ("input", "output")
):
# TODO(soulitzer): https://github.com/pytorch/pytorch/issues/97827
with fwAD.dual_level():
a_dual = fwAD.make_dual(a, a_t)
b_dual = A.apply(a_dual)
if mark_dirty:
self.assertTrue(a_dual is b_dual)
def test_needs_input_grads(self, device):
class A(torch.autograd.Function):
@staticmethod
def forward(x, y):
return x * y
@staticmethod
def setup_context(ctx, inputs, output):
return
@staticmethod
def backward(ctx, grad_output):
self.assertTrue(ctx.needs_input_grad[0])
self.assertFalse(ctx.needs_input_grad[1])
return None, None
x = torch.tensor(2.0, device=device)
y = torch.tensor(3.0, device=device)
# grad differentiates w.r.t. arg 0 by default
grad(A.apply)(x, y)
grad(grad(A.apply))(x, y)
def _get_NumpyCubeNotComposable(self):
class NumpyCubeNotComposable(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = input.cpu().numpy()
return torch.tensor(input_np**3, device=input.device), input_np
@staticmethod
def setup_context(ctx, inputs, output):
ctx.input_np = output[1]
ctx.device = inputs[0].device
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output, grad_saved):
result_np = 3 * (ctx.input_np**2)
return torch.tensor(result_np, device=ctx.device)
return NumpyCubeNotComposable
def test_once_differentiable_autograd_vjp(self, device):
NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()
def f(x):
y, _ = NumpyCubeNotComposable.apply(x)
return y
# regular autograd x vjp
x = torch.randn([], requires_grad=True, device=device)
grad_y = torch.randn_like(x, requires_grad=True)
_, vjp_fn = vjp(f, x)
(gx,) = vjp_fn(grad_y)
with self.assertRaisesRegex(RuntimeError, "marked with @once_differentiable"):
gx.backward()
# TODO: support torch.autograd.function.once_differentiable
# (or, if impossible, figure out how to raise a nice error)
# https://github.com/pytorch/pytorch/issues/90224
@unittest.expectedFailure
def test_once_differentiable_grad_vjp(self, device):
NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()
# grad x vjp
x = torch.randn([], device=device)
grad_y = torch.randn_like(x)
def h(x, grad_y):
_, vjp_fn = vjp(f, x) # noqa: F821
(gx,) = vjp_fn(grad_y)
return gx
grad(h, argnums=(0, 1))(x, grad_y)
def test_grad_fn_name(self, device):
names = []
class FooBar(torch.autograd.Function):
@staticmethod
def forward(x):
return x.clone()
@staticmethod
def setup_context(ctx, inputs, output):
return
@staticmethod
def backward(ctx, grad_output):
return grad_output
def f(x):
y = FooBar.apply(x)
names.append(type(y.grad_fn).__name__)
return y
x = torch.tensor(1.0)
grad(f)(x)
self.assertEqual(names, ["FooBarGeneratedBackward"])
@markDynamoStrictTest
class TestAutogradFunctionVmapAPI(TestCase):
def test_no_vmap_staticmethod_and_no_generate_vmap_rule(self, device):
class NumpyCube(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = to_numpy(input) # noqa: F821
dinput = torch.tensor(3 * input_np**2, device=input.device)
return torch.tensor(input_np**3, device=input.device), dinput
@staticmethod
def setup_context(ctx, inputs, output):
ctx.save_for_backward(inputs, output[1])
@staticmethod
def backward(ctx, grad_output, grad_saved):
raise RuntimeError("foobar")
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "does not have vmap support"):
vmap(NumpyCube.apply)(x)
def test_has_vmap_staticmethod_and_has_generate_vmap_rule(self, device):
class NumpyCube(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(input):
input_np = to_numpy(input) # noqa: F821
dinput = torch.tensor(3 * input_np**2, device=input.device)
return torch.tensor(input_np**3, device=input.device), dinput
@staticmethod
def setup_context(ctx, outputs, input):
ctx.save_for_backward(input, outputs[1])
@staticmethod
def backward(ctx, grad_output, grad_saved):
raise RuntimeError("foobar")
@staticmethod
def vmap(infos, in_dims, x):
raise RuntimeError("foobar")
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "generate_vmap_rule=True and"):
vmap(NumpyCube.apply)(x)
def test_info_object(self, device):
batch_size = 10
class Id(torch.autograd.Function):
@staticmethod
def forward(input):
pass
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def backward(ctx, grad_output, grad_saved):
pass
@staticmethod
def vmap(info, in_dims, input):
self.assertEqual(info.batch_size, batch_size)
self.assertEqual(info.randomness, randomness)
return input, in_dims[0]
x = torch.randn(batch_size, 3, device=device)
for randomness in ("error", "different", "same"):
vmap(Id.apply, randomness=randomness)(x)
def test_in_dims_single_input(self, device):
class Id(torch.autograd.Function):
@staticmethod
def forward(input):
pass
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def backward(ctx, grad_output, grad_saved):
pass
@staticmethod
def vmap(info, in_dims, input):
self.assertEqual(in_dims, (1,))
return input, in_dims[0]
B = 10
x = torch.randn(3, B, device=device)
vmap(Id.apply, in_dims=1)(x)
vmap(Id.apply, in_dims=(1,))(x)
def test_in_dims_multiple_inputs(self, device):
class Id(torch.autograd.Function):
@staticmethod
def forward(x, y):
pass
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def backward(ctx, grad_output, grad_saved):
pass
@staticmethod
def vmap(info, in_dims, x, y):
self.assertEqual(in_dims, (0, [0, 0]))
self.assertTrue(isinstance(in_dims, tuple))
self.assertTrue(isinstance(in_dims[1], list))
return (x, y), in_dims
x = torch.randn(2, device=device)
vmap(Id.apply)(x, [x, x])
def test_skips_empty_layer(self, device):
class Id(torch.autograd.Function):
@staticmethod
def forward(input):
return input
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def backward(ctx, grad_output, grad_saved):
pass
@staticmethod
def vmap(info, in_dims, input):
raise RuntimeError("expected to not be called")
def f(x):
y = torch.tensor(1.0)
y = Id.apply(y)
return x * 1
x = torch.randn(2, 3)
vmap(f)(x)
def test_none_returns(self, device):
class Zeros(torch.autograd.Function):
@staticmethod
def forward(input):
return torch.zeros(input.shape, device=input.device)
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, input):
assert in_dims == (0,)
return torch.zeros(input.shape[1:], device=input.device), None
B = 2
x = torch.randn(B, 3)
y = vmap(Zeros.apply)(x)
self.assertEqual(y, torch.zeros_like(x))
class TwoZeros(torch.autograd.Function):
@staticmethod
def forward(input):
r = torch.zeros(input.shape, device=input.device)
return r, r
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, input):
assert in_dims == (0,)
r = torch.zeros(input.shape[1:], device=input.device)
return (r, r), None
B = 2
x = torch.randn(B, 3)
result = vmap(TwoZeros.apply)(x)
self.assertTrue(isinstance(result, tuple))
y, z = result
self.assertEqual(y, torch.zeros_like(x))
self.assertEqual(z, torch.zeros_like(x))
def test_should_have_two_returns(self, device):
class Zeros(torch.autograd.Function):
@staticmethod
def forward(input):
r = torch.zeros(input.shape, device=input.device)
return r
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, input):
r = torch.zeros(input.shape[1:], device=input.device)
return r
B = 2
x = torch.randn(B, 3)
with self.assertRaisesRegex(RuntimeError, "to have two returns"):
result = vmap(Zeros.apply)(x)
class TwoZeros(torch.autograd.Function):
@staticmethod
def forward(input):
r = torch.zeros(input.shape, device=input.device)
return r, r
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, input):
r = torch.zeros(input.shape[1:], device=input.device)
return r, r, 0, 0
B = 2
x = torch.randn(B, 3)
with self.assertRaisesRegex(RuntimeError, "to have two returns"):
result = vmap(TwoZeros.apply)(x)
def test_incompatible_out_dims_error_msg(self, device):
class Zeros(torch.autograd.Function):
@staticmethod
def forward(input):
r = torch.zeros(input.shape, device=input.device)
return r
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, input):
r = torch.zeros(input.shape[1:], device=input.device)
return r, (None,)
B = 2
x = torch.randn(B, 3)
with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
result = vmap(Zeros.apply)(x)
class Zeros(torch.autograd.Function):
@staticmethod
def forward(input):
r = torch.zeros(input.shape, device=input.device)
return [r]
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, input):
r = torch.zeros(input.shape[1:], device=input.device)
return [r], (None,)
B = 2
x = torch.randn(B, 3)
with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
result = vmap(Zeros.apply)(x)
def test_kwarg_only_tensors(self, device):
with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
class MyClass(torch.autograd.Function):
@staticmethod
def forward(x, *, y):
return x + y
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, x, *, y):
assert in_dims == (0,)
return x + y, 0
x = torch.randn(3)
y = torch.randn(3)
vmap(MyClass.apply)(x, y=y)
@markDynamoStrictTest
class TestVmapOfGrad(TestCase):
def test_per_sample_grads_inplace_view(self, device):
def compute_loss(weight, x, t):
x = x.mm(weight)
y = x.squeeze_(0)
return (y - t).sum()
weight = torch.randn(16, 2, device=device)
x = torch.randn(64, 1, 16, device=device)
t = torch.randn(64, 2, device=device)
result = vmap(partial(grad(compute_loss), weight))(x, t)
expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
expected = torch.stack(expected)
# TODO: Check if the rtol is a problem
self.assertEqual(result, expected, atol=0, rtol=5e-4)
def test_new_zeros_materializes_tensor(self, device):
N = 3
C = 5
def foo(y, x):
result = x.new_zeros((C,))
result.copy_(y)
return result.sum()
x = torch.randn(N, device=device)
y = torch.randn(N, C, device=device)
result = vmap(grad(foo))(y, x)
self.assertEqual(result, torch.ones_like(y))
def test_new_empty_materializes_tensor(self, device):
N = 3
C = 5
def foo(y, x):
result = x.new_empty((C,))
result.copy_(y)
return result.sum()
x = torch.randn(N, device=device)
y = torch.randn(N, C, device=device)
result = vmap(grad(foo))(y, x)
self.assertEqual(result, torch.ones_like(y))
def test_per_sample_grads_simple(self, device):
def compute_loss(weight, x, t):
y = x @ weight
return ((y - t) ** 2).sum()
weight = torch.randn(16, 2, device=device)
x = torch.randn(64, 16, device=device)
t = torch.randn(64, 2, device=device)
result = vmap(partial(grad(compute_loss), weight))(x, t)
expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
expected = torch.stack(expected)
# TODO: Check if the rtol is a problem
self.assertEqual(result, expected, atol=0, rtol=5e-4)
def _compare_expected_and_result(self, expected, result, mechanism):
if mechanism == "make_functional":
expected = zip(*expected)
expected = tuple(torch.stack(shards) for shards in expected)
for r, e in zip(result, expected):
self.assertEqual(r, e, atol=0, rtol=1.5e-3)
else:
assert mechanism == "functional_call"
expected = {
k: tuple(d[k] for d in expected) for k, v in expected[0].items()
}
expected = {k: torch.stack(shards) for k, shards in expected.items()}
for key in result:
self.assertEqual(result[key], expected[key], atol=0, rtol=1.5e-3)
@tf32_on_and_off(0.005)
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_per_sample_grads_embeddingnet(self, device, mechanism):
class SampleNet(nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
self.emb = nn.Embedding(vocab_size, 16)
self.fc1 = nn.Linear(16, 16)
self.fc2 = nn.Linear(16, 2)
def forward(self, x):
x = self.emb(x)
x = torch.transpose(x, -1, -2)
x = torch.mean(x, -1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
def name(self):
return "SampleNet"
# Create our inputs...
vocab_size = 1000
batch_shape = [64]
words_per_sentence = 5
data = torch.randint(
0, vocab_size, (*batch_shape, words_per_sentence), device=device
)
targets = torch.randint(0, 1, (*batch_shape,), device=device)
# Construct our module
net = SampleNet(vocab_size).to(device=device)
criterion = nn.CrossEntropyLoss()
net_func, weights = _get_weights_and_functional_call(net, mechanism)
def compute_loss(weights, data, target):
output = net_func(weights, data)
result = criterion(output, target)
return result
expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
result = vmap(partial(grad(compute_loss), weights))(data, targets)
self._compare_expected_and_result(expected, result, mechanism)
def test_log_softmax(self, device):
x = torch.randn(3, 5, device=device)
v = torch.randn(5, device=device)
def foo(x, v):
_, vjp_fn = vjp(partial(torch.log_softmax, dim=-1), x)
return vjp_fn(v)[0]
result = vmap(foo, (0, None))(x, v)
v = v.expand_as(x)
x.requires_grad_()
output = torch.log_softmax(x, dim=-1)
output.backward(v)
self.assertEqual(result, x.grad)
jacrev_and_jacfwd = parametrize(
"jacapi", [subtest(jacrev, name="jacrev"), subtest(jacfwd, name="jacfwd")]
)
FIXME_jacrev_only = parametrize("jacapi", [subtest(jacrev, name="jacrev")])
@markDynamoStrictTest
class TestJac(VmapTearDownMixin, TestCase):
@jacrev_and_jacfwd
def test_simple(self, device, jacapi):
x = torch.randn(3, device=device)
y = jacapi(torch.sin)(x)
expected = torch.diagflat(x.cos())
assert torch.allclose(y, expected)
@jacrev_and_jacfwd
def test_simple_not_flat(self, device, jacapi):
x = torch.randn(2, 3, device=device)
y = jacapi(torch.sin)(x)
expected = torch.diagflat(x.view(-1).cos())
expected = expected.view(2, 3, 2, 3)
assert torch.allclose(y, expected)
@jacrev_and_jacfwd
def test_take(self, device, jacapi):
x = torch.rand(5)
def func(x):
y = torch.ones(3, dtype=torch.long)
z = torch.take(x, y)
return z
self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
@jacrev_and_jacfwd
def test_diff_numel(self, device, jacapi):
x = torch.randn(2, 4, device=device)
# Tensor[2, 4] -> Tensor[3, 1]
def f(x):
return x[0, 1:].unsqueeze(-1)
y = jacapi(f)(x)
self.assertEqual(y.shape, (3, 1, 2, 4))
expected = x.new_zeros(3, 1, 2, 4)
expected[0, 0, 0, 1] = 1
expected[1, 0, 0, 2] = 1
expected[2, 0, 0, 3] = 1
self.assertEqual(y, expected)
@jacrev_and_jacfwd
def test_vmap_on_jac_simple(self, device, jacapi):
x = torch.randn(2, 3, device=device)
y = vmap(jacapi(torch.sin))(x)
expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
assert torch.allclose(y, expected)
@jacrev_and_jacfwd
def test_nested_jac_simple(self, device, jacapi):
def foo(x):
return x.sin().sum()
x = torch.randn(3, device=device)
y = jacapi(jacapi(foo))(x)
expected = torch.diagflat(-x.sin())
assert torch.allclose(y, expected)
@jacrev_and_jacfwd
def test_multiple_args(self, device, jacapi):
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
z = jacapi(torch.multiply, argnums=1)(x, y)
expected = torch.diagflat(x)
assert torch.allclose(z, expected)
@jacrev_and_jacfwd
def test_multiple_outputs_multiple_argnums(self, device, jacapi):
def f(x, y):
return 2 * x + 3 * y, 4 * x + 5 * y
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
z = jacapi(f, argnums=(0, 1))(x, y)
expected_out0_x = torch.diagflat(torch.full_like(x, 2))
expected_out0_y = torch.diagflat(torch.full_like(y, 3))
expected_out1_x = torch.diagflat(torch.full_like(x, 4))
expected_out1_y = torch.diagflat(torch.full_like(y, 5))
self.assertEqual(len(z), 2)
self.assertTrue(isinstance(z, tuple))
self.assertEqual(len(z[0]), 2)
self.assertTrue(isinstance(z[0], tuple))
self.assertEqual(z[0][0], expected_out0_x)
self.assertEqual(z[0][1], expected_out0_y)
self.assertEqual(z[1][0], expected_out1_x)
self.assertEqual(z[1][1], expected_out1_y)
@jacrev_and_jacfwd
def test_multiple_outputs_single_argnums(self, device, jacapi):
def f(x, y):
return 2 * x + 3 * y, 4 * x + 5 * y
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
expected_out0_x = torch.diagflat(torch.full_like(x, 2))
expected_out1_x = torch.diagflat(torch.full_like(x, 4))
z = jacapi(f, argnums=0)(x, y)
self.assertEqual(len(z), 2)
self.assertTrue(isinstance(z, tuple))
self.assertEqual(z, (expected_out0_x, expected_out1_x))
z = jacapi(f, argnums=(0,))(x, y)
self.assertEqual(len(z), 2)
self.assertTrue(isinstance(z, tuple))
self.assertTrue(isinstance(z[0], tuple))
self.assertEqual(z, ((expected_out0_x,), (expected_out1_x,)))
@jacrev_and_jacfwd
def test_multiple_outputs_pytree(self, device, jacapi):
def f(x, y):
return {"left": 2 * x + 3 * y, "right": 4 * x + 5 * y}
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
z = jacapi(f, argnums=(0, 1))(x, y)
expected_left_x = torch.diagflat(torch.full_like(x, 2))
expected_left_y = torch.diagflat(torch.full_like(y, 3))
expected_right_x = torch.diagflat(torch.full_like(x, 4))
expected_right_y = torch.diagflat(torch.full_like(y, 5))
expected = {
"left": (expected_left_x, expected_left_y),
"right": (expected_right_x, expected_right_y),
}
self.assertTrue(isinstance(z, dict))
self.assertTrue(isinstance(z["left"], tuple))
self.assertTrue(isinstance(z["right"], tuple))
self.assertEqual(z, expected)
@jacrev_and_jacfwd
def test_multiple_inputs_pytree(self, device, jacapi):
def f(a, b, c):
a0, a1 = a
return a0 + a1 * 2 + b * 3 + c * 4
x = torch.randn([], device=device)
args = ((x, x), x, x)
result = jacapi(f, argnums=(0, 1, 2))(*args)
expected = (
(torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
torch.tensor(3.0, device=device),
torch.tensor(4.0, device=device),
)
self.assertEqual(result, expected)
result = jacapi(f, argnums=(0,))(*args)
expected = (
(torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
)
self.assertEqual(result, expected)
result = jacapi(f)(*args)
expected = (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device))
self.assertEqual(result, expected)
@jacrev_and_jacfwd
def test_dimensionality(self, device, jacapi):
def f(x):
return x
x = torch.randn([], device=device)
result = jacapi(f)(x)
self.assertEqual(result.dim(), 0)
self.assertEqual(result, torch.ones_like(x))
x = torch.randn([1], device=device)
result = jacapi(f)(x)
self.assertEqual(result.dim(), 2)
self.assertEqual(result, x.new_ones(1, 1))
@jacrev_and_jacfwd
def test_aux_tensor(self, device, jacapi):
def f(x):
y = x.clone()
return y, y.cos()
x = torch.randn(3, device=device)
result, aux = jacapi(f, has_aux=True)(x)
self.assertEqual(result, torch.eye(3, 3, device=device))
self.assertEqual(aux, x.cos())
@jacrev_and_jacfwd
def test_aux_pytree(self, device, jacapi):
def f(x):
y = x.clone()
return y, {"a": y.cos(), "b": [y.tan()]}
x = torch.randn(3, device=device)
result, aux = jacapi(f, has_aux=True)(x)
self.assertEqual(result, torch.eye(3, 3, device=device))
_, expected_aux = f(x)
self.assertEqual(aux, expected_aux)
for aux in [1, 1.0, "abc"]:
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = jacapi(lambda x: (x, aux), has_aux=True)(x)
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = jacapi(lambda x: (x, [x, aux]), has_aux=True)(x)
@jacrev_and_jacfwd
def test_outputs_can_any_pytree(self, device, jacapi):
x = torch.randn(2, 3, device=device)
for output in [None, ()]:
with self.assertRaisesRegex(
RuntimeError,
r"(vjp|jvp).+: Expected f to be a function that has non-empty output",
):
jacapi(lambda _: output)(x)
for output in [1, True, 12.2, "abc"]:
with self.assertRaisesRegex(
RuntimeError,
r"(vjp|jvp).+: expected f\(\*primals\) to return only tensors",
):
jacapi(lambda _: output)(x)
# Check list output
out = jacapi(lambda x: [x, x.sum()])(x)
assert isinstance(out, list) and len(out) == 2
# Check dict output
out = jacapi(lambda x: {"x": x, "xsum": x.sum()})(x)
assert isinstance(out, dict) and len(out) == 2 and "xsum" in out
def composite_output(x):
out = x.sum()
return [
(out, {"a": x, "out": [x, out]}),
]
out = jacapi(composite_output)(x)
assert isinstance(out, list)
assert isinstance(out[0], tuple) and isinstance(out[0][1], dict)
@jacrev_and_jacfwd
def test_multiple_inputs_outputs_pytree(self, device, jacapi):
def f(a, b, c):
a0, a1 = a
return a0 + a1 * 2, {"foo": b * 3 + c * 4}
x = torch.randn([], device=device)
zero = torch.zeros([], device=device)
args = ((x, x), x, x)
result = jacapi(f)(*args)
expected = (
(torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
{"foo": (zero, zero)},
)
self.assertEqual(result, expected)
result = jacapi(f, argnums=(0,))(*args)
expected = (
((torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),),
{"foo": ((zero, zero),)},
)
self.assertEqual(result, expected)
result = jacapi(f, argnums=(0, 1))(*args)
expected = (
(
(torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
zero,
),
{"foo": ((zero, zero), torch.tensor(3.0, device=device))},
)
self.assertEqual(result, expected)
@jacrev_and_jacfwd
def test_multiple_inputs_outputs_pytree_multidim(self, device, jacapi):
def f(dct):
a = dct["a"]
b = dct["b"]
return {"c": a.sin(), "d": b.cos()}
x = torch.randn(3, device=device)
args = ({"a": x, "b": x},)
result = jacapi(f)(*args)
expected = {
"c": {"a": x.cos().diagflat(), "b": x.new_zeros(3, 3)},
"d": {"a": x.new_zeros(3, 3), "b": -x.sin().diagflat()},
}
self.assertEqual(result, expected)
@jacrev_and_jacfwd
def test_unrelated_input(self, device, jacapi):
def f(x, y):
return x
x = torch.randn(2, 3, device=device)
y = torch.randn(2, 3, device=device)
result = jacapi(f, argnums=(0, 1))(x, y)
expected0 = torch.eye(6, 6, device=device).view(2, 3, 2, 3)
expected1 = y.new_zeros(2, 3, 2, 3)
expected = (expected0, expected1)
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
@jacrev_and_jacfwd
def test_unrelated_output(self, device, jacapi):
y = torch.randn(2, 3, device=device)
def f(x):
return y
x = torch.randn(2, 3, device=device)
result = jacapi(f)(x)
expected = x.new_zeros(2, 3, 2, 3)
self.assertEqual(result, expected)
@jacrev_and_jacfwd
def test_empty_output(self, device, jacapi):
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
def f(x, y):
return ()
with self.assertRaisesRegex(RuntimeError, "xpected"):
jacapi(f)(x, y)
@jacrev_and_jacfwd
def test_argnums_tuple(self, device, jacapi):
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
z = jacapi(torch.multiply, argnums=(0, 1))(x, y)
expected0 = torch.diagflat(y)
expected1 = torch.diagflat(x)
assert len(z) == 2
assert torch.allclose(z[0], expected0)
assert torch.allclose(z[1], expected1)
@jacrev_and_jacfwd
def test_argnums_effect_on_return(self, device, jacapi):
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
z = jacapi(torch.multiply, argnums=(0,))(x, y)
expected0 = torch.diagflat(y)
assert isinstance(z, tuple)
assert len(z) == 1
assert torch.allclose(z[0], expected0)
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
z = jacapi(torch.multiply, argnums=0)(x, y)
expected0 = torch.diagflat(y)
assert isinstance(z, torch.Tensor)
assert torch.allclose(z, expected0)
@jacrev_and_jacfwd
def test_argnums_defaults_to_zero(self, device, jacapi):
def f(x, y):
return x * 2 + y * 3
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
z = jacapi(f)(x, y)
expected = torch.diagflat(torch.full_like(x, 2))
self.assertEqual(z, expected)
@jacrev_and_jacfwd
def test_empty_argnums(self, device, jacapi):
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
jacapi(torch.sin, argnums=())(x)
@jacrev_and_jacfwd
def test_out_of_bounds_argnums(self, device, jacapi):
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"):
jacapi(torch.sin, argnums=2)(x)
@jacrev_and_jacfwd
def test_negative_argnums(self, device, jacapi):
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"):
jacapi(torch.sin, argnums=-2)(x)
@jacrev_and_jacfwd
def test_repeated_argnums(self, device, jacapi):
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "must be unique"):
jacapi(torch.sin, argnums=(0, 0))(x)
@jacrev_and_jacfwd
def test_float_argnums(self, device, jacapi):
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "must be int or Tuple"):
jacapi(torch.sin, argnums=0.0)(x)
with self.assertRaisesRegex(RuntimeError, "must be int"):
jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
def test_hessian_simple(self, device):
def f(x):
return x.sin()
x = torch.randn(3, device=device)
hessian(f)(x)
def _test_against_reference(self, f, inputs, jacapi):
def foo(inputs):
return f(*inputs)
expected = torch.autograd.functional.jacobian(f, inputs)
result = jacapi(foo)(inputs)
self.assertEqual(result, expected)
@jacrev_and_jacfwd
def test_against_reference_simple(self, device, jacapi):
def f(x):
return 3 * x**2
x = torch.randn(2, 3, 5, device=device)
self._test_against_reference(f, (x,), jacapi)
@jacrev_and_jacfwd
def test_against_reference_multi_input(self, device, jacapi):
def f(x, y):
return (x.cos() * x) @ y.sin()
x = torch.randn(2, 3, device=device)
y = torch.randn(3, 5, device=device)
self._test_against_reference(f, (x, y), jacapi)
@jacrev_and_jacfwd
def test_against_reference_multi_input_multi_output(self, device, jacapi):
def f(x, y):
return (x * x) @ y, x @ (x.sum(1) * y), y.sum()
x = torch.randn(5, 3, device=device)
y = torch.randn(3, 5, device=device)
self._test_against_reference(f, (x, y), jacapi)
@jacrev_and_jacfwd
def test_against_reference_unrelated_outputs(self, device, jacapi):
def f(x, y):
return x, y, x, y
x = torch.randn(2, device=device)
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y), jacapi)
@jacrev_and_jacfwd
def test_against_reference_zero_dim(self, device, jacapi):
# zero-dim output
def f(x, y):
return x.sum(), y.sum(), x * y
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y), jacapi)
# zero-dim input
def g(x):
return torch.stack([x, x, x])
x = torch.randn([], device=device)
self._test_against_reference(g, (x,), jacapi)
# Mixed zero-dim input / zero-dim output
def h(x, y):
return y.sum(), x * y
x = torch.randn([], device=device)
y = torch.randn(1, device=device)
self._test_against_reference(h, (x, y), jacapi)
@jacrev_and_jacfwd
def test_against_reference_correctness_different_devices(self, device, jacapi):
def f(x, y):
return x * y, (x * y).to(device=device)
x = torch.randn(3)
y = torch.randn(3)
self._test_against_reference(f, (x, y), jacapi)
@jacrev_and_jacfwd
def test_against_reference_default_arg(self, device, jacapi):
def f(x, y, z=3.0):
return x * y * z
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y), jacapi)
@jacrev_and_jacfwd
def test_inplace(self, device, jacapi):
def f(x, y):
y.copy_(x)
return y
out = jacapi(f, argnums=0) # x is differentiable
x, y = torch.randn(2, device=device), torch.randn(2, device=device)
self.assertEqual(out(x, y), torch.eye(y.shape[0]))
# testing tuple of argnums with the example that raised this issue originally
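        # g writes y into x[:2], so d/dy of (x**2).sum() is 2*y (the top-left block of
        # the Jacobian) and d/dz of (z**3).sum() is 3*z**2 (the bottom-right block);
        # every other block is zero, which is what expected_out encodes below.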
def g(x, y, z):
x[:2] = y
return torch.vstack([(x**2).sum(), (z**3).sum()])
out = jacapi(g, argnums=(1, 2))
x, y, z = (
torch.randn(3, device=device),
torch.randn(2, device=device),
torch.randn(2, device=device),
)
expected_out = (
torch.zeros(2, 1, 2, device=device),
torch.zeros(2, 1, 2, device=device),
)
expected_out[0][0][0] = 2 * y # top left corner
expected_out[1][1][0] = 3 * (z**2) # bottom right corner
out_val = out(x, y, z)
self.assertEqual(out_val, expected_out)
@parametrize("_preallocate_and_copy", (True, False))
def test_chunk_jacrev(self, device, _preallocate_and_copy):
x = torch.randn(10, 2, device=device)
y = torch.randn(1, 2, device=device)
def f(x, y):
return (x.sin(), x + y), (x + 2, x.sum())
for chunk_size in (1, 2, 3, 4, 7, 10, 1000):
expected = jacrev(f, argnums=(0, 1))(x, y)
actual = jacrev(
f,
argnums=(0, 1),
chunk_size=chunk_size,
_preallocate_and_copy=_preallocate_and_copy,
)(x, y)
self.assertEqual(actual, expected)
err_msg = "jacrev: `chunk_size` should be greater than 0."
with self.assertRaisesRegex(ValueError, err_msg):
jacrev(f, argnums=(0,), chunk_size=0)(x, y)
with self.assertRaisesRegex(ValueError, err_msg):
jacrev(f, argnums=(0,), chunk_size=-2)(x, y)
@parametrize("_preallocate_and_copy", (True, False))
def test_chunk_jacrev_composition(self, device, _preallocate_and_copy):
x = torch.randn(10, 2, device=device)
chunk_size = 3
def f(x):
return (x.sin(), x), (x + 2, x.sum())
expected = vmap(jacrev(jacrev(f)))(x)
actual = vmap(
jacrev(
jacrev(
f,
chunk_size=chunk_size,
_preallocate_and_copy=_preallocate_and_copy,
),
chunk_size=chunk_size,
)
)(x)
self.assertEqual(actual, expected)
# https://github.com/pytorch/pytorch/issues/127036
@xfailIfTorchDynamo
@parametrize("_preallocate_and_copy", (True, False))
def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
# With chunk_size=1, we shouldn't `vmap` and hence not be limited
        # by its constraints.
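        # (With chunk_size=1, jacrev should fall back to plain, un-vmapped vjp calls,
        # one basis vector at a time, so an operator whose backward produces
        # data-dependent shapes, like nonzero(), still works; larger chunk sizes go
        # through vmap and are expected to fail below.)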
x = torch.randn(3, 3, device=device)
# Function with Dynamic Op in Backward.
# This should cause jacrev/vmap(vjp) to fail.
class IdentityWithDynamicBackwardOp(torch.autograd.Function):
@staticmethod
def forward(input):
return input
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def backward(ctx, grad_output):
# dynamic op in backward pass.
grad_output.nonzero()
return grad_output
def f(x):
return IdentityWithDynamicBackwardOp.apply(x)
# With `chunk_size=1`, we don't use vmap. So the following should work.
jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
actual = jacfn(x)
expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
self.assertEqual(actual, expected)
# Should fail with `chunk_size=2`.
msg = (
r"vmap: We do not support batching operators that can output dynamic shape."
)
with self.assertRaisesRegex(RuntimeError, msg):
jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x)
def test_complex_error(self, device):
# Verify complex input raises error
# C -> C
def fn(x):
return x.conj()
x = torch.randn(1, device=device, dtype=torch.cfloat)
with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all inputs"):
jacrev(fn)(x)
with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all inputs"):
jacfwd(fn)(x)
# Verify complex output raises error
# R -> C
def fn(x):
return torch.conj(x * 0.5j)
x = torch.randn(1, device=device, dtype=torch.float)
with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all outputs"):
jacrev(fn)(x)
with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all outputs"):
jacfwd(fn)(x)
@jacrev_and_jacfwd
def test_jac_with_non_tensor_args(self, device, jacapi):
def f(t, int_x):
return t + int_x
t = torch.randn(3, 3, device=device)
actual = jacapi(f)(t, 3)
expected = torch.autograd.functional.jacobian(partial(f, int_x=3), t)
self.assertEqual(actual, expected)
@markDynamoStrictTest
class TestHessian(TestCase):
def _test_against_reference(self, f, inputs):
def foo(inputs):
return f(*inputs)
expected = torch.autograd.functional.hessian(f, inputs)
result = hessian(foo)(inputs)
self.assertEqual(result, expected)
def test_hessian_vectorize_correctness_simple(self, device):
def f(x):
return (3 * x**2).sum()
x = torch.randn(2, 3, 5, device=device)
self._test_against_reference(f, (x,))
def test_hessian_vectorize_correctness_multi_input(self, device):
def f(x, y, z):
return ((x.relu() * x) @ y.sin() @ z).sum()
x = torch.randn(2, 3, device=device)
y = torch.randn(3, 5, device=device)
z = torch.randn(5, 5, device=device)
self._test_against_reference(f, (x, y, z))
def test_hessian_vectorize_correctness_unrelated_outputs(self, device):
# output unrelated to one input
def f(x, y):
return (x**2).sum()
x = torch.randn(2, device=device)
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y))
# output unrelated to all inputs
def f(x, y):
return torch.ones([])
x = torch.randn(2, device=device)
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y))
def test_jacfwd_different_levels(self, device):
# Test case from:
# https://github.com/pytorch/functorch/issues/597
b = 8
n = 100
d = 2
x1 = torch.randn(b, n, d, device=device)
x2 = x1
A = 0.1 * torch.randn(b, d, d, device=device)
def loss(A, x1, x2):
x2_hat = (A @ (x1.T)).T
res = x2 - x2_hat
res_sqr = res**2
return res_sqr.sum()
hess1 = vmap(jacrev(jacrev(loss)))(A, x1, x2)
hess2 = vmap(hessian(loss))(A, x1, x2)
self.assertEqual(hess2, hess1)
@markDynamoStrictTest
class TestJvp(TestCase):
def test_inplace_on_captures(self, device):
x = torch.tensor([1.0, 2.0, 3.0], device=device)
captured = torch.randn(3, device=device)
def foo(x):
captured.copy_(x)
return (x * captured).sum()
with self.assertRaisesRegex(RuntimeError, "mutate a captured Tensor"):
grad(foo)(x)
def test_simple(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)
result = jvp(torch.sin, (x,), (t,))
expected = (x.sin(), x.cos() * t)
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
def test_multiple_inputs(self, device):
x = torch.randn(2, 3, device=device)
y = torch.randn(2, 3, device=device)
tx = torch.randn(2, 3, device=device)
ty = torch.randn(2, 3, device=device)
def f(x, y):
return x * y
result = jvp(f, (x, y), (tx, ty))
expected = (x * y, y * tx + x * ty)
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
def test_pytree_inputs(self, device):
def f(x, y, z):
a, b = x
return a + 2 * b + 3 * y + 4 * z
one = torch.tensor(1.0, device=device)
primal_outs, tangent_outs = jvp(
f, ((one, one), one, one), ((one, one), one, one)
)
self.assertEqual(primal_outs, one * 10)
self.assertEqual(tangent_outs, one * 10)
def test_pytree_inputs_error_cases(self, device):
def f(x):
return x
one = torch.tensor(1.0, device=device)
with self.assertRaisesRegex(RuntimeError, "Expected primals to be a tuple"):
jvp(f, one, one)
with self.assertRaisesRegex(RuntimeError, "same python structure"):
jvp(f, ((one, one), one), (one, one))
with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
jvp(f, ((one, one), 1), ((one, one), one))
with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
jvp(f, ((one, one), 1), ((1, one), one))
with self.assertRaisesRegex(RuntimeError, "at least one Tensor"):
jvp(f, ((),), ((),))
def test_unrelated_input(self, device):
def f(x, y):
return x
x = torch.randn(2, 3, device=device)
y = torch.randn(2, 3, device=device)
tx = torch.randn(2, 3, device=device)
ty = torch.randn(2, 3, device=device)
result = jvp(f, (x, y), (tx, ty))
expected = (x, tx)
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
def test_unrelated_output(self, device):
y = torch.randn(2, 3, device=device)
def f(x):
return y
x = torch.randn(2, 3, device=device)
tx = torch.randn(2, 3, device=device)
result = jvp(f, (x,), (tx,))
expected = (y, torch.zeros_like(y))
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
def test_strict_mode(self, device):
y = torch.randn(2, 3, device=device)
def f(x):
return x, y
x = torch.randn(2, 3, device=device)
tx = torch.randn(2, 3, device=device)
with self.assertRaisesRegex(RuntimeError, "strict"):
jvp(f, (x,), (tx,), strict=True)
def test_multiple_outputs(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)
def f(x):
return torch.sin(x), torch.cos(x)
result = jvp(f, (x,), (t,))
expected = (f(x), (x.cos() * t, -x.sin() * t))
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
def test_multiple_inputs_outputs(self, device):
x = torch.randn(2, 3, device=device)
y = torch.randn(2, 3, device=device)
tx = torch.randn(2, 3, device=device)
ty = torch.randn(2, 3, device=device)
def f(x, y):
return 2 * x + 3 * y, 4 * x + 5 * y
result = jvp(f, (x, y), (tx, ty))
expected = (f(x, y), f(tx, ty))
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
def test_primals_tangents_length_mismatch(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)
msg = "same python structure"
with self.assertRaisesRegex(RuntimeError, msg):
jvp(torch.sin, (x,), (t, t))
with self.assertRaisesRegex(RuntimeError, msg):
jvp(torch.sin, (x, x), (t, t, t))
def test_nonempty_primals_and_tangents(self, device):
with self.assertRaisesRegex(RuntimeError, "at least one Tensor"):
jvp(torch.sin, (), ())
def test_inputs_are_tuples_of_tensors(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)
with self.assertRaisesRegex(RuntimeError, "be a tuple"):
jvp(torch.sin, x, (t,))
with self.assertRaisesRegex(RuntimeError, "same python structure"):
jvp(torch.sin, (x,), t)
with self.assertRaisesRegex(RuntimeError, "same python structure"):
jvp(torch.sin, (x,), [t])
with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
jvp(torch.sin, (1.0,), (t,))
with self.assertRaisesRegex(RuntimeError, "only contain Tensors"):
jvp(torch.sin, (x,), (1.0,))
def test_outputs_can_any_pytree(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)
for output in [None, ()]:
with self.assertRaisesRegex(
RuntimeError,
r"jvp\(f, primals, tangents\): Expected f to be a function that has non-empty output",
):
jvp(lambda _: output, (x,), (t,))
for output in [1, True, 12.2, "abc"]:
with self.assertRaisesRegex(
RuntimeError,
r"jvp\(f, primals, tangents\): expected f\(\*primals\) to return only tensors",
):
jvp(lambda _: output, (x,), (t,))
# Check list output
out = jvp(lambda x: [x, x.sum()], (x,), (t,))
for i in range(2):
assert isinstance(out[i], list) and len(out[i]) == 2
# Check dict output
out = jvp(lambda x: {"x": x, "xsum": x.sum()}, (x,), (t,))
for i in range(2):
assert isinstance(out[i], dict) and len(out[i]) == 2 and "xsum" in out[i]
def composite_output(x):
out = x.sum()
return [
(out, {"a": x, "out": [x, out]}),
]
out = jvp(composite_output, (x,), (t,))
for i in range(2):
assert isinstance(out[i], list)
assert isinstance(out[i][0], tuple) and isinstance(out[i][0][1], dict)
def test_aux_tensor(self, device):
x = torch.randn(3, device=device)
t = torch.randn(3, device=device)
with self.assertRaisesRegex(
RuntimeError,
r"jvp\(f, primals, tangents\): output of function f should be a tuple",
):
jvp(lambda t: [t, t], (x,), (t,), has_aux=True)
with self.assertRaisesRegex(
RuntimeError,
r"jvp\(f, primals, tangents\): output of function f should be a tuple",
):
jvp(lambda t: (t, t + 2, t + 3), (x,), (t,), has_aux=True)
def f(z):
y = z.sin()
return y, z.cos()
out, jvp_out, aux = jvp(f, (x,), (t,), has_aux=True)
self.assertEqual(aux, x.cos())
self.assertEqual(out, x.sin())
self.assertEqual(jvp_out, t * x.cos())
def test_aux_pytree(self, device):
def f(x):
y = x.sin()
return y, {"a": x.cos(), "b": [x.tan()]}
x = torch.randn(3, device=device)
t = torch.randn(3, device=device)
out, jvp_out, aux = jvp(f, (x,), (t,), has_aux=True)
expected_out, expected_aux = f(x)
self.assertEqual(out, expected_out)
self.assertEqual(aux, expected_aux)
self.assertEqual(jvp_out, t * x.cos())
for aux in [1, 1.0, "abc"]:
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = jvp(lambda x: (x, aux), (x,), (t,), has_aux=True)
with self.assertRaisesRegex(
RuntimeError, r"Expected tensors, got unsupported type"
):
_ = jvp(lambda x: (x, [x, aux]), (x,), (t,), has_aux=True)
def test_autograd_function_disables_fwd_grad(self, device):
# Sanity check. We don't really assume this anywhere so
# it's fine if this breaks one day.
class MySquare(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
enabled = fwAD._is_fwd_grad_enabled()
self.assertFalse(enabled)
return x * x
@staticmethod
def backward(ctx, gx):
return gx
x = torch.randn(3, requires_grad=True)
MySquare.apply(x)
def test_disable_fwd_grad_outside(self, device):
x = torch.randn([], device=device)
t = torch.ones_like(x)
with fwAD._set_fwd_grad_enabled(False):
_, y = jvp(torch.sin, (x,), (t,))
self.assertEqual(y, x.cos())
def test_disable_fwd_grad_inside(self, device):
def f(x):
with fwAD._set_fwd_grad_enabled(False):
shift = x**2
return x**2 - shift
x = torch.randn([], device=device)
t = torch.ones_like(x)
_, y = jvp(f, (x,), (t,))
self.assertEqual(y, 2 * x)
_, y = jvp(lambda x: jvp(f, (x,), (t,))[1], (x,), (t,))
self.assertEqual(y, 2)
def test_disable_fwd_grad_mixed(self, device):
def f(x):
with fwAD._set_fwd_grad_enabled(False):
shift = x**2
return x**2 - shift
x = torch.randn([], device=device)
t = torch.ones_like(x)
with fwAD._set_fwd_grad_enabled(True):
_, y = jvp(f, (x,), (t,))
self.assertEqual(y, 2 * x)
def test_jvp_inside_autograd_function(self, device):
class MySin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
t = torch.ones_like(x)
_, neg_sin_x = jvp(torch.cos, (x,), (t,))
ctx.save_for_backward(x)
return -neg_sin_x
@staticmethod
def backward(ctx, gx):
(x,) = ctx.saved_tensors
t = torch.ones_like(x)
_, cos_x = jvp(torch.sin, (x,), (t,))
return gx * cos_x
x = torch.randn([], device=device, requires_grad=True)
y = MySin.apply(x)
self.assertEqual(y, x.sin())
(gx,) = torch.autograd.grad(y, x)
self.assertEqual(gx, x.cos())
def test_zerotensor_vmapjvp_interaction(self, device):
dummy = torch.ones(4, 1)
x = torch.randn(4, 2)
x_tangent = torch.randn(2)
def push_jvp(dummy, x):
result = jvp(torch.cov, (x,), (x_tangent,))
return result
# Should not error
vmap(vmap(push_jvp, (0, None)))(dummy, x)
@markDynamoStrictTest
class TestLinearize(TestCase):
@dtypes(torch.float)
def test_linearize_basic(self, device, dtype):
x_p = make_tensor((3, 1), device=device, dtype=dtype)
x_t = make_tensor((3, 1), device=device, dtype=dtype)
def fn(x):
return x.cos()
actual_output, jvp_fn = linearize(fn, x_p)
actual_jvp = jvp_fn(x_t)
expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,))
self.assertEqual(actual_output, expected_output)
self.assertEqual(actual_jvp, expected_jvp)
@dtypes(torch.float)
def test_linearize_return(self, device, dtype):
x_p = make_tensor((3, 1), device=device, dtype=dtype)
x_t = make_tensor((3, 1), device=device, dtype=dtype)
def fn(x):
return (x.cos(), x.sum())
actual_output, jvp_fn = linearize(fn, x_p)
actual_jvp = jvp_fn(x_t)
expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,))
self.assertEqual(actual_output, expected_output)
self.assertEqual(actual_jvp, expected_jvp)
@dtypes(torch.float)
def test_linearize_composition(self, device, dtype):
x_p = make_tensor((3, 1), device=device, dtype=dtype)
x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)
def fn(x):
return (x.cos(), x.sum())
_, jvp_fn = linearize(fn, x_p)
actual_batched_jvp = vmap(jvp_fn)(x_t)
def jvp_fn(x_t):
return jvp(fn, (x_p,), (x_t,))[1]
expected_batched_jvp = vmap(jvp_fn)(x_t)
self.assertEqual(actual_batched_jvp, expected_batched_jvp)
@dtypes(torch.float)
def test_linearize_nested_input_nested_output(self, device, dtype):
x_p = make_tensor((3, 1), device=device, dtype=dtype)
x_t = make_tensor((3, 1), device=device, dtype=dtype)
y_p = make_tensor((3, 1), device=device, dtype=dtype)
y_t = make_tensor((3, 1), device=device, dtype=dtype)
z_p = make_tensor((3, 1), device=device, dtype=dtype)
z_t = make_tensor((3, 1), device=device, dtype=dtype)
def fn(arg):
x = arg["x"]
y = arg["yz"][0]
z = arg["yz"][1]
return {"a": x.sum(), "b": {"c": y + z, "d": (x * z, y.exp())}}
inp_p = {"x": x_p, "yz": (y_p, z_p)}
inp_t = {"x": x_t, "yz": (y_t, z_t)}
actual_output, jvp_fn = linearize(fn, inp_p)
actual_jvp = jvp_fn(inp_t)
expected_output, expected_jvp = jvp(fn, (inp_p,), (inp_t,))
self.assertEqual(actual_output, expected_output)
self.assertEqual(actual_jvp, expected_jvp)
@onlyCUDA
def test_linearize_errors(self):
dtype = torch.float
device = torch.device("cpu")
x_p = make_tensor((3, 1), device=device, dtype=dtype)
x_t = make_tensor((3, 1), device=device, dtype=dtype)
def fn(x):
return x.sin()
_, jvp_fn = linearize(fn, x_p)
with self.assertRaisesRegex(
RuntimeError, "to have the same argspec as the primals"
):
jvp_fn((x_t, x_t))
with self.assertRaisesRegex(
RuntimeError, "in flattened pytree doesn't match the shape"
):
jvp_fn(x_t.unsqueeze(0))
with self.assertRaisesRegex(
RuntimeError, "in flattened pytree doesn't match the dtype"
):
jvp_fn(x_t.to(torch.double))
with self.assertRaisesRegex(
RuntimeError, "in flattened pytree doesn't match the device"
):
jvp_fn(x_t.to(torch.device("cuda")))
# The tests here follow the cases in [Forward Grad View/inplace]
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/autograd_meta.cpp#L18-L43
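# In short (summarizing the per-test comments below): Case 1 performs the in-place
# op with no view involved; Case 2 writes into a view and checks that the tangent
# propagates to the base; Case 3 writes into the base and checks propagation to the
# view; Cases 4 and 5 repeat the view/base propagation checks with a regular
# (non-dual) tensor as the destination and a dual tensor as the source.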
@markDynamoStrictTest
class TestVmapJvpInplaceView(TestCase):
# Case 1 in [Forward Grad View/inplace]
def test_all_dual_no_view(self, device):
B = 2
def push_jvp(f):
def inner(x, xt, y, yt):
return jvp(f, (x, y), (xt, yt))
return inner
def f(x, y):
x.copy_(y)
return x
x = torch.randn(3, B, device=device)
xt = torch.randn(3, B, device=device)
y = torch.randn(3, B, device=device)
yt = torch.randn(3, B, device=device)
out, out_tangent = vmap(push_jvp(f), in_dims=1)(x, xt, y, yt)
self.assertEqual(out, x.movedim(1, 0))
self.assertEqual(out_tangent, yt.movedim(1, 0))
x = torch.randn(3, B, device=device)
xt = torch.randn(3, B, device=device)
y = torch.randn(3, 3, device=device)[:, 1]
yt = torch.randn(6, device=device)[::2]
out, out_tangent = vmap(push_jvp(f), in_dims=(1, 1, None, None))(x, xt, y, yt)
self.assertEqual(out, x.movedim(1, 0))
self.assertEqual(out_tangent, yt.expand(B, 3))
# Case 2 in [Forward Grad View/inplace]
def test_all_dual_base_view_inplace(self, device):
B = 2
def push_jvp(f):
def inner(x, xt, y, yt):
return jvp(f, (x, y), (xt, yt))
return inner
# with view, propagate from view to base
def f(x, y):
view = x[:, ::2]
view.copy_(y)
return view, x
orig_x = torch.randn(2, 6, B, device=device)
orig_xt = torch.randn(2, 6, B, device=device)
x = orig_x.clone()
xt = orig_xt.clone()
y = torch.randn(2, B, 3, device=device)
yt = torch.randn(2, B, 3, device=device)
out, out_tangent = vmap(push_jvp(f), in_dims=(2, 2, 1, 1))(x, xt, y, yt)
expected_out = vmap(f, in_dims=(2, 1))(orig_x.clone(), y)
self.assertEqual(out[0], expected_out[0])
self.assertEqual(out[1], expected_out[1])
self.assertEqual(out_tangent[0], yt.movedim(1, 0))
expected_x_tangent = orig_xt.movedim(-1, 0).clone()
expected_x_tangent[:, :, ::2].copy_(yt.movedim(1, 0))
self.assertEqual(out_tangent[1], expected_x_tangent)
expected = orig_x.movedim(2, 0).clone()
expected[:, :, ::2] = y.movedim(1, 0)
self.assertEqual(x.movedim(2, 0), expected)
# Case 3 in [Forward Grad View/inplace]
def test_all_dual_base_inplace(self, device):
B = 2
def push_jvp(f):
def inner(x, xt, y, yt):
return jvp(f, (x, y), (xt, yt))
return inner
# Case 3: with view, propagate from base to view
def f(x, y):
view = x[0, ::2]
x.copy_(y)
return x, view
x = torch.randn(2, B, 6, device=device)
xt = torch.randn(2, 6, B, device=device)
y = torch.randn(2, B, 6, device=device)
yt = torch.randn(2, B, 6, device=device)
out, out_tangent = vmap(push_jvp(f), in_dims=(1, 2, 1, 1))(x.clone(), xt, y, yt)
expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y)
self.assertEqual(out[0], expected_out[0])
self.assertEqual(out[1], expected_out[1])
self.assertEqual(out_tangent[0], yt.movedim(1, 0))
self.assertEqual(out_tangent[1], yt.movedim(1, 0)[:, 0, ::2])
# Case 4 in [Forward Grad View/inplace]
def test_right_dual_view_prop(self, device):
B = 2
# Changes on the view must propagate to its base. Also:
# - x is a regular Tensor
# - y is a dual tensor
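        # (x is closed over via functools.partial in push_jvp, so jvp only treats y as a primal.)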
def f(x, y):
x = x.clone()
view = x[0]
view.copy_(y)
return view, x
def push_jvp(x, y, yt):
return jvp(partial(f, x), (y,), (yt,))
x = torch.randn(2, B, 6, device=device)
y = torch.randn(6, B, device=device)
yt = torch.randn(6, B, device=device)
outs, tangents = vmap(push_jvp, in_dims=(1, 1, 1))(x, y, yt)
expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y)
self.assertEqual(outs[0], expected_out[0])
self.assertEqual(outs[1], expected_out[1])
self.assertEqual(tangents[0], yt.movedim(1, 0))
expected_tangent_1 = torch.zeros_like(x).movedim(1, 0)
expected_tangent_1[:, 0].copy_(yt.movedim(1, 0))
self.assertEqual(tangents[1], expected_tangent_1)
# Case 5 in [Forward Grad View/inplace]
def test_right_dual_base_prop(self, device):
B = 2
        # Changes on the base must propagate to all its views. Also:
# - x is a regular Tensor
# - y is a dual tensor
def f(x, y):
x = x.clone()
view = x[0]
x.copy_(y)
return view, x
def push_jvp(x, y, yt):
return jvp(partial(f, x), (y,), (yt,))
x = torch.randn(2, B, 6)
y = torch.randn(2, 6, B)
yt = torch.randn(2, 6, B)
outs, tangents = vmap(push_jvp, in_dims=(1, 2, 2))(x, y, yt)
expected_out = vmap(f, in_dims=(1, 2))(x, y)
self.assertEqual(outs[0], expected_out[0])
self.assertEqual(outs[1], expected_out[1])
self.assertEqual(tangents[0], yt.movedim(2, 0)[:, 0])
self.assertEqual(tangents[1], yt.movedim(2, 0))
# Use for testing miscellaneous helper functions
@markDynamoStrictTest
class TestHelpers(TestCase):
def test_CtxWithSavedTensors_error_if_name_collision(self, device):
x = torch.randn([], device=device, requires_grad=True)
y = torch.randn([], device=device, requires_grad=True)
class A(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx._pt_inner_ctx = 1
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gy):
wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
ctx, (y,)
)
return gy
class B(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx._pt_new_saved_tensors = 1
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gy):
wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
ctx, (y,)
)
return gy
out = A.apply(x)
with self.assertRaisesRegex(RuntimeError, "name collision"):
out.backward()
out = B.apply(x)
with self.assertRaisesRegex(RuntimeError, "name collision"):
out.backward()
def test_CtxWithSavedTensors_nesting(self, device):
CtxWithSavedTensors = torch._functorch.autograd_function.CtxWithSavedTensors
x = torch.randn([], device=device, requires_grad=True)
y = torch.randn([], device=device)
z = torch.randn([], device=device)
class A(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gy):
ctx_y = CtxWithSavedTensors(ctx, (y,))
# Can't use self.assertEqual because that relies on TLS
# that is not available in multithread autograd
assert len(ctx_y.saved_tensors) == 1
assert torch.allclose(ctx_y.saved_tensors[0], y)
wrapped = CtxWithSavedTensors(ctx_y, (z,))
assert len(wrapped.saved_tensors) == 1
assert torch.allclose(wrapped.saved_tensors[0], z)
assert len(ctx_y.saved_tensors) == 1
assert torch.allclose(ctx_y.saved_tensors[0], y)
return gy * wrapped.saved_tensors[0]
out = A.apply(x)
out.backward()
self.assertEqual(x.grad, z)
def test_CtxWithSavedTensors_overrides_saved_tensors(self, device):
x = torch.randn([], device=device, requires_grad=True)
class A(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gy):
# The override can be literally anything
override = (1, 2, 3)
wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
ctx, override
)
assert wrapped.saved_tensors == override
return gy
out = A.apply(x)
out.backward()
def test_CtxWithSavedTensors_passthrough(self, device):
x = torch.randn([], device=device, requires_grad=True)
y = torch.randn([], device=device)
class A(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x * y
@staticmethod
def backward(ctx, gz):
# The override can be literally anything
override = (1, 2, 3)
wrapped = torch._functorch.autograd_function.CtxWithSavedTensors(
ctx, override
)
assert wrapped.needs_input_grad[0] == ctx.needs_input_grad[0]
assert wrapped.needs_input_grad[1] == ctx.needs_input_grad[1]
wrapped.foo = "bar"
assert wrapped.foo == "bar"
assert ctx.foo == "bar"
return gz, gz
out = A.apply(x, y)
out.backward()
def test_reductify_leaf(self, device):
reductify_leaf = torch._functorch.autograd_function.reductify_leaf
B = 2
# grad_input None case
output = reductify_leaf(None, None, 0, B)
self.assertIsNone(output)
output = reductify_leaf(None, None, None, B)
self.assertIsNone(output)
# grad_input has bdim, input does not have bdim
grad_input = torch.randn([B, 3, 4], device=device)
output = reductify_leaf(grad_input, 0, None, B)
self.assertEqual(output, grad_input.sum(0))
grad_input = torch.randn([3, B, 4], device=device)
output = reductify_leaf(grad_input, 1, None, B, (3,))
self.assertEqual(output, grad_input.sum(1))
# grad_input does not have bdim, input has bdim
# This can happen if the user returns a fresh Tensor from the backward pass
# that is unrelated to the input
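        # In that case the fresh gradient is broadcast across the batch dimension
        # (and summed over any extra leading dims when the per-sample input shape is
        # supplied), which is what the two checks below verify.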
grad_input = torch.randn([3, 4], device=device)
output = reductify_leaf(grad_input, None, 1, B)
self.assertEqual(output, grad_input.view(3, 1, 4).expand(3, B, 4))
grad_input = torch.randn([3, 4], device=device)
output = reductify_leaf(grad_input, None, 1, B, (4,))
self.assertEqual(output, grad_input.view(3, 4, 1).expand(3, 4, B).sum(0))
# grad_input has bdim, input has bdim
grad_input = torch.randn([B, 3, 4], device=device)
output = reductify_leaf(grad_input, 0, 1, B)
self.assertEqual(output, grad_input.movedim(0, 1))
grad_input = torch.randn([3, 4, 5, B], device=device)
output = reductify_leaf(grad_input, 3, 0, B, (5,))
self.assertEqual(output, grad_input.movedim(-1, 2).sum(0).sum(0))
@markDynamoStrictTest
class TestComposability(TestCase):
def test_deprecation_vmap(self, device):
x = torch.randn(3, device=device)
# functorch version of the API is deprecated
with self.assertWarnsRegex(FutureWarning, "Please use `torch.vmap`"):
vmap(torch.sin)
# the non-functorch version is not deprecated
with warnings.catch_warnings():
warnings.simplefilter("error")
torch.vmap(torch.sin)
# Some of these pass, some of these don't
@parametrize(
"transform",
["grad", "jacrev", "jacfwd", "grad_and_value", "hessian", "functionalize"],
)
def test_deprecation_transforms(self, device, transform):
api = getattr(functorch, transform)
new_api = getattr(torch.func, transform)
# functorch version of the API is deprecated
with self.assertWarnsRegex(
FutureWarning, f"Please use `torch.func.{transform}`"
):
api(torch.sin)
# the non-functorch version is not deprecated
with warnings.catch_warnings():
warnings.simplefilter("error")
new_api(torch.sin)
def test_grad_grad(self, device):
x = torch.randn([], device=device)
y = grad(grad(torch.sin))(x)
self.assertEqual(y, -x.sin())
def test_grad_vmap(self, device):
def foo(x):
y = vmap(torch.sin)(x)
return y.sum()
x = torch.randn(3, device=device)
y = grad(foo)(x)
self.assertEqual(y, x.cos())
def test_grad_vjp(self, device):
x = torch.randn(3, device=device)
def foo(x):
_, vjp_fn = vjp(torch.sin, x)
return vjp_fn(x)[0].sum()
y = grad(foo)(x)
expected = grad(lambda x: (x * x.cos()).sum())(x)
self.assertEqual(y, expected)
def test_vmap_grad(self, device):
x = torch.randn(3, device=device)
y = vmap(grad(torch.sin))(x)
self.assertEqual(y, x.cos())
def test_vmap_vmap(self, device):
x = torch.randn(2, 3, device=device)
y = vmap(vmap(torch.sin))(x)
self.assertEqual(y, x.sin())
def test_vmap_vjp(self, device):
x = torch.randn(3, device=device)
_, vjp_fn = vjp(torch.sin, x)
def foo(x):
_, vjp_fn = vjp(torch.sin, x)
return vjp_fn(x)
y = vmap(foo)(x)
self.assertEqual(y, vjp_fn(x))
# TODO: there's a very interesting error message when the following
# is on CPU
xs = torch.randn(5, 3, device=device)
expected = torch.stack([vjp_fn(x)[0] for x in xs])
result = vmap(lambda x: vjp_fn(x)[0])(xs)
self.assertEqual(result, expected)
def test_vjp_grad(self, device):
x = torch.randn([], device=device)
y, vjp_fn = vjp(grad(torch.sin), x)
self.assertEqual(y, x.cos())
v = torch.randn([])
self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
def test_vjp_vmap(self, device):
x = torch.randn(3, device=device)
y, vjp_fn = vjp(vmap(torch.sin), x)
self.assertEqual(y, x.sin())
v = torch.randn(3, device=device)
self.assertEqual(vjp_fn(v)[0], x.cos() * v)
def test_vjp_vjp(self, device):
x = torch.randn(3, device=device)
y, vjp_fn = vjp(torch.sin, x)
self.assertEqual(y, x.sin())
y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x)
self.assertEqual(y, x * x.cos())
y = vjp_fn(x)[0]
        # We don't check the value here; the point is just that the call runs.
def test_make_fx_vmap(self, device):
def f(x):
return torch.sin(x)
inp = torch.randn(5, 3)
f = vmap(f)
fx_f = make_fx(f)(inp)
new_inp = torch.randn(5, 3)
self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_jacrev(self, device):
def f(x):
return x.sin().sum()
inp = torch.randn(3)
f = jacrev(jacrev(f))
fx_f = make_fx(f)(inp)
new_inp = torch.randn(3)
self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_vjp(self, device):
def f(x):
return torch.sin(x).sum()
primals = torch.randn(3)
_, vjp_fn = vjp(f, primals)
cotangent = torch.randn(())
fx_f = make_fx(vjp_fn)(cotangent, True, True)
new_cotangent = torch.randn(())
self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
# FIXME: test fails in Windows
@unittest.skipIf(IS_WINDOWS, "fails in Windows; needs investigation")
@unittest.skipIf(IS_FBCODE, "can't subprocess in fbcode")
# it is redundant to run this test twice on a machine that has GPUs
@onlyCPU
def test_no_warning_on_import_functorch(self, device):
out = subprocess.check_output(
[sys.executable, "-W", "always", "-c", "import functorch"],
stderr=subprocess.STDOUT,
cwd=os.path.dirname(os.path.realpath(__file__)),
).decode("utf-8")
self.assertEqual(out, "")
def test_requires_grad_inside_transform(self, device):
def f(x):
x.requires_grad_()
return x.sin().sum()
x = torch.randn(3)
with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
vmap(f)(x)
with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
grad(f)(x)
with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
vmap(grad(f))(x)
x = torch.randn([])
with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
grad(grad(f))(x)
def test_retain_grad_inside_transform(self, device):
def f(x):
y = x.sin()
y.retain_grad()
return y.sum()
x = torch.randn(3)
with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"):
grad(f)(x)
def test_autograd_functional_jacrev_inside_transform(self, device):
def f(x):
y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x)
return y
B = 5
x = torch.randn(B, 3)
with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
vmap(f)(x)
x = torch.randn([])
with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
grad(f)(x)
def test_autograd_functional_vjp_inside_transform(self, device):
def f(x):
y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x)
return y
B = 5
x = torch.randn(B, 3)
with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
vmap(f)(x)
x = torch.randn([])
with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
grad(f)(x)
def test_autograd_functional_jvp_inside_transform(self, device):
def f(x):
t = torch.ones_like(x)
y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,))
return y
B = 5
x = torch.randn(B, 3)
with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
vmap(f)(x)
x = torch.randn([])
with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
grad(f)(x)
def test_autograd_functional_jacfwd_inside_transform(self, device):
def f(x):
y = torch.autograd.functional.jacobian(
lambda x: x.sin().sum(), x, strategy="forward-mode", vectorize=True
)
return y
B = 5
x = torch.randn(B, 3)
with self.assertRaisesRegex(
RuntimeError, "Batching rule not implemented for aten::_make_dual"
):
vmap(f)(x)
@parametrize(
"transform",
[
"vmap",
"grad",
"jacrev",
"jacfwd",
"grad_and_value",
"hessian",
"functionalize",
],
)
def test_autograd_function_no_setup_context(self, device, transform):
class MySin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x.sin()
@staticmethod
def backward(ctx, gy):
(x,) = ctx.saved_tensors
return gy * x.cos()
x = torch.randn(3, device=device)
transform = getattr(functorch, transform)
with self.assertRaisesRegex(RuntimeError, "must override the setup_context"):
transform(MySin.apply)(x)
# Some of these pass, some of these don't
@parametrize(
"transform",
[
"grad",
"jacrev",
"grad_and_value",
"hessian",
],
)
def test_transforms_dont_support_saved_tensor_hooks(self, device, transform):
def f(x):
return torch.sin(x).sum()
def g(x):
with torch.autograd.graph.save_on_cpu():
return f(x)
x = torch.randn(3, device=device)
if transform == "functionalize":
transform = functorch.experimental.functionalize
else:
transform = getattr(functorch, transform)
with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
with torch.autograd.graph.save_on_cpu():
transform(f)(x)
with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
transform(g)(x)
def test_vjp_doesnt_support_saved_tensor_hooks(self, device):
def f(x):
return torch.sin(x).sum()
def g(x):
with torch.autograd.graph.save_on_cpu():
return f(x)
x = torch.randn(3, device=device)
with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
with torch.autograd.graph.save_on_cpu():
vjp(f, x)
with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
vjp(g, x)
def test_jvp_supports_saved_tensor_hooks(self, device):
def f(x):
return torch.sin(x).sum()
def g(x):
with torch.autograd.graph.save_on_cpu():
return f(x)
x = torch.randn(3, device=device)
t = torch.randn(3, device=device)
# smoke tests
with torch.autograd.graph.save_on_cpu():
jvp(f, (x,), (t,))
# smoke tests
jvp(g, (x,), (t,))
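    # The next three tests check that functorch transforms still work when their
    # corresponding dispatch key has been locally excluded via _ExcludeDispatchKeyGuard,
    # and that the guard's exclude set remains in effect afterwards.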
def test_can_use_functionalize_when_key_is_excluded(self, device):
def f(x):
y = x.clone()
y.sin_()
return y
x = torch.randn([], device=device)
expected = f(x)
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
gm = make_fx(functorch.functionalize(f))(x)
self.assertTrue("sin_" not in gm.code)
self.assertEqual(gm(x), expected)
local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
self.assertTrue(local_exclude_set.has(DispatchKey.Functionalize))
def test_can_use_vmap_when_key_is_excluded(self, device):
def f(x):
return x.sum(0)
x = torch.randn(3, device=device)
expected = vmap(f)(x)
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched)):
result = vmap(f)(x)
self.assertEqual(result, expected)
local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
self.assertTrue(local_exclude_set.has(DispatchKey.FuncTorchBatched))
def test_can_use_grad_when_key_is_excluded(self, device):
def f(x):
return x.sin()
x = torch.randn([], device=device)
expected = grad(f)(x)
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd)):
result = grad(f)(x)
self.assertEqual(result, expected)
local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
self.assertTrue(local_exclude_set.has(DispatchKey.Autograd))
@markDynamoStrictTest
class TestMakeFunctional(TestCase):
@parametrize("disable_autograd_tracking", [True, False])
def test_disable_autograd_tracking(self, disable_autograd_tracking):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
mod = Foo()
_, params = make_functional(
mod, disable_autograd_tracking=disable_autograd_tracking
)
self.assertEqual(len(params), 2)
for param in params:
self.assertEqual(param.requires_grad, not disable_autograd_tracking)
def test_parameter_tying(self):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)
self.linear.bias = self.bias
self.linear_tied = self.linear
def forward(self, x):
x = self.linear(x)
x = self.linear_tied(x)
x = x + self.bias
return x
torch.manual_seed(1)
mod = Foo()
func, _ = make_functional(mod)
torch.manual_seed(0)
mod = Foo()
_, params = make_functional(mod)
self.assertEqual(len(params), 2)
x = torch.randn(2, 3)
result = func(params, x)
expected = mod(x)
self.assertEqual(result, expected)
def test_buffer_tying(self):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)
self.buffer = nn.Buffer(torch.randn(3))
self.buffer_tied = self.buffer
def forward(self, x):
x = self.linear(x)
x = x + self.bias
x = x + self.buffer
x = x + self.buffer_tied
return x
torch.manual_seed(1)
mod = Foo()
func, _, _ = make_functional_with_buffers(mod)
torch.manual_seed(0)
mod = Foo()
_, params, buffers = make_functional_with_buffers(mod)
self.assertEqual(len(params), 3)
self.assertEqual(len(buffers), 1)
x = torch.randn(2, 3)
result = func(params, buffers, x)
expected = mod(x)
self.assertEqual(result, expected)
@parametrize("disable_autograd_tracking", [True, False])
def test_with_buffers_disable_autograd_tracking(self, disable_autograd_tracking):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.buffer = nn.Buffer(torch.randn(3))
def forward(self, x):
x = self.linear(x)
x = x + self.buffer
return x
mod = Foo()
_, params, buffers = make_functional_with_buffers(
mod, disable_autograd_tracking=disable_autograd_tracking
)
self.assertEqual(len(params), 2)
self.assertEqual(len(buffers), 1)
for param in params:
self.assertEqual(param.requires_grad, not disable_autograd_tracking)
@parametrize("detach_params", [True, False])
def test_using_detach_functional_call(self, detach_params):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.buffer = nn.Buffer(torch.randn(3))
def forward(self, x):
x = self.linear(x)
x = x + self.buffer
return x
def params_dict(mod):
named_params = mod.named_parameters()
return (
{k: v.detach() for k, v in named_params}
if detach_params
else dict(named_params)
)
mod = Foo()
x = torch.randn(3, 3)
d = (params_dict(mod), dict(mod.named_buffers()))
out = functional_call(mod, d, x)
self.assertEqual(out.grad_fn is None, detach_params)
def test_parameter_tying_grad(self):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.weight = self.linear.weight
self.bias = self.linear.bias
def forward(self, x):
x = self.linear(x)
x = F.linear(x, self.weight, self.bias)
return x
x = torch.randn(2, 3)
torch.manual_seed(0)
mod = Foo()
loss = mod(x).sum()
expected = torch.autograd.grad(loss, mod.parameters())
mod = Foo()
fmod, _, _ = make_functional_with_buffers(mod)
torch.manual_seed(0)
mod = Foo()
_, params, buffers = make_functional_with_buffers(mod)
def compute_loss(params, buffers, x):
return fmod(params, buffers, x).sum()
result = grad(compute_loss)(params, buffers, x)
self.assertEqual(result, expected)
def test_parameter_tying_ensemble(self):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.weight = self.linear.weight
self.bias = self.linear.bias
self.buffer = nn.Buffer(torch.randn(3))
self.buffer_tied = self.buffer
def forward(self, x):
x = self.linear(x)
x = F.linear(x, self.weight, self.bias)
x = x + self.buffer
x = x + self.buffer_tied
return x
num_models = 2
xs = torch.randn(num_models, 64, 3)
models = [Foo() for _ in range(num_models)]
fmodel, _, _ = combine_state_for_ensemble(models)
torch.manual_seed(0)
models = [Foo() for _ in range(num_models)]
_, params, buffers = combine_state_for_ensemble(models)
result = vmap(fmodel)(params, buffers, xs)
torch.manual_seed(0)
models = [Foo() for _ in range(num_models)]
expected = torch.stack([model(x) for model, x in zip(models, xs)])
self.assertEqual(result, expected)
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_correctness_mnist(self, mechanism):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
                return F.log_softmax(x, dim=1)
x = torch.randn(64, 1, 32, 32)
torch.manual_seed(301)
fnet, _ = _get_weights_and_functional_call(Net(), mechanism)
torch.manual_seed(0)
_, params = _get_weights_and_functional_call(Net(), mechanism)
result = fnet(params, x)
torch.manual_seed(0)
net = Net()
expected = net(x)
self.assertEqual(result, expected)
def test_combine_state_for_ensemble_error(self):
in_features = 2
out_features = 2
models = []
with self.assertRaisesRegex(RuntimeError, "Expected at least one model"):
_ = combine_state_for_ensemble(models)
num_models = 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
models[1].eval()
with self.assertRaisesRegex(RuntimeError, "same training/eval mode"):
_ = combine_state_for_ensemble(models)
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
models[1] = torch.nn.Conv2d(3, 3, (3, 3))
with self.assertRaisesRegex(RuntimeError, "models to be of the same class"):
_ = combine_state_for_ensemble(models)
def test_combine_state_for_ensemble_smoke(self):
in_features = 2
out_features = 2
num_models = 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
_ = combine_state_for_ensemble(models)
def test_stack_module_state_smoke(self):
in_features = 2
out_features = 2
num_models = 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
_ = stack_module_state(models)
def test_stack_module_state_leaf(self):
in_features = 2
out_features = 2
num_models = 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
params, buffers = stack_module_state(models)
for param in params.values():
self.assertTrue(param.requires_grad)
self.assertTrue(param.is_leaf)
def test_stack_module_state_mismatch_error(self):
in_features = 2
out_features = 2
num_models = 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
models[0].weight.requires_grad_(False)
with self.assertRaisesRegex(RuntimeError, "same .requires_grad"):
params, buffers = stack_module_state(models)
def test_stack_module_state_error(self):
in_features = 2
out_features = 2
models = []
with self.assertRaisesRegex(
RuntimeError, "stack_module_state:.* Expected at least one model"
):
_ = stack_module_state(models)
num_models = 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
models[1].eval()
with self.assertRaisesRegex(
RuntimeError, "stack_module_state:.* same training/eval mode."
):
_ = stack_module_state(models)
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
models[1] = torch.nn.Conv2d(3, 3, (3, 3))
with self.assertRaisesRegex(
RuntimeError, "stack_module_state:.* models to be of the same class"
):
_ = stack_module_state(models)
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_make_functional_state_correctly_returned_after_forward(self, mechanism):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
def get_module_info(mod):
if mechanism == "make_functional":
return make_functional(mod)
else:
assert mechanism == "functional_call"
return mod, dict(mod.named_parameters())
mod = Net()
func_mod, params = get_module_info(mod)
# state in func.names_map
mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod
old_state_linear_weight = mod.linear.weight
old_state_linear_bias = mod.linear.bias
self.assertIsNotNone(old_state_linear_weight)
self.assertIsNotNone(old_state_linear_bias)
x = torch.randn(4, 3)
if mechanism == "make_functional":
func_mod(params, x)
else:
assert mechanism == "functional_call"
functional_call(func_mod, params, x)
mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod
new_state_linear_weight = mod.linear.weight
new_state_linear_bias = mod.linear.bias
self.assertIsNotNone(new_state_linear_weight)
self.assertIsNotNone(new_state_linear_bias)
self.assertEqual(old_state_linear_weight, new_state_linear_weight)
self.assertEqual(old_state_linear_bias, new_state_linear_bias)
@markDynamoStrictTest
class TestExamplesCorrectness(TestCase):
def _update_params(self, params, grads, alpha, mechanism):
if mechanism == "make_functional":
return [(params[i] - alpha * grads[i]) for i in range(len(params))]
else:
assert mechanism == "functional_call"
return {k: params[k] - alpha * grads[k] for k in params}
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_maml_regression(self, device, mechanism):
class ThreeLayerNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(1, 40)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(40, 40)
self.relu2 = nn.ReLU()
self.fc3 = nn.Linear(40, 1)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
x = self.relu2(x)
x = self.fc3(x)
return x
# TODO: should replace with F.mse_loss
def mse_loss(x, y):
return torch.mean((x - y) ** 2)
net, params = _get_weights_and_functional_call(
ThreeLayerNet().to(device), mechanism
)
K = 20
num_tasks = 4
alpha = 0.1
def sample_tasks(outer_batch_size, inner_batch_size):
# Select amplitude and phase for the task
As = []
phases = []
for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=0.5))
phases.append(np.random.uniform(low=0.0, high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = np.random.uniform(
low=-5.0, high=5.0, size=(inner_batch_size, 1)
)
y = A * np.sin(x + phase)
xs.append(x)
ys.append(y)
return torch.tensor(xs, dtype=torch.float, device=device), torch.tensor(
ys, dtype=torch.float, device=device
)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
def get_loss_for_task(use_transform, x1, y1, x2, y2):
def inner_loss(params, x1, y1):
f = net(params, x1)
loss = mse_loss(f, y1)
return loss
if use_transform:
grads = grad(inner_loss)(params, x1, y1)
else:
loss = inner_loss(params, x1, y1)
grad_params, spec = tree_flatten(params)
grads = torch.autograd.grad(loss, grad_params, create_graph=True)
grads = tree_unflatten(grads, spec)
new_params = self._update_params(params, grads, alpha, mechanism)
v_f = net(new_params, x2)
return mse_loss(v_f, y2)
task = sample_tasks(num_tasks, K)
list_params = (
params if mechanism == "make_functional" else list(params.values())
)
# Compute with vmap+grad
inner_losses = vmap(partial(get_loss_for_task, True))(
task[0], task[1], task[2], task[3]
)
loss2 = sum(inner_losses) / len(inner_losses)
result_grads = torch.autograd.grad(loss2, list_params)
# Compute without vmap+grad
inner_losses = [
get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
for i in range(num_tasks)
]
loss2 = sum(inner_losses) / len(inner_losses)
expected_grads = torch.autograd.grad(loss2, list_params)
self.assertEqual(result_grads, expected_grads)
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_maml_omniglot(self, device, mechanism):
        # TODO: there appear to be precision issues for float32
dtype = torch.double
# TODO: We don't support inplace relu?
inplace_relu = False
n_way = 5
n_inner_iter = 2
num_tasks = 2
        # The real example uses batch norm, but it is numerically unstable in the first
        # iteration (when near 0) and won't produce the same gradients. Use group norm instead.
net = (
nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.GroupNorm(64, 64, affine=True),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.GroupNorm(64, 64, affine=True),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.GroupNorm(64, 64, affine=True),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Flatten(),
nn.Linear(64, n_way),
)
.to(device)
.to(dtype)
)
fnet, params, buffers = _get_weights_and_functional_call_with_buffers(
net, mechanism
)
net = (params, buffers, fnet)
def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
params, buffers, fnet = net
querysz = x_qry.size(0)
def compute_loss(new_params, buffers, x, y):
logits = fnet(new_params, buffers, x)
loss = F.cross_entropy(logits, y)
return loss
new_params = params
for _ in range(n_inner_iter):
if use_transform:
grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
else:
res = compute_loss(new_params, buffers, x_spt, y_spt)
grad_params, spec = tree_flatten(new_params)
grads = torch.autograd.grad(res, grad_params, create_graph=True)
grads = tree_unflatten(grads, spec)
new_params = self._update_params(new_params, grads, 1e-1, mechanism)
qry_logits = fnet(new_params, buffers, x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry)
qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
return qry_loss, qry_acc
# Get some sample inputs...
x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device)
y_spt = torch.randint(0, 5, (num_tasks, 25), device=device)
x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device)
y_qry = torch.randint(0, 5, (num_tasks, 75), device=device)
# compute with vmap + grad
compute_loss = partial(loss_for_task, net, n_inner_iter, True)
qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry)
list_params = (
params if mechanism == "make_functional" else list(params.values())
)
result_grads = torch.autograd.grad(qry_losses.sum(), list_params)
# compute without vmap + grad
compute_loss = partial(loss_for_task, net, n_inner_iter, False)
losses = [
compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0]
for i in range(num_tasks)
]
expected_grads = torch.autograd.grad(sum(losses), list_params)
self.assertEqual(result_grads, expected_grads)
@parametrize("mechanism", ["make_functional", "functional_call"])
@parametrize("originally_track_running_stats", [True, False])
def test_update_batch_norm(self, device, originally_track_running_stats, mechanism):
dtype = torch.double
inplace_relu = False
classes = 5
num_batches = 2
net = (
nn.Sequential(
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(
64, affine=True, track_running_stats=originally_track_running_stats
),
nn.ReLU(inplace=inplace_relu),
nn.Flatten(),
nn.Linear(43264, classes),
)
.to(device)
.to(dtype)
)
replace_all_batch_norm_modules_(net)
transformed_net = net
fnet, params, buffers = _get_weights_and_functional_call_with_buffers(
transformed_net, mechanism
)
criterion = nn.CrossEntropyLoss()
def compute_loss(x, y, params, buffers):
return criterion(fnet(params, buffers, x), y)
# Get some sample inputs...
x = torch.randn(num_batches, 1, 64, 28, 28, device=device, dtype=dtype)
y = torch.randint(0, classes, (num_batches, 1), device=device)
# compute some per sample grads with vmap + grad
result_grads = vmap(grad(compute_loss, argnums=2), in_dims=(0, 0, None, None))(
x, y, params, buffers
)
# compute some per sample grads without vmap + grad
fnet, params, buffers = _get_weights_and_functional_call_with_buffers(
transformed_net, mechanism
)
flat_params, spec = tree_flatten(params)
expected_grads = [
torch.autograd.grad(compute_loss(x[i], y[i], params, buffers), flat_params)
for i in range(num_batches)
]
expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]
expected_grads = tree_unflatten(expected_grads, spec)
self.assertEqual(result_grads, expected_grads)
@parametrize("jac", ["jacfwd", "jacrev"])
def test_lennard_jones_batched_jac(self, device, jac):
sigma = 0.5
epsilon = 4.0
jac = getattr(functorch, jac)
def lennard_jones(r):
return epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)
def lennard_jones_force(r):
"""Get magnitude of LJ force"""
return -epsilon * (
(-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)
)
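        # Note: lennard_jones_force(r) equals -d/dr lennard_jones(r), i.e. the force is
        # the negative derivative of the potential; make_prediction below builds the
        # network's force prediction the same way, from the (batched) Jacobian of the
        # energy model.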
r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device)
drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device))
norms = torch.norm(drs, dim=1).reshape(-1, 1)
training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1)
training_forces = torch.stack(
[force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]
)
model = nn.Sequential(
nn.Linear(1, 16),
nn.Tanh(),
nn.Linear(16, 16),
nn.Tanh(),
nn.Linear(16, 16),
nn.Tanh(),
nn.Linear(16, 16),
nn.Tanh(),
nn.Linear(16, 1),
).to(device)
def make_prediction(model, drs, use_functorch):
norms = torch.norm(drs, dim=1).reshape(-1, 1)
energies = model(norms)
if use_functorch:
network_derivs = vmap(jac(model))(norms).squeeze(-1)
forces = -network_derivs * drs / norms
else:
forces = []
for r, dr in zip(norms, drs):
network_deriv = torch.autograd.functional.jacobian(
model, r, create_graph=True
)
force = -network_deriv * dr / r
forces.append(force)
forces = torch.cat(forces)
return energies, forces
def loss_fn(energies, forces, predicted_energies, predicted_forces):
return (
F.mse_loss(energies, predicted_energies)
+ 0.01 * F.mse_loss(forces, predicted_forces) / 3
)
energies, forces = make_prediction(model, drs, use_functorch=True)
loss = loss_fn(training_energies, training_forces, energies, forces)
result = torch.autograd.grad(loss, model.parameters())
energies, forces = make_prediction(model, drs, use_functorch=False)
loss = loss_fn(training_energies, training_forces, energies, forces)
expected = torch.autograd.grad(loss, model.parameters())
self.assertEqual(result, expected)
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_ensemble_regression(self, device, mechanism):
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
ts = torch.linspace(0, 1, n_samples)
rs = ts**0.5
thetas = rs * rotations * 2 * math.pi
signs = torch.randint(0, 2, (n_samples,)) * 2 - 1
labels = (signs > 0).to(torch.long)
xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std
ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std
points = torch.stack([xs, ys], dim=1)
return points.to(device), labels.to(device)
points, labels = make_spirals(100, noise_std=0.05)
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes
self.fc1 = nn.Linear(2, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x
loss_fn = nn.NLLLoss()
func_model, weights = _get_weights_and_functional_call(
MLPClassifier().to(device), mechanism
)
def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets):
output = func_model(weights, batch)
loss = loss_fn(output, targets)
return loss
if use_transform:
grad_weights, loss = grad_and_value(compute_loss)(
weights, batch, targets
)
else:
loss = compute_loss(weights, batch, targets)
flat_weights, spec = tree_flatten(weights)
flat_grad_weights = torch.autograd.grad(loss, flat_weights)
grad_weights = tree_unflatten(flat_grad_weights, spec)
new_weights = self._update_params(weights, grad_weights, lr, mechanism)
return (loss, new_weights)
def unpack(train_result):
return train_result[0], train_result[1]
def init_fn(num_models):
models = tuple(MLPClassifier().to(device) for _ in range(num_models))
if mechanism == "make_functional":
return combine_state_for_ensemble(models)[1]
else:
return stack_module_state(models)[0]
def slice_weights(batched_weights, index):
return tree_map(
lambda weight: weight[index].detach().requires_grad_(), batched_weights
)
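        # vmapping the transform-based step over the stacked ensemble weights
        # should match running the non-transform step separately on each
        # member's slice of the stacked weights.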
batched_weights = init_fn(num_models=2)
parallel_train_step_fn = vmap(
partial(train_step_fn, True), in_dims=(0, None, None)
)
result_loss, result_weights = unpack(
parallel_train_step_fn(batched_weights, points, labels)
)
loss0, weights0 = unpack(
train_step_fn(False, slice_weights(batched_weights, 0), points, labels)
)
loss1, weights1 = unpack(
train_step_fn(False, slice_weights(batched_weights, 1), points, labels)
)
expected_loss = torch.stack([loss0, loss1])
weights0, spec0 = tree_flatten(weights0)
weights1, spec1 = tree_flatten(weights1)
assert spec0 == spec1
expected_weights = tuple(
torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1)
)
expected_weights = tree_unflatten(expected_weights, spec0)
self.assertEqual(result_loss, expected_loss)
self.assertEqual(result_weights, expected_weights)
@parametrize(
"dropout_layer",
[
subtest(nn.Dropout, "Dropout"),
subtest(nn.AlphaDropout, "AlphaDropout"),
subtest(nn.FeatureAlphaDropout, "FeatureAlphaDropout"),
],
)
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_find_learning_rate_ensembling(self, device, dropout_layer, mechanism):
        # This example mimics what a user might do when trying to find the optimal learning rate. They would
        # want to run a bunch of models with the same behavior (including the same dropout!) and have each of
        # them run with a different learning rate. Specifically, this is an example of using the same
        # randomness with vmap.
points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint(
0, 2, (100,), device=device
)
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes
self.dropout = dropout_layer()
self.fc1 = nn.Linear(16, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
def forward(self, x):
x = self.dropout(x)
x = torch.flatten(x, start_dim=1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x
loss_fn = nn.NLLLoss()
func_model, weights = _get_weights_and_functional_call(
MLPClassifier().to(device), mechanism
)
def train_step_fn(weights, batch, targets, lr):
def compute_loss(weights, batch, targets):
output = func_model(weights, batch)
loss = loss_fn(output, targets)
return loss
grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
new_weights = self._update_params(weights, grad_weights, lr, mechanism)
if mechanism != "make_functional":
new_weights = list(new_weights.values())
# NB: return looks weird because torch.vmap must return Tensors
return (loss, *new_weights)
def unpack(train_result):
return train_result[0], train_result[1:]
def init_fn(num_models):
og_model = MLPClassifier().to(device)
models = tuple(
copy.deepcopy(og_model) for _ in range(num_models)
) # have same initialization
if mechanism == "make_functional":
return combine_state_for_ensemble(models)[1]
else:
return stack_module_state(models)[0]
batched_weights = init_fn(num_models=2)
parallel_train_step_fn = vmap(
train_step_fn, in_dims=(0, None, None, 0), randomness="same"
)
lrs = torch.tensor([0.2, 0.4], device=device)
result_loss, result_weights = unpack(
parallel_train_step_fn(batched_weights, points, labels, lrs)
)
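        # With identical initialization and randomness="same", both ensemble
        # members see the same dropout mask, so their first-step losses match;
        # the different learning rates still produce different updated weights.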
self.assertEqual(result_loss[0], result_loss[1])
self.assertNotEqual(
tuple(weight[0] for weight in result_weights),
tuple(weight[1] for weight in result_weights),
)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_resnet18_per_sample_grads(self, device, mechanism):
import torchvision.models as models
model = models.__dict__["resnet18"](
pretrained=False, norm_layer=(lambda c: nn.GroupNorm(min(32, c), c))
).to(device)
criterion = nn.CrossEntropyLoss(
reduction="sum"
        )  # avoid cross-batch reductions for the per-sample for-loop comparison
func_model, weights = _get_weights_and_functional_call(model, mechanism)
def compute_loss(weights, image, target):
image = image.unsqueeze(0)
target = target.unsqueeze(0)
output = func_model(weights, image)
loss = criterion(output, target)
return loss
batch_size = 3
images = torch.randn(batch_size, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (batch_size,), device=device)
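        # vmap over grad with in_dims=(None, 0, 0) broadcasts the weights and
        # maps over the batch dimension of images/targets, producing one
        # gradient per sample in a single call.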
result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
weights, images, targets
)
flat_weights, spec = tree_flatten(weights)
expected_grads = [
torch.autograd.grad(
compute_loss(weights, images[i], targets[i]), flat_weights
)
for i in range(batch_size)
]
expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]
expected_grads = tree_unflatten(expected_grads, spec)
self.assertEqual(result_grads, expected_grads, atol=1e-3, rtol=1.0)
def normalize_devices(fx_g):
for node in fx_g.graph.nodes:
args = list(node.args)
for idx, arg in enumerate(args):
if isinstance(arg, torch.device):
args[idx] = "cpu"
node.args = tuple(args)
new_kwargs = {}
for k, v in node.kwargs.items():
if isinstance(v, torch.device):
v = "cpu"
new_kwargs[k] = v
node.kwargs = new_kwargs
fx_g.recompile()
return fx_g
@markDynamoStrictTest
class TestFunctionalize(TestCase):
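    # Runs `f` eagerly, under functionalize, and under
    # functionalize(remove="mutations_and_views") (optionally wrapped in vmap),
    # then checks that the outputs and any input mutations agree.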
def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False):
inpt1 = inpt.clone()
inpt2 = inpt.clone()
inpt3 = inpt.clone()
expected_outputs = f(inpt1)
if skip_vmap:
actual_outputs = functionalize(f)(inpt2)
else:
actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
        # Right now the flavor of functionalize that also removes view ops
        # isn't being used with vmap.
        # That's because {view}_copy ops don't have batching rules yet
        # (although we should probably fix that).
actual_outputs_view_copy = functionalize(f, remove="mutations_and_views")(inpt3)
# Check that outputs are the same
self.assertEqual(actual_outputs, expected_outputs)
self.assertEqual(actual_outputs_view_copy, expected_outputs)
# Inputs might have been mutated by f: check that they were mutated properly
self.assertEqual(inpt1, inpt2)
self.assertEqual(inpt1, inpt3)
def test_simple_view(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
tmp = torch.ones(2, device=device)
y = x.view(4, 2)
y.add_(tmp)
return x
self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
def test_multioutput_view(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
tmp = torch.ones(2, device=device)
y1, y2 = x.split(2)
y1_view = y1.diagonal()
y1_view.add_(tmp)
return x
self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
def test_inplace_view(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
tmp = torch.ones(4, device=device)
y = x + x
y2 = y.transpose(1, 0)
z = y2[0]
z.add_(tmp)
return y
self._check_functionalize_correctness(
f, torch.zeros(4, 2, device=device), skip_vmap=True
)
# See https://github.com/pytorch/functorch/issues/780
def test_linear(self, device):
def f(x, y, z) -> torch.Tensor:
return torch._C._nn.linear(x, y, z)
x = torch.randn(14, 1, 384, device=device)
y = torch.randn(96, 384, device=device)
z = torch.randn(96, device=device)
out_expected = f(x, y, z)
out_actual = functionalize(f)(x, y, z)
self.assertEqual(out_expected, out_actual)
def test_multioutput_inplace_slice_view(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
tmp = torch.ones(2, 2, device=device)
y = x.view(8)
z0 = y.reshape(2, 4)
z1 = z0.transpose(1, 0)
z1.unsqueeze_(0)
z1.squeeze_()
z2, z3 = z1.split(2)
z2.add_(tmp)
return x
# See Note [Fix vmap slice_scatter]
self._check_functionalize_correctness(
f, torch.zeros(4, 2, device=device), skip_vmap=True
)
# Ensure functionalize works with List[Optional[Tensor]] arguments.
# See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085
def test_functionalize_opt_tensor_list(self, device):
def f(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return x[indices]
inpta = torch.ones(4, device=device)
inptb = torch.arange(2, device=device)
out1 = f(inpta, inptb)
out2 = functionalize(f)(inpta, inptb)
self.assertEqual(out1, out2)
out = make_fx(functionalize(f))(inpta, inptb)
self.assertExpectedInline(
(out.code),
"""\
def forward(self, x_1, indices_1) -> torch.Tensor:
index = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None
return index
""",
)
# Ensure grad(functionalize(f)) works
def test_functionalize_grad(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
tmp = torch.ones(2, device=device)
y = x + x
z = y.view(4, 2)
y.add_(tmp)
return z.sum()
inpt1 = torch.ones(4, 2, device=device)
inpt2 = torch.ones(4, 2, device=device)
out1 = grad(f)(inpt1)
out2 = grad(functionalize(f))(inpt2)
self.assertEqual(out1, out2)
self.assertEqual(inpt1, inpt2)
@unittest.skipIf(IS_FBCODE, "fails in fbcode")
def test_vmap_functionalize_jvp(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
y = x + x
z = y.view(-1)
y.add_(1)
return z
def jvp_wrapper(x, t):
return jvp(
f,
(x,),
(t,),
)
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)
out1 = vmap(jvp_wrapper)(x, t)
out2 = vmap(functionalize(jvp_wrapper))(x, t)
self.assertEqual(out1, out2)
# TODO: move this test into test_fake_tensor.py
# once functionalize() can be used in core tests.
def test_functionalize_fake_tensors(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
y = x.detach()
return y + y
with FakeTensorMode() as mode:
x = torch.ones(2, device=device, requires_grad=True)
out = functionalize(f)(x)
self.assertEqual(x.size(), (2,))
def test_functionalize_fx_simple(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
tmp = torch.ones(2, device=device)
y = x.view(4, 2)
y.add_(tmp)
return x
# There's a copy_ in the graph, because the input (x) was mutated.
# To preserve semantics, functionalize() needs to propagate the mutation.
fn = make_fx(functionalize(f, remove="mutations_and_views"))
out = fn(torch.zeros(4, 2, device=device))
out = normalize_devices(out)
self.assertExpectedInline(
(out.code),
"""\
def forward(self, x_1) -> torch.Tensor:
ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = None
return view_copy_1
""",
)
def test_functionalize_fx_transpose_simple(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
return x.transpose(1, 0)
fn = make_fx(functionalize(f, remove="mutations_and_views"))
out = fn(torch.zeros(4, 2, device=device))
out = normalize_devices(out)
self.assertExpectedInline(
out.code,
"""\
def forward(self, x_1) -> torch.Tensor:
transpose_copy = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None
return transpose_copy
""",
)
def test_functionalize_fx_out_op(self, device):
def f(inpt: torch.Tensor) -> torch.Tensor:
out = torch.empty((), dtype=torch.float32)
torch.add(inpt, inpt, out=out)
out_view = out.view(4)
out_view.add_(1)
return out
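        # functionalize(remove="mutations_and_views") turns the out= call and
        # the in-place view mutation into functional add/view_copy ops, as the
        # expected graph below shows.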
fn = make_fx(functionalize(f, remove="mutations_and_views"))
out = fn(torch.arange(4, device=device, dtype=torch.float32))
out = normalize_devices(out)
self.assertExpectedInline(
out.code,
"""\
def forward(self, inpt_1) -> torch.Tensor:
empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False)
add = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [4])
view_copy_1 = torch.ops.aten.view_copy.default(add, [4]); add = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]); add_1 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4])
return view_copy_2
""",
)
def test_functionalize_fx_multi_out_op(self, device):
def f(inpt: torch.Tensor) -> torch.Tensor:
mins = torch.empty(4, dtype=torch.float32)
maxs = torch.empty(2, 2, dtype=torch.float32)
maxs_view = maxs.view(4)
inpt_view = inpt.view(2, 4)
torch.aminmax(inpt_view, dim=0, out=(mins, maxs_view))
return (maxs, mins)
fn = make_fx(functionalize(f, remove="mutations_and_views"))
out = fn(torch.arange(8, device=device, dtype=torch.float32))
out = normalize_devices(out)
self.assertExpectedInline(
out.code,
"""\
def forward(self, inpt_1) -> torch.Tensor:
empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = 'cpu', pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None
aminmax = torch.ops.aten.aminmax.default(view_copy_1, dim = 0); view_copy_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4])
return (view_copy_2, getitem)
""",
)
def test_functionalize_fx_reapply_views_simple(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
tmp = torch.ones(2, device=device)
y = x.view(4, 2)
y.add_(tmp)
return x
out = make_fx(functionalize(f))(torch.zeros(4, 2, device=device))
out = normalize_devices(out)
self.assertExpectedInline(
out.code,
"""\
def forward(self, x_1) -> torch.Tensor:
ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
view = torch.ops.aten.view.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = None
return view_1
""",
)
def test_functionalize_nonfunctional_output(self, device):
global_out = torch.ones(2, device=device)
def f() -> torch.Tensor:
return global_out
out = make_fx(functionalize(f))()
out = normalize_devices(out)
self.assertExpectedInline(
out.code,
"""\
def forward(self) -> torch.Tensor:
_tensor_constant0 = self._tensor_constant0
return _tensor_constant0
""",
)
def test_functionalize_optional_tensorlist1(self, device):
def f(a, b) -> torch.Tensor:
# at::index has OptionalTensorList arguments,
# test that here
return a[b]
a = torch.arange(4).reshape(2, 2)
b = torch.ones(2, dtype=torch.long)
out = make_fx(functionalize(f))(a, b)
out = normalize_devices(out)
self.assertExpectedInline(
out.code,
"""\
def forward(self, a_1, b_1) -> torch.Tensor:
index = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None
return index
""",
)
@unittest.skipIf(IS_FBCODE, "fails in fbcode")
def test_functionalize_optional_tensorlist2(self, device):
def f(a, b) -> torch.Tensor:
# See https://github.com/pytorch/pytorch/pull/77846
return torch.ops.aten.index(a, b)
a = torch.arange(4).reshape(2, 2)
b = torch.ones(2, dtype=torch.long)
out = make_fx(functionalize(f))(a, b)
self.assertExpectedInline(
out.code,
"""\
def forward(self, a_1, b_1) -> torch.Tensor:
unbind = torch.ops.aten.unbind.int(b_1); b_1 = None
getitem = unbind[0]
getitem_1 = unbind[1]; unbind = None
index = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None
return index
""",
)
def test_resize_program_inputs(self, device):
def f(x):
x.resize_(10)
x.fill_(2)
fn = make_fx(functionalize(f))
out = fn(torch.zeros(0, device=device))
out = normalize_devices(out)
self.assertExpectedInline(
(out.code),
"""\
def forward(self, x_1):
resize = torch.ops.aten.resize.default(x_1, [10])
fill = torch.ops.aten.fill.Scalar(resize, 2); resize = None
resize_ = torch.ops.aten.resize_.default(x_1, [10]); x_1 = None
copy_ = torch.ops.aten.copy_.default(resize_, fill); resize_ = fill = None
return None
""",
)
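# Build a toy HigherOrderOperator ("mysum") with py_impl rules for the
# functorch Vmap and Grad transforms plus the Autograd dispatch keys;
# TestHigherOrderOperatorInteraction below checks that it composes with
# vmap/grad like torch.sum.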
def construct_sum_pyop():
mysum = HigherOrderOperator("mysum")
@mysum.py_impl(torch._C._functorch.TransformType.Vmap)
def mysum_batch_rule(interpreter, x, dim):
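        # Batch rule: if x isn't batched at this vmap level, lower and recurse;
        # otherwise unwrap x, move its batch dim to the front, call mysum on
        # dim + 1, and re-wrap the result with _add_batch_dim.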
if not torch._C._functorch.is_batchedtensor(x):
with interpreter.lower():
x = x.view_as(x) # unnecessary, just here to test the dispatch
return mysum(x, dim)
bdim = torch._C._functorch.maybe_get_bdim(x)
value = torch._C._functorch.get_unwrapped(x)
with interpreter.lower():
value = value.movedim(bdim, 0)
result = mysum(value, dim + 1)
return torch._C._functorch._add_batch_dim(result, 0, interpreter.level())
@mysum.py_impl(torch._C._functorch.TransformType.Grad)
def mysum_grad_rule(interpreter, x, dim):
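        # Grad rule: route the computation through a single-level
        # autograd.Function so that the backward (unsqueeze + expand of the
        # incoming gradient) is registered at exactly this grad-transform level.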
level = interpreter.level()
class MySum(torch.autograd.function._SingleLevelFunction):
@staticmethod
def forward(ctx, x, dim):
ctx.x_shape = x.shape
ctx.dim = dim
x = torch._C._functorch._unwrap_for_grad(x, level)
with torch.enable_grad(), interpreter.lower():
x = x.view_as(x) # unnecessary, just here to test the dispatch
y = mysum(x, dim)
y = torch._C._functorch._wrap_for_grad(y, level)
return y
@staticmethod
def backward(ctx, gy):
return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None
with enable_single_level_autograd_function():
return MySum.apply(x, dim)
@mysum.py_impl(torch._C.DispatchKey.AutogradCPU)
def mysum_autograd_cpu(x, dim):
return torch.sum(x, dim)
@mysum.py_impl(torch._C.DispatchKey.AutogradCUDA)
def mysum_autograd_cuda(x, dim):
return torch.sum(x, dim)
return mysum
sum_pyop = construct_sum_pyop()
@markDynamoStrictTest
class TestHigherOrderOperatorInteraction(TestCase):
def test_basic_sum(self, device):
x = torch.randn(2, 3, 4, device=device)
result = sum_pyop(x, 1)
self.assertEqual(result, torch.sum(x, 1))
def test_vmap_sum(self, device):
x = torch.randn(2, 3, 4, device=device)
result = vmap(sum_pyop, (0, None))(x, 0)
self.assertEqual(result, torch.sum(x, 1))
result = vmap(vmap(sum_pyop, (0, None)), (0, None))(x, 0)
self.assertEqual(result, torch.sum(x, 2))
def test_grad_sum(self, device):
x = torch.randn(3, device=device)
gx = grad(sum_pyop)(x, 0)
self.assertEqual(gx, torch.ones_like(x))
def test_grad_grad_sum(self, device):
x = torch.randn(3, requires_grad=True, device=device)
def f(x):
# higher order grad. Requires a non-linearity
return sum_pyop(x.sin(), 0)
def grad_f_sum(x):
return grad(f)(x).sum()
ggx = grad(grad_f_sum)(x)
self.assertEqual(ggx, -x.sin())
def test_vmap_grad_sum(self, device):
x = torch.randn(2, 3, device=device)
gx = vmap(grad(sum_pyop), (0, None))(x, 0)
self.assertEqual(gx, torch.ones_like(x))
def test_no_grad_outside_grad(self, device):
x = torch.randn(3, device=device, requires_grad=True)
with torch.no_grad():
y = grad(sum_pyop)(x, 0)
self.assertEqual(y, torch.ones_like(x))
self.assertFalse(y.requires_grad)
def test_no_grad_inside_grad(self, device):
def f(x):
with torch.no_grad():
shift = sum_pyop(x**2, 0)
return sum_pyop(x**2, 0) - shift
x = torch.randn(3, device=device)
y = grad(f)(x)
self.assertEqual(y, 2 * x)
y = grad(lambda x: grad(f)(x).sum())(x)
self.assertEqual(y, torch.full_like(x, 2))
x = torch.randn(3, device=device, requires_grad=True)
y = grad(f)(x)
(z,) = torch.autograd.grad(y.sum(), x)
self.assertEqual(z, torch.full_like(x, 2))
def test_grad_name_wrapping(self, device):
def my_fn(x):
return x.sum()
grad_fn = grad(my_fn)
self.assertEqual(grad_fn.__name__, "my_fn")
def test_functional_call_multiple_dicts(self):
mod = nn.Linear(1, 1)
x = torch.randn((1, 1))
params = ({"weight": torch.zeros(1, 1)}, {"bias": torch.ones(1)})
functional_call(mod, params, x)
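# Mark a function with torch._dynamo.allow_in_graph so that torch.compile can
# capture calls to it (here, the functorch transforms) in the graph instead of
# graph-breaking on them.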
def traceable(f):
f = allow_in_graph(f)
@wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper
@markDynamoStrictTest
class TestCompileTransforms(TestCase):
@skipIfRocm(msg="test leaks memory on ROCm")
# torch.compile is not supported on Windows CUDA.
    # Triton only supports GPUs with SM70 or later.
@expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater))
def test_compile_vmap_hessian(self, device):
        # The model and inputs are a smaller version of the code in the
        # benchmark repo:
# https://github.com/pytorch/benchmark/blob/main/userbenchmark/functorch/vmap_hessian_fc.py
D = 2
B = 4
x = torch.randn(B, D, device=device)
model = nn.Sequential(nn.Linear(D, D), nn.ReLU()).to(device)
params_and_buffers = (
dict(model.named_parameters()),
dict(model.named_buffers()),
)
def predict(params_and_buffers, x):
out = torch.func.functional_call(model, params_and_buffers, x)
return out, out
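        # Forward-over-reverse (jacfwd of jacrev) Hessian of predict w.r.t. x,
        # batched over the inputs with vmap; has_aux threads the raw model
        # output alongside the derivatives.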
fn = vmap(
jacfwd(jacrev(predict, argnums=1, has_aux=True), argnums=1, has_aux=True),
in_dims=(None, 0),
)
expected = fn(params_and_buffers, x)
opt_fn = torch.compile(traceable(fn))
actual = opt_fn(params_and_buffers, x)
self.assertEqual(actual, expected)
# torch.compile is not supported on Windows
@expectedFailureIf(IS_WINDOWS)
@torch._dynamo.config.patch(suppress_errors=False)
def test_grad_deprecated_api(self, device):
x = torch.randn((), device=device)
y = torch.randn((), device=device)
def wrapper_fn(x, y):
return functorch.grad(torch.mul)(x, y)
actual = wrapper_fn(x, y)
expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y)
fn = torch.compile(wrapper_fn, backend="eager", fullgraph=True)
self.assertEqual(actual, expected)
def wrapper_fn(x, y):
return functorch.grad(torch.mul, argnums=(0, 1))(x, y)
actual = wrapper_fn(x, y)
expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y)
self.assertEqual(actual, expected)
only_for = ("cpu", "cuda")
instantiate_device_type_tests(
TestGradTransform,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestVmapOfGrad,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestJac,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestJvp,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestLinearize,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestVmapJvpInplaceView,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestHessian,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestComposability,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestExamplesCorrectness,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestHigherOrderOperatorInteraction,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestFunctionalize,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestAutogradFunction,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestAutogradFunctionVmapAPI,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(
TestHelpers,
globals(),
only_for=only_for,
)
instantiate_parametrized_tests(
TestMakeFunctional,
)
instantiate_device_type_tests(
TestCompileTransforms,
globals(),
only_for=only_for,
)
if __name__ == "__main__":
run_tests()