Add SDPA patterns for T5 variants when batch size is 1 (#163252)

As mentioned in
https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/fuse_attention.py#L838, this PR generates the patterns for the case where the batch size is 1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163252
Approved by: https://github.com/Valentine233, https://github.com/jansel
This commit is contained in:
CaoE
2025-09-26 08:50:06 +00:00
committed by PyTorch MergeBot
parent 04b51499f7
commit c8e5b7dabb
5 changed files with 654 additions and 88 deletions

View File

@ -997,7 +997,8 @@ class TestSDPAPatternRewriterTemplate(TestCase):
attn_weights = scores.float().softmax(dim=-1).type(value.dtype)
return attn_weights.matmul(value)
tensor_shape = (4, 2, 16, 32)
tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)]
for tensor_shape in tensor_shapes:
attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device)
args = [
torch.randn(tensor_shape, device=self.device),
@ -1027,7 +1028,8 @@ class TestSDPAPatternRewriterTemplate(TestCase):
attn_weights = scores.float().softmax(dim=-1).type(value.dtype)
return attn_weights.matmul(value), key, value
tensor_shape = (4, 2, 16, 32)
tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)]
for tensor_shape in tensor_shapes:
attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device)
args = [
torch.randn(tensor_shape, device=self.device),
@ -1067,7 +1069,8 @@ class TestSDPAPatternRewriterTemplate(TestCase):
attn_weights = scores.float().softmax(dim=-1).type(value.dtype)
return attn_weights.matmul(value), key, value
tensor_shape = (4, 2, 16, 32)
tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)]
for tensor_shape in tensor_shapes:
args = [
torch.randn(tensor_shape, device=self.device),
torch.randn(tensor_shape, device=self.device),

View File

@ -581,42 +581,6 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p):
)
def _sfdp_pattern_24(query, key, value, attention_mask):
"""
this pattern is for MBartForCausalLM/PLBartForCausalLM.
attn_mask has a different dtype with QKV.
there is no scale in sdpa.
"""
bs = query.size(0)
n_head = query.size(1)
seq_len = query.size(2)
head_size = query.size(3)
q = query.view(bs * n_head, -1, head_size)
k = key.reshape(bs * n_head, -1, head_size)
v = value.reshape(bs * n_head, -1, head_size)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask
attn_weights = attn_weights.view(bs * n_head, seq_len, -1)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
if query.dtype == torch.half:
attn_weights = attn_weights.to(torch.half)
attn_output = torch.bmm(attn_weights, v)
attn_output = attn_output.view(bs, n_head, seq_len, head_size)
return attn_output
def _sfdp_replacement_24(query, key, value, attention_mask):
    """Fused replacement for `_sfdp_pattern_24`; bumps the fuse_attention counter."""
    counters["inductor"]["fuse_attention"] += 1
    # The mask is cast to the query dtype before being handed to SDPA.
    mask = attention_mask.to(dtype=query.dtype)
    return _scaled_dot_product_attention(
        query, key, value, attn_mask=mask, is_causal=False, scale=1
    )
def _sfdp_pattern_21(query, key, value, attn_mask):
# for T5 with inplace add
query = query.permute([0, 2, 1, 3])
@ -643,7 +607,7 @@ def _sfdp_replacement_21(query, key, value, attn_mask):
query,
key,
value,
attn_mask=attn_mask,
attn_mask=attn_mask.to(dtype=query.dtype),
is_causal=False,
scale=1.0,
)
@ -676,7 +640,7 @@ def _sfdp_replacement_22(query, key, value, attn_mask):
query,
key,
value,
attn_mask=attn_mask,
attn_mask=attn_mask.to(dtype=query.dtype),
is_causal=False,
scale=1.0,
),
@ -723,6 +687,42 @@ def _sfdp_replacement_23(query, key, value):
)
def _sfdp_pattern_24(query, key, value, attention_mask):
"""
this pattern is for MBartForCausalLM/PLBartForCausalLM.
attn_mask has a different dtype with QKV.
there is no scale in sdpa.
"""
bs = query.size(0)
n_head = query.size(1)
seq_len = query.size(2)
head_size = query.size(3)
q = query.view(bs * n_head, -1, head_size)
k = key.reshape(bs * n_head, -1, head_size)
v = value.reshape(bs * n_head, -1, head_size)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask
attn_weights = attn_weights.view(bs * n_head, seq_len, -1)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
if query.dtype == torch.half:
attn_weights = attn_weights.to(torch.half)
attn_output = torch.bmm(attn_weights, v)
attn_output = attn_output.view(bs, n_head, seq_len, head_size)
return attn_output
def _sfdp_replacement_24(query, key, value, attention_mask):
    """Fused replacement for `_sfdp_pattern_24`; bumps the fuse_attention counter."""
    counters["inductor"]["fuse_attention"] += 1
    # The mask is cast to the query dtype before being handed to SDPA.
    mask = attention_mask.to(dtype=query.dtype)
    return _scaled_dot_product_attention(
        query, key, value, attn_mask=mask, is_causal=False, scale=1
    )
def _sfdp_params_check(match):
assert all(k in match.kwargs for k in ("query", "key", "value"))
query = match.kwargs["query"].meta["val"]
@ -1024,6 +1024,13 @@ def _get_sfdp_patterns():
{},
_sfdp_params_check,
),
(
_sfdp_pattern_21,
_sfdp_replacement_21,
[g_bs1(), g_bs1(), g_bs1(), m_bs1_float()],
{},
_sfdp_params_check,
),
(
_sfdp_pattern_22,
_sfdp_replacement_22,
@ -1031,6 +1038,13 @@ def _get_sfdp_patterns():
{},
_sfdp_params_check,
),
(
_sfdp_pattern_22,
_sfdp_replacement_22,
[g_bs1(), g_bs1(), g_bs1(), m_bs1_float()],
{},
_sfdp_params_check,
),
(
_sfdp_pattern_23,
_sfdp_replacement_23,
@ -1038,6 +1052,13 @@ def _get_sfdp_patterns():
{},
_sfdp_params_check,
),
(
_sfdp_pattern_23,
_sfdp_replacement_23,
[g_bs1(), g_bs1(), g_bs1()],
{},
_sfdp_params_check,
),
(
_sfdp_pattern_24,
_sfdp_replacement_24,

View File

@ -119,6 +119,88 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
_sfdp_pattern_21_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
# Pattern graph assigned to `_sfdp_pattern_21_bs1_training`.
# Forward part: query, key and value are each permuted, expanded and viewed;
# the query/key bmm result has `attn_mask` added; the amax/sub/exp/sum/div
# chain (the shape of a numerically stable softmax) feeds a second bmm with
# value. No aten.clone node appears anywhere in this graph.
# Each `_users=N` agrees with how many later nodes in this block consume that node.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward part: combines the incoming gradient `tangents_1` with the forward nodes above.
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
# NOTE(review): fma(neg(div), sum(mul), mul) looks like the softmax backward — confirm.
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
view_default_10 = CallFunction(aten.view.default, fma_default, Ignored())
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
# Outputs: the forward result followed by three backward results. The trailing
# None is presumably the slot for `attn_mask`, which has no gradient pattern — TODO confirm.
_sfdp_pattern_21_bs1_training = MultiOutputPattern([view_default_7,
permute_default_6,
permute_default_9,
permute_default_11,
None
])
# Pattern graph assigned to `_sfdp_pattern_21_bs1_inference` (forward only).
# query, key and value are each permuted, expanded and viewed; the query/key
# bmm result has `attn_mask` added; the amax/sub/exp/sum/div chain (the shape
# of a numerically stable softmax) feeds a second bmm with value.
# No aten.clone node appears anywhere in this graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
# Consumed by both amax and sub below, hence `_users=2`.
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
# Consumed by both sum and div below, hence `_users=2`.
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
# The final view is the pattern root; it is declared with `_users=0`
# (presumably because nothing inside the pattern consumes it — TODO confirm).
_sfdp_pattern_21_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
@ -215,3 +297,95 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
_sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
# Pattern graph assigned to `_sfdp_pattern_21_half_bs1_training`.
# Same node sequence as a permute/expand/view + bmm + add-mask + softmax + bmm
# attention, with prims.convert_element_type nodes inserted after the mask add,
# before and after the amax/sub/exp/sum/div chain, and in the backward part.
# No aten.clone node appears anywhere in this graph.
# Each `_users=N` agrees with how many later nodes in this block consume that node.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward part: combines the incoming gradient `tangents_1` with the forward nodes above.
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
# NOTE(review): fma(neg(div), sum(mul), mul) looks like the softmax backward — confirm.
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored())
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
# Two consecutive dtype conversions are part of the pattern as written.
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored())
convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored())
view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
# Outputs: the forward result followed by three backward results. The trailing
# None is presumably the slot for `attn_mask`, which has no gradient pattern — TODO confirm.
_sfdp_pattern_21_half_bs1_training = MultiOutputPattern([view_default_7,
permute_default_6,
permute_default_9,
permute_default_11,
None
])
# Pattern graph assigned to `_sfdp_pattern_21_half_bs1_inference` (forward only).
# permute/expand/view on query, key and value, bmm, add `attn_mask`, the
# amax/sub/exp/sum/div chain, then a second bmm with value.
# prims.convert_element_type nodes sit after the mask add and on both sides of
# the amax/sub/exp/sum/div chain. No aten.clone node appears in this graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
# Consumed by both amax and sub below, hence `_users=2`.
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
# Consumed by both sum and div below, hence `_users=2`.
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
# The final view is the pattern root; it is declared with `_users=0`
# (presumably because nothing inside the pattern consumes it — TODO confirm).
_sfdp_pattern_21_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)

View File

@ -125,6 +125,94 @@ _sfdp_pattern_22_inference = MultiOutputPattern([view_default_7,
])
# Pattern graph assigned to `_sfdp_pattern_22_bs1_training`.
# Forward part: permute/expand/view on query, key and value, bmm, add
# `attn_mask`, the amax/sub/exp/sum/div chain (the shape of a numerically
# stable softmax), then a second bmm with value.
# The permuted key (permute_default_1) and permuted value (permute_default_3)
# are also listed as outputs, which is why both carry `_users=2`.
# No aten.clone node appears anywhere in this graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward part: combines the incoming gradient `tangents_1` with the forward nodes above.
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
# NOTE(review): fma(neg(div), sum(mul), mul) looks like the softmax backward — confirm.
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
view_default_10 = CallFunction(aten.view.default, fma_default, Ignored())
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
# Outputs: the forward result, the permuted key and value, then three backward
# results. The trailing None is presumably the slot for `attn_mask`, which has
# no gradient pattern — TODO confirm.
_sfdp_pattern_22_bs1_training = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3,
permute_default_6,
permute_default_9,
permute_default_11,
None
])
# Pattern graph assigned to `_sfdp_pattern_22_bs1_inference` (forward only).
# permute/expand/view on query, key and value, bmm, add `attn_mask`, the
# amax/sub/exp/sum/div chain (the shape of a numerically stable softmax), then
# a second bmm with value. No aten.clone node appears in this graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
# `_users=2`: feeds permute_default_2 and is also listed as an output below.
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
# `_users=2`: feeds expand_default_3 and is also listed as an output below.
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Outputs: the attention result plus the permuted key and permuted value.
_sfdp_pattern_22_bs1_inference = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3
])
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
@ -227,3 +315,101 @@ _sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3
])
# Pattern graph assigned to `_sfdp_pattern_22_half_bs1_training`.
# permute/expand/view + bmm + add-mask + amax/sub/exp/sum/div + bmm attention,
# with prims.convert_element_type nodes inserted after the mask add, on both
# sides of the amax/sub/exp/sum/div chain, and in the backward part.
# The permuted key (permute_default_1) and permuted value (permute_default_3)
# are also listed as outputs, which is why both carry `_users=2`.
# No aten.clone node appears anywhere in this graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward part: combines the incoming gradient `tangents_1` with the forward nodes above.
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
# NOTE(review): fma(neg(div), sum(mul), mul) looks like the softmax backward — confirm.
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored())
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
# Two consecutive dtype conversions are part of the pattern as written.
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored())
convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored())
view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
# Outputs: the forward result, the permuted key and value, then three backward
# results. The trailing None is presumably the slot for `attn_mask`, which has
# no gradient pattern — TODO confirm.
_sfdp_pattern_22_half_bs1_training = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3,
permute_default_6,
permute_default_9,
permute_default_11,
None
])
# Pattern graph assigned to `_sfdp_pattern_22_half_bs1_inference` (forward only).
# permute/expand/view on query, key and value, bmm, add `attn_mask`, the
# amax/sub/exp/sum/div chain, then a second bmm with value.
# prims.convert_element_type nodes sit after the mask add and on both sides of
# the amax/sub/exp/sum/div chain. No aten.clone node appears in this graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
# `_users=2`: feeds permute_default_2 and is also listed as an output below.
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
# `_users=2`: feeds expand_default_3 and is also listed as an output below.
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Outputs: the attention result plus the permuted key and permuted value.
_sfdp_pattern_22_half_bs1_inference = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3
])

View File

@ -122,6 +122,91 @@ _sfdp_pattern_23_inference = MultiOutputPattern([view_default_7,
])
# _sfdp_pattern_23_bs1_training
# Pattern graph over the keyword inputs 'query', 'key', 'value' and
# 'tangents_1', exposed as a MultiOutputPattern with six outputs (listed at the
# end of this block).
# NOTE(review): the name suggests the batch-size-1 training variant of SDPA
# pattern 23 -- confirm against the code that produces these patterns.
# Every shape / permutation / dim argument is Ignored(); the only literal
# argument is the `True` passed to amax and sum.
# Each `_users=N` below equals the number of references to that node inside
# this block (membership in the output list counts as one reference).

# query branch: permute -> expand -> view. The view is referenced twice
# (bmm_default and permute_default_7).
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
# key branch: permute (also returned as an output) -> permute -> expand -> view.
# The view is referenced twice (bmm_default and permute_default_5).
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
# query-side x key-side bmm followed by three consecutive views; no mask add and
# no dtype conversion appear in this variant.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
# amax -> sub -> exp -> sum -> div, i.e. exp(x - amax(x)) / sum(exp(x - amax(x))):
# the max-subtracted softmax form. The reduction dims are Ignored() here.
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
# div_Tensor is referenced three times: expand_default_2, neg_default, mul_Tensor.
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
# value branch: permute (also returned as an output) -> expand -> view.
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
# normalized weights x value-side bmm; view_default_7 is the first output.
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Everything below consumes 'tangents_1'.
# NOTE(review): presumably the backward graph, with tangents_1 the incoming
# gradient of view_default_7 -- confirm.
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
# mul -> sum -> fma(neg(div), sum, mul), built on the div_Tensor node above.
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
# Three consecutive views, mirroring view_default_2..4 above.
view_default_10 = CallFunction(aten.view.default, fma_default, Ignored())
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2)
# bmm against the permuted key-side view -> permute_default_6
# (presumably the gradient w.r.t. query -- TODO confirm).
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
# bmm of the permuted query-side view with view_default_12, then two permutes
# -> permute_default_9 (presumably the gradient w.r.t. key -- TODO confirm).
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
# bmm of the permuted weights view with the tangents view -> permute_default_11
# (presumably the gradient w.r.t. value -- TODO confirm).
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
# Outputs, in order: final view, permuted key, permuted value, then the three
# permute nodes produced by the tangents_1-dependent part of the graph.
_sfdp_pattern_23_bs1_training = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3,
permute_default_6,
permute_default_9,
permute_default_11
])
# _sfdp_pattern_23_bs1_inference
# Pattern graph over the keyword inputs 'query', 'key' and 'value', exposed as a
# MultiOutputPattern with three outputs. No 'tangents_1' input appears in this
# variant.
# NOTE(review): the name suggests the batch-size-1 inference variant of SDPA
# pattern 23 -- confirm against the code that produces these patterns.
# Every shape / permutation / dim argument is Ignored(); the only literal
# argument is the `True` passed to amax and sum.
# Each `_users=N` below equals the number of references to that node inside
# this block (membership in the output list counts as one reference).

# query branch: permute -> expand -> view.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
# key branch: permute (also returned as an output) -> permute -> expand -> view.
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
# query-side x key-side bmm followed by three consecutive views.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored())
# view_default_4 is referenced twice: by amax and by sub.
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
# amax -> sub -> exp -> sum -> div, i.e. exp(x - amax(x)) / sum(exp(x - amax(x))):
# the max-subtracted softmax form. The reduction dims are Ignored() here.
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
# value branch: permute (also returned as an output) -> expand -> view.
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
# normalized weights x value-side bmm; the final view is the first output.
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Outputs, in order: final view, permuted key, permuted value.
_sfdp_pattern_23_bs1_inference = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3
])
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
@ -223,3 +308,100 @@ _sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3
])
# _sfdp_pattern_23_half_bs1_training
# Pattern graph over the keyword inputs 'query', 'key', 'value' and
# 'tangents_1', exposed as a MultiOutputPattern with six outputs (listed at the
# end of this block). It differs from a plain permute/expand/view/bmm/softmax
# chain by the prims.convert_element_type nodes placed before and after the
# amax/sub/exp/sum/div section and again in the tangents_1-dependent part.
# NOTE(review): "half" presumably means the reduced-precision trace and "bs1"
# batch size 1 -- confirm against the code that produces these patterns.
# Every shape / permutation / dtype / dim argument is Ignored(); the only
# literal argument is the `True` passed to amax and sum.
# Each `_users=N` below equals the number of references to that node inside
# this block (membership in the output list counts as one reference).

# query branch: permute -> expand -> view. The view is referenced twice
# (bmm_default and permute_default_7).
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
# key branch: permute (also returned as an output) -> permute -> expand -> view.
# The view is referenced twice (bmm_default and permute_default_5).
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
# query-side x key-side bmm, one view, then two back-to-back dtype conversions,
# two views, and a third conversion whose result feeds amax and sub.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
# amax -> sub -> exp -> sum -> div, i.e. exp(x - amax(x)) / sum(exp(x - amax(x))):
# the max-subtracted softmax form. The reduction dims are Ignored() here.
amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
# div_Tensor is referenced three times: convert_element_type_default_3,
# neg_default and mul_Tensor.
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
# The normalized weights are converted again before the expand/view that feeds
# the second bmm.
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
# value branch: permute (also returned as an output) -> expand -> view.
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
# normalized weights x value-side bmm; view_default_7 is the first output.
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Everything below consumes 'tangents_1'.
# NOTE(review): presumably the backward graph, with tangents_1 the incoming
# gradient of view_default_7 -- confirm.
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
# The bmm result is converted before it is multiplied with div_Tensor
# (the unconverted division result, not convert_element_type_default_3).
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored())
# mul -> sum -> fma(neg(div), sum, mul), built on the div_Tensor node above.
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
# convert -> view -> view -> convert -> convert -> view: the same op sequence as
# convert_element_type_default..convert_element_type_default_2 above, reversed.
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored())
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored())
convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored())
view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2)
# bmm against the permuted key-side view -> permute_default_6
# (presumably the gradient w.r.t. query -- TODO confirm).
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
# bmm of the permuted query-side view with view_default_12, then two permutes
# -> permute_default_9 (presumably the gradient w.r.t. key -- TODO confirm).
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
# bmm of the permuted weights view with the tangents view -> permute_default_11
# (presumably the gradient w.r.t. value -- TODO confirm).
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
# Outputs, in order: final view, permuted key, permuted value, then the three
# permute nodes produced by the tangents_1-dependent part of the graph.
_sfdp_pattern_23_half_bs1_training = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3,
permute_default_6,
permute_default_9,
permute_default_11
])
# _sfdp_pattern_23_half_bs1_inference
# Pattern graph over the keyword inputs 'query', 'key' and 'value', exposed as a
# MultiOutputPattern with three outputs. No 'tangents_1' input appears in this
# variant. prims.convert_element_type nodes sit before and after the
# amax/sub/exp/sum/div section.
# NOTE(review): "half" presumably means the reduced-precision trace and "bs1"
# batch size 1 -- confirm against the code that produces these patterns.
# Every shape / permutation / dtype / dim argument is Ignored(); the only
# literal argument is the `True` passed to amax and sum.
# Each `_users=N` below equals the number of references to that node inside
# this block (membership in the output list counts as one reference).

# query branch: permute -> expand -> view.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
# key branch: permute (also returned as an output) -> permute -> expand -> view.
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
# query-side x key-side bmm, one view, then two back-to-back dtype conversions,
# two views, and a third conversion whose result feeds amax and sub.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
# amax -> sub -> exp -> sum -> div, i.e. exp(x - amax(x)) / sum(exp(x - amax(x))):
# the max-subtracted softmax form. The reduction dims are Ignored() here.
amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
# The normalized weights are converted again before the expand/view that feeds
# the second bmm.
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored())
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
# value branch: permute (also returned as an output) -> expand -> view.
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
# normalized weights x value-side bmm; the final view is the first output.
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Outputs, in order: final view, permuted key, permuted value.
_sfdp_pattern_23_half_bs1_inference = MultiOutputPattern([view_default_7,
permute_default_1,
permute_default_3
])