[functorch] beef testing
@@ -4,3 +4,24 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers


# Monkeypatching lol
_old_cross_entropy = torch.nn.functional.cross_entropy


def cross_entropy(input, target, weight=None, size_average=None,
                  ignore_index=-100, reduce=None, reduction='mean'):
    if input.dim() == 1 and target.dim() == 0:
        input = input.unsqueeze(0)
        target = target.unsqueeze(0)

    result = _old_cross_entropy(
        input, target, weight, size_average,
        ignore_index, reduce, reduction)
    if reduction == 'none':
        return result.squeeze(0)
    return result


torch.nn.functional.cross_entropy = cross_entropy

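Editorial note, not part of the diff: the patch above exists so that cross entropy can be called on an unbatched sample (1-D logits, 0-D target), which is exactly what the per-sample-gradient tests later in this commit produce under vmap of grad. A minimal, hedged sketch of that pattern, with an illustrative tiny linear model and shapes:

import torch
import torch.nn.functional as F
from functools import partial
import functorch  # importing functorch applies the cross_entropy patch above
from functorch import grad, vmap

weight = torch.randn(16, 2)
xs = torch.randn(64, 16)          # 64 samples, 16 features each
ts = torch.randint(0, 2, (64,))   # 64 integer class targets

def loss(weight, x, t):
    # Inside vmap, x is 1-D and t is 0-D; the patched cross_entropy
    # unsqueezes both so the stock kernel sees a batch of one.
    return F.cross_entropy(x @ weight, t)

per_sample_grads = vmap(partial(grad(loss), weight))(xs, ts)
print(per_sample_grads.shape)  # torch.Size([64, 16, 2])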
@@ -3,10 +3,10 @@ from functools import partial
import collections
import torch.nn as nn
import torch.nn.functional as F
from torch.make_functional import make_functional, make_functional_with_buffers
import gc

from .vmap import vmap
from .make_functional import make_functional, make_functional_with_buffers

from functorch._C import (
    _wrap_for_grad,
@@ -15,13 +15,6 @@ from functorch._C import (
    _grad_decrement_nesting,
)

# x = torch.ones(2, 3)
# y = torch.ones(2, 3)
# # result = vmap(torch.add)(x, y)
# result = vmap(vmap(torch.add))(x, y)

# assert torch.allclose(result, x + y)

# TODO: replace all of these with pytrees
def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
@@ -96,46 +89,6 @@ def jacrev(f):
        return result
    return wrapper_fn

#
#
# def jacrev(f, diff_argnums=(0,)):
#     def wrapper(*args):
#         torch._C._grad_increment_nesting()
#         output = None
#         grad_outputs = None
#         try:
#             args = [_create_differentiable(arg) if i in diff_argnums else arg
#                     for i, arg in enumerate(args)]
#             output = f(*args)
#             # Only support single tensor output for now
#             assert isinstance(output, torch.Tensor)
#             output_numel = output.numel()
#             if output_numel != 0:
#                 grad_output = torch.eye(output_numel).view(output_numel, *output.shape)
#
#             diff_args = [args[i] for i in diff_argnums]
#             single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
#             # TODO: quick hack...
#             if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
#                 diff_args = diff_args[0]
#             # NB: need create_graph so that backward pass isn't run in no_grad mode
#
#             def compute_vjp(v):
#                 return torch.autograd.grad(output, diff_args, v, create_graph=True)
#
#             if output_numel == 0:
#                 grad_input = compute_vjp(grad_output)
#             else:
#                 grad_input = vmap(compute_vjp)(grad_output)
#
#             if single_diff_arg:
#                 grad_input = grad_input[0]
#         finally:
#             _undo_create_differentiable(args)
#             torch._C._grad_decrement_nesting()
#         return grad_input, output
#     return wrapper

def grad_with_value(f, diff_argnums=(0,), has_aux=False):
    def wrapper(*args):
        level = _grad_increment_nesting()

@@ -8,8 +8,8 @@ import warnings
from functorch._C import (
    _add_batch_dim,
    _remove_batch_dim,
    _vmapmode_decrement_nesting,
    _vmapmode_increment_nesting,
    _vmap_decrement_nesting,
    _vmap_increment_nesting,
)

in_dims_t = Union[int, Tuple]
@@ -244,14 +244,6 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
    vmap does not provide general autobatching or handle variable-length
    sequences out of the box.
    """
    warnings.warn(
        'functorch.vmap is an experimental prototype that is subject to '
        'change and/or deletion. Please use at your own risk. There may be '
        'unexpected performance cliffs due to certain operators not being '
        'implemented. To see detailed performance warnings please use '
        '`torch._C._debug_only_display_vmap_fallback_warnings(True) '
        'before the call to `vmap`.',
        stacklevel=2)
    return _vmap(func, in_dims, out_dims)

# A version of vmap but without the initial "experimental prototype" warning
@@ -259,12 +251,14 @@ def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> C
    @functools.wraps(func)
    def wrapped(*args):
        _check_out_dims_is_int_or_int_tuple(out_dims, func)
        vmap_level = _vmapmode_increment_nesting()
        vmap_level = _vmap_increment_nesting()
        torch._C._vmapmode_increment_nesting()
        try:
            batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
            batched_outputs = func(*batched_inputs)
            _validate_outputs(batched_outputs, func)
            return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
        finally:
            _vmapmode_decrement_nesting()
            torch._C._vmapmode_decrement_nesting()
            _vmap_decrement_nesting()
    return wrapped

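Editorial note: the rewritten `_vmap` above brackets every call with `_vmap_increment_nesting()` (which pushes a Batched dynamic layer) plus `torch._C._vmapmode_increment_nesting()`, and unwinds both in `finally` in reverse order, so each nesting level is tracked independently. That is what lets nested vmaps compose, as in the smoke test left commented out in eager_transforms.py; a minimal illustration:

import torch
from functorch import vmap

x = torch.ones(2, 3)
y = torch.ones(2, 3)

# The outer call maps over dim 0 of x and y; the inner call maps over dim 0
# of each (3,)-slice. Each level pushes and pops its own dynamic layer.
result = vmap(vmap(torch.add))(x, y)
assert torch.allclose(result, x + y)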
@@ -1487,8 +1487,8 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
  // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
  // here is to just directly call the underlying implementation.
  m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
  m.impl("_add_batch_dim", native::_add_batch_dim);
  m.impl("_remove_batch_dim", native::_remove_batch_dim);
  // m.impl("_add_batch_dim", native::_add_batch_dim);
  // m.impl("_remove_batch_dim", native::_remove_batch_dim);

  m.impl("max_pool2d", at::native::max_pool2d); // composite
  m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);

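Editorial note: the two registrations above are what let pooling run under vmap: `max_pool2d` is registered as a composite that re-dispatches to `max_pool2d_with_indices`, which now has a batching rule. A hedged usage sketch (shapes are illustrative and not taken from a test in this commit):

import torch
import torch.nn.functional as F
from functorch import vmap

# Treat dim 0 as 8 independent (C=3, 16x16) images; vmap supplies the batching.
images = torch.randn(8, 3, 16, 16)
pooled = vmap(lambda img: F.max_pool2d(img, kernel_size=2))(images)
print(pooled.shape)  # expected: torch.Size([8, 3, 8, 8])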
@@ -31,15 +31,23 @@ class DynamicLayerStackHolder : public c10::DebugInfoBase {
thread_local std::shared_ptr<DynamicLayerStackHolder> kDynamicLayerStack;

static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
  if (kDynamicLayerStack == nullptr) {
    kDynamicLayerStack = std::make_shared<DynamicLayerStackHolder>();
    c10::ThreadLocalDebugInfo::_push(
        // TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use
        c10::DebugInfoKind::PRODUCER_INFO,
        kDynamicLayerStack);
  if (kDynamicLayerStack != nullptr) {
    // TODO: can figure out how to memoize this. std::call_once with thread_local?
    return kDynamicLayerStack->dynamicLayerStack;
  }
  if (ThreadLocalDebugInfo::current() != nullptr) {
    // TODO: this is going to break if someone else uses PRODUCER_INFO
    kDynamicLayerStack = std::static_pointer_cast<DynamicLayerStackHolder>(
        ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PRODUCER_INFO));
    TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr);
    return kDynamicLayerStack->dynamicLayerStack;
  }
  kDynamicLayerStack = std::make_shared<DynamicLayerStackHolder>();
  c10::ThreadLocalDebugInfo::_push(
      // TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use
      c10::DebugInfoKind::PRODUCER_INFO,
      kDynamicLayerStack);
  TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr);
  // TODO: can figure out how to memoize this. std::call_once with thread_local?
  return kDynamicLayerStack->dynamicLayerStack;
}

@@ -111,17 +119,17 @@ DynamicLayer popDynamicLayerAndDeleteMetadata() {
  auto level = result.layerId();

  // TODO: is this lock safe? No one else should be writing to the same bucket
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << "deleting metadata" << std::endl;
  }
  // if (c10::show_dispatch_trace_enabled()) {
  //   std::cout << "deleting metadata" << std::endl;
  // }
  auto& data = getGlobalDynmetaData();
  auto it = data.find(level);
  if (it == data.end()) {
    return result;
  }
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << "deleted metadata for level " << level << std::endl;
  }
  // if (c10::show_dispatch_trace_enabled()) {
  //   std::cout << "deleted metadata for level " << level << std::endl;
  // }
  // invalidate the thing
  *(it->second) = false;
  data.erase(level);
@@ -215,9 +223,9 @@ static void sanityCheckStack(torch::jit::Stack* stack) {

void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
  }
  // if (c10::show_dispatch_trace_enabled()) {
  //   std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
  // }
  if (dynamicLayerStack.size() == 0) {
    sanityCheckStack(stack);
    c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
@@ -276,12 +284,11 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
    if (!maybe_tensor_wrapper) {
      return tensor;
    }
    if (maybe_tensor_wrapper->level().value() == cur_level) {
    auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
    TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level);
    if (tensor_wrapper_level == cur_level) {
      return maybe_tensor_wrapper->value();
    }
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "unwrap " << cur_level << std::endl;
    }
    return tensor;
  };
  auto wrap = [&](const Tensor& tensor) {
@@ -291,9 +298,9 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
    if (cur_level == 1) {
      return tensor;
    }
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "wrap " << cur_level << std::endl;
    }
    // if (c10::show_dispatch_trace_enabled()) {
    //   std::cout << "wrap " << cur_level << std::endl;
    // }
    return makeTensorWrapper(tensor, cur_level);
  };

@@ -1,60 +0,0 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/VmapMode.h>
#include <functorch/csrc/Constants.h>

namespace at {
namespace functorch {
namespace impl {

/// thread_local is a feature that is not enabled by Caffe2 mobile
/// build (e.g. iOS). Therefore, we only provide `at::VmapMode`
/// when we are not in mobile build or when FEATURE_TORCH_MOBILE
/// is on.
#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)

thread_local int64_t VmapMode_current_vmap_level = 0;

int64_t VmapMode::current_vmap_level() {
  return VmapMode_current_vmap_level;
}

int64_t VmapMode::increment_nesting() {
  VmapMode_current_vmap_level++;

  auto level = initAndPushDynamicLayer(kBatchedKey);
  if (VmapMode_current_vmap_level == 1) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
  }
  return level;
}

int64_t VmapMode::decrement_nesting() {
  VmapMode_current_vmap_level--;
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey);
  if (VmapMode_current_vmap_level == 0) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
  }
  // TODO: this return value should never be used
  return VmapMode_current_vmap_level;
}

#else

int64_t VmapMode::current_nesting_level() {
  TORCH_CHECK(false, "VmapMode is not supported on mobile");
}

int64_t VmapMode::increment_nesting() {
  TORCH_CHECK(false, "VmapMode is not supported on mobile");
}

int64_t VmapMode::decrement_nesting() {
  TORCH_CHECK(false, "VmapMode is not supported on mobile");
}

#endif

} // namespace impl
}
} // namespace at
@@ -1,30 +0,0 @@
#pragma once

#include <c10/core/impl/LocalDispatchKeySet.h>

namespace at {
namespace functorch {
namespace impl {

// VmapMode contains a thread local count of how many nested vmaps
// we are currently inside. That number is known as the `vmap level`.
// VmapMode is used in the implementation of the Python `torch.vmap` API.
//
// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.

struct TORCH_API VmapMode {
  // Returns the vmap level, aka the count of how many nested vmaps we're in.
  static int64_t current_vmap_level();

  // Increment the count of nested vmaps. If this causes the vmap level to be
  // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
  static int64_t increment_nesting();

  // Decrements the count of nested vmaps. If this causes the vmap level to be
  // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
  static int64_t decrement_nesting();
};

} // namespace impl
}
} // namespace at
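Editorial note: with functorch's own VmapMode shim (the two deleted files above) gone, level bookkeeping is the pair `_vmap_increment_nesting` / `_vmap_decrement_nesting` exported from functorch._C plus the existing `torch._C._vmapmode_*` toggles, exactly as the rewritten `_vmap` wrapper earlier in this commit uses them. A condensed sketch of that bracketing contract; the helper name here is invented for illustration:

import torch
from functorch._C import _vmap_increment_nesting, _vmap_decrement_nesting

def run_at_new_vmap_level(fn, *args):
    # Hypothetical helper: push a Batched dynamic layer, enable VmapMode,
    # then tear both down in reverse order, mirroring _vmap in vmap.py.
    vmap_level = _vmap_increment_nesting()
    torch._C._vmapmode_increment_nesting()
    try:
        return fn(vmap_level, *args)
    finally:
        torch._C._vmapmode_decrement_nesting()
        _vmap_decrement_nesting()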
@@ -5,7 +5,6 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/VmapMode.h>

namespace at {
namespace functorch {
@@ -155,7 +154,19 @@ int64_t _grad_increment_nesting() {
}

int64_t _grad_decrement_nesting() {
  return popDynamicLayerAndDeleteMetadata().layerId();
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == DispatchKey::Autograd);
  return layer.layerId();
}

int64_t _vmap_increment_nesting() {
  return initAndPushDynamicLayer(kBatchedKey);
}

int64_t _vmap_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey);
  return layer.layerId();
}


@@ -165,8 +176,8 @@ int64_t _grad_decrement_nesting() {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim");
  m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim");
  m.def("_vmapmode_increment_nesting", &at::functorch::impl::VmapMode::increment_nesting, "add batch dim");
  m.def("_vmapmode_decrement_nesting", &at::functorch::impl::VmapMode::decrement_nesting, "remove batch dim");
  m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "remove batch dim");
  m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "remove batch dim");
  m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim");
  m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim");
  m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim");

@@ -7,7 +7,7 @@ import functools
import itertools
import warnings
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
    skipCUDAIfNoMagma
    skipCUDAIfNoMagma, onlyOnCPUAndCUDA
import types
from functools import partial

@@ -16,19 +16,19 @@ from functorch import grad, vjp, vmap, make_functional, jacrev


class TestGradTransform(TestCase):
    def test_primitive(self):
        x = torch.randn([])
    def test_primitive(self, device):
        x = torch.randn([], device=device)
        result = grad(torch.sin)(x)
        self.assertEqual(result, torch.cos(x))

    def test_composite_simple(self):
        x = torch.randn(2, 3, 4)
    def test_composite_simple(self, device):
        x = torch.randn(2, 3, 4, device=device)
        result = grad(lambda x: torch.flatten(x).sum())(x)
        self.assertEqual(result, torch.ones_like(x))

    def test_composite_complicated(self):
        x = torch.randn(3)
        y = torch.randn(3, 5)
    def test_composite_complicated(self, device):
        x = torch.randn(3, device=device)
        y = torch.randn(3, 5, device=device)

        def foo(x, y):
            result = x @ y
@@ -42,10 +42,10 @@ class TestGradTransform(TestCase):

        self.assertEqual(result, expected)

    def test_composite_two_ops(self):
    def test_composite_two_ops(self, device):
        N, C = 2, 5
        y = torch.randn(N, C)
        targets = torch.randint(0, C, (N,))
        y = torch.randn(N, C, device=device)
        targets = torch.randint(0, C, (N,), device=device)

        def foo(y, targets):
            return F.cross_entropy(y, targets)
@@ -57,8 +57,8 @@ class TestGradTransform(TestCase):

        self.assertEqual(result, expected)

    def _test_attributes(self, get_attr_lambda):
        x = torch.randn(2, 3, 5, dtype=torch.double)
    def _test_attributes(self, get_attr_lambda, device):
        x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
        expected = get_attr_lambda(x)

        def foo(x):
@@ -67,20 +67,20 @@ class TestGradTransform(TestCase):

        grad(foo)(x)

    def test_shape(self):
        self._test_attributes(lambda x: x.shape)
    def test_shape(self, device):
        self._test_attributes(lambda x: x.shape, device)

    def test_dtype(self):
        self._test_attributes(lambda x: x.dtype)
    def test_dtype(self, device):
        self._test_attributes(lambda x: x.dtype, device)

    def test_is_cuda(self):
        self._test_attributes(lambda x: x.is_cuda)
    def test_is_cuda(self, device):
        self._test_attributes(lambda x: x.is_cuda, device)

    def test_numel(self):
        self._test_attributes(lambda x: x.numel())
    def test_numel(self, device):
        self._test_attributes(lambda x: x.numel(), device)

    def test_inplace(self):
        x = torch.randn([])
    def test_inplace(self, device):
        x = torch.randn([], device=device)

        def foo(x):
            return x.clone().sin_()
@@ -88,8 +88,8 @@ class TestGradTransform(TestCase):
        result = grad(foo)(x)
        self.assertEqual(result, x.cos())

    def test_inplace_on_view(self):
        x = torch.randn(3)
    def test_inplace_on_view(self, device):
        x = torch.randn(3, device=device)

        def foo(x):
            y = x.clone()
@@ -105,8 +105,8 @@ class TestGradTransform(TestCase):

        self.assertEqual(result, expected)

    def test_inplace_on_view_base(self):
        x = torch.randn(3)
    def test_inplace_on_view_base(self, device):
        x = torch.randn(3, device=device)

        def foo(x):
            y = x.clone()
@@ -122,13 +122,13 @@ class TestGradTransform(TestCase):

        self.assertEqual(result, expected)

    def test_nesting_simple(self):
        x = torch.randn([])
    def test_nesting_simple(self, device):
        x = torch.randn([], device=device)
        result = grad(grad(torch.sin))(x)
        self.assertEqual(result, -torch.sin(x))

    def test_escaped_wrappers_are_marked_as_dead(self):
        x = torch.randn([])
    def test_escaped_wrappers_are_marked_as_dead(self, device):
        x = torch.randn([], device=device)
        escaped = []
        def foo(x):
            y = x.sin()
@@ -136,10 +136,10 @@ class TestGradTransform(TestCase):
            return y

        result = grad(foo)(x)
        self.assertEqual(escaped[0].dlevel(), -1)
        self.assertEqual(functorch._C.dlevel(escaped[0]), -1)

    def test_escaped_wrappers_are_ignored(self):
        x = torch.randn([])
    def test_escaped_wrappers_are_ignored(self, device):
        x = torch.randn([], device=device)
        escaped = []
        def foo(x):
            y = x.sin()
@@ -149,28 +149,28 @@ class TestGradTransform(TestCase):
        result = grad(foo)(x)

        something = escaped[0].sum()
        self.assertEqual(something.dlevel(), 0)
        self.assertEqual(functorch._C.dlevel(something), 0)
        self.assertEqual(something, x.sin().sum())

    def test_vjp(self):
        x = torch.randn([])
    def test_vjp(self, device):
        x = torch.randn([], device=device)
        out, vjp_fn = vjp(torch.sin, x)
        self.assertEqual(out, x.sin())

        v = torch.randn([])
        v = torch.randn([], device=device)
        result, = vjp_fn(v)
        self.assertEqual(result, v * x.cos())

    def test_composed_with_autograd(self):
        x = torch.randn([], requires_grad=True)
    def test_composed_with_autograd(self, device):
        x = torch.randn([], requires_grad=True, device=device)

        y = grad(torch.sin)(x)
        result, = torch.autograd.grad(y, x)
        self.assertEqual(result, -x.sin())

    def test_grad_of_vjp_composition(self):
        x = torch.randn([])
        y = torch.randn([])
    def test_grad_of_vjp_composition(self, device):
        x = torch.randn([], device=device)
        y = torch.randn([], device=device)

        def foo(x, y):
            out, vjp_fn = vjp(torch.sin, x)
@@ -180,9 +180,9 @@ class TestGradTransform(TestCase):
        expected = x.cos()
        self.assertEqual(result, expected)

    def test_vjp_of_grad_composition(self):
        x = torch.randn([])
        y = torch.randn([])
    def test_vjp_of_grad_composition(self, device):
        x = torch.randn([], device=device)
        y = torch.randn([], device=device)

        def foo(x, y):
            out, vjp_fn = vjp(grad(torch.sin), x)
@@ -192,9 +192,9 @@ class TestGradTransform(TestCase):
        expected = -y * x.sin()
        self.assertEqual(result, expected)

    def test_grad_of_vjp_of_grad_composition(self):
        x = torch.randn([])
        y = torch.randn([])
    def test_grad_of_vjp_of_grad_composition(self, device):
        x = torch.randn([], device=device)
        y = torch.randn([], device=device)

        def foo(x, y):
            df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
@@ -204,9 +204,9 @@ class TestGradTransform(TestCase):
        expected = x.cos()
        self.assertEqual(result, expected)

    def test_views(self):
        x = torch.randn([], requires_grad=True)
        y = torch.randn([], requires_grad=True)
    def test_views(self, device):
        x = torch.randn([], requires_grad=True, device=device)
        y = torch.randn([], requires_grad=True, device=device)

        def silly_sin(x):
            x = x.view([])
@@ -223,34 +223,34 @@ class TestGradTransform(TestCase):
        self.assertEqual(grads[0], -x.sin())
        self.assertEqual(grads[1], -y.sin())

    def test_view_inplace_simple(self):
    def test_view_inplace_simple(self, device):
        def foo(x):
            x = x.clone()
            x.view([]).sin_()
            return x

        x = torch.randn([], requires_grad=True)
        x = torch.randn([], requires_grad=True, device=device)
        result = grad(foo)(x)
        self.assertEqual(result, x.cos())


class TestVmapOfGrad(TestCase):
    def test_per_sample_grads_inplace_view(self):
    def test_per_sample_grads_inplace_view(self, device):
        def compute_loss(weight, x, t):
            x = x.mm(weight)
            y = x.squeeze_(0)
            return (y - t).sum()

        weight = torch.randn(16, 2)
        x = torch.randn(64, 1, 16)
        t = torch.randn(64, 2)
        weight = torch.randn(16, 2, device=device)
        x = torch.randn(64, 1, 16, device=device)
        t = torch.randn(64, 2, device=device)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)

    def test_new_zeros_materializes_tensor(self):
    def test_new_zeros_materializes_tensor(self, device):
        N = 3
        C = 5

@@ -259,25 +259,25 @@ class TestVmapOfGrad(TestCase):
            result.copy_(y)
            return result.sum()

        x = torch.randn(N)
        y = torch.randn(N, C)
        x = torch.randn(N, device=device)
        y = torch.randn(N, C, device=device)
        result = vmap(grad(foo))(x, y)

    def test_per_sample_grads_simple(self):
    def test_per_sample_grads_simple(self, device):
        def compute_loss(weight, x, t):
            y = x @ weight
            return ((y - t) ** 2).sum()

        weight = torch.randn(16, 2)
        x = torch.randn(64, 16)
        t = torch.randn(64, 2)
        weight = torch.randn(16, 2, device=device)
        x = torch.randn(64, 16, device=device)
        t = torch.randn(64, 2, device=device)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)

    def test_per_sample_grads_embeddingnet(self):
    def test_per_sample_grads_embeddingnet(self, device):
        class SampleNet(nn.Module):
            def __init__(self, vocab_size: int):
                super().__init__()
@@ -301,11 +301,11 @@ class TestVmapOfGrad(TestCase):
        vocab_size = 1000
        batch_shape = [64]
        words_per_sentence = 5
        data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence))
        targets = torch.randint(0, 1, (*batch_shape,))
        data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
        targets = torch.randint(0, 1, (*batch_shape,), device=device)

        # Construct our module
        net = SampleNet(vocab_size)
        net = SampleNet(vocab_size).to(device=device)
        criterion = nn.CrossEntropyLoss()

        params = dict(net.named_parameters())
@@ -326,42 +326,42 @@ class TestVmapOfGrad(TestCase):
            self.assertEqual(r, e, atol=0, rtol=1e-4)

class TestJacrev(TestCase):
    def test_simple(self):
        x = torch.randn(3)
    def test_simple(self, device):
        x = torch.randn(3, device=device)
        y = jacrev(torch.sin)(x)
        expected = torch.diagflat(x.cos())
        assert torch.allclose(y, expected)

    def test_simple_not_flat(self):
        x = torch.randn(2, 3)
    def test_simple_not_flat(self, device):
        x = torch.randn(2, 3, device=device)
        y = jacrev(torch.sin)(x)
        expected = torch.diagflat(x.view(-1).cos())
        expected = expected.view(2, 3, 2, 3)
        assert torch.allclose(y, expected)

    def test_vmap_on_jacrev_simple(self):
        x = torch.randn(2, 3)
    def test_vmap_on_jacrev_simple(self, device):
        x = torch.randn(2, 3, device=device)
        y = vmap(jacrev(torch.sin))(x)
        expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
        assert torch.allclose(y, expected)

    def test_hessian_simple(self):
    def test_hessian_simple(self, device):
        def foo(x):
            return x.sin().sum()

        x = torch.randn(3)
        x = torch.randn(3, device=device)
        y = jacrev(jacrev(foo))(x)
        expected = torch.diagflat(-x.sin())
        assert torch.allclose(y, expected)


class TestComposability(TestCase):
    def test_grad_grad(self):
        x = torch.randn([])
    def test_grad_grad(self, device):
        x = torch.randn([], device=device)
        y = grad(grad(torch.sin))(x)
        self.assertEqual(y, -x.sin())

    def test_grad_vmap(self):
    def test_grad_vmap(self, device):
        def foo(x):
            y = vmap(torch.sin)(x)
            return y.sum()
@@ -370,8 +370,8 @@ class TestComposability(TestCase):
        y = grad(foo)(x)
        self.assertEqual(y, x.cos())

    def test_grad_vjp(self):
        x = torch.randn(3)
    def test_grad_vjp(self, device):
        x = torch.randn(3, device=device)

        def foo(x):
            _, vjp_fn = vjp(torch.sin, x)
@@ -381,18 +381,18 @@ class TestComposability(TestCase):
        expected = grad(lambda x: (x * x.cos()).sum())(x)
        self.assertEqual(y, expected)

    def test_vmap_grad(self):
        x = torch.randn(3)
    def test_vmap_grad(self, device):
        x = torch.randn(3, device=device)
        y = vmap(grad(torch.sin))(x)
        self.assertEqual(y, x.cos())

    def test_vmap_vmap(self):
        x = torch.randn(2, 3)
    def test_vmap_vmap(self, device):
        x = torch.randn(2, 3, device=device)
        y = vmap(vmap(torch.sin))(x)
        self.assertEqual(y, x.sin())

    def test_vmap_vjp(self):
        x = torch.randn(3)
    def test_vmap_vjp(self, device):
        x = torch.randn(3, device=device)
        _, vjp_fn = vjp(torch.sin, x)

        def foo(x):
@@ -406,24 +406,24 @@ class TestComposability(TestCase):
        expected = torch.stack([vjp_fn(x)[0] for x in xs])
        self.assertEqual(vmap(lambda x: vjp_fn(x)[0])(xs), expected)

    def test_vjp_grad(self):
        x = torch.randn([])
    def test_vjp_grad(self, device):
        x = torch.randn([], device=device)
        y, vjp_fn = vjp(grad(torch.sin), x)
        self.assertEqual(y, x.cos())

        v = torch.randn([])
        self.assertEqual(vjp_fn(v)[0], -x.sin() * v)

    def test_vjp_vmap(self):
        x = torch.randn(3)
    def test_vjp_vmap(self, device):
        x = torch.randn(3, device=device)
        y, vjp_fn = vjp(vmap(torch.sin), x)
        self.assertEqual(y, x.sin())

        v = torch.randn(3)
        v = torch.randn(3, device=device)
        self.assertEqual(vjp_fn(v)[0], x.cos() * v)

    def test_vjp_vjp(self):
        x = torch.randn(3)
    def test_vjp_vjp(self, device):
        x = torch.randn(3, device=device)
        y, vjp_fn = vjp(torch.sin, x)
        self.assertEqual(y, x.sin())

@@ -433,6 +433,27 @@ class TestComposability(TestCase):
        y = vjp_fn(x)[0]
        # Honestly IDK what the result here is... but at least it runs

instantiate_device_type_tests(
    TestGradTransform,
    globals(),
    None,
)
instantiate_device_type_tests(
    TestVmapOfGrad,
    globals(),
    None,
)
instantiate_device_type_tests(
    TestJacrev,
    globals(),
    None,
)
instantiate_device_type_tests(
    TestComposability,
    globals(),
    None,
)


if __name__ == '__main__':
    run_tests()

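Editorial note: the `instantiate_device_type_tests(..., globals(), None)` calls above clone each generic TestCase into per-device classes (CPU always, CUDA when available) and inject the `device` string that every rewritten test now accepts. A self-contained sketch of the same pattern, with an illustrative class name:

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests
from functorch import grad

class ExampleDeviceGenericTest(TestCase):
    def test_sin_grad(self, device):
        # `device` is supplied per instantiated variant ('cpu', 'cuda', ...)
        x = torch.randn([], device=device)
        self.assertEqual(grad(torch.sin)(x), x.cos())

# Generates ExampleDeviceGenericTestCPU / ...CUDA and removes the generic class.
instantiate_device_type_tests(ExampleDeviceGenericTest, globals(), None)

if __name__ == '__main__':
    run_tests()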