diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp index 559ee3c4d5e1..a66cf316b00e 100644 --- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp +++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp @@ -105,5 +105,24 @@ void testFuserPass_3() { testing::FileCheck().check("prim::TensorExprGroup")->run(*g); } } + +void testFuserPass_0DimInput() { + KernelScope kernel_scope; + const auto graph_string = R"IR( + graph(%x : Float(device=cuda), + %y : Float(device=cuda)): + %one : int = prim::Constant[value=1]() + %a : Float(device=cuda) = aten::mul(%x, %y) + %b : Float(device=cuda) = aten::add(%x, %a, %one) + return (%b))IR"; + auto g = std::make_shared<Graph>(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // We should not fuse 0-dim tensors + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 2379eb4ebd60..fe49ae2e6b5d 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -246,6 +246,7 @@ namespace jit { _(FuserPass_1) \ _(FuserPass_2) \ _(FuserPass_3) \ + _(FuserPass_0DimInput) \ _(TrainBasic) #define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \ diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 804bc2f3f04e..2e3cd435df74 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -105,7 +105,8 @@ class TestTEFuser(JitTestCase): @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_sum_simple(self): def func(x): - return x.sum() * 2 + x2 = x * x + return x2.sum() a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index b4fe4e61c554..66ebd308364f 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ 
b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -465,8 +465,13 @@ class TensorExprFuser { bool allShapesAreKnown(Node* node) { // TODO: Relax the checks to support dynamic shapes for (Value* input : node->inputs()) { - if (input->type()->cast<TensorType>() && !input->isCompleteTensor()) { - return false; + if (input->type()->cast<TensorType>()) { + if (!input->isCompleteTensor()) { + return false; + } + if (*input->type()->cast<TensorType>()->dim() == 0) { + return false; + } } } return true;