mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Use functorch._C._set_vmap_fallback_warning_enabled
This commit is contained in:
@ -196,7 +196,7 @@ batched jacobians:
|
||||
|
||||
## Debugging
|
||||
`functorch._C.dump_tensor`: Dumps dispatch keys on stack
|
||||
`torch._C._debug_only_display_vmap_fallback_warnings(True)`: Shows vmap fallbacks to loop/stack
|
||||
`functorch._C._set_vmap_fallback_warning_enabled(False)` if the vmap warning spam bothers you.
|
||||
|
||||
## Future Plans
|
||||
|
||||
|
@ -13,6 +13,16 @@
|
||||
namespace at {
|
||||
namespace functorch {
|
||||
|
||||
bool kVmapFallbackWarningEnabled = true;
|
||||
|
||||
bool isVmapFallbackWarningEnabled() {
|
||||
return kVmapFallbackWarningEnabled;
|
||||
}
|
||||
|
||||
void setVmapFallbackWarningEnabled(bool enabled) {
|
||||
kVmapFallbackWarningEnabled = enabled;
|
||||
}
|
||||
|
||||
// Given a linear index, return the actual index.
|
||||
// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 0]
|
||||
static at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
|
||||
@ -71,7 +81,7 @@ static bool isInplaceOp(const FunctionSchema& schema) {
|
||||
}
|
||||
|
||||
static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
|
||||
if (!globalContext().areVmapFallbackWarningsEnabled()) {
|
||||
if (!isVmapFallbackWarningEnabled()) {
|
||||
return;
|
||||
}
|
||||
auto uses_stack = is_inplace ? "" : " and stack";
|
||||
|
@ -21,6 +21,9 @@ namespace functorch {
|
||||
// write batching rules for operators whenever possible.
|
||||
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
||||
|
||||
bool isVmapFallbackWarningEnabled();
|
||||
void setVmapFallbackWarningEnabled(bool enabled);
|
||||
|
||||
|
||||
}
|
||||
} // namespace at
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <functorch/csrc/BatchedTensorImpl.h>
|
||||
#include <functorch/csrc/VmapTransforms.h>
|
||||
#include <functorch/csrc/PythonKey.h>
|
||||
#include <functorch/csrc/BatchedFallback.h>
|
||||
|
||||
namespace at {
|
||||
namespace functorch {
|
||||
@ -182,6 +183,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim");
|
||||
m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim");
|
||||
m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "add batch dim");
|
||||
m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "Set vmap fallback warnings");
|
||||
m.def("dlevel", &at::functorch::dlevel, "add batch dim");
|
||||
m.def("dump_tensor", &at::functorch::dump_tensor, "add batch dim");
|
||||
|
||||
|
Reference in New Issue
Block a user