From a39cb7a80f22ec79497e430c4e4347e75565fdb0 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 21 Apr 2021 10:44:44 -0700 Subject: [PATCH] [functorch] beef testing --- functorch/functorch/__init__.py | 21 ++ functorch/functorch/_src/eager_transforms.py | 49 +--- functorch/functorch/_src/vmap.py | 18 +- .../functorch/csrc/BatchingRegistrations.cpp | 4 +- functorch/functorch/csrc/DynamicLayer.cpp | 53 +++-- functorch/functorch/csrc/VmapMode.cpp | 60 ----- functorch/functorch/csrc/VmapMode.h | 30 --- functorch/functorch/csrc/init.cpp | 19 +- functorch/test/test_eager_transforms.py | 211 ++++++++++-------- 9 files changed, 191 insertions(+), 274 deletions(-) delete mode 100644 functorch/functorch/csrc/VmapMode.cpp delete mode 100644 functorch/functorch/csrc/VmapMode.h diff --git a/functorch/functorch/__init__.py b/functorch/functorch/__init__.py index c62bfc751b91..40f0f61e36da 100644 --- a/functorch/functorch/__init__.py +++ b/functorch/functorch/__init__.py @@ -4,3 +4,24 @@ from . import _C from ._src.vmap import vmap from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev from ._src.make_functional import make_functional, make_functional_with_buffers + + +# Monkeypatching lol +_old_cross_entropy = torch.nn.functional.cross_entropy + + +def cross_entropy(input, target, weight=None, size_average=None, + ignore_index=-100, reduce=None, reduction='mean'): + if input.dim() == 1 and target.dim() == 0: + input = input.unsqueeze(0) + target = target.unsqueeze(0) + + result = _old_cross_entropy( + input, target, weight, size_average, + ignore_index, reduce, reduction) + if reduction == 'none': + return result.squeeze(0) + return result + + +torch.nn.functional.cross_entropy = cross_entropy diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py index 756081f81534..5c69f530af5a 100644 --- a/functorch/functorch/_src/eager_transforms.py +++ b/functorch/functorch/_src/eager_transforms.py @@ -3,10 +3,10 @@ from functools import partial import collections import torch.nn as nn import torch.nn.functional as F -from torch.make_functional import make_functional, make_functional_with_buffers import gc from .vmap import vmap +from .make_functional import make_functional, make_functional_with_buffers from functorch._C import ( _wrap_for_grad, @@ -15,13 +15,6 @@ from functorch._C import ( _grad_decrement_nesting, ) -# x = torch.ones(2, 3) -# y = torch.ones(2, 3) -# # result = vmap(torch.add)(x, y) -# result = vmap(vmap(torch.add))(x, y) - -# assert torch.allclose(result, x + y) - # TODO: replace all of these with pytrees def _create_differentiable(tensor_or_tuple_of_tensors, level=None): if isinstance(tensor_or_tuple_of_tensors, torch.Tensor): @@ -96,46 +89,6 @@ def jacrev(f): return result return wrapper_fn -# -# -# def jacrev(f, diff_argnums=(0,)): -# def wrapper(*args): -# torch._C._grad_increment_nesting() -# output = None -# grad_outputs = None -# try: -# args = [_create_differentiable(arg) if i in diff_argnums else arg -# for i, arg in enumerate(args)] -# output = f(*args) -# # Only support single tensor output for now -# assert isinstance(output, torch.Tensor) -# output_numel = output.numel() -# if output_numel != 0: -# grad_output = torch.eye(output_numel).view(output_numel, *output.shape) -# -# diff_args = [args[i] for i in diff_argnums] -# single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1 -# # TODO: quick hack... 
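A minimal sketch of what the cross_entropy shim in functorch/__init__.py above enables, assuming `import functorch` has already run and installed the monkeypatch: an unbatched call (1-D input, 0-D target) is routed through the batched kernel and produces the same value as the explicitly batched call.

    import torch
    import torch.nn.functional as F
    import functorch  # importing functorch installs the cross_entropy shim above

    logits = torch.randn(5)    # a single sample: shape (C,) rather than (N, C)
    target = torch.tensor(3)   # 0-dim class index
    unbatched = F.cross_entropy(logits, target)
    batched = F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))
    assert torch.allclose(unbatched, batched)
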
-# if len(diff_args) == 1 and isinstance(diff_args[0], tuple): -# diff_args = diff_args[0] -# # NB: need create_graph so that backward pass isn't run in no_grad mode -# -# def compute_vjp(v): -# return torch.autograd.grad(output, diff_args, v, create_graph=True) -# -# if output_numel == 0: -# grad_input = compute_vjp(grad_output) -# else: -# grad_input = vmap(compute_vjp)(grad_output) -# -# if single_diff_arg: -# grad_input = grad_input[0] -# finally: -# _undo_create_differentiable(args) -# torch._C._grad_decrement_nesting() -# return grad_input, output -# return wrapper - def grad_with_value(f, diff_argnums=(0,), has_aux=False): def wrapper(*args): level = _grad_increment_nesting() diff --git a/functorch/functorch/_src/vmap.py b/functorch/functorch/_src/vmap.py index a836b0870e6b..066c03f00e17 100644 --- a/functorch/functorch/_src/vmap.py +++ b/functorch/functorch/_src/vmap.py @@ -8,8 +8,8 @@ import warnings from functorch._C import ( _add_batch_dim, _remove_batch_dim, - _vmapmode_decrement_nesting, - _vmapmode_increment_nesting, + _vmap_decrement_nesting, + _vmap_increment_nesting, ) in_dims_t = Union[int, Tuple] @@ -244,14 +244,6 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca vmap does not provide general autobatching or handle variable-length sequences out of the box. """ - warnings.warn( - 'functorch.vmap is an experimental prototype that is subject to ' - 'change and/or deletion. Please use at your own risk. There may be ' - 'unexpected performance cliffs due to certain operators not being ' - 'implemented. To see detailed performance warnings please use ' - '`torch._C._debug_only_display_vmap_fallback_warnings(True) ' - 'before the call to `vmap`.', - stacklevel=2) return _vmap(func, in_dims, out_dims) # A version of vmap but without the initial "experimental prototype" warning @@ -259,12 +251,14 @@ def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> C @functools.wraps(func) def wrapped(*args): _check_out_dims_is_int_or_int_tuple(out_dims, func) - vmap_level = _vmapmode_increment_nesting() + vmap_level = _vmap_increment_nesting() + torch._C._vmapmode_increment_nesting() try: batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func) batched_outputs = func(*batched_inputs) _validate_outputs(batched_outputs, func) return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func) finally: - _vmapmode_decrement_nesting() + torch._C._vmapmode_decrement_nesting() + _vmap_decrement_nesting() return wrapped diff --git a/functorch/functorch/csrc/BatchingRegistrations.cpp b/functorch/functorch/csrc/BatchingRegistrations.cpp index a053c87e3462..8ac7a96392f7 100644 --- a/functorch/functorch/csrc/BatchingRegistrations.cpp +++ b/functorch/functorch/csrc/BatchingRegistrations.cpp @@ -1487,8 +1487,8 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) { // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution // here is to just directly call the underlying implementation. 
m.impl("size.int", static_cast(native::size)); - m.impl("_add_batch_dim", native::_add_batch_dim); - m.impl("_remove_batch_dim", native::_remove_batch_dim); + // m.impl("_add_batch_dim", native::_add_batch_dim); + // m.impl("_remove_batch_dim", native::_remove_batch_dim); m.impl("max_pool2d", at::native::max_pool2d); // composite m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule); diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp index 6b9283eabbc8..557911d5e272 100644 --- a/functorch/functorch/csrc/DynamicLayer.cpp +++ b/functorch/functorch/csrc/DynamicLayer.cpp @@ -31,15 +31,23 @@ class DynamicLayerStackHolder : public c10::DebugInfoBase { thread_local std::shared_ptr kDynamicLayerStack; static std::vector& dynamicLayerStackAccessor() { - if (kDynamicLayerStack == nullptr) { - kDynamicLayerStack = std::make_shared(); - c10::ThreadLocalDebugInfo::_push( - // TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use - c10::DebugInfoKind::PRODUCER_INFO, - kDynamicLayerStack); + if (kDynamicLayerStack != nullptr) { + // TODO: can figure out how to memoize this. std::call_once with thread_local? + return kDynamicLayerStack->dynamicLayerStack; } + if (ThreadLocalDebugInfo::current() != nullptr) { + // TODO: this is going to break if someone else uses PRODUCER_INFO + kDynamicLayerStack = std::static_pointer_cast( + ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PRODUCER_INFO)); + TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr); + return kDynamicLayerStack->dynamicLayerStack; + } + kDynamicLayerStack = std::make_shared(); + c10::ThreadLocalDebugInfo::_push( + // TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use + c10::DebugInfoKind::PRODUCER_INFO, + kDynamicLayerStack); TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr); - // TODO: can figure out how to memoize this. std::call_once with thread_local? return kDynamicLayerStack->dynamicLayerStack; } @@ -111,17 +119,17 @@ DynamicLayer popDynamicLayerAndDeleteMetadata() { auto level = result.layerId(); // TODO: is this lock safe? 
No one else should be writing to the same bucket - if (c10::show_dispatch_trace_enabled()) { - std::cout << "deleting metadata" << std::endl; - } + // if (c10::show_dispatch_trace_enabled()) { + // std::cout << "deleting metadata" << std::endl; + // } auto& data = getGlobalDynmetaData(); auto it = data.find(level); if (it == data.end()) { return result; } - if (c10::show_dispatch_trace_enabled()) { - std::cout << "deleted metadata for level " << level << std::endl; - } + // if (c10::show_dispatch_trace_enabled()) { + // std::cout << "deleted metadata for level " << level << std::endl; + // } // invalidate the thing *(it->second) = false; data.erase(level); @@ -215,9 +223,9 @@ static void sanityCheckStack(torch::jit::Stack* stack) { void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { auto& dynamicLayerStack = dynamicLayerStackAccessor(); - if (c10::show_dispatch_trace_enabled()) { - std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl; - } + // if (c10::show_dispatch_trace_enabled()) { + // std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl; + // } if (dynamicLayerStack.size() == 0) { sanityCheckStack(stack); c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset); @@ -276,12 +284,11 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* if (!maybe_tensor_wrapper) { return tensor; } - if (maybe_tensor_wrapper->level().value() == cur_level) { + auto tensor_wrapper_level = maybe_tensor_wrapper->level().value(); + TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level); + if (tensor_wrapper_level == cur_level) { return maybe_tensor_wrapper->value(); } - if (c10::show_dispatch_trace_enabled()) { - std::cout << "unwrap " << cur_level << std::endl; - } return tensor; }; auto wrap = [&](const Tensor& tensor) { @@ -291,9 +298,9 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* if (cur_level == 1) { return tensor; } - if (c10::show_dispatch_trace_enabled()) { - std::cout << "wrap " << cur_level << std::endl; - } + // if (c10::show_dispatch_trace_enabled()) { + // std::cout << "wrap " << cur_level << std::endl; + // } return makeTensorWrapper(tensor, cur_level); }; diff --git a/functorch/functorch/csrc/VmapMode.cpp b/functorch/functorch/csrc/VmapMode.cpp deleted file mode 100644 index 8d61b01c5930..000000000000 --- a/functorch/functorch/csrc/VmapMode.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include -#include -#include - -namespace at { -namespace functorch { -namespace impl { - -/// thread_local is a feature that is not enabled by Caffe2 mobile -/// build (e.g. iOS). Therefore, we only provide `at::VmapMode` -/// when we are not in mobile build or when FEATURE_TORCH_MOBILE -/// is on. 
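The functorch-local VmapMode being deleted here is superseded by the `_vmap_increment_nesting`/`_vmap_decrement_nesting` bindings added in csrc/init.cpp below; on the Python side, `_src/vmap.py` above now pairs them with core PyTorch's vmap-mode counters. A minimal sketch of that pairing, mirroring the `wrapped` body in vmap.py with the actual batching elided:

    import torch
    from functorch._C import _vmap_increment_nesting, _vmap_decrement_nesting

    vmap_level = _vmap_increment_nesting()      # pushes a Batched DynamicLayer
    torch._C._vmapmode_increment_nesting()      # bumps core PyTorch's VmapMode TLS
    try:
        pass  # add batch dims, run the function, remove batch dims here
    finally:
        torch._C._vmapmode_decrement_nesting()
        _vmap_decrement_nesting()               # pops the layer and deletes its metadata
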
-#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE) - -thread_local int64_t VmapMode_current_vmap_level = 0; - -int64_t VmapMode::current_vmap_level() { - return VmapMode_current_vmap_level; -} - -int64_t VmapMode::increment_nesting() { - VmapMode_current_vmap_level++; - - auto level = initAndPushDynamicLayer(kBatchedKey); - if (VmapMode_current_vmap_level == 1) { - c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true); - } - return level; -} - -int64_t VmapMode::decrement_nesting() { - VmapMode_current_vmap_level--; - auto layer = popDynamicLayerAndDeleteMetadata(); - TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey); - if (VmapMode_current_vmap_level == 0) { - c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false); - } - // TODO: this return value should never be used - return VmapMode_current_vmap_level; -} - -#else - -int64_t VmapMode::current_nesting_level() { - TORCH_CHECK(false, "VmapMode is not supported on mobile"); -} - -int64_t VmapMode::increment_nesting() { - TORCH_CHECK(false, "VmapMode is not supported on mobile"); -} - -int64_t VmapMode::decrement_nesting() { - TORCH_CHECK(false, "VmapMode is not supported on mobile"); -} - -#endif - -} // namespace impl -} -} // namespace at diff --git a/functorch/functorch/csrc/VmapMode.h b/functorch/functorch/csrc/VmapMode.h deleted file mode 100644 index c1f86d38b37f..000000000000 --- a/functorch/functorch/csrc/VmapMode.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include - -namespace at { -namespace functorch { -namespace impl { - -// VmapMode contains a thread local count of how many nested vmaps -// we are currently inside. That number is known as the `vmap level`. -// VmapMode is used in the implementation of the Python `torch.vmap` API. -// -// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. - -struct TORCH_API VmapMode { - // Returns the vmap level, aka the count of how many nested vmaps we're in. - static int64_t current_vmap_level(); - - // Increment the count of nested vmaps. If this causes the vmap level to be - // greater than 0, then it enables DispatchKey::VmapMode on all tensors. - static int64_t increment_nesting(); - - // Decrements the count of nested vmaps. If this causes the vmap level to be - // equal to 0, then it disables DispatchKey::VmapMode on all tensors. 
- static int64_t decrement_nesting(); -}; - -} // namespace impl -} -} // namespace at diff --git a/functorch/functorch/csrc/init.cpp b/functorch/functorch/csrc/init.cpp index 38b7529f25cc..fad08a06b3f6 100644 --- a/functorch/functorch/csrc/init.cpp +++ b/functorch/functorch/csrc/init.cpp @@ -5,7 +5,6 @@ #include #include #include -#include namespace at { namespace functorch { @@ -155,7 +154,19 @@ int64_t _grad_increment_nesting() { } int64_t _grad_decrement_nesting() { - return popDynamicLayerAndDeleteMetadata().layerId(); + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == DispatchKey::Autograd); + return layer.layerId(); +} + +int64_t _vmap_increment_nesting() { + return initAndPushDynamicLayer(kBatchedKey); +} + +int64_t _vmap_decrement_nesting() { + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey); + return layer.layerId(); } @@ -165,8 +176,8 @@ int64_t _grad_decrement_nesting() { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim"); m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim"); - m.def("_vmapmode_increment_nesting", &at::functorch::impl::VmapMode::increment_nesting, "add batch dim"); - m.def("_vmapmode_decrement_nesting", &at::functorch::impl::VmapMode::decrement_nesting, "remove batch dim"); + m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "remove batch dim"); + m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "remove batch dim"); m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim"); m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim"); m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim"); diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py index 975c4eb08608..5de7b2349682 100644 --- a/functorch/test/test_eager_transforms.py +++ b/functorch/test/test_eager_transforms.py @@ -7,7 +7,7 @@ import functools import itertools import warnings from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ - skipCUDAIfNoMagma + skipCUDAIfNoMagma, onlyOnCPUAndCUDA import types from functools import partial @@ -16,19 +16,19 @@ from functorch import grad, vjp, vmap, make_functional, jacrev class TestGradTransform(TestCase): - def test_primitive(self): - x = torch.randn([]) + def test_primitive(self, device): + x = torch.randn([], device=device) result = grad(torch.sin)(x) self.assertEqual(result, torch.cos(x)) - def test_composite_simple(self): - x = torch.randn(2, 3, 4) + def test_composite_simple(self, device): + x = torch.randn(2, 3, 4, device=device) result = grad(lambda x: torch.flatten(x).sum())(x) self.assertEqual(result, torch.ones_like(x)) - def test_composite_complicated(self): - x = torch.randn(3) - y = torch.randn(3, 5) + def test_composite_complicated(self, device): + x = torch.randn(3, device=device) + y = torch.randn(3, 5, device=device) def foo(x, y): result = x @ y @@ -42,10 +42,10 @@ class TestGradTransform(TestCase): self.assertEqual(result, expected) - def test_composite_two_ops(self): + def test_composite_two_ops(self, device): N, C = 2, 5 - y = torch.randn(N, C) - targets = torch.randint(0, C, (N,)) + y = torch.randn(N, C, device=device) + targets = torch.randint(0, C, (N,), device=device) def foo(y, targets): return F.cross_entropy(y, targets) @@ -57,8 +57,8 @@ class 
TestGradTransform(TestCase): self.assertEqual(result, expected) - def _test_attributes(self, get_attr_lambda): - x = torch.randn(2, 3, 5, dtype=torch.double) + def _test_attributes(self, get_attr_lambda, device): + x = torch.randn(2, 3, 5, dtype=torch.double, device=device) expected = get_attr_lambda(x) def foo(x): @@ -67,20 +67,20 @@ class TestGradTransform(TestCase): grad(foo)(x) - def test_shape(self): - self._test_attributes(lambda x: x.shape) + def test_shape(self, device): + self._test_attributes(lambda x: x.shape, device) - def test_dtype(self): - self._test_attributes(lambda x: x.dtype) + def test_dtype(self, device): + self._test_attributes(lambda x: x.dtype, device) - def test_is_cuda(self): - self._test_attributes(lambda x: x.is_cuda) + def test_is_cuda(self, device): + self._test_attributes(lambda x: x.is_cuda, device) - def test_numel(self): - self._test_attributes(lambda x: x.numel()) + def test_numel(self, device): + self._test_attributes(lambda x: x.numel(), device) - def test_inplace(self): - x = torch.randn([]) + def test_inplace(self, device): + x = torch.randn([], device=device) def foo(x): return x.clone().sin_() @@ -88,8 +88,8 @@ class TestGradTransform(TestCase): result = grad(foo)(x) self.assertEqual(result, x.cos()) - def test_inplace_on_view(self): - x = torch.randn(3) + def test_inplace_on_view(self, device): + x = torch.randn(3, device=device) def foo(x): y = x.clone() @@ -105,8 +105,8 @@ class TestGradTransform(TestCase): self.assertEqual(result, expected) - def test_inplace_on_view_base(self): - x = torch.randn(3) + def test_inplace_on_view_base(self, device): + x = torch.randn(3, device=device) def foo(x): y = x.clone() @@ -122,13 +122,13 @@ class TestGradTransform(TestCase): self.assertEqual(result, expected) - def test_nesting_simple(self): - x = torch.randn([]) + def test_nesting_simple(self, device): + x = torch.randn([], device=device) result = grad(grad(torch.sin))(x) self.assertEqual(result, -torch.sin(x)) - def test_escaped_wrappers_are_marked_as_dead(self): - x = torch.randn([]) + def test_escaped_wrappers_are_marked_as_dead(self, device): + x = torch.randn([], device=device) escaped = [] def foo(x): y = x.sin() @@ -136,10 +136,10 @@ class TestGradTransform(TestCase): return y result = grad(foo)(x) - self.assertEqual(escaped[0].dlevel(), -1) + self.assertEqual(functorch._C.dlevel(escaped[0]), -1) - def test_escaped_wrappers_are_ignored(self): - x = torch.randn([]) + def test_escaped_wrappers_are_ignored(self, device): + x = torch.randn([], device=device) escaped = [] def foo(x): y = x.sin() @@ -149,28 +149,28 @@ class TestGradTransform(TestCase): result = grad(foo)(x) something = escaped[0].sum() - self.assertEqual(something.dlevel(), 0) + self.assertEqual(functorch._C.dlevel(something), 0) self.assertEqual(something, x.sin().sum()) - def test_vjp(self): - x = torch.randn([]) + def test_vjp(self, device): + x = torch.randn([], device=device) out, vjp_fn = vjp(torch.sin, x) self.assertEqual(out, x.sin()) - v = torch.randn([]) + v = torch.randn([], device=device) result, = vjp_fn(v) self.assertEqual(result, v * x.cos()) - def test_composed_with_autograd(self): - x = torch.randn([], requires_grad=True) + def test_composed_with_autograd(self, device): + x = torch.randn([], requires_grad=True, device=device) y = grad(torch.sin)(x) result, = torch.autograd.grad(y, x) self.assertEqual(result, -x.sin()) - def test_grad_of_vjp_composition(self): - x = torch.randn([]) - y = torch.randn([]) + def test_grad_of_vjp_composition(self, device): + x = torch.randn([], 
device=device) + y = torch.randn([], device=device) def foo(x, y): out, vjp_fn = vjp(torch.sin, x) @@ -180,9 +180,9 @@ class TestGradTransform(TestCase): expected = x.cos() self.assertEqual(result, expected) - def test_vjp_of_grad_composition(self): - x = torch.randn([]) - y = torch.randn([]) + def test_vjp_of_grad_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) def foo(x, y): out, vjp_fn = vjp(grad(torch.sin), x) @@ -192,9 +192,9 @@ class TestGradTransform(TestCase): expected = -y * x.sin() self.assertEqual(result, expected) - def test_grad_of_vjp_of_grad_composition(self): - x = torch.randn([]) - y = torch.randn([]) + def test_grad_of_vjp_of_grad_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) def foo(x, y): df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x) @@ -204,9 +204,9 @@ class TestGradTransform(TestCase): expected = x.cos() self.assertEqual(result, expected) - def test_views(self): - x = torch.randn([], requires_grad=True) - y = torch.randn([], requires_grad=True) + def test_views(self, device): + x = torch.randn([], requires_grad=True, device=device) + y = torch.randn([], requires_grad=True, device=device) def silly_sin(x): x = x.view([]) @@ -223,34 +223,34 @@ class TestGradTransform(TestCase): self.assertEqual(grads[0], -x.sin()) self.assertEqual(grads[1], -y.sin()) - def test_view_inplace_simple(self): + def test_view_inplace_simple(self, device): def foo(x): x = x.clone() x.view([]).sin_() return x - x = torch.randn([], requires_grad=True) + x = torch.randn([], requires_grad=True, device=device) result = grad(foo)(x) self.assertEqual(result, x.cos()) class TestVmapOfGrad(TestCase): - def test_per_sample_grads_inplace_view(self): + def test_per_sample_grads_inplace_view(self, device): def compute_loss(weight, x, t): x = x.mm(weight) y = x.squeeze_(0) return (y - t).sum() - weight = torch.randn(16, 2) - x = torch.randn(64, 1, 16) - t = torch.randn(64, 2) + weight = torch.randn(16, 2, device=device) + x = torch.randn(64, 1, 16, device=device) + t = torch.randn(64, 2, device=device) result = vmap(partial(grad(compute_loss), weight))(x, t) expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)] expected = torch.stack(expected) # TODO: Check if the rtol is a problem self.assertEqual(result, expected, atol=0, rtol=5e-4) - def test_new_zeros_materializes_tensor(self): + def test_new_zeros_materializes_tensor(self, device): N = 3 C = 5 @@ -259,25 +259,25 @@ class TestVmapOfGrad(TestCase): result.copy_(y) return result.sum() - x = torch.randn(N) - y = torch.randn(N, C) + x = torch.randn(N, device=device) + y = torch.randn(N, C, device=device) result = vmap(grad(foo))(x, y) - def test_per_sample_grads_simple(self): + def test_per_sample_grads_simple(self, device): def compute_loss(weight, x, t): y = x @ weight return ((y - t) ** 2).sum() - weight = torch.randn(16, 2) - x = torch.randn(64, 16) - t = torch.randn(64, 2) + weight = torch.randn(16, 2, device=device) + x = torch.randn(64, 16, device=device) + t = torch.randn(64, 2, device=device) result = vmap(partial(grad(compute_loss), weight))(x, t) expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)] expected = torch.stack(expected) # TODO: Check if the rtol is a problem self.assertEqual(result, expected, atol=0, rtol=5e-4) - def test_per_sample_grads_embeddingnet(self): + def test_per_sample_grads_embeddingnet(self, device): class SampleNet(nn.Module): def __init__(self, vocab_size: int): 
super().__init__() @@ -301,11 +301,11 @@ class TestVmapOfGrad(TestCase): vocab_size = 1000 batch_shape = [64] words_per_sentence = 5 - data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence)) - targets = torch.randint(0, 1, (*batch_shape,)) + data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device) + targets = torch.randint(0, 1, (*batch_shape,), device=device) # Construct our module - net = SampleNet(vocab_size) + net = SampleNet(vocab_size).to(device=device) criterion = nn.CrossEntropyLoss() params = dict(net.named_parameters()) @@ -326,42 +326,42 @@ class TestVmapOfGrad(TestCase): self.assertEqual(r, e, atol=0, rtol=1e-4) class TestJacrev(TestCase): - def test_simple(self): - x = torch.randn(3) + def test_simple(self, device): + x = torch.randn(3, device=device) y = jacrev(torch.sin)(x) expected = torch.diagflat(x.cos()) assert torch.allclose(y, expected) - def test_simple_not_flat(self): - x = torch.randn(2, 3) + def test_simple_not_flat(self, device): + x = torch.randn(2, 3, device=device) y = jacrev(torch.sin)(x) expected = torch.diagflat(x.view(-1).cos()) expected = expected.view(2, 3, 2, 3) assert torch.allclose(y, expected) - def test_vmap_on_jacrev_simple(self): - x = torch.randn(2, 3) + def test_vmap_on_jacrev_simple(self, device): + x = torch.randn(2, 3, device=device) y = vmap(jacrev(torch.sin))(x) expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)]) assert torch.allclose(y, expected) - def test_hessian_simple(self): + def test_hessian_simple(self, device): def foo(x): return x.sin().sum() - x = torch.randn(3) + x = torch.randn(3, device=device) y = jacrev(jacrev(foo))(x) expected = torch.diagflat(-x.sin()) assert torch.allclose(y, expected) class TestComposability(TestCase): - def test_grad_grad(self): - x = torch.randn([]) + def test_grad_grad(self, device): + x = torch.randn([], device=device) y = grad(grad(torch.sin))(x) self.assertEqual(y, -x.sin()) - def test_grad_vmap(self): + def test_grad_vmap(self, device): def foo(x): y = vmap(torch.sin)(x) return y.sum() @@ -370,8 +370,8 @@ class TestComposability(TestCase): y = grad(foo)(x) self.assertEqual(y, x.cos()) - def test_grad_vjp(self): - x = torch.randn(3) + def test_grad_vjp(self, device): + x = torch.randn(3, device=device) def foo(x): _, vjp_fn = vjp(torch.sin, x) @@ -381,18 +381,18 @@ class TestComposability(TestCase): expected = grad(lambda x: (x * x.cos()).sum())(x) self.assertEqual(y, expected) - def test_vmap_grad(self): - x = torch.randn(3) + def test_vmap_grad(self, device): + x = torch.randn(3, device=device) y = vmap(grad(torch.sin))(x) self.assertEqual(y, x.cos()) - def test_vmap_vmap(self): - x = torch.randn(2, 3) + def test_vmap_vmap(self, device): + x = torch.randn(2, 3, device=device) y = vmap(vmap(torch.sin))(x) self.assertEqual(y, x.sin()) - def test_vmap_vjp(self): - x = torch.randn(3) + def test_vmap_vjp(self, device): + x = torch.randn(3, device=device) _, vjp_fn = vjp(torch.sin, x) def foo(x): @@ -406,24 +406,24 @@ class TestComposability(TestCase): expected = torch.stack([vjp_fn(x)[0] for x in xs]) self.assertEqual(vmap(lambda x: vjp_fn(x)[0])(xs), expected) - def test_vjp_grad(self): - x = torch.randn([]) + def test_vjp_grad(self, device): + x = torch.randn([], device=device) y, vjp_fn = vjp(grad(torch.sin), x) self.assertEqual(y, x.cos()) v = torch.randn([]) self.assertEqual(vjp_fn(v)[0], -x.sin() * v) - def test_vjp_vmap(self): - x = torch.randn(3) + def test_vjp_vmap(self, device): + x = torch.randn(3, device=device) y, vjp_fn = 
vjp(vmap(torch.sin), x) self.assertEqual(y, x.sin()) - v = torch.randn(3) + v = torch.randn(3, device=device) self.assertEqual(vjp_fn(v)[0], x.cos() * v) - def test_vjp_vjp(self): - x = torch.randn(3) + def test_vjp_vjp(self, device): + x = torch.randn(3, device=device) y, vjp_fn = vjp(torch.sin, x) self.assertEqual(y, x.sin()) @@ -433,6 +433,27 @@ class TestComposability(TestCase): y = vjp_fn(x)[0] # Honestly IDK what the result here is... but at least it runs +instantiate_device_type_tests( + TestGradTransform, + globals(), + None, +) +instantiate_device_type_tests( + TestVmapOfGrad, + globals(), + None, +) +instantiate_device_type_tests( + TestJacrev, + globals(), + None, +) +instantiate_device_type_tests( + TestComposability, + globals(), + None, +) + if __name__ == '__main__': run_tests()
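
The test changes above follow PyTorch's device-generic test pattern: each test takes a `device` argument, allocates its tensors with `device=device`, and `instantiate_device_type_tests` generates per-device classes (e.g. TestGradTransformCPU, TestGradTransformCUDA) at module scope. A minimal self-contained sketch of the same pattern, assuming the standard `common_utils` imports used by the test file:

    import torch
    from torch.testing._internal.common_utils import TestCase, run_tests
    from torch.testing._internal.common_device_type import instantiate_device_type_tests
    from functorch import grad

    class TestExample(TestCase):
        def test_sin_grad(self, device):
            x = torch.randn([], device=device)
            self.assertEqual(grad(torch.sin)(x), torch.cos(x))

    # Generates TestExampleCPU (and TestExampleCUDA when CUDA is available).
    instantiate_device_type_tests(TestExample, globals(), None)

    if __name__ == '__main__':
        run_tests()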