[TensorExpr] Fuser: do not fuse ops with 0-dim tensors. (#44073)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44073

We don't yet have proper support for 0-dim tensors on the NNC side or in the JIT IR->NNC lowering.

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D23487905

Pulled By: ZolotukhinM

fbshipit-source-id: da0da7478fc8ce7b455176c95d8fd610c94352c1
This commit is contained in:
Mikhail Zolotukhin
2020-09-02 22:55:56 -07:00
committed by Facebook GitHub Bot
parent 3da82aee03
commit 40fec4e739
4 changed files with 29 additions and 3 deletions

View File

@ -105,5 +105,24 @@ void testFuserPass_3() {
testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}
}
void testFuserPass_0DimInput() {
KernelScope kernel_scope;
const auto graph_string = R"IR(
graph(%x : Float(device=cuda),
%y : Float(device=cuda)):
%one : int = prim::Constant[value=1]()
%a : Float(device=cuda) = aten::mul(%x, %y)
%b : Float(device=cuda) = aten::add(%x, %a, %one)
return (%b))IR";
auto g = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, g.get());
g->lint();
FuseTensorExprs(g);
// We should not fuse 0-dim tensors
testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}
} // namespace jit
} // namespace torch

View File

@ -246,6 +246,7 @@ namespace jit {
_(FuserPass_1) \
_(FuserPass_2) \
_(FuserPass_3) \
_(FuserPass_0DimInput) \
_(TrainBasic)
#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \

View File

@ -105,7 +105,8 @@ class TestTEFuser(JitTestCase):
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
def test_sum_simple(self):
def func(x):
return x.sum() * 2
x2 = x * x
return x2.sum()
a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
a = a.reshape(5, 3)

View File

@ -465,8 +465,13 @@ class TensorExprFuser {
bool allShapesAreKnown(Node* node) {
// TODO: Relax the checks to support dynamic shapes
for (Value* input : node->inputs()) {
if (input->type()->cast<TensorType>() && !input->isCompleteTensor()) {
return false;
if (input->type()->cast<TensorType>()) {
if (!input->isCompleteTensor()) {
return false;
}
if (*input->type()->cast<TensorType>()->dim() == 0) {
return false;
}
}
}
return true;