mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor CUDA's amp cast policy to be generic (#124051)
# Motivation

This PR creates one op list per autocast cast policy:

- `AT_FORALL_LOWER_PRECISION_FP` for policy `lower_precision_fp`
- `AT_FORALL_FP32` for policy `fp32`
- `AT_FORALL_FP32_SET_OPT_DTYPE` for policy `fp32_set_opt_dtype`
- `AT_FORALL_PROMOTE` for policy `promote`

This ensures that other backends can reuse the policy op lists.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124051
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
ghstack dependencies: #124050
This commit is contained in:
committed by
PyTorch MergeBot
parent
8ad66e05d2
commit
ec608a5d66
@ -241,135 +241,187 @@ namespace {
|
||||
/*****************************************
|
||||
Explicit registration for out-of-place ops
|
||||
*****************************************/
|
||||
|
||||
// Op list for the `lower_precision_fp` cast policy, written as an X-macro:
// the caller-supplied macro `_` is invoked once per entry.
// An entry is either `(op)` or `(op, second_arg)`; the second argument is
// presumably the aten overload name (e.g. `deprecated`, `input`) -- confirm
// against the KERNEL_CUDA definition.
// In this file the list is expanded through a variadic wrapper that appends
// `lower_precision_fp` and forwards to KERNEL_CUDA.
// cudnn_convolution and cudnn_convolution_transpose are not part of this list;
// they are registered separately with explicit KERNEL_CUDA lines.
#define AT_FORALL_LOWER_PRECISION_FP(_) \
|
||||
_(_convolution, deprecated) \
|
||||
_(_convolution) \
|
||||
_(conv1d) \
|
||||
_(conv2d) \
|
||||
_(conv3d) \
|
||||
_(conv_tbc) \
|
||||
_(conv_transpose1d) \
|
||||
_(conv_transpose2d, input) \
|
||||
_(conv_transpose3d, input) \
|
||||
_(convolution) \
|
||||
_(prelu) \
|
||||
_(addmm) \
|
||||
_(addmv) \
|
||||
_(addr) \
|
||||
_(matmul) \
|
||||
_(einsum) \
|
||||
_(mm) \
|
||||
_(mv) \
|
||||
_(linalg_vecdot) \
|
||||
_(linear) \
|
||||
_(addbmm) \
|
||||
_(baddbmm) \
|
||||
_(bmm) \
|
||||
_(chain_matmul) \
|
||||
_(linalg_multi_dot) \
|
||||
_(_thnn_fused_lstm_cell) \
|
||||
_(_thnn_fused_gru_cell) \
|
||||
_(lstm_cell) \
|
||||
_(gru_cell) \
|
||||
_(rnn_tanh_cell) \
|
||||
_(rnn_relu_cell) \
|
||||
_(_scaled_dot_product_flash_attention) \
|
||||
_(scaled_dot_product_attention)
|
||||
|
||||
// Op list for the `fp32` cast policy, written as an X-macro: the
// caller-supplied macro `_` is invoked once per entry.
// An entry is either `(op)` or `(op, second_arg)`; the second argument is
// presumably the aten overload name (e.g. `Tensor_Scalar`, `dim`) -- confirm
// against the KERNEL_CUDA definition.
// In this file the list is expanded through a variadic wrapper that appends
// `fp32` and forwards to KERNEL_CUDA.
#define AT_FORALL_FP32(_) \
|
||||
_(acos) \
|
||||
_(asin) \
|
||||
_(cosh) \
|
||||
_(erfinv) \
|
||||
_(exp) \
|
||||
_(expm1) \
|
||||
_(log) \
|
||||
_(log10) \
|
||||
_(log2) \
|
||||
_(log1p) \
|
||||
_(reciprocal) \
|
||||
_(rsqrt) \
|
||||
_(sinh) \
|
||||
_(tan) \
|
||||
_(pow, Tensor_Scalar) \
|
||||
_(pow, Tensor_Tensor) \
|
||||
_(pow, Scalar) \
|
||||
_(softplus) \
|
||||
_(layer_norm) \
|
||||
_(native_layer_norm) \
|
||||
_(group_norm) \
|
||||
_(frobenius_norm, dim) \
|
||||
_(nuclear_norm) \
|
||||
_(nuclear_norm, dim) \
|
||||
_(cosine_similarity) \
|
||||
_(poisson_nll_loss) \
|
||||
_(cosine_embedding_loss) \
|
||||
_(nll_loss) \
|
||||
_(nll_loss2d) \
|
||||
_(hinge_embedding_loss) \
|
||||
_(kl_div) \
|
||||
_(l1_loss) \
|
||||
_(smooth_l1_loss) \
|
||||
_(huber_loss) \
|
||||
_(mse_loss) \
|
||||
_(margin_ranking_loss) \
|
||||
_(multilabel_margin_loss) \
|
||||
_(soft_margin_loss) \
|
||||
_(triplet_margin_loss) \
|
||||
_(multi_margin_loss) \
|
||||
_(binary_cross_entropy_with_logits) \
|
||||
_(dist) \
|
||||
_(pdist) \
|
||||
_(cdist) \
|
||||
_(renorm) \
|
||||
_(logsumexp)
|
||||
|
||||
// Op list for the `fp32_set_opt_dtype` cast policy, written as an X-macro:
// the caller-supplied macro `_` is invoked once per entry.
// An entry is either `(op)` or `(op, second_arg)`; the second argument is
// presumably the aten overload name (e.g. `dim_int`, `Dimname`) -- confirm
// against the KERNEL_CUDA definition.
// In this file the list is expanded through a variadic wrapper that appends
// `fp32_set_opt_dtype` and forwards to KERNEL_CUDA.
// The `norm` overloads that take an explicit (non-optional) dtype are
// intentionally absent: the registration code keeps them commented out because
// an explicit dtype should not be flipped even when autocasting.
#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
|
||||
_(prod) \
|
||||
_(prod, dim_int) \
|
||||
_(prod, dim_Dimname) \
|
||||
_(softmax, int) \
|
||||
_(softmax, Dimname) \
|
||||
_(log_softmax, int) \
|
||||
_(log_softmax, Dimname) \
|
||||
_(cumprod) \
|
||||
_(cumprod, dimname) \
|
||||
_(cumsum) \
|
||||
_(cumsum, dimname) \
|
||||
_(linalg_vector_norm) \
|
||||
_(linalg_matrix_norm) \
|
||||
_(linalg_matrix_norm, str_ord) \
|
||||
_(sum) \
|
||||
_(sum, dim_IntList) \
|
||||
_(sum, dim_DimnameList)
|
||||
|
||||
// X-macro list of ops registered with the `fp32_append_dtype` policy; the
// caller-supplied macro `_` is invoked once per entry (in this file it is
// KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA).
// Each entry passes five arguments: the namespaced function (ADD_NS(norm)),
// the operator name string, two function signatures, and the policy name.
// The second signature is the first one with a trailing ScalarType added (and,
// for "norm.Scalar", `const Scalar&` widened to `const c10::optional<Scalar>&`).
// NOTE(review): judging by the macro name, the first signature is the
// registered one and the second is the redispatch target -- confirm against the
// KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA definition.
// The fp32_append_dtype wrapper overrides implicit promotion behavior; norm
// does not implicitly promote, but keep that in mind when adding ops here.
#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
|
||||
_(ADD_NS(norm), \
|
||||
"norm.Scalar", \
|
||||
Tensor(const Tensor&, const Scalar&), \
|
||||
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
|
||||
fp32_append_dtype) \
|
||||
_(ADD_NS(norm), \
|
||||
"norm.ScalarOpt_dim", \
|
||||
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
|
||||
Tensor( \
|
||||
const Tensor&, \
|
||||
const c10::optional<Scalar>&, \
|
||||
IntArrayRef, \
|
||||
bool, \
|
||||
ScalarType), \
|
||||
fp32_append_dtype) \
|
||||
_(ADD_NS(norm), \
|
||||
"norm.names_ScalarOpt_dim", \
|
||||
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
|
||||
Tensor( \
|
||||
const Tensor&, \
|
||||
const c10::optional<Scalar>&, \
|
||||
DimnameList, \
|
||||
bool, \
|
||||
ScalarType), \
|
||||
fp32_append_dtype)
|
||||
|
||||
// Op list for the `promote` cast policy, written as an X-macro: the
// caller-supplied macro `_` is invoked once per entry, and every entry here is
// a bare `(op)` with no second argument.
// In this file the list is expanded through a variadic wrapper that appends
// `promote` and forwards to KERNEL_CUDA.
#define AT_FORALL_PROMOTE(_) \
|
||||
_(addcdiv) \
|
||||
_(addcmul) \
|
||||
_(atan2) \
|
||||
_(bilinear) \
|
||||
_(cross) \
|
||||
_(dot) \
|
||||
_(grid_sampler) \
|
||||
_(index_put) \
|
||||
_(tensordot) \
|
||||
_(scatter_add)
|
||||
|
||||
// Registers a fallback for the Autocast dispatch key using the wildcard
// namespace `_`. The fallback is a fallthrough kernel; presumably this means
// ops with no explicit Autocast registration skip this key and continue to the
// next one in dispatch order -- confirm against the dispatcher documentation.
TORCH_LIBRARY_IMPL(_, Autocast, m) {
|
||||
m.fallback(torch::CppFunction::makeFallthrough());
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, Autocast, m) {
|
||||
// lower_precision_fp
|
||||
KERNEL_CUDA(_convolution, deprecated, lower_precision_fp)
|
||||
KERNEL_CUDA(_convolution, lower_precision_fp)
|
||||
KERNEL_CUDA(conv1d, lower_precision_fp)
|
||||
KERNEL_CUDA(conv2d, lower_precision_fp)
|
||||
KERNEL_CUDA(conv3d, lower_precision_fp)
|
||||
KERNEL_CUDA(conv_tbc, lower_precision_fp)
|
||||
KERNEL_CUDA(conv_transpose1d, lower_precision_fp)
|
||||
KERNEL_CUDA(conv_transpose2d, input, lower_precision_fp)
|
||||
KERNEL_CUDA(conv_transpose3d, input, lower_precision_fp)
|
||||
KERNEL_CUDA(convolution, lower_precision_fp)
|
||||
#define _KERNEL_CUDA_LOW_PRECISION_FP(...) \
|
||||
KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)
|
||||
|
||||
AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
|
||||
KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
|
||||
KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)
|
||||
KERNEL_CUDA(prelu, lower_precision_fp)
|
||||
KERNEL_CUDA(addmm, lower_precision_fp)
|
||||
KERNEL_CUDA(addmv, lower_precision_fp)
|
||||
KERNEL_CUDA(addr, lower_precision_fp)
|
||||
KERNEL_CUDA(matmul, lower_precision_fp)
|
||||
KERNEL_CUDA(einsum, lower_precision_fp)
|
||||
KERNEL_CUDA(mm, lower_precision_fp)
|
||||
KERNEL_CUDA(mv, lower_precision_fp)
|
||||
KERNEL_CUDA(linalg_vecdot, lower_precision_fp)
|
||||
KERNEL_CUDA(linear, lower_precision_fp)
|
||||
KERNEL_CUDA(addbmm, lower_precision_fp)
|
||||
KERNEL_CUDA(baddbmm, lower_precision_fp)
|
||||
KERNEL_CUDA(bmm, lower_precision_fp)
|
||||
KERNEL_CUDA(chain_matmul, lower_precision_fp)
|
||||
KERNEL_CUDA(linalg_multi_dot, lower_precision_fp)
|
||||
KERNEL_CUDA(_thnn_fused_lstm_cell, lower_precision_fp)
|
||||
KERNEL_CUDA(_thnn_fused_gru_cell, lower_precision_fp)
|
||||
KERNEL_CUDA(lstm_cell, lower_precision_fp)
|
||||
KERNEL_CUDA(gru_cell, lower_precision_fp)
|
||||
KERNEL_CUDA(rnn_tanh_cell, lower_precision_fp)
|
||||
KERNEL_CUDA(rnn_relu_cell, lower_precision_fp)
|
||||
KERNEL_CUDA(_scaled_dot_product_flash_attention, lower_precision_fp)
|
||||
KERNEL_CUDA(scaled_dot_product_attention, lower_precision_fp)
|
||||
|
||||
// fp32
|
||||
KERNEL_CUDA(acos, fp32)
|
||||
KERNEL_CUDA(asin, fp32)
|
||||
KERNEL_CUDA(cosh, fp32)
|
||||
KERNEL_CUDA(erfinv, fp32)
|
||||
KERNEL_CUDA(exp, fp32)
|
||||
KERNEL_CUDA(expm1, fp32)
|
||||
KERNEL_CUDA(log, fp32)
|
||||
KERNEL_CUDA(log10, fp32)
|
||||
KERNEL_CUDA(log2, fp32)
|
||||
KERNEL_CUDA(log1p, fp32)
|
||||
KERNEL_CUDA(reciprocal, fp32)
|
||||
KERNEL_CUDA(rsqrt, fp32)
|
||||
KERNEL_CUDA(sinh, fp32)
|
||||
KERNEL_CUDA(tan, fp32)
|
||||
KERNEL_CUDA(pow, Tensor_Scalar, fp32)
|
||||
KERNEL_CUDA(pow, Tensor_Tensor, fp32)
|
||||
KERNEL_CUDA(pow, Scalar, fp32)
|
||||
KERNEL_CUDA(softplus, fp32)
|
||||
KERNEL_CUDA(layer_norm, fp32)
|
||||
KERNEL_CUDA(native_layer_norm, fp32)
|
||||
KERNEL_CUDA(group_norm, fp32)
|
||||
KERNEL_CUDA(frobenius_norm, dim, fp32)
|
||||
KERNEL_CUDA(nuclear_norm, fp32)
|
||||
KERNEL_CUDA(nuclear_norm, dim, fp32)
|
||||
KERNEL_CUDA(cosine_similarity, fp32)
|
||||
KERNEL_CUDA(poisson_nll_loss, fp32)
|
||||
KERNEL_CUDA(cosine_embedding_loss, fp32)
|
||||
KERNEL_CUDA(nll_loss, fp32)
|
||||
KERNEL_CUDA(nll_loss2d, fp32)
|
||||
KERNEL_CUDA(hinge_embedding_loss, fp32)
|
||||
KERNEL_CUDA(kl_div, fp32)
|
||||
KERNEL_CUDA(l1_loss, fp32)
|
||||
KERNEL_CUDA(smooth_l1_loss, fp32)
|
||||
KERNEL_CUDA(huber_loss, fp32)
|
||||
KERNEL_CUDA(mse_loss, fp32)
|
||||
KERNEL_CUDA(margin_ranking_loss, fp32)
|
||||
KERNEL_CUDA(multilabel_margin_loss, fp32)
|
||||
KERNEL_CUDA(soft_margin_loss, fp32)
|
||||
KERNEL_CUDA(triplet_margin_loss, fp32)
|
||||
KERNEL_CUDA(multi_margin_loss, fp32)
|
||||
KERNEL_CUDA(binary_cross_entropy_with_logits, fp32)
|
||||
KERNEL_CUDA(dist, fp32)
|
||||
KERNEL_CUDA(pdist, fp32)
|
||||
KERNEL_CUDA(cdist, fp32)
|
||||
KERNEL_CUDA(renorm, fp32)
|
||||
KERNEL_CUDA(logsumexp, fp32)
|
||||
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)
|
||||
|
||||
AT_FORALL_FP32(_KERNEL_CUDA_FP32)
|
||||
|
||||
// fp32_set_opt_dtype
|
||||
KERNEL_CUDA(prod, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(prod, dim_int, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(prod, dim_Dimname, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(softmax, int, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(softmax, Dimname, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(log_softmax, int, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(log_softmax, Dimname, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(cumprod, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(cumprod, dimname, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(cumsum, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(cumsum, dimname, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(linalg_vector_norm, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(linalg_matrix_norm, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(linalg_matrix_norm, str_ord, fp32_set_opt_dtype)
|
||||
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
|
||||
KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)
|
||||
|
||||
AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
|
||||
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
|
||||
// when autocasting.
|
||||
// KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
|
||||
// KERNEL_CUDA(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype)
|
||||
// KERNEL_CUDA(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(sum, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(sum, dim_IntList, fp32_set_opt_dtype)
|
||||
KERNEL_CUDA(sum, dim_DimnameList, fp32_set_opt_dtype)
|
||||
|
||||
// fp32_append_dtype
|
||||
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
|
||||
// norm does not implicitly promote, but be aware when adding new ops to this policy.
|
||||
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, const Scalar&), Tensor (const Tensor &, const c10::optional<Scalar>&, ScalarType), fp32_append_dtype)
|
||||
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool, ScalarType), fp32_append_dtype)
|
||||
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool, ScalarType), fp32_append_dtype)
|
||||
AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
|
||||
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA)
|
||||
|
||||
// promote
|
||||
KERNEL_CUDA(addcdiv, promote)
|
||||
KERNEL_CUDA(addcmul, promote)
|
||||
KERNEL_CUDA(atan2, promote)
|
||||
KERNEL_CUDA(bilinear, promote)
|
||||
KERNEL_CUDA(cross, promote)
|
||||
KERNEL_CUDA(dot, promote)
|
||||
KERNEL_CUDA(grid_sampler, promote)
|
||||
KERNEL_CUDA(index_put, promote)
|
||||
KERNEL_CUDA(tensordot, promote)
|
||||
KERNEL_CUDA(scatter_add, promote)
|
||||
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)
|
||||
|
||||
AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)
|
||||
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
|
||||
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
|
||||
|
Reference in New Issue
Block a user