mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72587

This pattern frequently appears in a few graphs:

```
%result = prim::If(%condition)
  block0():
    -> (%a)
  block1():
    -> (%b)
```

This is slow, particularly in static runtime. Static runtime creates memory planners/block runners for each sub-block, which eats up a lot of memory and introduces a lot of extra overhead for this relatively simple operation.

This diff introduces a new op that replaces nodes like the above with a single op meant to act like a ternary operator:

```
%result = prim::IfThenElse(%condition, %a, %b)
```

Test Plan: New unit tests

Reviewed By: eellison

Differential Revision: D34091789

fbshipit-source-id: eb6a8c460c39b4c019a1f4ab1f3f1e5b6edc400c
(cherry picked from commit 0f1b335e5b83f402bda2dcdd9ecb411e0b67c651)
54 lines
1.4 KiB
C++
54 lines
1.4 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <test/cpp/jit/test_utils.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/passes/add_if_then_else.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) {
|
|
const auto src = R"IR(
|
|
graph(%cond: bool, %a: Tensor, %b: Tensor):
|
|
%result: Tensor = prim::If(%cond)
|
|
block0():
|
|
-> (%a)
|
|
block1():
|
|
-> (%b)
|
|
return (%result)
|
|
)IR";
|
|
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(src, graph.get());
|
|
EXPECT_TRUE(AddIfThenElseOp(graph));
|
|
|
|
testing::FileCheck()
|
|
.check_count("= prim::IfThenElse", 1, /*exactly*/ true)
|
|
->check_count("= prim::If", 0, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
|
|
TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) {
|
|
const auto src = R"IR(
|
|
graph(%cond: bool, %a: Tensor, %b: Tensor):
|
|
%result1: Tensor, %result2: Tensor = prim::If(%cond)
|
|
block0():
|
|
-> (%a, %b)
|
|
block1():
|
|
-> (%b, %a)
|
|
return (%result1, %result2)
|
|
)IR";
|
|
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(src, graph.get());
|
|
EXPECT_FALSE(AddIfThenElseOp(graph));
|
|
|
|
testing::FileCheck()
|
|
.check_count("= prim::IfThenElse", 0, /*exactly*/ true)
|
|
->check_count("= prim::If", 1, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|