[Static Runtime] Make canEnableStaticRuntime examine sub-blocks (#87396)

Summary:
Someone was running into problems where

1) Static Runtime enablement would fail
2) We would try to fall back to the JIT interpreter *after trying to create `StaticModule`*
3) The fallback would fail because Static Runtime had already mangled the graph.

We don't want to prevent Static Runtime from mutating its input graph, because avoiding that would mean copying the graph, which raises memory concerns. The intent of `canEnableStaticRuntime` is to catch issues in the module before Static Runtime modifies it.

With this diff, `StaticModule` instantiation can be avoided by querying `canEnableStaticRuntime` first, which fixes the issue.
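
A minimal sketch of what such a call site could look like (the helper name, headers, and the assumption that the module is already frozen are illustrative, not part of this change):

```cpp
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/runtime/static/impl.h>

// Hypothetical call site: probe support up front so the graph is never
// mutated in the case where we end up falling back to the JIT interpreter.
void runModel(torch::jit::Module& frozen_module) {
  auto graph = frozen_module.get_method("forward").graph();
  if (torch::jit::canEnableStaticRuntime(graph)) {
    // Only now do we let StaticModule take (and possibly rewrite) the graph.
    torch::jit::StaticModule static_module(frozen_module);
    // ... run through static_module ...
  } else {
    // The graph is untouched, so the ordinary JIT interpreter path still works.
    // ... run frozen_module.forward(...) as usual ...
  }
}
```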

Test Plan: New unit test
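
Roughly, the test needs a graph whose unsupported op only appears inside a sub-block (e.g. under `prim::If`). A sketch of that shape, with an assumed test name and IR string rather than the actual fixture from the PR:

```cpp
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/impl.h>

// Hypothetical fixture: aten::__is__ on two tensors is one of the ops the
// check rejects, and here it only appears inside a prim::If sub-block, so
// the old top-level-only scan would have missed it.
TEST(StaticRuntime, CanEnableStaticRuntimeSeesSubBlocks) {
  const std::string ir = R"IR(
    graph(%a : Tensor, %b : Tensor, %cond : bool):
      %result : bool = prim::If(%cond)
        block0():
          %is : bool = aten::__is__(%a, %b)
          -> (%is)
        block1():
          %f : bool = prim::Constant[value=0]()
          -> (%f)
      return (%result)
  )IR";
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, graph.get());
  EXPECT_FALSE(torch::jit::canEnableStaticRuntime(graph));
}
```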

Differential Revision: D40564452

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87396
Approved by: https://github.com/tenpercent
Author: Mike Iovine
Date: 2022-10-26 14:34:29 +00:00
Committed by: PyTorch MergeBot
Parent: 72f446b9bc
Commit: ed7a8ab436
2 changed files with 37 additions and 8 deletions

@@ -56,9 +56,9 @@ namespace jit {
 
 namespace {
 
-bool allArgsAreTensors(Node* node) {
+bool allArgsAreTensors(const Node* node) {
   const auto& inputs = node->inputs();
-  return std::all_of(inputs.begin(), inputs.end(), [](Value* value) {
+  return std::all_of(inputs.begin(), inputs.end(), [](const Value* value) {
     return value->type()->kind() == TypeKind::TensorType;
   });
 }
@@ -69,7 +69,7 @@ bool allArgsAreTensors(Node* node) {
 // These are rarely-used ops. Disallowing them typically eliminates
 // corner cases in graph optimizations, allowing for more aggressive
 // optimizations and better performance.
-bool isUnsupportedOp(Node* node) {
+bool isUnsupportedOp(const Node* node) {
   auto kind = node->kind();
   if (kind != aten::__is__ && kind != aten::__isnot__) {
     return false;
@@ -87,12 +87,21 @@ bool isUnsupportedOp(Node* node) {
   return allArgsAreTensors(node);
 }
 
-// graph must be frozen or canEnableStaticRuntime would return false
-// if there's any prim::CallMethod op left in the graph
-bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
-  // check for sub-blocks
+namespace {
+
+bool canEnableStaticRuntimeImpl(const Block* block) {
+  if (block == nullptr) {
+    return false;
+  }
+
   bool can_support = true;
-  for (auto* node : graph->block()->nodes()) {
+  for (auto* node : block->nodes()) {
+    for (auto* subblock : node->blocks()) {
+      // The ordering prevents && from short circuiting, which we want -
+      // it's useful to see *all* the unsupported ops.
+      can_support = canEnableStaticRuntimeImpl(subblock) && can_support;
+    }
     const auto kind = node->kind();
     if (kind == prim::Constant) {
       continue;
@@ -107,6 +116,14 @@ bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
   return can_support;
 }
 
+} // namespace
+
+// Graph must be frozen. canEnableStaticRuntime will return false
+// if there's any prim::CallMethod ops left in the graph.
+bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
+  return canEnableStaticRuntimeImpl(graph->block());
+}
+
 namespace {
 
 auto sr_metadata_registerer = torch::class_<StaticRuntimeMetadata>(