mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytorch][mobile] fixed AutoGradMode/AutoNonVariableTypeMode uses for mobile callsites
Summary: There are three guards related to mobile build: * AutoGradMode * AutoNonVariableTypeMode * GraphOptimizerEnabledGuard Today we need set some of these guards before calling libtorch APIs because we customized mobile build to only support inference (for both OSS and most FB use cases) to optimize binary size. Several changes were made since 1.3 release so there are already inconsistent uses of these guards in the codebase. I did a sweep of all mobile related model loading & forward() call sites, trying to unify the use of these guards: Full JIT: still set all three guards. More specifically: * OSS: Fixed a bug of not setting the guard at model load time correctly in Android JNI. * FB: Not covered by this diff (as we are using mobile interpreter for most internal builds). Lite JIT (mobile interpreter): only needs AutoNonVariableTypeMode guard. AutoGradMode doesn't seem to be relevant (so removed from a few places) and GraphOptimizerEnabledGuard definitely not relevant (only full JIT has graph optimizer). More specifically: * OSS: At this point we are not committed to support Lite-JIT. For Android it shares the same code with FB JNI callsites. * FB: ** JNI callsites: Use the unified LiteJITCallGuard. ** For iOS/C++: manually set AutoNonVariableTypeMode for _load_for_mobile() & forward() callsites. Ideally we should avoid having to set AutoNonVariableTypeMode for mobile interpreter. It's currently needed for dynamic dispatch + inference-only mobile build (where variable kernels are not registered) - without the guard it will try to run `variable_fallback_kernel` and crash (PR #34038). The proper fix will take some time so using this workaround to unblock selective BUCK build which depends on dynamic dispatch. PS. The current status (of having to set AutoNonVariableTypeMode) should not block running FL model + mobile interpreter - if all necessary variable kernels are registered then it can call _load_for_mobile()/forward() against the FL model without setting the AutoNonVariableTypeMode guard. It's still inconvenient for JAVA callsites as it's set unconditionally inside JNI methods. Test Plan: - CI Reviewed By: xta0 Differential Revision: D20498017 fbshipit-source-id: ba6740f66839a61790873df46e8e66e4e141c728
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a4048b4703
commit
6e47e7bf52
@ -26,6 +26,12 @@ namespace {
|
||||
struct JITCallGuard {
|
||||
// AutoGrad is disabled for mobile by default.
|
||||
torch::autograd::AutoGradMode no_autograd_guard{false};
|
||||
// VariableType dispatch is not included in default mobile build. We need set
|
||||
// this guard globally to avoid dispatch error (only for dynamic dispatch).
|
||||
// Thanks to the unification of Variable class and Tensor class it's no longer
|
||||
// required to toggle the NonVariableTypeMode per op - so it doesn't hurt to
|
||||
// always set NonVariableTypeMode for inference only use case.
|
||||
torch::AutoNonVariableTypeMode non_var_guard{true};
|
||||
// Disable graph optimizer to ensure list of unused ops are not changed for
|
||||
// custom mobile build.
|
||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
|
||||
@ -111,11 +117,11 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
/* need_inputs */ false,
|
||||
/* sampled */ false);
|
||||
#endif
|
||||
JITCallGuard guard;
|
||||
}
|
||||
|
||||
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
|
||||
preModuleLoadSetup();
|
||||
JITCallGuard guard;
|
||||
module_ = torch::jit::load(std::move(modelPath->toStdString()));
|
||||
module_.eval();
|
||||
}
|
||||
@ -147,6 +153,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
"Could not get buffer for asset '%s'",
|
||||
assetName->toStdString().c_str());
|
||||
}
|
||||
JITCallGuard guard;
|
||||
module_ = torch::jit::load(torch::make_unique<MemoryReadAdapter>(
|
||||
assetBuffer, AAsset_getLength(asset)));
|
||||
AAsset_close(asset);
|
||||
|
@ -12,10 +12,23 @@
|
||||
|
||||
#include "pytorch_jni_common.h"
|
||||
|
||||
using namespace pytorch_jni;
|
||||
|
||||
namespace pytorch_jni {
|
||||
|
||||
namespace {
|
||||
|
||||
struct LiteJITCallGuard {
|
||||
// VariableType dispatch is not included in default mobile build. We need set
|
||||
// this guard globally to avoid dispatch error (only for dynamic dispatch).
|
||||
// Thanks to the unification of Variable class and Tensor class it's no longer
|
||||
// required to toggle the NonVariableTypeMode per op - so it doesn't hurt to
|
||||
// always set NonVariableTypeMode for inference only use case.
|
||||
// TODO: avoid having to set this guard for custom mobile build with mobile
|
||||
// interpreter.
|
||||
torch::AutoNonVariableTypeMode non_var_guard{true};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
private:
|
||||
friend HybridBase;
|
||||
@ -31,6 +44,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
}
|
||||
|
||||
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
|
||||
LiteJITCallGuard guard;
|
||||
module_ = torch::jit::_load_for_mobile(std::move(modelPath->toStdString()));
|
||||
}
|
||||
|
||||
@ -55,8 +69,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
}
|
||||
|
||||
auto output = [&]() {
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
at::AutoNonVariableTypeMode non_var_type_mode(true);
|
||||
LiteJITCallGuard guard;
|
||||
return module_.forward(inputs);
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
@ -78,7 +91,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
}
|
||||
if (auto method = module_.find_method(methodName)) {
|
||||
auto output = [&]() {
|
||||
at::AutoNonVariableTypeMode non_var_type_mode(true);
|
||||
LiteJITCallGuard guard;
|
||||
return module_.run_method(methodName, inputs);
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
|
@ -24,7 +24,10 @@ int main(int argc, char** argv) {
|
||||
std::cerr << FLAGS_model << ":Model file is not provided\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
// TODO: avoid having to set this guard for custom mobile build with mobile
|
||||
// interpreter.
|
||||
torch::AutoNonVariableTypeMode non_var_guard{true};
|
||||
torch::jit::mobile::Module bc = torch::jit::_load_for_mobile(FLAGS_model);
|
||||
return 0;
|
||||
}
|
||||
|
Reference in New Issue
Block a user