pytorch/test/cpp/jit/test_constant_pooling.cpp
Zachary DeVito c59e35b147 interpreter handling for varargs to remove need for looking at Node (#32791)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32791

When a registered operator has varargs (its schema ends with "..."),
the interpreter now appends the number of arguments to the top of
the stack before invoking the operator. This allows the removal of more
uses of Node* in the interpreter.
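
As a rough illustration of this calling convention (a minimal sketch, not the actual torch::jit implementation; `Value`, `Stack`, `pop`, and `listConstructLike` below are simplified stand-ins), a varargs op first pops the count that the interpreter appended, then pops that many inputs:

```
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-ins for the JIT's IValue/Stack types (illustrative only).
using Value = int64_t;
using Stack = std::vector<Value>;

Value pop(Stack& stack) {
  Value v = stack.back();
  stack.pop_back();
  return v;
}

// A varargs-style operation: the interpreter has already pushed the number of
// inputs on top of the stack, so the op no longer needs to consult Node*.
void listConstructLike(Stack& stack) {
  int64_t num_inputs = pop(stack);  // count appended by the interpreter
  std::vector<Value> elems(num_inputs);
  for (int64_t i = num_inputs - 1; i >= 0; --i) {
    elems[i] = pop(stack);  // inputs were pushed left-to-right, so pop in reverse
  }
  // "Construct" the result; pushing the element count back is enough for this sketch.
  stack.push_back(static_cast<Value>(elems.size()));
}

int main() {
  Stack stack = {7, 8, 9};  // three operands pushed by earlier instructions
  stack.push_back(3);       // the interpreter appends the argument count
  listConstructLike(stack);
  assert(stack.size() == 1 && stack.back() == 3);
  return 0;
}
```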

This PR also cleans up the constructors for Operator to make it more
likely that someone chooses the correct one. After making these ops:

```
USES NODE: prim::TupleUnpack(...) -> (...)
USES NODE: prim::TupleSlice(...) -> (...)
USES NODE: prim::TupleConstruct(...) -> (...)
USES NODE: prim::ListUnpack(...) -> (...)
USES NODE: prim::ListConstruct(...) -> (...)
USES NODE: prim::DictConstruct(...) -> (...)
USES NODE: prim::Constant() -> (...)
USES NODE: prim::isinstance(...) -> (...)
USES NODE: prim::CreateObject(...) -> (...)
USES NODE: prim::fork(...) -> (...)
USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack
```

into interpreter primitives, we can remove all but two constructors for Operator:
one that takes (schema_string, operation), and one that takes (symbol, op_creator) for
the remaining special-case primitives.
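
A minimal sketch of those two constructor shapes, assuming cut-down stand-ins for Operation, OperationCreator, Symbol, and Stack rather than the real torch::jit declarations:

```
#include <functional>
#include <string>
#include <vector>

// Illustrative stand-ins only; the real torch::jit types are richer.
struct Node;
using Stack = std::vector<int64_t>;
using Operation = std::function<void(Stack&)>;
using OperationCreator = Operation (*)(const Node*);
struct Symbol { std::string qualified_name; };

// A cut-down Operator keeping just the two constructor shapes named above.
struct Operator {
  // (schema_string, operation): ordinary ops whose behavior is fully
  // described by their schema and which never look at the Node.
  Operator(std::string schema_string, Operation op)
      : schema_(std::move(schema_string)), op_(std::move(op)) {}

  // (symbol, op_creator): the remaining special primitives whose Operation
  // must still be built per Node.
  Operator(Symbol symbol, OperationCreator op_creator)
      : schema_(std::move(symbol.qualified_name)), op_creator_(op_creator) {}

  std::string schema_;
  Operation op_;
  OperationCreator op_creator_ = nullptr;
};
```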

Test Plan: Imported from OSS

Differential Revision: D19673158

Pulled By: zdevito

fbshipit-source-id: 95442a001538a6f53c1db4a210f8557ef118de66
2020-02-18 15:04:48 -08:00

86 lines
2.4 KiB
C++

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/testing/file_check.h>
#include "test/cpp/jit/test_base.h"
#include <sstream>
#include <string>
namespace torch {
namespace jit {
void testConstantPooling() {
  {
    // Two identical integer constants should be pooled into a single
    // prim::Constant node.
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph():
  %8 : int = prim::Constant[value=1]()
  %10 : int = prim::Constant[value=1]()
  return (%8, %10)
  )IR",
        &*graph);
    ConstantPooling(graph);
    testing::FileCheck()
        .check_count("prim::Constant", 1, /*exactly*/ true)
        ->run(*graph);
  }
  {
    // The two "abc" constants in the If blocks are pooled into one; the
    // distinct "bcd" constant stays separate.
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph(%cond : Tensor):
  %a : str = prim::Constant[value="bcd"]()
  %3 : bool = aten::Bool(%cond)
  %b : str = prim::If(%3)
    block0():
      %b.1 : str = prim::Constant[value="abc"]()
      -> (%b.1)
    block1():
      %b.2 : str = prim::Constant[value="abc"]()
      -> (%b.2)
  %7 : (str, str) = prim::TupleConstruct(%a, %b)
  return (%7)
  )IR",
        &*graph);
    ConstantPooling(graph);
    testing::FileCheck()
        .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
        ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
        ->run(*graph);
  }
  {
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph():
  %2 : int = prim::Constant[value=2]()
  %1 : int = prim::Constant[value=1]()
  %5 : int? = prim::Constant()
  %7 : Device? = prim::Constant()
  %15 : bool = prim::Constant[value=0]()
  %10 : int = prim::Constant[value=6]()
  %3 : int[] = prim::ListConstruct(%1, %2)
  %x : Tensor = aten::tensor(%3, %5, %7, %15)
  %y : Tensor = aten::tensor(%3, %10, %7, %15)
  %9 : int[] = prim::ListConstruct(%1, %2)
  %z : Tensor = aten::tensor(%9, %10, %7, %15)
  prim::Print(%x, %y, %z)
  return (%1)
  )IR",
        &*graph);
    // three tensors created - two different dtypes among the three;
    // constant propagation is run first because we don't have good support
    // for parsing tensor constants directly
    ConstantPropagation(graph);
    ConstantPooling(graph);
    testing::FileCheck()
        .check_count("Float(2) = prim::Constant", 1, /*exactly*/ true)
        ->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
        ->run(*graph);
  }
}
} // namespace jit
} // namespace torch