Mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
More aggressively market functorch.vmap when torch.vmap gets called (#67347)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67347

This PR:
- changes the warning when torch.vmap gets called to suggest using functorch.vmap
- changes the warning when a batching rule isn't implemented to suggest using functorch.vmap

Test Plan:
- test/test_vmap.py

Reviewed By: H-Huang
Differential Revision: D31966603
Pulled By: zou3519
fbshipit-source-id: b01dc1c2e298ce899b4a3a5fb333222a8d5bfb56
Committed by: Facebook GitHub Bot
Parent: da5ffe752a
Commit: a8b93cb3ec
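To see the user-facing effect of this change, here is a minimal illustrative sketch (not part of the commit) of calling torch.vmap after this commit and switching to the suggested functorch.vmap. It assumes a PyTorch build containing this change with the separate functorch package installed; the helper `dot` is made up for illustration.

import warnings

import torch
from functorch import vmap as functorch_vmap  # assumes functorch is installed


def dot(x, y):
    # Toy per-example function, purely illustrative.
    return (x * y).sum()


x = torch.randn(8, 3)
y = torch.randn(8, 3)

# torch.vmap still works, but after this commit it emits a UserWarning that
# points users at functorch.vmap instead.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out_legacy = torch.vmap(dot)(x, y)
print([str(w.message) for w in caught])  # expect "Please use functorch.vmap ..."

# For simple cases like this one, functorch.vmap is a drop-in replacement.
out_functorch = functorch_vmap(dot)(x, y)
assert torch.allclose(out_legacy, out_functorch)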
@@ -66,13 +66,16 @@ static bool isInplaceOp(const c10::FunctionSchema& schema) {
   return return_alias_info && return_alias_info->isWrite();
 }
 
-static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
+static void warnFallback(const c10::FunctionSchema& schema) {
   if (!globalContext().areVmapFallbackWarningsEnabled()) {
     return;
   }
-  auto uses_stack = is_inplace ? "" : " and stack";
-  TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back "
-             "to slow (for loop", uses_stack, ") implementation");
+  TORCH_WARN("There is a performance drop because we have not yet implemented ",
+             "the batching rule for ", schema.operator_name(), ". ",
+             "We've moved development of vmap to functorch "
+             "(https://github.com/pytorch/functorch), please try functorch.vmap "
+             "instead and/or file ",
+             "an issue on GitHub so that we can prioritize its implementation.");
 }
 
 // The general flow of the algorithm is as follows.
@@ -88,7 +91,7 @@ static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
 // the operator, and then pop the results off the stack.
 void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
-  warnFallback(schema, /*in_place*/true);
+  warnFallback(schema);
 
   const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
   const auto arguments = torch::jit::last(stack, num_arguments);
@@ -260,7 +263,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
   TORCH_CHECK(num_returns >= 1,
               "Batching rule not implemented for ", schema.operator_name(), ". ",
               "The fallback path does not support operations with no returns.");
-  warnFallback(schema, /*in_place*/false);
+  warnFallback(schema);
 
   const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
   const auto arguments = torch::jit::last(stack, num_arguments);
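The rewritten warnFallback above is still gated on areVmapFallbackWarningsEnabled(), so the per-operator message is silent by default. Below is a hedged sketch (not part of the commit) of surfacing it from Python via the torch._C._debug_only_display_vmap_fallback_warnings flag referenced elsewhere in this diff; torch.hypot is only a guess at an operator without a batching rule, so whether the fallback actually fires depends on the build.

import warnings

import torch

x = torch.randn(4, 3)
y = torch.randn(4, 3)

# Opt in to the per-operator fallback warnings emitted by warnFallback().
torch._C._debug_only_display_vmap_fallback_warnings(True)
try:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # torch.hypot is used here only as a candidate op that may lack a batching rule.
        torch.vmap(torch.hypot)(x, y)
    for w in caught:
        if "There is a performance drop" in str(w.message):
            print("fallback warning:", w.message)
finally:
    # Restore the default (fallback warnings disabled) so other code is unaffected.
    torch._C._debug_only_display_vmap_fallback_warnings(False)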
@@ -12,7 +12,7 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
 import types
 
 
-FALLBACK_REGEX = r'falling back to slow \(for loop( and stack)?\) implementation'
+FALLBACK_REGEX = r'There is a performance drop'
 
 class EnableVmapFallbackWarnings:
     def __enter__(self):
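The hunk above only shows the FALLBACK_REGEX constant changing; the assertions that consume it are outside the shown context. Purely as a sanity check of the regex change, the following standalone sketch rebuilds the new warning text from the TORCH_WARN pieces in the C++ hunk (the operator name aten::hypot is made up) and verifies that the old pattern no longer matches while the new one does.

import re

OLD_FALLBACK_REGEX = r'falling back to slow \(for loop( and stack)?\) implementation'
NEW_FALLBACK_REGEX = r'There is a performance drop'

# New message produced by warnFallback(); "aten::hypot" stands in for
# schema.operator_name() of some operator without a batching rule.
message = (
    "There is a performance drop because we have not yet implemented "
    "the batching rule for aten::hypot. "
    "We've moved development of vmap to functorch "
    "(https://github.com/pytorch/functorch), please try functorch.vmap "
    "instead and/or file an issue on GitHub so that we can prioritize its implementation."
)

assert re.search(OLD_FALLBACK_REGEX, message) is None      # old pattern: no match
assert re.search(NEW_FALLBACK_REGEX, message) is not None  # new pattern: match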
@@ -158,9 +158,10 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
     gradients when composed with autograd.
 
     .. note::
-        We are actively developing a different and improved vmap prototype
-        `here. <https://github.com/zou3519/functorch>`_ The improved
-        prototype is able to arbitrarily compose with gradient computation.
+        We have moved development of vmap to
+        `functorch. <https://github.com/pytorch/functorch>`_ functorch's
+        vmap is able to arbitrarily compose with gradient computation
+        and contains significant performance improvements.
         Please give that a try if that is what you're looking for.
 
     Furthermore, if you're interested in using vmap for your use case,
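The new docstring text claims that functorch's vmap composes with gradient computation. A small illustrative sketch of that claim (assuming the functorch package is installed; the loss function and shapes are made up):

import torch
from functorch import grad, vmap  # assumes functorch is installed


def loss(weight, x):
    # Toy scalar-valued function so grad() is well defined.
    return ((x * weight).sum()) ** 2


weight = torch.randn(3)
batch = torch.randn(8, 3)

# grad() differentiates w.r.t. the first argument; vmap() maps that gradient
# over the batch dimension of `batch`, yielding per-sample gradients.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weight, batch)
print(per_sample_grads.shape)  # torch.Size([8, 3])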
@@ -247,12 +248,11 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
         sequences out of the box.
     """
     warnings.warn(
-        'torch.vmap is an experimental prototype that is subject to '
-        'change and/or deletion. Please use at your own risk. There may be '
-        'unexpected performance cliffs due to certain operators not being '
-        'implemented. To see detailed performance warnings please use '
-        '`torch._C._debug_only_display_vmap_fallback_warnings(True) '
-        'before the call to `vmap`.',
+        'Please use functorch.vmap instead of torch.vmap '
+        '(https://github.com/pytorch/functorch). '
+        'We\'ve moved development on torch.vmap over to functorch; '
+        'functorch\'s vmap has a multitude of significant performance and '
+        'functionality improvements.',
         stacklevel=2)
     return _vmap(func, in_dims, out_dims)
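Since the message above is emitted with warnings.warn (a plain UserWarning), callers that are mid-migration can scope it out like any other warning. A hedged sketch, matching on a prefix of the new message as just one way to target the filter:

import warnings

import torch

# Silence only the functorch-migration warning; other warnings stay visible.
warnings.filterwarnings(
    "ignore",
    message=r"Please use functorch\.vmap instead of torch\.vmap",
)

x = torch.randn(5, 3)
out = torch.vmap(torch.sum)(x)  # no migration warning is printed now
print(out.shape)  # torch.Size([5])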