#include #include #include #include #include #include #include #include #include namespace torch { namespace jit { using namespace torch::jit::tensorexpr; struct WithCPUFuser { WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { overrideCanFuseOnCPU(val); } ~WithCPUFuser() { overrideCanFuseOnCPU(cpuFuserEnabled); } bool cpuFuserEnabled; }; TEST(TEFuserPass, FuserPass_1) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%0 : Float(128, strides=[1], device=cpu), %1 : Float(128, strides=[1], device=cpu)): %12 : int = prim::Constant[value=1]() %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) return (%5))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g); // We should not be able to fuse across the in-place operation here. testing::FileCheck() .check("prim::TensorExprGroup_") ->check("aten::add_") ->check("prim::TensorExprGroup_") ->run(*g); } TEST(TEFuserPass, FuserPass_2) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%0 : Float(128, strides=[1], device=cpu), %1 : Float(128, strides=[1], device=cpu)): %12 : int = prim::Constant[value=1]() %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) return (%d))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g); // We should not be able to fuse across the in-place operation here. testing::FileCheck() .check("aten::add_") ->check("prim::TensorExprGroup_0") ->run(*g); } TEST(TEFuserPass, FuserPass_3) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(128, strides=[1], device=cpu), %y : Float(128, strides=[1], device=cpu)): %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) return (%r))IR"; { auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 2); // We should not create a fusion group since its size would be too small testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } { auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 1); // We should create a fusion group since its size is above the threshold testing::FileCheck().check("prim::TensorExprGroup")->run(*g); } } TEST(TEFuserPass, FuserPass_0DimInput) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(device=cpu), %y : Float(device=cpu)): %one : int = prim::Constant[value=1]() %a : Float(device=cpu) = aten::mul(%x, %y) %b : Float(device=cpu) = aten::add(%x, %a, %one) return (%b))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g); // We should fuse 0-dim tensors too testing::FileCheck().check("prim::TensorExprGroup")->run(*g); } TEST(TEFuserPass, FuserPass_UnfusibleDevice) { WithCPUFuser cf(false); const auto graph_string = R"IR( graph(%x : Float(10, strides=[1], device=cpu), %y : Float(10, strides=[1], device=cpu)): %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) return (%a))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 1); // Test that we're not starting fusion groups from nodes with unfusible device testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } TEST(TEFuserPass, FuserPass_UnknownShapes) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Tensor, %y : Tensor): %a : Tensor = aten::mul(%x, %y) %b : Tensor = aten::mul(%x, %a) return (%b))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g); // Test that we're not generating fusion groups when shapes are not known testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } TEST(TEFuserPass, FuserPass_Multidevice) { { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(10, strides=[1], device=cpu), %y : Float(20, strides=[1], device=cpu), %z : Float(30, strides=[1], device=cpu)): %dim : int = prim::Constant[value=0]() %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) return (%cat))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 1); // We should be able to fuse this testing::FileCheck().check("prim::TensorExprGroup")->run(*g); } { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(10, strides=[1], device=cpu), %y : Float(20, strides=[1], device=cuda:0), %z : Float(30, strides=[1], device=cpu)): %dim : int = prim::Constant[value=0]() %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) return (%cat))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 1); // We should not fuse this aten::cat since its inputs are from different // devices testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(10, strides=[1], device=cpu), %y : Float(20, strides=[1], device=cpu), %z : Float(10, strides=[1], device=cuda:0)): %dim : int = prim::Constant[value=0]() %xy_list : Tensor[] = prim::ListConstruct(%x, %y) %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) return (%r))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 2); // Test that we check device before merging one node (cat) into another // (mul) testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(10, strides=[1], device=cpu), %y : Float(20, strides=[1], device=cpu), %z : Float(10, strides=[1], device=cuda:0)): %z2 : Tensor = aten::mul(%z, %z) %dim : int = prim::Constant[value=0]() %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) return (%cat))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 2); // Test that we check device before merging one node (mul) into another // (cat) testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(10, strides=[1], device=cpu), %y : Float(20, strides=[1], device=cuda:0)): %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) return (%r))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 1); // We should not fuse this graph since its inputs are from different devices testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(10, strides=[1], device=cuda:0), %y : Float(20, strides=[1], device=cuda:1), %z : Float(20, strides=[1], device=cpu)): %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) return (%x2, %y2, %z2))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 2); // We should not fuse these two computations since they use different // devices testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } } TEST(TEFuserPass, FuserPass_MergeGroups) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%a : Float(128, strides=[1], device=cpu), %b : Float(128, strides=[1], device=cpu)): %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) return (%x, %y))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 1); // The %x and %y computations are completely independent and yet we should put // them into a single fusion group rather than having two separate ones. testing::FileCheck() .check("= prim::TensorExprGroup_") ->check_not("= prim::TensorExprGroup_") ->run(*g); } TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Bool(8, strides=[1], device=cpu), %y : Bool(8, strides=[1], device=cpu)): %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) %b : Tensor = aten::__or__(%a, %y) return (%b) )IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 2); testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } TEST(TEFuserPass, FuserPass_Where) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(8, strides=[1], device=cpu), %y : Float(8, strides=[1], device=cpu), %z : Float(8, strides=[1], device=cpu)): %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) return (%b) )IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 2); testing::FileCheck().check("prim::TensorExprGroup")->run(*g); } TEST(TEFuserPass, FuserPass_WhereList) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%x : Float(8, strides=[1], device=cpu), %y : Float(8, strides=[1], device=cpu), %z : Float(8, strides=[1], device=cpu)): %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) %b : Tensor[] = aten::where(%cond) return (%b) )IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs(g, /* min_group_size= */ 2); testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } TEST(TEFuserPass, DynamicShapeFusion) { WithCPUFuser cf; const auto graph_string = R"IR( graph(%0 : Float(10, 5, strides=[5, 1], device=cpu), %1 : Float(10, 5, strides=[5, 1], device=cpu)): %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1) %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1) return (%3))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); g->lint(); FuseTensorExprs( g, /* min_group_size = */ 2, /* add_composed_op = */ true, /* fuse_to_dynamic_shapes = */ true); Code code(g, ""); testing::FileCheck() .check("prim::TensorExprDynamicGroup_") ->check("prim::TensorExprDynamicGuard") ->check("prim::TensorExprGroup_") ->run(*g); auto run_and_compare = [&](const std::vector& inputs) { TORCH_INTERNAL_ASSERT(inputs.size() == 2); auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]); InterpreterState interp(code); Stack stack(inputs.begin(), inputs.end()); interp.run(stack); at::Tensor out = pop(stack).toTensor(); ASSERT_TRUE(at::allclose(out, ref)); }; std::vector inputs = {at::rand({10, 5}), at::rand({10, 5})}; run_and_compare(inputs); std::vector inputs2 = {at::rand({20, 5}), at::rand({20, 5})}; run_and_compare(inputs2); std::vector inputs3 = {at::rand({25, 60}), at::rand({25, 60})}; run_and_compare(inputs3); } } // namespace jit } // namespace torch