[functorch] Use functorch._C._set_vmap_fallback_warning_enabled

Richard Zou
2021-05-04 12:09:15 -07:00
committed by Jon Janzen
parent 15ab42ce7c
commit 9a81203259
4 changed files with 17 additions and 2 deletions


@ -196,7 +196,7 @@ batched jacobians:
 ## Debugging
 `functorch._C.dump_tensor`: Dumps dispatch keys on stack
-`torch._C._debug_only_display_vmap_fallback_warnings(True)`: Shows vmap fallbacks to loop/stack
+`functorch._C._set_vmap_fallback_warning_enabled(False)` if the vmap warning spam bothers you.
 ## Future Plans
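To illustrate the documented toggle, here is a minimal usage sketch (assuming a build of functorch that includes this commit; the choice of op is incidental):

```python
import torch
import functorch._C
from functorch import vmap

# Suppress the per-op fallback warnings while iterating on a model.
functorch._C._set_vmap_fallback_warning_enabled(False)

x = torch.randn(3, 5)
out = vmap(torch.sum)(x)  # ops without batching rules now fall back quietly

# Re-enable the warnings when hunting for missing batching rules.
functorch._C._set_vmap_fallback_warning_enabled(True)
```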


@ -13,6 +13,16 @@
 namespace at {
 namespace functorch {
+
+bool kVmapFallbackWarningEnabled = true;
+
+bool isVmapFallbackWarningEnabled() {
+  return kVmapFallbackWarningEnabled;
+}
+
+void setVmapFallbackWarningEnabled(bool enabled) {
+  kVmapFallbackWarningEnabled = enabled;
+}
 
 // Given a linear index, return the actual index.
 // Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 1]
 static at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
@ -71,7 +81,7 @@ static bool isInplaceOp(const FunctionSchema& schema) {
 }
 
 static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
-  if (!globalContext().areVmapFallbackWarningsEnabled()) {
+  if (!isVmapFallbackWarningEnabled()) {
     return;
   }
   auto uses_stack = is_inplace ? "" : " and stack";


@ -21,6 +21,9 @@ namespace functorch {
 // write batching rules for operators whenever possible.
 void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+bool isVmapFallbackWarningEnabled();
+void setVmapFallbackWarningEnabled(bool enabled);
 
 }
 } // namespace at


@ -6,6 +6,7 @@
 #include <functorch/csrc/BatchedTensorImpl.h>
 #include <functorch/csrc/VmapTransforms.h>
 #include <functorch/csrc/PythonKey.h>
+#include <functorch/csrc/BatchedFallback.h>
 
 namespace at {
 namespace functorch {
@ -182,6 +183,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim");
m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim");
m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "add batch dim");
m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "Set vmap fallback warnings");
m.def("dlevel", &at::functorch::dlevel, "add batch dim");
m.def("dump_tensor", &at::functorch::dump_tensor, "add batch dim");