More aggressively market functorch.vmap when torch.vmap gets called (#67347)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67347

This PR:
- changes the warning when torch.vmap gets called to suggest using
functorch.vmap
- changes the warning when a batching rule isn't implemented to suggest
using functorch.vmap

Test Plan: - test/test_vmap.py

Reviewed By: H-Huang

Differential Revision: D31966603

Pulled By: zou3519

fbshipit-source-id: b01dc1c2e298ce899b4a3a5fb333222a8d5bfb56
This commit is contained in:
Richard Zou
2021-11-12 16:08:29 -08:00
committed by Facebook GitHub Bot
parent da5ffe752a
commit a8b93cb3ec
3 changed files with 19 additions and 16 deletions

View File

@ -66,13 +66,16 @@ static bool isInplaceOp(const c10::FunctionSchema& schema) {
return return_alias_info && return_alias_info->isWrite();
}
static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
static void warnFallback(const c10::FunctionSchema& schema) {
if (!globalContext().areVmapFallbackWarningsEnabled()) {
return;
}
auto uses_stack = is_inplace ? "" : " and stack";
TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back "
"to slow (for loop", uses_stack, ") implementation");
TORCH_WARN("There is a performance drop because we have not yet implemented ",
"the batching rule for ", schema.operator_name(), ". ",
"We've moved development of vmap to to functorch "
"(https://github.com/pytorch/functorch), please try functorch.vmap "
"instead and/or file ",
" an issue on GitHub so that we can prioritize its implementation.");
}
// The general flow of the algorithm is as follows.
@ -88,7 +91,7 @@ static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
// the operator, and then pop the results off the stack.
void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
warnFallback(schema, /*in_place*/true);
warnFallback(schema);
const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
const auto arguments = torch::jit::last(stack, num_arguments);
@ -260,7 +263,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
TORCH_CHECK(num_returns >= 1,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support operations with no returns.");
warnFallback(schema, /*in_place*/false);
warnFallback(schema);
const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
const auto arguments = torch::jit::last(stack, num_arguments);

View File

@ -12,7 +12,7 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
import types
FALLBACK_REGEX = r'falling back to slow \(for loop( and stack)?\) implementation'
FALLBACK_REGEX = r'There is a performance drop'
class EnableVmapFallbackWarnings:
def __enter__(self):

View File

@ -158,9 +158,10 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
gradients when composed with autograd.
.. note::
We are actively developing a different and improved vmap prototype
`here. <https://github.com/zou3519/functorch>`_ The improved
prototype is able to arbitrarily compose with gradient computation.
We have moved development of vmap to
`functorch. <https://github.com/pytorch/functorch>`_ functorch's
vmap is able to arbitrarily compose with gradient computation
and contains significant performance improvements.
Please give that a try if that is what you're looking for.
Furthermore, if you're interested in using vmap for your use case,
@ -247,12 +248,11 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
sequences out of the box.
"""
warnings.warn(
'torch.vmap is an experimental prototype that is subject to '
'change and/or deletion. Please use at your own risk. There may be '
'unexpected performance cliffs due to certain operators not being '
'implemented. To see detailed performance warnings please use '
'`torch._C._debug_only_display_vmap_fallback_warnings(True) '
'before the call to `vmap`.',
'Please use functorch.vmap instead of torch.vmap '
'(https://github.com/pytorch/functorch). '
'We\'ve moved development on torch.vmap over to functorch; '
'functorch\'s vmap has a multitude of significant performance and '
'functionality improvements.',
stacklevel=2)
return _vmap(func, in_dims, out_dims)