Files
pytorch/test/cpp/jit/test_add_if_then_else.cpp
Mike Iovine d1c5f9e439 [JIT][SR] Introduce prim::IfThenElse (#72587)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72587

This pattern frequently appears in a few graphs:

```
%result = prim::If(%condition)
  block0():
    -> (%a)
  block1():
    -> (%b)
```

This is slow, particularly in Static Runtime, which creates a memory planner and a block runner for every sub-block. That consumes a lot of memory and adds substantial overhead to what is a very simple operation.

This diff introduces a new op that replaces nodes like the above with a single op meant to act like a ternary operator:

```
%result = prim::IfThenElse(%condition, %a, %b)
```

Test Plan: New unit tests

Reviewed By: eellison

Differential Revision: D34091789

fbshipit-source-id: eb6a8c460c39b4c019a1f4ab1f3f1e5b6edc400c
(cherry picked from commit 0f1b335e5b83f402bda2dcdd9ecb411e0b67c651)
2022-02-17 18:22:48 +00:00

54 lines
1.4 KiB
C++

#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>
namespace torch {
namespace jit {

// A single-output prim::If whose two blocks only forward already-existing
// values should be collapsed into one
//   prim::IfThenElse(%cond, %true_value, %false_value)
// node.
//
// NOTE: the JIT IR parser is indentation-sensitive, so the graph body and the
// block bodies inside the IR literals below must stay indented.
TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) {
  const auto src = R"IR(
    graph(%cond: bool, %a: Tensor, %b: Tensor):
      %result: Tensor = prim::If(%cond)
        block0():
          -> (%a)
        block1():
          -> (%b)
      return (%result)
  )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(src, graph.get());
  EXPECT_TRUE(AddIfThenElseOp(graph));

  // Run each FileCheck independently so it scans the whole graph. A chained
  // check only inspects text after the previous match. "prim::If(" includes
  // the paren so it cannot match the "prim::IfThenElse(" prefix.
  testing::FileCheck()
      .check_count("= prim::IfThenElse", 1, /*exactly*/ true)
      ->run(*graph);
  testing::FileCheck().check_not("= prim::If(")->run(*graph);

  // The text checks above do not pin operand order. A pass that swapped the
  // true/false values would still satisfy them, so verify the wiring
  // structurally: (condition, block0 value, block1 value).
  const auto if_then_else_kind =
      c10::Symbol::fromQualString("prim::IfThenElse");
  Node* if_then_else = nullptr;
  for (Node* node : graph->nodes()) {
    if (node->kind() == if_then_else_kind) {
      if_then_else = node;
    }
  }
  ASSERT_NE(if_then_else, nullptr);
  ASSERT_EQ(if_then_else->inputs().size(), 3u);
  EXPECT_EQ(if_then_else->input(0), graph->inputs()[0]); // %cond
  EXPECT_EQ(if_then_else->input(1), graph->inputs()[1]); // %a (block0)
  EXPECT_EQ(if_then_else->input(2), graph->inputs()[2]); // %b (block1)

  // The graph must now return the IfThenElse result in place of the If's.
  ASSERT_EQ(graph->outputs().size(), 1u);
  EXPECT_EQ(graph->outputs()[0], if_then_else->output());
}

// A prim::If with more than one output does not fit the single-result
// IfThenElse form. It must be left untouched, and the pass is expected to
// return false.
TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) {
  const auto src = R"IR(
    graph(%cond: bool, %a: Tensor, %b: Tensor):
      %result1: Tensor, %result2: Tensor = prim::If(%cond)
        block0():
          -> (%a, %b)
        block1():
          -> (%b, %a)
      return (%result1, %result2)
  )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(src, graph.get());
  EXPECT_FALSE(AddIfThenElseOp(graph));

  // Independent whole-graph checks. "prim::If(" avoids matching the
  // "prim::IfThenElse(" prefix.
  testing::FileCheck().check_not("= prim::IfThenElse")->run(*graph);
  testing::FileCheck()
      .check_count("= prim::If(", 1, /*exactly*/ true)
      ->run(*graph);
}

} // namespace jit
} // namespace torch