mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TensorExpr] Fuser: do not fuse ops with 0-dim tensors. (#44073)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44073 We don't have a proper support on NNC and JIT IR->NNC lowering side for it yet. Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D23487905 Pulled By: ZolotukhinM fbshipit-source-id: da0da7478fc8ce7b455176c95d8fd610c94352c1
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3da82aee03
commit
40fec4e739
@ -105,5 +105,24 @@ void testFuserPass_3() {
|
||||
testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
|
||||
}
|
||||
}
|
||||
|
||||
void testFuserPass_0DimInput() {
|
||||
KernelScope kernel_scope;
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x : Float(device=cuda),
|
||||
%y : Float(device=cuda)):
|
||||
%one : int = prim::Constant[value=1]()
|
||||
%a : Float(device=cuda) = aten::mul(%x, %y)
|
||||
%b : Float(device=cuda) = aten::add(%x, %a, %one)
|
||||
return (%b))IR";
|
||||
auto g = std::make_shared<Graph>();
|
||||
torch::jit::parseIR(graph_string, g.get());
|
||||
|
||||
g->lint();
|
||||
FuseTensorExprs(g);
|
||||
|
||||
// We should not fuse 0-dim tensors
|
||||
testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -246,6 +246,7 @@ namespace jit {
|
||||
_(FuserPass_1) \
|
||||
_(FuserPass_2) \
|
||||
_(FuserPass_3) \
|
||||
_(FuserPass_0DimInput) \
|
||||
_(TrainBasic)
|
||||
|
||||
#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \
|
||||
|
@ -105,7 +105,8 @@ class TestTEFuser(JitTestCase):
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
def test_sum_simple(self):
|
||||
def func(x):
|
||||
return x.sum() * 2
|
||||
x2 = x * x
|
||||
return x2.sum()
|
||||
|
||||
a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu')
|
||||
a = a.reshape(5, 3)
|
||||
|
@ -465,8 +465,13 @@ class TensorExprFuser {
|
||||
bool allShapesAreKnown(Node* node) {
|
||||
// TODO: Relax the checks to support dynamic shapes
|
||||
for (Value* input : node->inputs()) {
|
||||
if (input->type()->cast<TensorType>() && !input->isCompleteTensor()) {
|
||||
return false;
|
||||
if (input->type()->cast<TensorType>()) {
|
||||
if (!input->isCompleteTensor()) {
|
||||
return false;
|
||||
}
|
||||
if (*input->type()->cast<TensorType>()->dim() == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
Reference in New Issue
Block a user