[tensorexpr] Enabled aten::stack in the fuser pass with static shapes (#74077)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74077

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D34808051

Pulled By: huiguoo

fbshipit-source-id: 213e2ffdf87fb1a74104037cea7ef25e4bfd4307
(cherry picked from commit ad9e84842e5b47eda845827d325b08ba361a8286)
Committed by PyTorch MergeBot · parent 317b1a0ed9 · commit 90c3699cc8
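What this enables, as a rough sketch (not part of the commit): with the TensorExpr fuser active on CPU and static shapes (e.g. a traced function whose profiled sizes are fixed), a torch.stack of fusible elementwise ops should now land inside a single prim::TensorExprGroup. The foo/hx/cx names mirror the test_stack test added below; the torch._C flags are internal knobs and may vary across versions.

    import torch

    torch._C._jit_override_can_fuse_on_cpu(True)  # allow fusion groups on CPU
    torch._C._jit_set_texpr_fuser_enabled(True)   # use the TensorExpr fuser

    def foo(hx, cx):
        return torch.stack((hx + cx, hx - cx))

    hx = torch.randn(3, 20)
    cx = torch.randn(3, 20)
    traced = torch.jit.trace(foo, (hx, cx))
    traced(hx, cx)
    traced(hx, cx)                                # warm up the profiling executor
    print(traced.graph_for(hx, cx))               # expect a prim::TensorExprGroup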
test/cpp/tensorexpr/test_te_fuser_pass.cpp
@@ -317,6 +317,25 @@ TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
   testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
 }
 
+TEST(TEFuserPass, FuserPass_Stack) {
+  WithCPUFuser cf;
+  const auto graph_string =
+      R"IR(graph(%y.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu),
+      %x.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu)):
+  %1 : int = prim::Constant[value=2]()
+  %9 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%x.1)
+  %7 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%y.1)
+  %5 : Tensor[] = prim::ListConstruct(%9, %7)
+  %z.2 : Float(5, 3, 2, 3, 6, strides=[108, 36, 18, 6, 1], requires_grad=0, device=cpu) = aten::stack(%5, %1)
+  return (%z.2)
+)IR";
+  auto g = std::make_shared<Graph>();
+  torch::jit::parseIR(graph_string, g.get());
+  g->lint();
+  FuseTensorExprs(g, /* min_group_size= */ 2);
+  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
+}
+
 TEST(TEFuserPass, FuserPass_Where) {
   WithCPUFuser cf;
   const auto graph_string = R"IR(
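A note on the test above (a reading of it, not part of the commit): FuseTensorExprs runs with min_group_size = 2, so a candidate group must contain at least that many fusible nodes to survive; routing both inputs through aten::tanh gives the pass enough elementwise work to keep the group, which is presumably why the tanh calls are there.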
test/test_jit_fuser_te.py
@@ -739,6 +739,22 @@ class TestTEFuser(JitTestCase):
         # XXX: TE fuser can handle concats in a fusion group.
         # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
 
+    def test_stack(self):
+        # "aten::stack fusion is not enabled yet with dynamic shapes"
+        if self.dynamic_shapes:
+            return True
+        with set_fusion_group_inlining(True):
+            for device in self.devices:
+                hx = torch.randn(3, 20, dtype=torch.float, device=device)
+                cx = torch.randn(3, 20, dtype=torch.float, device=device)
+
+                def foo(hx, cx):
+                    return torch.stack((hx + cx, hx - cx))
+
+                ge = self.checkTrace(foo, (hx, cx))
+                graph = ge.graph_for(hx, cx)
+                self.assertAllFused(graph)
+
     def test_remove_output_used_only_in_size(self):
         for device in self.devices:
             def test_fuse(a, b):
@@ -1781,6 +1797,7 @@ class TestTEFuser(JitTestCase):
         devices = self.devices
         list_ops = [
             torch.cat,
+            torch.stack
         ]
         for dtype, op, device in product(self.dtypes, list_ops, devices):
             if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
test/test_tensorexpr_pybind.py
@@ -390,6 +390,31 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
         np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
         np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
 
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_stack(self):
+        def f(a, b):
+            return torch.stack((a, b), dim=1)
+
+        device = "cpu"
+        x = torch.rand((3, 5), device=device)
+        y = torch.rand((3, 5), device=device)
+        graph_str = """
+graph(%x.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu),
+      %y.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu)):
+  %1 : int = prim::Constant[value=1]()
+  %5 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
+  %z.2 : Float(3, 2, 5, strides=[10, 5, 1], requires_grad=0, device=cpu) = aten::stack(%5, %1) # local/stack.py:39:12
+  return (%z.2)
+"""
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = te.TensorExprKernel(graph)
+        res1 = kernel.run((x, y))
+        res2 = kernel.fallback((x, y))
+        correct = f(x, y)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_alloc_in_loop(self):
         a, tmp, b = [
torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -94,6 +94,7 @@ bool isSupported(Node* node) {
   };
   static const OperatorSet supported_misc_set{
       "aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
+      "aten::stack(Tensor[] tensors, int dim=0) -> Tensor",
       "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
   };
   // clang-format on
@@ -771,7 +772,7 @@ class TensorExprFuser {
 
     std::vector<Node*> nodes_to_merge = {to_merge};
 
-    if (to_merge->kind() == aten::cat) {
+    if (to_merge->kind() == aten::cat || to_merge->kind() == aten::stack) {
       Node* listconstruct = to_merge->input(0)->node();
       nodes_to_merge.push_back(listconstruct);
     }
@@ -1053,7 +1054,6 @@ class TensorExprFuser {
     REQ(isFusableOnDevice(node));
     REQ(operators_not_to_fuse.find(node->kind()) ==
        operators_not_to_fuse.end());
-
     for (Value* input : node->inputs()) {
       if (auto const& tt = input->type()->cast<TensorType>()) {
         auto st = tt->scalarType();
@@ -1066,7 +1066,7 @@ class TensorExprFuser {
         }
       }
     }
-    if (node->kind() == aten::cat) {
+    if (node->kind() == aten::cat || node->kind() == aten::stack) {
       REQ(node->input(0)->node()->kind() == prim::ListConstruct);
       REQ(node->input(0)->uses().size() == 1);
       REQ(node->input(1)->node()->kind() == prim::Constant);
@@ -1120,7 +1120,8 @@ class TensorExprFuser {
     REQ(nInputs <= subgraphArgLimit);
 
     // Device checks
-    if (consumer->kind() != aten::cat && producer->kind() != aten::cat) {
+    if (consumer->kind() != aten::cat && producer->kind() != aten::cat &&
+        consumer->kind() != aten::stack && producer->kind() != aten::stack) {
       // aten::cat needs a special handling because it takes a Tensor[] as its
       // input We deal with that in the code below.
       auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
@@ -1154,7 +1155,7 @@ class TensorExprFuser {
       REQ(producer->kind() != prim::Constant);
     }
 
-    if (producer->kind() == aten::cat) {
+    if (producer->kind() == aten::cat || producer->kind() == aten::stack) {
      REQ(producer->input(0)->node()->kind() == prim::ListConstruct);
      REQ(producer->input(0)->uses().size() == 1);
      REQ(producer->input(1)->node()->kind() == prim::Constant);
@@ -1172,7 +1173,8 @@ class TensorExprFuser {
         REQ(isFusableOnDevice(input->node()));
       }
       REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
-    } else if (consumer->kind() == aten::cat) {
+    } else if (
+        consumer->kind() == aten::cat || consumer->kind() == aten::stack) {
       REQ(consumer->input(0)->node()->kind() == prim::ListConstruct);
       REQ(consumer->input(0)->uses().size() == 1);
       REQ(consumer->input(1)->node()->kind() == prim::Constant);
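Why aten::stack can ride on the existing aten::cat paths above: like cat, stack consumes a Tensor[] built by prim::ListConstruct (hence the fuser pulls the ListConstruct node into the group and requires the dim argument to be a constant), and stack is expressible as a cat of unsqueezed inputs. A quick illustration of that identity in ordinary PyTorch (not code from this commit):

    import torch

    ts = [torch.randn(3, 5), torch.randn(3, 5)]
    dim = 1
    # stack(ts, dim) == cat([t.unsqueeze(dim) for t in ts], dim)
    assert torch.equal(torch.stack(ts, dim),
                       torch.cat([t.unsqueeze(dim) for t in ts], dim))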