Refactor CUDA's amp cast policy to be generic (#124051)

# Motivation
This PR creates op lists for the different autocast cast policies:
- `AT_FORALL_LOWER_PRECISION_FP` for policy `lower_precision_fp`
- `AT_FORALL_FP32` for policy `fp32`
- `AT_FORALL_FP32_SET_OPT_DTYPE` for policy `fp32_set_opt_dtype`
- `AT_FORALL_PROMOTE` for policy `promote`.

This allows other backends to reuse the same policy op lists instead of duplicating per-op registrations; see the sketch below.
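As a rough illustration (not part of this PR), a backend that defines its own registration macro could reuse one of these lists to stamp out its autocast registrations. `KERNEL_MYBACKEND` below is a hypothetical placeholder, analogous to the existing `KERNEL_CUDA`:

```cpp
// Hypothetical sketch: reuse the shared op list with a backend-specific
// registration macro. KERNEL_MYBACKEND is a placeholder, assumed to take the
// same arguments as KERNEL_CUDA (op name, optional overload, cast policy).
#define _KERNEL_MYBACKEND_LOW_PRECISION_FP(...) \
  KERNEL_MYBACKEND(__VA_ARGS__, lower_precision_fp)

// Expands to one KERNEL_MYBACKEND(op[, overload], lower_precision_fp) per list
// entry, e.g. KERNEL_MYBACKEND(conv2d, lower_precision_fp) and
// KERNEL_MYBACKEND(conv_transpose2d, input, lower_precision_fp).
AT_FORALL_LOWER_PRECISION_FP(_KERNEL_MYBACKEND_LOW_PRECISION_FP)
```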

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124051
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
ghstack dependencies: #124050
Author: Yu, Guangye
Date: 2024-04-17 09:55:28 +00:00
Committed by: PyTorch MergeBot
Parent: 8ad66e05d2
Commit: ec608a5d66

@@ -241,135 +241,187 @@ namespace {
/*****************************************
Explicit registration for out-of-place ops
*****************************************/
#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)
#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp)
#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)
#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)
#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)
TORCH_LIBRARY_IMPL(_, Autocast, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(aten, Autocast, m) {
// lower_precision_fp
KERNEL_CUDA(_convolution, deprecated, lower_precision_fp)
KERNEL_CUDA(_convolution, lower_precision_fp)
KERNEL_CUDA(conv1d, lower_precision_fp)
KERNEL_CUDA(conv2d, lower_precision_fp)
KERNEL_CUDA(conv3d, lower_precision_fp)
KERNEL_CUDA(conv_tbc, lower_precision_fp)
KERNEL_CUDA(conv_transpose1d, lower_precision_fp)
KERNEL_CUDA(conv_transpose2d, input, lower_precision_fp)
KERNEL_CUDA(conv_transpose3d, input, lower_precision_fp)
KERNEL_CUDA(convolution, lower_precision_fp)
#define _KERNEL_CUDA_LOW_PRECISION_FP(...) \
KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)
AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)
KERNEL_CUDA(prelu, lower_precision_fp)
KERNEL_CUDA(addmm, lower_precision_fp)
KERNEL_CUDA(addmv, lower_precision_fp)
KERNEL_CUDA(addr, lower_precision_fp)
KERNEL_CUDA(matmul, lower_precision_fp)
KERNEL_CUDA(einsum, lower_precision_fp)
KERNEL_CUDA(mm, lower_precision_fp)
KERNEL_CUDA(mv, lower_precision_fp)
KERNEL_CUDA(linalg_vecdot, lower_precision_fp)
KERNEL_CUDA(linear, lower_precision_fp)
KERNEL_CUDA(addbmm, lower_precision_fp)
KERNEL_CUDA(baddbmm, lower_precision_fp)
KERNEL_CUDA(bmm, lower_precision_fp)
KERNEL_CUDA(chain_matmul, lower_precision_fp)
KERNEL_CUDA(linalg_multi_dot, lower_precision_fp)
KERNEL_CUDA(_thnn_fused_lstm_cell, lower_precision_fp)
KERNEL_CUDA(_thnn_fused_gru_cell, lower_precision_fp)
KERNEL_CUDA(lstm_cell, lower_precision_fp)
KERNEL_CUDA(gru_cell, lower_precision_fp)
KERNEL_CUDA(rnn_tanh_cell, lower_precision_fp)
KERNEL_CUDA(rnn_relu_cell, lower_precision_fp)
KERNEL_CUDA(_scaled_dot_product_flash_attention, lower_precision_fp)
KERNEL_CUDA(scaled_dot_product_attention, lower_precision_fp)
// fp32
KERNEL_CUDA(acos, fp32)
KERNEL_CUDA(asin, fp32)
KERNEL_CUDA(cosh, fp32)
KERNEL_CUDA(erfinv, fp32)
KERNEL_CUDA(exp, fp32)
KERNEL_CUDA(expm1, fp32)
KERNEL_CUDA(log, fp32)
KERNEL_CUDA(log10, fp32)
KERNEL_CUDA(log2, fp32)
KERNEL_CUDA(log1p, fp32)
KERNEL_CUDA(reciprocal, fp32)
KERNEL_CUDA(rsqrt, fp32)
KERNEL_CUDA(sinh, fp32)
KERNEL_CUDA(tan, fp32)
KERNEL_CUDA(pow, Tensor_Scalar, fp32)
KERNEL_CUDA(pow, Tensor_Tensor, fp32)
KERNEL_CUDA(pow, Scalar, fp32)
KERNEL_CUDA(softplus, fp32)
KERNEL_CUDA(layer_norm, fp32)
KERNEL_CUDA(native_layer_norm, fp32)
KERNEL_CUDA(group_norm, fp32)
KERNEL_CUDA(frobenius_norm, dim, fp32)
KERNEL_CUDA(nuclear_norm, fp32)
KERNEL_CUDA(nuclear_norm, dim, fp32)
KERNEL_CUDA(cosine_similarity, fp32)
KERNEL_CUDA(poisson_nll_loss, fp32)
KERNEL_CUDA(cosine_embedding_loss, fp32)
KERNEL_CUDA(nll_loss, fp32)
KERNEL_CUDA(nll_loss2d, fp32)
KERNEL_CUDA(hinge_embedding_loss, fp32)
KERNEL_CUDA(kl_div, fp32)
KERNEL_CUDA(l1_loss, fp32)
KERNEL_CUDA(smooth_l1_loss, fp32)
KERNEL_CUDA(huber_loss, fp32)
KERNEL_CUDA(mse_loss, fp32)
KERNEL_CUDA(margin_ranking_loss, fp32)
KERNEL_CUDA(multilabel_margin_loss, fp32)
KERNEL_CUDA(soft_margin_loss, fp32)
KERNEL_CUDA(triplet_margin_loss, fp32)
KERNEL_CUDA(multi_margin_loss, fp32)
KERNEL_CUDA(binary_cross_entropy_with_logits, fp32)
KERNEL_CUDA(dist, fp32)
KERNEL_CUDA(pdist, fp32)
KERNEL_CUDA(cdist, fp32)
KERNEL_CUDA(renorm, fp32)
KERNEL_CUDA(logsumexp, fp32)
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)
AT_FORALL_FP32(_KERNEL_CUDA_FP32)
// fp32_set_opt_dtype
KERNEL_CUDA(prod, fp32_set_opt_dtype)
KERNEL_CUDA(prod, dim_int, fp32_set_opt_dtype)
KERNEL_CUDA(prod, dim_Dimname, fp32_set_opt_dtype)
KERNEL_CUDA(softmax, int, fp32_set_opt_dtype)
KERNEL_CUDA(softmax, Dimname, fp32_set_opt_dtype)
KERNEL_CUDA(log_softmax, int, fp32_set_opt_dtype)
KERNEL_CUDA(log_softmax, Dimname, fp32_set_opt_dtype)
KERNEL_CUDA(cumprod, fp32_set_opt_dtype)
KERNEL_CUDA(cumprod, dimname, fp32_set_opt_dtype)
KERNEL_CUDA(cumsum, fp32_set_opt_dtype)
KERNEL_CUDA(cumsum, dimname, fp32_set_opt_dtype)
KERNEL_CUDA(linalg_vector_norm, fp32_set_opt_dtype)
KERNEL_CUDA(linalg_matrix_norm, fp32_set_opt_dtype)
KERNEL_CUDA(linalg_matrix_norm, str_ord, fp32_set_opt_dtype)
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)
AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
// when autocasting.
// KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
// KERNEL_CUDA(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype)
// KERNEL_CUDA(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype)
KERNEL_CUDA(sum, fp32_set_opt_dtype)
KERNEL_CUDA(sum, dim_IntList, fp32_set_opt_dtype)
KERNEL_CUDA(sum, dim_DimnameList, fp32_set_opt_dtype)
// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
// norm does not implicitly promote, but be aware when adding new ops to this policy.
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, const Scalar&), Tensor (const Tensor &, const c10::optional<Scalar>&, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool, ScalarType), fp32_append_dtype)
AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA)
// promote
KERNEL_CUDA(addcdiv, promote)
KERNEL_CUDA(addcmul, promote)
KERNEL_CUDA(atan2, promote)
KERNEL_CUDA(bilinear, promote)
KERNEL_CUDA(cross, promote)
KERNEL_CUDA(dot, promote)
KERNEL_CUDA(grid_sampler, promote)
KERNEL_CUDA(index_put, promote)
KERNEL_CUDA(tensordot, promote)
KERNEL_CUDA(scatter_add, promote)
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)
AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)
m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));