More aggressively market functorch.vmap when torch.vmap gets called (#67347)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67347

This PR:
- changes the warning when torch.vmap gets called to suggest using
functorch.vmap
- changes the warning when a batching rule isn't implemented to suggest
using functorch.vmap

Test Plan: - test/test_vmap.py

Reviewed By: H-Huang

Differential Revision: D31966603

Pulled By: zou3519

fbshipit-source-id: b01dc1c2e298ce899b4a3a5fb333222a8d5bfb56
This commit is contained in:
Richard Zou
2021-11-12 16:08:29 -08:00
committed by Facebook GitHub Bot
parent da5ffe752a
commit a8b93cb3ec
3 changed files with 19 additions and 16 deletions

View File

@ -158,9 +158,10 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
gradients when composed with autograd.
.. note::
We are actively developing a different and improved vmap prototype
`here. <https://github.com/zou3519/functorch>`_ The improved
prototype is able to arbitrarily compose with gradient computation.
We have moved development of vmap to
`functorch. <https://github.com/pytorch/functorch>`_ functorch's
vmap is able to arbitrarily compose with gradient computation
and contains significant performance improvements.
Please give that a try if that is what you're looking for.
Furthermore, if you're interested in using vmap for your use case,
@ -247,12 +248,11 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
sequences out of the box.
"""
warnings.warn(
'torch.vmap is an experimental prototype that is subject to '
'change and/or deletion. Please use at your own risk. There may be '
'unexpected performance cliffs due to certain operators not being '
'implemented. To see detailed performance warnings please use '
'`torch._C._debug_only_display_vmap_fallback_warnings(True) '
'before the call to `vmap`.',
'Please use functorch.vmap instead of torch.vmap '
'(https://github.com/pytorch/functorch). '
'We\'ve moved development on torch.vmap over to functorch; '
'functorch\'s vmap has a multitude of significant performance and '
'functionality improvements.',
stacklevel=2)
return _vmap(func, in_dims, out_dims)