mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Static Runtime] Make canEnableStaticRuntime examine sub-blocks (#87396)
Summary: Someone was running into problems where 1) Static Runtime enablement would fail 2) We would try to fall back to the JIT interpreter *after trying to create `StaticModule`* 3) The fallback fails because Static Runtime mangled the graph. We don't want to prevent Static Runtime from mutating its input due to memory concerns. The intent of `canEnableStaticRuntime` is to catch issues in the module before Static Runtime messes with it. With this diff, `StaticModule` instantiation can be avoided by querying `canEnableStaticRuntime` and the issue is fixed. Test Plan: New unit test Differential Revision: D40564452 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87396 Approved by: https://github.com/tenpercent
This commit is contained in:
committed by
PyTorch MergeBot
parent
72f446b9bc
commit
ed7a8ab436
@ -56,9 +56,9 @@ namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
bool allArgsAreTensors(Node* node) {
|
||||
bool allArgsAreTensors(const Node* node) {
|
||||
const auto& inputs = node->inputs();
|
||||
return std::all_of(inputs.begin(), inputs.end(), [](Value* value) {
|
||||
return std::all_of(inputs.begin(), inputs.end(), [](const Value* value) {
|
||||
return value->type()->kind() == TypeKind::TensorType;
|
||||
});
|
||||
}
|
||||
@ -69,7 +69,7 @@ bool allArgsAreTensors(Node* node) {
|
||||
// These are rarely-used ops. Disallowing them typically eliminates
|
||||
// corner cases in graph optimizations, allowing for more aggressive
|
||||
// optimizations and better performance.
|
||||
bool isUnsupportedOp(Node* node) {
|
||||
bool isUnsupportedOp(const Node* node) {
|
||||
auto kind = node->kind();
|
||||
if (kind != aten::__is__ && kind != aten::__isnot__) {
|
||||
return false;
|
||||
@ -87,12 +87,21 @@ bool isUnsupportedOp(Node* node) {
|
||||
return allArgsAreTensors(node);
|
||||
}
|
||||
|
||||
// graph must be frozen or canEnableStaticRuntime would return false
|
||||
// if there's any prim::CallMethod op left in the graph
|
||||
bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
// check for sub-blocks
|
||||
namespace {
|
||||
|
||||
bool canEnableStaticRuntimeImpl(const Block* block) {
|
||||
if (block == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool can_support = true;
|
||||
for (auto* node : graph->block()->nodes()) {
|
||||
for (auto* node : block->nodes()) {
|
||||
for (auto* subblock : node->blocks()) {
|
||||
// The ordering prevents && from short circuiting, which we want -
|
||||
// it's useful to see *all* the unsupported ops.
|
||||
can_support = canEnableStaticRuntimeImpl(subblock) && can_support;
|
||||
}
|
||||
|
||||
const auto kind = node->kind();
|
||||
if (kind == prim::Constant) {
|
||||
continue;
|
||||
@ -107,6 +116,14 @@ bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
return can_support;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Graph must be frozen. canEnableStaticRuntime will return false
|
||||
// if there's any prim::CallMethod ops left in the graph.
|
||||
bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
return canEnableStaticRuntimeImpl(graph->block());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
auto sr_metadata_registerer = torch::class_<StaticRuntimeMetadata>(
|
||||
|
||||
Reference in New Issue
Block a user