[functorch] beef up testing

This commit is contained in:
Richard Zou
2021-04-21 10:44:44 -07:00
committed by Jon Janzen
parent 93888a3779
commit a39cb7a80f
9 changed files with 191 additions and 274 deletions

View File

@@ -4,3 +4,24 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
# Monkeypatching lol
_old_cross_entropy = torch.nn.functional.cross_entropy
def cross_entropy(input, target, weight=None, size_average=None,
ignore_index=-100, reduce=None, reduction='mean'):
if input.dim() == 1 and target.dim() == 0:
input = input.unsqueeze(0)
target = target.unsqueeze(0)
result = _old_cross_entropy(
input, target, weight, size_average,
ignore_index, reduce, reduction)
if reduction == 'none':
return result.squeeze(0)
return result
torch.nn.functional.cross_entropy = cross_entropy
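Editorial note, not part of the commit: the patch above exists because transforms such as vmap and per-sample grad hand F.cross_entropy one example at a time, i.e. 1-D logits and a 0-D target, which the cross_entropy of the time only accepted in batched form. A minimal sketch of the call shape this enables, assuming the patch has been applied (for example by importing functorch):

import torch
import torch.nn.functional as F
import functorch  # importing functorch applies the cross_entropy patch above

logits = torch.randn(5)    # a single sample's class scores (1-D)
target = torch.tensor(2)   # a single sample's label (0-D)
# The patch adds a temporary batch dimension of size 1 and, for reduction='none',
# squeezes it back out, so this unbatched call goes through.
loss = F.cross_entropy(logits, target)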

View File

@@ -3,10 +3,10 @@ from functools import partial
import collections
import torch.nn as nn
import torch.nn.functional as F
from torch.make_functional import make_functional, make_functional_with_buffers
import gc
from .vmap import vmap
from .make_functional import make_functional, make_functional_with_buffers
from functorch._C import (
_wrap_for_grad,
@@ -15,13 +15,6 @@ from functorch._C import (
_grad_decrement_nesting,
)
# x = torch.ones(2, 3)
# y = torch.ones(2, 3)
# # result = vmap(torch.add)(x, y)
# result = vmap(vmap(torch.add))(x, y)
# assert torch.allclose(result, x + y)
# TODO: replace all of these with pytrees
def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
@@ -96,46 +89,6 @@ def jacrev(f):
return result
return wrapper_fn
#
#
# def jacrev(f, diff_argnums=(0,)):
# def wrapper(*args):
# torch._C._grad_increment_nesting()
# output = None
# grad_outputs = None
# try:
# args = [_create_differentiable(arg) if i in diff_argnums else arg
# for i, arg in enumerate(args)]
# output = f(*args)
# # Only support single tensor output for now
# assert isinstance(output, torch.Tensor)
# output_numel = output.numel()
# if output_numel != 0:
# grad_output = torch.eye(output_numel).view(output_numel, *output.shape)
#
# diff_args = [args[i] for i in diff_argnums]
# single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
# # TODO: quick hack...
# if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
# diff_args = diff_args[0]
# # NB: need create_graph so that backward pass isn't run in no_grad mode
#
# def compute_vjp(v):
# return torch.autograd.grad(output, diff_args, v, create_graph=True)
#
# if output_numel == 0:
# grad_input = compute_vjp(grad_output)
# else:
# grad_input = vmap(compute_vjp)(grad_output)
#
# if single_diff_arg:
# grad_input = grad_input[0]
# finally:
# _undo_create_differentiable(args)
# torch._C._grad_decrement_nesting()
# return grad_input, output
# return wrapper
def grad_with_value(f, diff_argnums=(0,), has_aux=False):
def wrapper(*args):
level = _grad_increment_nesting()

View File

@@ -8,8 +8,8 @@ import warnings
from functorch._C import (
_add_batch_dim,
_remove_batch_dim,
_vmapmode_decrement_nesting,
_vmapmode_increment_nesting,
_vmap_decrement_nesting,
_vmap_increment_nesting,
)
in_dims_t = Union[int, Tuple]
@@ -244,14 +244,6 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
vmap does not provide general autobatching or handle variable-length
sequences out of the box.
"""
warnings.warn(
'functorch.vmap is an experimental prototype that is subject to '
'change and/or deletion. Please use at your own risk. There may be '
'unexpected performance cliffs due to certain operators not being '
'implemented. To see detailed performance warnings please use '
'`torch._C._debug_only_display_vmap_fallback_warnings(True) '
'before the call to `vmap`.',
stacklevel=2)
return _vmap(func, in_dims, out_dims)
# A version of vmap but without the initial "experimental prototype" warning
@@ -259,12 +251,14 @@ def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> C
@functools.wraps(func)
def wrapped(*args):
_check_out_dims_is_int_or_int_tuple(out_dims, func)
vmap_level = _vmapmode_increment_nesting()
vmap_level = _vmap_increment_nesting()
torch._C._vmapmode_increment_nesting()
try:
batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
batched_outputs = func(*batched_inputs)
_validate_outputs(batched_outputs, func)
return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
finally:
_vmapmode_decrement_nesting()
torch._C._vmapmode_decrement_nesting()
_vmap_decrement_nesting()
return wrapped
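Editorial note, not part of the diff: _vmap_increment_nesting pushes a Batched dynamic layer and returns its level, while the retained torch._C._vmapmode_* calls keep core PyTorch's VmapMode bookkeeping in step. A hedged sketch of the composition this per-level nesting is meant to support, mirroring the tests later in this commit:

import torch
from functorch import vmap, grad

x = torch.randn(2, 3)
# Each transform pushes its own dynamic layer, so the batch dimension handled
# by vmap and the autograd wrapping handled by grad stay at distinct levels.
per_row_grads = vmap(grad(lambda row: row.sin().sum()))(x)
assert torch.allclose(per_row_grads, x.cos())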

View File

@@ -1487,8 +1487,8 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
// Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
// here is to just directly call the underlying implementation.
m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
m.impl("_add_batch_dim", native::_add_batch_dim);
m.impl("_remove_batch_dim", native::_remove_batch_dim);
// m.impl("_add_batch_dim", native::_add_batch_dim);
// m.impl("_remove_batch_dim", native::_remove_batch_dim);
m.impl("max_pool2d", at::native::max_pool2d); // composite
m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);

View File

@@ -31,15 +31,23 @@ class DynamicLayerStackHolder : public c10::DebugInfoBase {
thread_local std::shared_ptr<DynamicLayerStackHolder> kDynamicLayerStack;
static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
if (kDynamicLayerStack == nullptr) {
kDynamicLayerStack = std::make_shared<DynamicLayerStackHolder>();
c10::ThreadLocalDebugInfo::_push(
// TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use
c10::DebugInfoKind::PRODUCER_INFO,
kDynamicLayerStack);
if (kDynamicLayerStack != nullptr) {
// TODO: can figure out how to memoize this. std::call_once with thread_local?
return kDynamicLayerStack->dynamicLayerStack;
}
if (ThreadLocalDebugInfo::current() != nullptr) {
// TODO: this is going to break if someone else uses PRODUCER_INFO
kDynamicLayerStack = std::static_pointer_cast<DynamicLayerStackHolder>(
ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PRODUCER_INFO));
TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr);
return kDynamicLayerStack->dynamicLayerStack;
}
kDynamicLayerStack = std::make_shared<DynamicLayerStackHolder>();
c10::ThreadLocalDebugInfo::_push(
// TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use
c10::DebugInfoKind::PRODUCER_INFO,
kDynamicLayerStack);
TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr);
// TODO: can figure out how to memoize this. std::call_once with thread_local?
return kDynamicLayerStack->dynamicLayerStack;
}
@@ -111,17 +119,17 @@ DynamicLayer popDynamicLayerAndDeleteMetadata() {
auto level = result.layerId();
// TODO: is this lock safe? No one else should be writing to the same bucket
if (c10::show_dispatch_trace_enabled()) {
std::cout << "deleting metadata" << std::endl;
}
// if (c10::show_dispatch_trace_enabled()) {
// std::cout << "deleting metadata" << std::endl;
// }
auto& data = getGlobalDynmetaData();
auto it = data.find(level);
if (it == data.end()) {
return result;
}
if (c10::show_dispatch_trace_enabled()) {
std::cout << "deleted metadata for level " << level << std::endl;
}
// if (c10::show_dispatch_trace_enabled()) {
// std::cout << "deleted metadata for level " << level << std::endl;
// }
// invalidate the thing
*(it->second) = false;
data.erase(level);
@@ -215,9 +223,9 @@ static void sanityCheckStack(torch::jit::Stack* stack) {
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
if (c10::show_dispatch_trace_enabled()) {
std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
}
// if (c10::show_dispatch_trace_enabled()) {
// std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
// }
if (dynamicLayerStack.size() == 0) {
sanityCheckStack(stack);
c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
@@ -276,12 +284,11 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
if (!maybe_tensor_wrapper) {
return tensor;
}
if (maybe_tensor_wrapper->level().value() == cur_level) {
auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level);
if (tensor_wrapper_level == cur_level) {
return maybe_tensor_wrapper->value();
}
if (c10::show_dispatch_trace_enabled()) {
std::cout << "unwrap " << cur_level << std::endl;
}
return tensor;
};
auto wrap = [&](const Tensor& tensor) {
@@ -291,9 +298,9 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
if (cur_level == 1) {
return tensor;
}
if (c10::show_dispatch_trace_enabled()) {
std::cout << "wrap " << cur_level << std::endl;
}
// if (c10::show_dispatch_trace_enabled()) {
// std::cout << "wrap " << cur_level << std::endl;
// }
return makeTensorWrapper(tensor, cur_level);
};
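Editorial note, not part of the diff: the new TORCH_INTERNAL_ASSERT encodes the invariant that a tensor wrapper's level never exceeds the current dynamic layer, so each transform wraps and unwraps tensors at its own level only. A minimal Python-level illustration of the nesting this supports, mirroring test_nesting_simple in the test changes below:

import torch
from functorch import grad

x = torch.randn([])
# Each grad call pushes its own dynamic layer and unwraps only the wrappers
# belonging to that layer, so nested transforms compose cleanly.
assert torch.allclose(grad(grad(torch.sin))(x), -x.sin())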

View File

@@ -1,60 +0,0 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/VmapMode.h>
#include <functorch/csrc/Constants.h>
namespace at {
namespace functorch {
namespace impl {
/// thread_local is a feature that is not enabled by Caffe2 mobile
/// build (e.g. iOS). Therefore, we only provide `at::VmapMode`
/// when we are not in mobile build or when FEATURE_TORCH_MOBILE
/// is on.
#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
thread_local int64_t VmapMode_current_vmap_level = 0;
int64_t VmapMode::current_vmap_level() {
return VmapMode_current_vmap_level;
}
int64_t VmapMode::increment_nesting() {
VmapMode_current_vmap_level++;
auto level = initAndPushDynamicLayer(kBatchedKey);
if (VmapMode_current_vmap_level == 1) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
}
return level;
}
int64_t VmapMode::decrement_nesting() {
VmapMode_current_vmap_level--;
auto layer = popDynamicLayerAndDeleteMetadata();
TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey);
if (VmapMode_current_vmap_level == 0) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
}
// TODO: this return value should never be used
return VmapMode_current_vmap_level;
}
#else
int64_t VmapMode::current_nesting_level() {
TORCH_CHECK(false, "VmapMode is not supported on mobile");
}
int64_t VmapMode::increment_nesting() {
TORCH_CHECK(false, "VmapMode is not supported on mobile");
}
int64_t VmapMode::decrement_nesting() {
TORCH_CHECK(false, "VmapMode is not supported on mobile");
}
#endif
} // namespace impl
}
} // namespace at

View File

@@ -1,30 +0,0 @@
#pragma once
#include <c10/core/impl/LocalDispatchKeySet.h>
namespace at {
namespace functorch {
namespace impl {
// VmapMode contains a thread local count of how many nested vmaps
// we are currently inside. That number is known as the `vmap level`.
// VmapMode is used in the implementation of the Python `torch.vmap` API.
//
// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
struct TORCH_API VmapMode {
// Returns the vmap level, aka the count of how many nested vmaps we're in.
static int64_t current_vmap_level();
// Increment the count of nested vmaps. If this causes the vmap level to be
// greater than 0, then it enables DispatchKey::VmapMode on all tensors.
static int64_t increment_nesting();
// Decrements the count of nested vmaps. If this causes the vmap level to be
// equal to 0, then it disables DispatchKey::VmapMode on all tensors.
static int64_t decrement_nesting();
};
} // namespace impl
}
} // namespace at

View File

@@ -5,7 +5,6 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/VmapMode.h>
namespace at {
namespace functorch {
@@ -155,7 +154,19 @@ int64_t _grad_increment_nesting() {
}
int64_t _grad_decrement_nesting() {
return popDynamicLayerAndDeleteMetadata().layerId();
auto layer = popDynamicLayerAndDeleteMetadata();
TORCH_INTERNAL_ASSERT(layer.key() == DispatchKey::Autograd);
return layer.layerId();
}
int64_t _vmap_increment_nesting() {
return initAndPushDynamicLayer(kBatchedKey);
}
int64_t _vmap_decrement_nesting() {
auto layer = popDynamicLayerAndDeleteMetadata();
TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey);
return layer.layerId();
}
@@ -165,8 +176,8 @@ int64_t _grad_decrement_nesting() {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim");
m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim");
m.def("_vmapmode_increment_nesting", &at::functorch::impl::VmapMode::increment_nesting, "add batch dim");
m.def("_vmapmode_decrement_nesting", &at::functorch::impl::VmapMode::decrement_nesting, "remove batch dim");
m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "remove batch dim");
m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "remove batch dim");
m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim");
m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim");
m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim");

View File

@@ -7,7 +7,7 @@ import functools
import itertools
import warnings
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma
skipCUDAIfNoMagma, onlyOnCPUAndCUDA
import types
from functools import partial
@@ -16,19 +16,19 @@ from functorch import grad, vjp, vmap, make_functional, jacrev
class TestGradTransform(TestCase):
def test_primitive(self):
x = torch.randn([])
def test_primitive(self, device):
x = torch.randn([], device=device)
result = grad(torch.sin)(x)
self.assertEqual(result, torch.cos(x))
def test_composite_simple(self):
x = torch.randn(2, 3, 4)
def test_composite_simple(self, device):
x = torch.randn(2, 3, 4, device=device)
result = grad(lambda x: torch.flatten(x).sum())(x)
self.assertEqual(result, torch.ones_like(x))
def test_composite_complicated(self):
x = torch.randn(3)
y = torch.randn(3, 5)
def test_composite_complicated(self, device):
x = torch.randn(3, device=device)
y = torch.randn(3, 5, device=device)
def foo(x, y):
result = x @ y
@@ -42,10 +42,10 @@ class TestGradTransform(TestCase):
self.assertEqual(result, expected)
def test_composite_two_ops(self):
def test_composite_two_ops(self, device):
N, C = 2, 5
y = torch.randn(N, C)
targets = torch.randint(0, C, (N,))
y = torch.randn(N, C, device=device)
targets = torch.randint(0, C, (N,), device=device)
def foo(y, targets):
return F.cross_entropy(y, targets)
@@ -57,8 +57,8 @@ class TestGradTransform(TestCase):
self.assertEqual(result, expected)
def _test_attributes(self, get_attr_lambda):
x = torch.randn(2, 3, 5, dtype=torch.double)
def _test_attributes(self, get_attr_lambda, device):
x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
expected = get_attr_lambda(x)
def foo(x):
@@ -67,20 +67,20 @@ class TestGradTransform(TestCase):
grad(foo)(x)
def test_shape(self):
self._test_attributes(lambda x: x.shape)
def test_shape(self, device):
self._test_attributes(lambda x: x.shape, device)
def test_dtype(self):
self._test_attributes(lambda x: x.dtype)
def test_dtype(self, device):
self._test_attributes(lambda x: x.dtype, device)
def test_is_cuda(self):
self._test_attributes(lambda x: x.is_cuda)
def test_is_cuda(self, device):
self._test_attributes(lambda x: x.is_cuda, device)
def test_numel(self):
self._test_attributes(lambda x: x.numel())
def test_numel(self, device):
self._test_attributes(lambda x: x.numel(), device)
def test_inplace(self):
x = torch.randn([])
def test_inplace(self, device):
x = torch.randn([], device=device)
def foo(x):
return x.clone().sin_()
@@ -88,8 +88,8 @@ class TestGradTransform(TestCase):
result = grad(foo)(x)
self.assertEqual(result, x.cos())
def test_inplace_on_view(self):
x = torch.randn(3)
def test_inplace_on_view(self, device):
x = torch.randn(3, device=device)
def foo(x):
y = x.clone()
@@ -105,8 +105,8 @@ class TestGradTransform(TestCase):
self.assertEqual(result, expected)
def test_inplace_on_view_base(self):
x = torch.randn(3)
def test_inplace_on_view_base(self, device):
x = torch.randn(3, device=device)
def foo(x):
y = x.clone()
@@ -122,13 +122,13 @@ class TestGradTransform(TestCase):
self.assertEqual(result, expected)
def test_nesting_simple(self):
x = torch.randn([])
def test_nesting_simple(self, device):
x = torch.randn([], device=device)
result = grad(grad(torch.sin))(x)
self.assertEqual(result, -torch.sin(x))
def test_escaped_wrappers_are_marked_as_dead(self):
x = torch.randn([])
def test_escaped_wrappers_are_marked_as_dead(self, device):
x = torch.randn([], device=device)
escaped = []
def foo(x):
y = x.sin()
@@ -136,10 +136,10 @@ class TestGradTransform(TestCase):
return y
result = grad(foo)(x)
self.assertEqual(escaped[0].dlevel(), -1)
self.assertEqual(functorch._C.dlevel(escaped[0]), -1)
def test_escaped_wrappers_are_ignored(self):
x = torch.randn([])
def test_escaped_wrappers_are_ignored(self, device):
x = torch.randn([], device=device)
escaped = []
def foo(x):
y = x.sin()
@@ -149,28 +149,28 @@ class TestGradTransform(TestCase):
result = grad(foo)(x)
something = escaped[0].sum()
self.assertEqual(something.dlevel(), 0)
self.assertEqual(functorch._C.dlevel(something), 0)
self.assertEqual(something, x.sin().sum())
def test_vjp(self):
x = torch.randn([])
def test_vjp(self, device):
x = torch.randn([], device=device)
out, vjp_fn = vjp(torch.sin, x)
self.assertEqual(out, x.sin())
v = torch.randn([])
v = torch.randn([], device=device)
result, = vjp_fn(v)
self.assertEqual(result, v * x.cos())
def test_composed_with_autograd(self):
x = torch.randn([], requires_grad=True)
def test_composed_with_autograd(self, device):
x = torch.randn([], requires_grad=True, device=device)
y = grad(torch.sin)(x)
result, = torch.autograd.grad(y, x)
self.assertEqual(result, -x.sin())
def test_grad_of_vjp_composition(self):
x = torch.randn([])
y = torch.randn([])
def test_grad_of_vjp_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)
def foo(x, y):
out, vjp_fn = vjp(torch.sin, x)
@@ -180,9 +180,9 @@ class TestGradTransform(TestCase):
expected = x.cos()
self.assertEqual(result, expected)
def test_vjp_of_grad_composition(self):
x = torch.randn([])
y = torch.randn([])
def test_vjp_of_grad_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)
def foo(x, y):
out, vjp_fn = vjp(grad(torch.sin), x)
@@ -192,9 +192,9 @@ class TestGradTransform(TestCase):
expected = -y * x.sin()
self.assertEqual(result, expected)
def test_grad_of_vjp_of_grad_composition(self):
x = torch.randn([])
y = torch.randn([])
def test_grad_of_vjp_of_grad_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)
def foo(x, y):
df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
@@ -204,9 +204,9 @@ class TestGradTransform(TestCase):
expected = x.cos()
self.assertEqual(result, expected)
def test_views(self):
x = torch.randn([], requires_grad=True)
y = torch.randn([], requires_grad=True)
def test_views(self, device):
x = torch.randn([], requires_grad=True, device=device)
y = torch.randn([], requires_grad=True, device=device)
def silly_sin(x):
x = x.view([])
@@ -223,34 +223,34 @@ class TestGradTransform(TestCase):
self.assertEqual(grads[0], -x.sin())
self.assertEqual(grads[1], -y.sin())
def test_view_inplace_simple(self):
def test_view_inplace_simple(self, device):
def foo(x):
x = x.clone()
x.view([]).sin_()
return x
x = torch.randn([], requires_grad=True)
x = torch.randn([], requires_grad=True, device=device)
result = grad(foo)(x)
self.assertEqual(result, x.cos())
class TestVmapOfGrad(TestCase):
def test_per_sample_grads_inplace_view(self):
def test_per_sample_grads_inplace_view(self, device):
def compute_loss(weight, x, t):
x = x.mm(weight)
y = x.squeeze_(0)
return (y - t).sum()
weight = torch.randn(16, 2)
x = torch.randn(64, 1, 16)
t = torch.randn(64, 2)
weight = torch.randn(16, 2, device=device)
x = torch.randn(64, 1, 16, device=device)
t = torch.randn(64, 2, device=device)
result = vmap(partial(grad(compute_loss), weight))(x, t)
expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
expected = torch.stack(expected)
# TODO: Check if the rtol is a problem
self.assertEqual(result, expected, atol=0, rtol=5e-4)
def test_new_zeros_materializes_tensor(self):
def test_new_zeros_materializes_tensor(self, device):
N = 3
C = 5
@@ -259,25 +259,25 @@ class TestVmapOfGrad(TestCase):
result.copy_(y)
return result.sum()
x = torch.randn(N)
y = torch.randn(N, C)
x = torch.randn(N, device=device)
y = torch.randn(N, C, device=device)
result = vmap(grad(foo))(x, y)
def test_per_sample_grads_simple(self):
def test_per_sample_grads_simple(self, device):
def compute_loss(weight, x, t):
y = x @ weight
return ((y - t) ** 2).sum()
weight = torch.randn(16, 2)
x = torch.randn(64, 16)
t = torch.randn(64, 2)
weight = torch.randn(16, 2, device=device)
x = torch.randn(64, 16, device=device)
t = torch.randn(64, 2, device=device)
result = vmap(partial(grad(compute_loss), weight))(x, t)
expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
expected = torch.stack(expected)
# TODO: Check if the rtol is a problem
self.assertEqual(result, expected, atol=0, rtol=5e-4)
def test_per_sample_grads_embeddingnet(self):
def test_per_sample_grads_embeddingnet(self, device):
class SampleNet(nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
@@ -301,11 +301,11 @@ class TestVmapOfGrad(TestCase):
vocab_size = 1000
batch_shape = [64]
words_per_sentence = 5
data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence))
targets = torch.randint(0, 1, (*batch_shape,))
data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
targets = torch.randint(0, 1, (*batch_shape,), device=device)
# Construct our module
net = SampleNet(vocab_size)
net = SampleNet(vocab_size).to(device=device)
criterion = nn.CrossEntropyLoss()
params = dict(net.named_parameters())
@@ -326,42 +326,42 @@ class TestVmapOfGrad(TestCase):
self.assertEqual(r, e, atol=0, rtol=1e-4)
class TestJacrev(TestCase):
def test_simple(self):
x = torch.randn(3)
def test_simple(self, device):
x = torch.randn(3, device=device)
y = jacrev(torch.sin)(x)
expected = torch.diagflat(x.cos())
assert torch.allclose(y, expected)
def test_simple_not_flat(self):
x = torch.randn(2, 3)
def test_simple_not_flat(self, device):
x = torch.randn(2, 3, device=device)
y = jacrev(torch.sin)(x)
expected = torch.diagflat(x.view(-1).cos())
expected = expected.view(2, 3, 2, 3)
assert torch.allclose(y, expected)
def test_vmap_on_jacrev_simple(self):
x = torch.randn(2, 3)
def test_vmap_on_jacrev_simple(self, device):
x = torch.randn(2, 3, device=device)
y = vmap(jacrev(torch.sin))(x)
expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
assert torch.allclose(y, expected)
def test_hessian_simple(self):
def test_hessian_simple(self, device):
def foo(x):
return x.sin().sum()
x = torch.randn(3)
x = torch.randn(3, device=device)
y = jacrev(jacrev(foo))(x)
expected = torch.diagflat(-x.sin())
assert torch.allclose(y, expected)
class TestComposability(TestCase):
def test_grad_grad(self):
x = torch.randn([])
def test_grad_grad(self, device):
x = torch.randn([], device=device)
y = grad(grad(torch.sin))(x)
self.assertEqual(y, -x.sin())
def test_grad_vmap(self):
def test_grad_vmap(self, device):
def foo(x):
y = vmap(torch.sin)(x)
return y.sum()
@@ -370,8 +370,8 @@ class TestComposability(TestCase):
y = grad(foo)(x)
self.assertEqual(y, x.cos())
def test_grad_vjp(self):
x = torch.randn(3)
def test_grad_vjp(self, device):
x = torch.randn(3, device=device)
def foo(x):
_, vjp_fn = vjp(torch.sin, x)
@@ -381,18 +381,18 @@ class TestComposability(TestCase):
expected = grad(lambda x: (x * x.cos()).sum())(x)
self.assertEqual(y, expected)
def test_vmap_grad(self):
x = torch.randn(3)
def test_vmap_grad(self, device):
x = torch.randn(3, device=device)
y = vmap(grad(torch.sin))(x)
self.assertEqual(y, x.cos())
def test_vmap_vmap(self):
x = torch.randn(2, 3)
def test_vmap_vmap(self, device):
x = torch.randn(2, 3, device=device)
y = vmap(vmap(torch.sin))(x)
self.assertEqual(y, x.sin())
def test_vmap_vjp(self):
x = torch.randn(3)
def test_vmap_vjp(self, device):
x = torch.randn(3, device=device)
_, vjp_fn = vjp(torch.sin, x)
def foo(x):
@@ -406,24 +406,24 @@ class TestComposability(TestCase):
expected = torch.stack([vjp_fn(x)[0] for x in xs])
self.assertEqual(vmap(lambda x: vjp_fn(x)[0])(xs), expected)
def test_vjp_grad(self):
x = torch.randn([])
def test_vjp_grad(self, device):
x = torch.randn([], device=device)
y, vjp_fn = vjp(grad(torch.sin), x)
self.assertEqual(y, x.cos())
v = torch.randn([], device=device)
self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
def test_vjp_vmap(self):
x = torch.randn(3)
def test_vjp_vmap(self, device):
x = torch.randn(3, device=device)
y, vjp_fn = vjp(vmap(torch.sin), x)
self.assertEqual(y, x.sin())
v = torch.randn(3)
v = torch.randn(3, device=device)
self.assertEqual(vjp_fn(v)[0], x.cos() * v)
def test_vjp_vjp(self):
x = torch.randn(3)
def test_vjp_vjp(self, device):
x = torch.randn(3, device=device)
y, vjp_fn = vjp(torch.sin, x)
self.assertEqual(y, x.sin())
@@ -433,6 +433,27 @@ class TestComposability(TestCase):
y = vjp_fn(x)[0]
# Honestly IDK what the result here is... but at least it runs
instantiate_device_type_tests(
TestGradTransform,
globals(),
None,
)
instantiate_device_type_tests(
TestVmapOfGrad,
globals(),
None,
)
instantiate_device_type_tests(
TestJacrev,
globals(),
None,
)
instantiate_device_type_tests(
TestComposability,
globals(),
None,
)
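Editorial note, not part of the commit: instantiate_device_type_tests is PyTorch's device-generic test machinery; it replaces each generic class above with concrete per-device classes whose test methods receive a device string. A minimal sketch under that assumption, with ExampleGeneric as a hypothetical class name:

import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class ExampleGeneric(TestCase):
    def test_add_zero(self, device):
        x = torch.randn(3, device=device)
        self.assertEqual(x + 0, x)

# Generates ExampleGenericCPU (and ExampleGenericCUDA when CUDA is available)
# and removes the generic class from this scope, as done above for the
# functorch test classes.
instantiate_device_type_tests(ExampleGeneric, globals())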
if __name__ == '__main__':
run_tests()