mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add SDPA patterns for T5 variants when batch size is 1 (#163252)
As mentioned in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/fuse_attention.py#L838, this PR generates patterns for the cases batch size == 1. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163252 Approved by: https://github.com/Valentine233, https://github.com/jansel
This commit is contained in:
@ -997,20 +997,21 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
||||
attn_weights = scores.float().softmax(dim=-1).type(value.dtype)
|
||||
return attn_weights.matmul(value)
|
||||
|
||||
tensor_shape = (4, 2, 16, 32)
|
||||
attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device)
|
||||
args = [
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
attn_mask,
|
||||
]
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
)
|
||||
tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)]
|
||||
for tensor_shape in tensor_shapes:
|
||||
attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device)
|
||||
args = [
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
attn_mask,
|
||||
]
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
)
|
||||
|
||||
def _test_sdpa_rewriter_22(self):
|
||||
def dot_prod_attention(
|
||||
@ -1027,30 +1028,31 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
||||
attn_weights = scores.float().softmax(dim=-1).type(value.dtype)
|
||||
return attn_weights.matmul(value), key, value
|
||||
|
||||
tensor_shape = (4, 2, 16, 32)
|
||||
attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device)
|
||||
args = [
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
attn_mask,
|
||||
]
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
)
|
||||
# test attn_mask with stride of last dim != 1
|
||||
attn_mask_ = attn_mask.transpose(2, 3)
|
||||
args[3] = attn_mask_
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
contains=self.device == "cpu",
|
||||
)
|
||||
tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)]
|
||||
for tensor_shape in tensor_shapes:
|
||||
attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device)
|
||||
args = [
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
attn_mask,
|
||||
]
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
)
|
||||
# test attn_mask with stride of last dim != 1
|
||||
attn_mask_ = attn_mask.transpose(2, 3)
|
||||
args[3] = attn_mask_
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
contains=self.device == "cpu",
|
||||
)
|
||||
|
||||
def _test_sdpa_rewriter_23(self):
|
||||
def dot_prod_attention(
|
||||
@ -1067,18 +1069,19 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
||||
attn_weights = scores.float().softmax(dim=-1).type(value.dtype)
|
||||
return attn_weights.matmul(value), key, value
|
||||
|
||||
tensor_shape = (4, 2, 16, 32)
|
||||
args = [
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
]
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
)
|
||||
tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)]
|
||||
for tensor_shape in tensor_shapes:
|
||||
args = [
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
torch.randn(tensor_shape, device=self.device),
|
||||
]
|
||||
self._check_common(
|
||||
dot_prod_attention,
|
||||
args1=args,
|
||||
has_dropout=False,
|
||||
check_train=False,
|
||||
)
|
||||
|
||||
def _test_sdpa_rewriter_24(self):
|
||||
def dot_prod_attention(
|
||||
|
@ -581,42 +581,6 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p):
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_pattern_24(query, key, value, attention_mask):
|
||||
"""
|
||||
this pattern is for MBartForCausalLM/PLBartForCausalLM.
|
||||
attn_mask has a different dtype with QKV.
|
||||
there is no scale in sdpa.
|
||||
"""
|
||||
bs = query.size(0)
|
||||
n_head = query.size(1)
|
||||
seq_len = query.size(2)
|
||||
head_size = query.size(3)
|
||||
q = query.view(bs * n_head, -1, head_size)
|
||||
k = key.reshape(bs * n_head, -1, head_size)
|
||||
v = value.reshape(bs * n_head, -1, head_size)
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask
|
||||
attn_weights = attn_weights.view(bs * n_head, seq_len, -1)
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
if query.dtype == torch.half:
|
||||
attn_weights = attn_weights.to(torch.half)
|
||||
attn_output = torch.bmm(attn_weights, v)
|
||||
attn_output = attn_output.view(bs, n_head, seq_len, head_size)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_24(query, key, value, attention_mask):
|
||||
counters["inductor"]["fuse_attention"] += 1
|
||||
return _scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask.to(dtype=query.dtype),
|
||||
is_causal=False,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_pattern_21(query, key, value, attn_mask):
|
||||
# for T5 with inplace add
|
||||
query = query.permute([0, 2, 1, 3])
|
||||
@ -643,7 +607,7 @@ def _sfdp_replacement_21(query, key, value, attn_mask):
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attn_mask,
|
||||
attn_mask=attn_mask.to(dtype=query.dtype),
|
||||
is_causal=False,
|
||||
scale=1.0,
|
||||
)
|
||||
@ -676,7 +640,7 @@ def _sfdp_replacement_22(query, key, value, attn_mask):
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attn_mask,
|
||||
attn_mask=attn_mask.to(dtype=query.dtype),
|
||||
is_causal=False,
|
||||
scale=1.0,
|
||||
),
|
||||
@ -723,6 +687,42 @@ def _sfdp_replacement_23(query, key, value):
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_pattern_24(query, key, value, attention_mask):
|
||||
"""
|
||||
this pattern is for MBartForCausalLM/PLBartForCausalLM.
|
||||
attn_mask has a different dtype with QKV.
|
||||
there is no scale in sdpa.
|
||||
"""
|
||||
bs = query.size(0)
|
||||
n_head = query.size(1)
|
||||
seq_len = query.size(2)
|
||||
head_size = query.size(3)
|
||||
q = query.view(bs * n_head, -1, head_size)
|
||||
k = key.reshape(bs * n_head, -1, head_size)
|
||||
v = value.reshape(bs * n_head, -1, head_size)
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask
|
||||
attn_weights = attn_weights.view(bs * n_head, seq_len, -1)
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
if query.dtype == torch.half:
|
||||
attn_weights = attn_weights.to(torch.half)
|
||||
attn_output = torch.bmm(attn_weights, v)
|
||||
attn_output = attn_output.view(bs, n_head, seq_len, head_size)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_24(query, key, value, attention_mask):
|
||||
counters["inductor"]["fuse_attention"] += 1
|
||||
return _scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask.to(dtype=query.dtype),
|
||||
is_causal=False,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_params_check(match):
|
||||
assert all(k in match.kwargs for k in ("query", "key", "value"))
|
||||
query = match.kwargs["query"].meta["val"]
|
||||
@ -1024,6 +1024,13 @@ def _get_sfdp_patterns():
|
||||
{},
|
||||
_sfdp_params_check,
|
||||
),
|
||||
(
|
||||
_sfdp_pattern_21,
|
||||
_sfdp_replacement_21,
|
||||
[g_bs1(), g_bs1(), g_bs1(), m_bs1_float()],
|
||||
{},
|
||||
_sfdp_params_check,
|
||||
),
|
||||
(
|
||||
_sfdp_pattern_22,
|
||||
_sfdp_replacement_22,
|
||||
@ -1031,6 +1038,13 @@ def _get_sfdp_patterns():
|
||||
{},
|
||||
_sfdp_params_check,
|
||||
),
|
||||
(
|
||||
_sfdp_pattern_22,
|
||||
_sfdp_replacement_22,
|
||||
[g_bs1(), g_bs1(), g_bs1(), m_bs1_float()],
|
||||
{},
|
||||
_sfdp_params_check,
|
||||
),
|
||||
(
|
||||
_sfdp_pattern_23,
|
||||
_sfdp_replacement_23,
|
||||
@ -1038,6 +1052,13 @@ def _get_sfdp_patterns():
|
||||
{},
|
||||
_sfdp_params_check,
|
||||
),
|
||||
(
|
||||
_sfdp_pattern_23,
|
||||
_sfdp_replacement_23,
|
||||
[g_bs1(), g_bs1(), g_bs1()],
|
||||
{},
|
||||
_sfdp_params_check,
|
||||
),
|
||||
(
|
||||
_sfdp_pattern_24,
|
||||
_sfdp_replacement_24,
|
||||
|
@ -119,6 +119,88 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
_sfdp_pattern_21_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
|
||||
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
view_default_10 = CallFunction(aten.view.default, fma_default, Ignored())
|
||||
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
|
||||
view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2)
|
||||
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
||||
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
|
||||
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
||||
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
|
||||
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
||||
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
|
||||
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
||||
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
|
||||
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
||||
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
|
||||
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
|
||||
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
||||
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
|
||||
_sfdp_pattern_21_bs1_training = MultiOutputPattern([view_default_7,
|
||||
permute_default_6,
|
||||
permute_default_9,
|
||||
permute_default_11,
|
||||
None
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
_sfdp_pattern_21_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
||||
@ -215,3 +297,95 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
||||
view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
_sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
|
||||
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored())
|
||||
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
|
||||
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored())
|
||||
convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored())
|
||||
view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2)
|
||||
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
||||
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
|
||||
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
||||
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
|
||||
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
||||
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
|
||||
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
||||
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
|
||||
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
||||
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
|
||||
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
|
||||
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
||||
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
|
||||
_sfdp_pattern_21_half_bs1_training = MultiOutputPattern([view_default_7,
|
||||
permute_default_6,
|
||||
permute_default_9,
|
||||
permute_default_11,
|
||||
None
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
_sfdp_pattern_21_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
@ -125,6 +125,94 @@ _sfdp_pattern_22_inference = MultiOutputPattern([view_default_7,
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
|
||||
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
view_default_10 = CallFunction(aten.view.default, fma_default, Ignored())
|
||||
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
|
||||
view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2)
|
||||
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
||||
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
|
||||
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
||||
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
|
||||
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
||||
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
|
||||
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
||||
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
|
||||
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
||||
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
|
||||
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
|
||||
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
||||
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
|
||||
_sfdp_pattern_22_bs1_training = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3,
|
||||
permute_default_6,
|
||||
permute_default_9,
|
||||
permute_default_11,
|
||||
None
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
_sfdp_pattern_22_bs1_inference = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
||||
@ -227,3 +315,101 @@ _sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
|
||||
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored())
|
||||
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
|
||||
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored())
|
||||
convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored())
|
||||
view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2)
|
||||
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
||||
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
|
||||
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
||||
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
|
||||
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
||||
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
|
||||
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
||||
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
|
||||
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
||||
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
|
||||
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
|
||||
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
||||
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
|
||||
_sfdp_pattern_22_half_bs1_training = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3,
|
||||
permute_default_6,
|
||||
permute_default_9,
|
||||
permute_default_11,
|
||||
None
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask'))
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
_sfdp_pattern_22_half_bs1_inference = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3
|
||||
])
|
||||
|
@ -122,6 +122,91 @@ _sfdp_pattern_23_inference = MultiOutputPattern([view_default_7,
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
|
||||
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
view_default_10 = CallFunction(aten.view.default, fma_default, Ignored())
|
||||
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
|
||||
view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2)
|
||||
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
||||
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
|
||||
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
||||
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
|
||||
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
||||
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
|
||||
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
||||
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
|
||||
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
||||
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
|
||||
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
|
||||
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
||||
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
|
||||
_sfdp_pattern_23_bs1_training = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3,
|
||||
permute_default_6,
|
||||
permute_default_9,
|
||||
permute_default_11
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
_sfdp_pattern_23_bs1_inference = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
||||
@ -223,3 +308,100 @@ _sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4)
|
||||
view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored())
|
||||
view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored())
|
||||
convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored())
|
||||
convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored())
|
||||
view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2)
|
||||
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
||||
bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5)
|
||||
view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
||||
permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored())
|
||||
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
||||
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12)
|
||||
view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
||||
permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored())
|
||||
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
||||
permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored())
|
||||
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8)
|
||||
view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
||||
permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored())
|
||||
_sfdp_pattern_23_half_bs1_training = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3,
|
||||
permute_default_6,
|
||||
permute_default_9,
|
||||
permute_default_11
|
||||
])
|
||||
|
||||
|
||||
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
||||
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
||||
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
||||
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
||||
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
||||
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
||||
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored())
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2)
|
||||
amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True)
|
||||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored())
|
||||
view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
_sfdp_pattern_23_half_bs1_inference = MultiOutputPattern([view_default_7,
|
||||
permute_default_1,
|
||||
permute_default_3
|
||||
])
|
||||
|
Reference in New Issue
Block a user