[aoti][mps] Add fused_rms and sdpa_mps fallback ops (#156844)

Needed for llama3.1

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156844
Approved by: https://github.com/desertfire
ghstack dependencies: #156843
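
For context, a minimal sketch (not part of this PR) of the kind of MPS export these fallbacks are meant to unblock: a tiny RMSNorm + SDPA block compiled with AOTInductor. The module, shapes, and dimensions are made up, and whether these ops actually lower to the new `aoti_torch_mps_*` fallbacks depends on the build and decomposition settings.

```python
# Illustrative sketch only: export an RMSNorm + SDPA block with AOTInductor on MPS.
# Module, shapes, and dims are assumptions; they are not taken from this PR.
import torch


class Block(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = torch.nn.RMSNorm(dim)

    def forward(self, q, k, v):
        q = self.norm(q)
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


model = Block(64).to("mps").eval()
example_inputs = tuple(torch.randn(1, 8, 16, 64, device="mps") for _ in range(3))
exported = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(exported)
```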
Authored by angelayi on 2025-06-25 20:06:49 -07:00; committed by PyTorch MergeBot
parent 17dab018e3
commit aff9c1eec5
2 changed files with 4 additions and 0 deletions


@@ -18,7 +18,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
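
To make the new SDPA shim concrete, below is a hedged sketch of the ATen op it wraps, called from Python. The argument order follows the shim declaration above; the keyword names, defaults, and return arity are assumptions inferred from the math SDPA variant, not confirmed by this diff.

```python
# Hedged sketch: invoking the MPS math-SDPA op behind
# aoti_torch_mps__scaled_dot_product_attention_math_for_mps.
# Keyword names and defaults are assumptions about the schema.
import torch

q, k, v = (torch.randn(1, 8, 16, 64, device="mps") for _ in range(3))
out, attn_weights = torch.ops.aten._scaled_dot_product_attention_math_for_mps(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
    dropout_mask=None,
    scale=None,
)
```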


@@ -39,10 +39,12 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
"aten._flash_attention_forward.default": {},
"aten._fused_moving_avg_obs_fq_helper_functional.default": {},
"aten._fused_moving_avg_obs_fq_helper.default": {},
"aten._fused_rms_norm.default": {},
"aten._histogramdd_from_bin_cts.default": {},
"aten._int_mm.out": {},
"aten._pdist_backward.default": {},
"aten._pdist_forward.default": {},
"aten._scaled_dot_product_attention_math_for_mps.default": {},
"aten._scaled_dot_product_cudnn_attention_backward.default": {},
"aten._scaled_dot_product_cudnn_attention.default": {},
"aten._scaled_dot_product_efficient_attention_backward.default": {},