diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 25e96fa9f1e9..4438df288487 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -997,20 +997,21 @@ class TestSDPAPatternRewriterTemplate(TestCase): attn_weights = scores.float().softmax(dim=-1).type(value.dtype) return attn_weights.matmul(value) - tensor_shape = (4, 2, 16, 32) - attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device) - args = [ - torch.randn(tensor_shape, device=self.device), - torch.randn(tensor_shape, device=self.device), - torch.randn(tensor_shape, device=self.device), - attn_mask, - ] - self._check_common( - dot_prod_attention, - args1=args, - has_dropout=False, - check_train=False, - ) + tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)] + for tensor_shape in tensor_shapes: + attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device) + args = [ + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + attn_mask, + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) def _test_sdpa_rewriter_22(self): def dot_prod_attention( @@ -1027,30 +1028,31 @@ class TestSDPAPatternRewriterTemplate(TestCase): attn_weights = scores.float().softmax(dim=-1).type(value.dtype) return attn_weights.matmul(value), key, value - tensor_shape = (4, 2, 16, 32) - attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device) - args = [ - torch.randn(tensor_shape, device=self.device), - torch.randn(tensor_shape, device=self.device), - torch.randn(tensor_shape, device=self.device), - attn_mask, - ] - self._check_common( - dot_prod_attention, - args1=args, - has_dropout=False, - check_train=False, - ) - # test attn_mask with stride of last dim != 1 - attn_mask_ = attn_mask.transpose(2, 3) - args[3] = attn_mask_ - self._check_common( - dot_prod_attention, - args1=args, - has_dropout=False, - check_train=False, - contains=self.device == "cpu", - ) + tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)] + for tensor_shape in tensor_shapes: + attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device) + args = [ + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + attn_mask, + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) + # test attn_mask with stride of last dim != 1 + attn_mask_ = attn_mask.transpose(2, 3) + args[3] = attn_mask_ + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + contains=self.device == "cpu", + ) def _test_sdpa_rewriter_23(self): def dot_prod_attention( @@ -1067,18 +1069,19 @@ class TestSDPAPatternRewriterTemplate(TestCase): attn_weights = scores.float().softmax(dim=-1).type(value.dtype) return attn_weights.matmul(value), key, value - tensor_shape = (4, 2, 16, 32) - args = [ - torch.randn(tensor_shape, device=self.device), - torch.randn(tensor_shape, device=self.device), - torch.randn(tensor_shape, device=self.device), - ] - self._check_common( - dot_prod_attention, - args1=args, - has_dropout=False, - check_train=False, - ) + tensor_shapes = [(4, 2, 16, 32), (1, 2, 16, 32)] + for tensor_shape in tensor_shapes: + args = [ + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) def _test_sdpa_rewriter_24(self): def dot_prod_attention( diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 5f449eb49664..9a09d2531348 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -581,42 +581,6 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): ) -def _sfdp_pattern_24(query, key, value, attention_mask): - """ - this pattern is for MBartForCausalLM/PLBartForCausalLM. - attn_mask has a different dtype with QKV. - there is no scale in sdpa. - """ - bs = query.size(0) - n_head = query.size(1) - seq_len = query.size(2) - head_size = query.size(3) - q = query.view(bs * n_head, -1, head_size) - k = key.reshape(bs * n_head, -1, head_size) - v = value.reshape(bs * n_head, -1, head_size) - attn_weights = torch.bmm(q, k.transpose(1, 2)) - attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask - attn_weights = attn_weights.view(bs * n_head, seq_len, -1) - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - if query.dtype == torch.half: - attn_weights = attn_weights.to(torch.half) - attn_output = torch.bmm(attn_weights, v) - attn_output = attn_output.view(bs, n_head, seq_len, head_size) - return attn_output - - -def _sfdp_replacement_24(query, key, value, attention_mask): - counters["inductor"]["fuse_attention"] += 1 - return _scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask.to(dtype=query.dtype), - is_causal=False, - scale=1, - ) - - def _sfdp_pattern_21(query, key, value, attn_mask): # for T5 with inplace add query = query.permute([0, 2, 1, 3]) @@ -643,7 +607,7 @@ def _sfdp_replacement_21(query, key, value, attn_mask): query, key, value, - attn_mask=attn_mask, + attn_mask=attn_mask.to(dtype=query.dtype), is_causal=False, scale=1.0, ) @@ -676,7 +640,7 @@ def _sfdp_replacement_22(query, key, value, attn_mask): query, key, value, - attn_mask=attn_mask, + attn_mask=attn_mask.to(dtype=query.dtype), is_causal=False, scale=1.0, ), @@ -723,6 +687,42 @@ def _sfdp_replacement_23(query, key, value): ) +def _sfdp_pattern_24(query, key, value, attention_mask): + """ + this pattern is for MBartForCausalLM/PLBartForCausalLM. + attn_mask has a different dtype with QKV. + there is no scale in sdpa. + """ + bs = query.size(0) + n_head = query.size(1) + seq_len = query.size(2) + head_size = query.size(3) + q = query.view(bs * n_head, -1, head_size) + k = key.reshape(bs * n_head, -1, head_size) + v = value.reshape(bs * n_head, -1, head_size) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask + attn_weights = attn_weights.view(bs * n_head, seq_len, -1) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + if query.dtype == torch.half: + attn_weights = attn_weights.to(torch.half) + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.view(bs, n_head, seq_len, head_size) + return attn_output + + +def _sfdp_replacement_24(query, key, value, attention_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask.to(dtype=query.dtype), + is_causal=False, + scale=1, + ) + + def _sfdp_params_check(match): assert all(k in match.kwargs for k in ("query", "key", "value")) query = match.kwargs["query"].meta["val"] @@ -1024,6 +1024,13 @@ def _get_sfdp_patterns(): {}, _sfdp_params_check, ), + ( + _sfdp_pattern_21, + _sfdp_replacement_21, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float()], + {}, + _sfdp_params_check, + ), ( _sfdp_pattern_22, _sfdp_replacement_22, @@ -1031,6 +1038,13 @@ def _get_sfdp_patterns(): {}, _sfdp_params_check, ), + ( + _sfdp_pattern_22, + _sfdp_replacement_22, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float()], + {}, + _sfdp_params_check, + ), ( _sfdp_pattern_23, _sfdp_replacement_23, @@ -1038,6 +1052,13 @@ def _get_sfdp_patterns(): {}, _sfdp_params_check, ), + ( + _sfdp_pattern_23, + _sfdp_replacement_23, + [g_bs1(), g_bs1(), g_bs1()], + {}, + _sfdp_params_check, + ), ( _sfdp_pattern_24, _sfdp_replacement_24, diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py index ad27e6eb6bb8..4ebd4a4e14e4 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py @@ -119,6 +119,88 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) _sfdp_pattern_21_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_bs1_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) @@ -215,3 +297,95 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) _sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py index 41a433e40543..0971c09ad972 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py @@ -125,6 +125,94 @@ _sfdp_pattern_22_inference = MultiOutputPattern([view_default_7, ]) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) @@ -227,3 +315,101 @@ _sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7, permute_default_1, permute_default_3 ]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py index dc6f27cd2849..2be036c2e8ae 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py @@ -122,6 +122,91 @@ _sfdp_pattern_23_inference = MultiOutputPattern([view_default_7, ]) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) @@ -223,3 +308,100 @@ _sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7, permute_default_1, permute_default_3 ]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +])