mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32791 When a registered operator has varargs (its schema ends with `...`), the interpreter now pushes the number of arguments onto the top of the stack before invoking the operator. This allows the removal of more uses of Node* in the interpreter. This PR also cleans up the constructors for Operator to make it more likely that someone chooses the correct one. After turning these ops: ``` USES NODE: prim::TupleUnpack(...) -> (...) USES NODE: prim::TupleSlice(...) -> (...) USES NODE: prim::TupleConstruct(...) -> (...) USES NODE: prim::ListUnpack(...) -> (...) USES NODE: prim::ListConstruct(...) -> (...) USES NODE: prim::DictConstruct(...) -> (...) USES NODE: prim::Constant() -> (...) USES NODE: prim::isinstance(...) -> (...) USES NODE: prim::CreateObject(...) -> (...) USES NODE: prim::fork(...) -> (...) USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack ``` into interpreter primitives, we can remove all but two constructors for operators: one that is (schema_string, operation), and one that is (symbol, op_creator) for the remaining special-case primitives. Test Plan: Imported from OSS Differential Revision: D19673158 Pulled By: zdevito fbshipit-source-id: 95442a001538a6f53c1db4a210f8557ef118de66
86 lines
2.4 KiB
C++
86 lines
2.4 KiB
C++
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/testing/file_check.h>
#include "test/cpp/jit/test_base.h"

#include <memory>
#include <sstream>
#include <string>

namespace torch {
|
|
namespace jit {
|
|
|
|
void testConstantPooling() {
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
script::parseIR(
|
|
R"IR(
|
|
graph():
|
|
%8 : int = prim::Constant[value=1]()
|
|
%10 : int = prim::Constant[value=1]()
|
|
return (%8, %10)
|
|
)IR",
|
|
&*graph);
|
|
ConstantPooling(graph);
|
|
testing::FileCheck()
|
|
.check_count("prim::Constant", 1, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%cond : Tensor):
|
|
%a : str = prim::Constant[value="bcd"]()
|
|
%3 : bool = aten::Bool(%cond)
|
|
%b : str = prim::If(%3)
|
|
block0():
|
|
%b.1 : str = prim::Constant[value="abc"]()
|
|
-> (%b.1)
|
|
block1():
|
|
%b.2 : str = prim::Constant[value="abc"]()
|
|
-> (%b.2)
|
|
%7 : (str, str) = prim::TupleConstruct(%a, %b)
|
|
return (%7)
|
|
)IR",
|
|
&*graph);
|
|
ConstantPooling(graph);
|
|
testing::FileCheck()
|
|
.check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
|
|
->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
script::parseIR(
|
|
R"IR(
|
|
graph():
|
|
%2 : int = prim::Constant[value=2]()
|
|
%1 : int = prim::Constant[value=1]()
|
|
%5 : int? = prim::Constant()
|
|
%7 : Device? = prim::Constant()
|
|
%15: bool = prim::Constant[value=0]()
|
|
%10 : int = prim::Constant[value=6]()
|
|
%3 : int[] = prim::ListConstruct(%1, %2)
|
|
%x : Tensor = aten::tensor(%3, %5, %7, %15)
|
|
%y : Tensor = aten::tensor(%3, %10, %7, %15)
|
|
%9 : int[] = prim::ListConstruct(%1, %2)
|
|
%z : Tensor = aten::tensor(%9, %10, %7, %15)
|
|
prim::Print(%x, %y, %z)
|
|
return (%1)
|
|
)IR",
|
|
&*graph);
|
|
// three tensors created - two different devices among the three
|
|
// don't have good support for parsing tensor constants
|
|
ConstantPropagation(graph);
|
|
ConstantPooling(graph);
|
|
testing::FileCheck()
|
|
.check_count("Float(2) = prim::Constant", 1, /*exactly*/ true)
|
|
->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|