[tensorexpr] Enabled aten::stack in the fuser pass with static shapes (#74077)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74077

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D34808051

Pulled By: huiguoo

fbshipit-source-id: 213e2ffdf87fb1a74104037cea7ef25e4bfd4307
(cherry picked from commit ad9e84842e5b47eda845827d325b08ba361a8286)
This commit is contained in:
Hui Guo
2022-03-30 18:55:53 -07:00
committed by PyTorch MergeBot
parent 317b1a0ed9
commit 90c3699cc8
4 changed files with 69 additions and 6 deletions

View File

@ -317,6 +317,25 @@ TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}
TEST(TEFuserPass, FuserPass_Stack) {
  WithCPUFuser cf;
  // Two tanh producers feed a ListConstruct that is consumed by aten::stack.
  // With fully static shapes, the fuser should pull all of these nodes into
  // a single prim::TensorExprGroup.
  const auto graph_string =
      R"IR(graph(%y.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu),
%x.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu)):
%1 : int = prim::Constant[value=2]()
%9 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%x.1)
%7 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%y.1)
%5 : Tensor[] = prim::ListConstruct(%9, %7)
%z.2 : Float(5, 3, 2, 3, 6, strides=[108, 36, 18, 6, 1], requires_grad=0, device=cpu) = aten::stack(%5, %1)
return (%z.2)
)IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();
  // min_group_size=2: tanh + stack alone are enough to form a group.
  FuseTensorExprs(graph, /* min_group_size= */ 2);
  // A fusion group must have been created.
  testing::FileCheck().check("prim::TensorExprGroup")->run(*graph);
}
TEST(TEFuserPass, FuserPass_Where) {
WithCPUFuser cf;
const auto graph_string = R"IR(

View File

@ -739,6 +739,22 @@ class TestTEFuser(JitTestCase):
# XXX: TE fuser can handle concats in a fusion group.
# FileCheck().check("FusedConcat").check_next("return").run(str(graph))
def test_stack(self):
    """torch.stack over fusible inputs should end up fully fused.

    Fusion group inlining is enabled so the traced graph keeps the fusion
    group visible for ``assertAllFused``.
    """
    if self.dynamic_shapes:
        # aten::stack fusion is not enabled yet with dynamic shapes.
        # Report this as a skip rather than `return True`, which would
        # make the test silently count as a pass.
        self.skipTest("aten::stack fusion is not enabled yet with dynamic shapes")
    with set_fusion_group_inlining(True):
        for device in self.devices:
            hx = torch.randn(3, 20, dtype=torch.float, device=device)
            cx = torch.randn(3, 20, dtype=torch.float, device=device)

            def foo(hx, cx):
                # Both stack inputs are themselves fusible elementwise ops,
                # so the whole expression should form one fusion group.
                return torch.stack((hx + cx, hx - cx))

            ge = self.checkTrace(foo, (hx, cx))
            graph = ge.graph_for(hx, cx)
            self.assertAllFused(graph)
def test_remove_output_used_only_in_size(self):
for device in self.devices:
def test_fuse(a, b):
@ -1781,6 +1797,7 @@ class TestTEFuser(JitTestCase):
devices = self.devices
list_ops = [
torch.cat,
torch.stack
]
for dtype, op, device in product(self.dtypes, list_ops, devices):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":

View File

@ -390,6 +390,31 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_stack(self):
    """Compile a graph containing aten::stack via TensorExprKernel and check
    both the compiled run and the interpreter fallback against eager mode."""

    def f(a, b):
        return torch.stack((a, b), dim=1)

    device = "cpu"
    lhs = torch.rand((3, 5), device=device)
    rhs = torch.rand((3, 5), device=device)
    # IR mirroring f() above: ListConstruct feeding aten::stack along dim 1.
    graph_str = """
graph(%x.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu),
%y.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu)):
%1 : int = prim::Constant[value=1]()
%5 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
%z.2 : Float(3, 2, 5, strides=[10, 5, 1], requires_grad=0, device=cpu) = aten::stack(%5, %1) # local/stack.py:39:12
return (%z.2)
"""
    parsed = torch._C.parse_ir(graph_str)
    kernel = te.TensorExprKernel(parsed)
    expected = f(lhs, rhs)
    compiled_out = kernel.run((lhs, rhs))
    fallback_out = kernel.fallback((lhs, rhs))
    np.testing.assert_allclose(compiled_out.numpy(), expected.numpy(), atol=2e-3)
    np.testing.assert_allclose(fallback_out.numpy(), expected.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_alloc_in_loop(self):
a, tmp, b = [

View File

@ -94,6 +94,7 @@ bool isSupported(Node* node) {
};
static const OperatorSet supported_misc_set{
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
"aten::stack(Tensor[] tensors, int dim=0) -> Tensor",
"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
};
// clang-format on
@ -771,7 +772,7 @@ class TensorExprFuser {
std::vector<Node*> nodes_to_merge = {to_merge};
if (to_merge->kind() == aten::cat) {
if (to_merge->kind() == aten::cat || to_merge->kind() == aten::stack) {
Node* listconstruct = to_merge->input(0)->node();
nodes_to_merge.push_back(listconstruct);
}
@ -1053,7 +1054,6 @@ class TensorExprFuser {
REQ(isFusableOnDevice(node));
REQ(operators_not_to_fuse.find(node->kind()) ==
operators_not_to_fuse.end());
for (Value* input : node->inputs()) {
if (auto const& tt = input->type()->cast<TensorType>()) {
auto st = tt->scalarType();
@ -1066,7 +1066,7 @@ class TensorExprFuser {
}
}
}
if (node->kind() == aten::cat) {
if (node->kind() == aten::cat || node->kind() == aten::stack) {
REQ(node->input(0)->node()->kind() == prim::ListConstruct);
REQ(node->input(0)->uses().size() == 1);
REQ(node->input(1)->node()->kind() == prim::Constant);
@ -1120,7 +1120,8 @@ class TensorExprFuser {
REQ(nInputs <= subgraphArgLimit);
// Device checks
if (consumer->kind() != aten::cat && producer->kind() != aten::cat) {
if (consumer->kind() != aten::cat && producer->kind() != aten::cat &&
consumer->kind() != aten::stack && producer->kind() != aten::stack) {
// aten::cat and aten::stack need special handling because they take a
// Tensor[] as their input. We deal with that in the code below.
auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
@ -1154,7 +1155,7 @@ class TensorExprFuser {
REQ(producer->kind() != prim::Constant);
}
if (producer->kind() == aten::cat) {
if (producer->kind() == aten::cat || producer->kind() == aten::stack) {
REQ(producer->input(0)->node()->kind() == prim::ListConstruct);
REQ(producer->input(0)->uses().size() == 1);
REQ(producer->input(1)->node()->kind() == prim::Constant);
@ -1172,7 +1173,8 @@ class TensorExprFuser {
REQ(isFusableOnDevice(input->node()));
}
REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
} else if (consumer->kind() == aten::cat) {
} else if (
consumer->kind() == aten::cat || consumer->kind() == aten::stack) {
REQ(consumer->input(0)->node()->kind() == prim::ListConstruct);
REQ(consumer->input(0)->uses().size() == 1);
REQ(consumer->input(1)->node()->kind() == prim::Constant);