[6/N] Fix Wextra-semi warning (#139605)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139605
Approved by: https://github.com/ezyang
cyy
2024-11-04 13:43:16 +00:00
committed by PyTorch MergeBot
parent 2ce2e4df4e
commit 419a7e197d
105 changed files with 396 additions and 396 deletions
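
The change is mechanical and follows two patterns visible throughout the diff: trailing semicolons after function and member-function definitions, and trailing semicolons after registration macros (REGISTER_DISPATCH, DEFINE_DISPATCH, CREATE_UNARY_UFUNC, and friends) whose expansions already end in a complete statement. A minimal sketch of both patterns follows; the names used (Widget, REGISTER_KERNEL, register_impl) and the build invocation are illustrative assumptions, not PyTorch APIs.

// Minimal sketch of the two -Wextra-semi patterns cleaned up in this PR.
// Assumed build: clang++ -std=c++17 -Wextra-semi -c example.cpp

// 1. Redundant ';' after a function definition.
struct Widget {
  void reset() {}    // correct
  // void reset() {};  // would trigger -Wextra-semi (extra ';' after member function definition)
};

void report_error() {
}  // free functions likewise take no trailing ';' -- writing "};" is redundant

// 2. Redundant ';' after a registration-style macro whose expansion already
//    ends in a complete declaration.
int register_impl(const char* /*name*/) { return 0; }

#define REGISTER_KERNEL(name) \
  static int registered_##name = register_impl(#name);

REGISTER_KERNEL(my_kernel)     // correct: the macro supplies its own ';'
// REGISTER_KERNEL(my_kernel);  // the extra ';' is an empty declaration

In the macro-heavy files below, the semicolon sometimes also moves out of call sites that live inside other macro definitions (for example the ONE_HIDDEN_RNN and CREATE_UNARY_UFUNC blocks), but the rule is the same: the innermost macro supplies the terminating ';', and every layer above it omits it.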

View File

@ -510,7 +510,7 @@ void OperatorEntry::reportSignatureError(const CppSignature& call_signature, con
"This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
"Please make sure that the function signature matches the signature in the operator registration call."
);
};
}
#ifndef STRIP_ERROR_MESSAGES
static std::string post_process_dispatch_key_str(std::string dispatch_key) {

View File

@ -132,7 +132,7 @@ template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>,
inline void cvt_to_fp32(const __m128i& a, __m256& o);
template <> inline void cvt_to_fp32<BFloat16>(const __m128i& a, __m256& o) {
cvtbf16_fp32(a, o);
};
}
template <> inline void cvt_to_fp32<Half>(const __m128i& a, __m256& o) {
cvtfp16_fp32(a, o);
}
@ -1071,8 +1071,8 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(c
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
return cvt_from_fp32<type>(__m256(a), __m256(b)); \
}
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16)
CONVERT_VECTORIZED_INIT(Half, half)
#else // defined(CPU_CAPABILITY_AVX2)
@ -1096,9 +1096,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
convert(arr, arr2, K); \
return Vectorized<type>::loadu(arr2); \
}
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
#if !(defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256))
CONVERT_NON_VECTORIZED_INIT(Half, half);
CONVERT_NON_VECTORIZED_INIT(Half, half)
#endif
#endif // defined(CPU_CAPABILITY_AVX2)

View File

@ -120,11 +120,11 @@ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
}
void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)));
INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)))
}
void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case) {
INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)));
INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)))
}
} // namespace at::functorch

View File

@ -10,7 +10,7 @@
namespace at::mps {
C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback)
namespace HeapAllocator {

View File

@ -330,7 +330,7 @@ void gemv_fast_path<at::Half>(
y,
*incy);
}
INSTANTIATE(c10::BFloat16);
INSTANTIATE(c10::BFloat16)
#else
template <>
bool scal_use_fast_path<at::Half>(

View File

@ -1251,7 +1251,7 @@ embedding_bag(const Tensor &weight, const Tensor &indices,
mode, sparse, per_sample_weights, include_last_offset, padding_idx);
}
return out;
};
}
std::tuple<Tensor, Tensor, Tensor, Tensor>
embedding_bag(const Tensor &weight, const Tensor &indices,

View File

@ -64,7 +64,7 @@ Tensor& max_unpooling2d_forward_out_cpu(
}
return output;
};
}
Tensor max_unpooling2d_forward_cpu(
const Tensor& self,

View File

@ -871,7 +871,7 @@ static std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cpu(
return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}
REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu)
} // namespace native
} // namespace at

View File

@ -741,7 +741,7 @@ static std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cpu(
return std::tie(grad_input, grad_weight, grad_bias);
}
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu)
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu)
} // namespace at::native

View File

@ -14,7 +14,7 @@ using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const
int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel)
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel)
// averge pooling has same signature for forward and backward

View File

@ -1187,10 +1187,10 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
DEFINE_DISPATCH(NAME##_miopen_stub); \
DEFINE_DISPATCH(NAME##_packed_cudnn_stub); \
DEFINE_DISPATCH(NAME##_packed_miopen_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_miopen_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_miopen_stub); \
REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub) \
REGISTER_NO_CPU_DISPATCH(NAME##_miopen_stub) \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub) \
REGISTER_NO_CPU_DISPATCH(NAME##_packed_miopen_stub) \
\
std::tuple<Tensor, Tensor> NAME( \
const Tensor& _input, \
@ -1415,17 +1415,17 @@ static std::tuple<Tensor, Tensor> quantized_gru_data_legacy(
using tanf_cell_type = SimpleCell<tanh_f, CellParams>;
ONE_HIDDEN_RNN(rnn_tanh, tanf_cell_type)
using relu_cell_type = SimpleCell<relu_f, CellParams>;
ONE_HIDDEN_RNN(rnn_relu, relu_cell_type);
ONE_HIDDEN_RNN(rnn_relu, relu_cell_type)
DEFINE_DISPATCH(lstm_cudnn_stub);
DEFINE_DISPATCH(lstm_packed_cudnn_stub);
DEFINE_DISPATCH(lstm_miopen_stub);
DEFINE_DISPATCH(lstm_packed_miopen_stub);
DEFINE_DISPATCH(lstm_mkldnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub);
REGISTER_NO_CPU_DISPATCH(lstm_packed_miopen_stub);
REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub)
REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub)
REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub)
REGISTER_NO_CPU_DISPATCH(lstm_packed_miopen_stub)
std::tuple<Tensor, Tensor, Tensor> lstm(
const Tensor& _input, TensorList hx,
@ -1857,9 +1857,9 @@ static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
// Quantized LSTM cell
using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;
DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx)
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx)
// Helpers for simpler cells
using simple_hx_type = const Tensor&;
@ -1871,21 +1871,21 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;
DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx)
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx)
// Quantized RNN w/ ReLU cell
using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx)
using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx)
// Quantized RNN w/ tanh cell
using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx)
using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx)
namespace {

View File

@ -932,11 +932,11 @@ Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {
Tensor special_multigammaln(const Tensor& self, int64_t p) {
return self.mvlgamma(p);
};
}
Tensor& special_multigammaln_out(const Tensor& self, int64_t p, Tensor& result) {
return at::mvlgamma_out(result, self, p);
};
}
std::tuple<Tensor, Tensor> frexp(const Tensor& self) {
Tensor mantissa = at::empty_like(self);

View File

@ -1415,16 +1415,16 @@ REGISTER_DISPATCH(laguerre_polynomial_l_stub, &laguerre_polynomial_l_kernel)
REGISTER_DISPATCH(legendre_polynomial_p_stub, &legendre_polynomial_p_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_t_stub,
&shifted_chebyshev_polynomial_t_kernel);
&shifted_chebyshev_polynomial_t_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_u_stub,
&shifted_chebyshev_polynomial_u_kernel);
&shifted_chebyshev_polynomial_u_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_v_stub,
&shifted_chebyshev_polynomial_v_kernel);
&shifted_chebyshev_polynomial_v_kernel)
REGISTER_DISPATCH(
shifted_chebyshev_polynomial_w_stub,
&shifted_chebyshev_polynomial_w_kernel);
&shifted_chebyshev_polynomial_w_kernel)
// Might enable AVX512 dispatch after enabling explicit vectorization for them.
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel)
REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel)

View File

@ -241,5 +241,5 @@ static void multinomial_with_replacement_kernel_impl(
REGISTER_DISPATCH(
multinomial_with_replacement_stub,
&multinomial_with_replacement_kernel_impl);
&multinomial_with_replacement_kernel_impl)
} // namespace at::native

View File

@ -22,7 +22,7 @@ inline namespace CPU_CAPABILITY {
constexpr auto kF32RegisterPairsPerIteration = 4;
constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;;
constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
namespace {
template <typename T>
@ -328,8 +328,8 @@ void fp16_gemv_trans(
#if !defined(C10_MOBILE)
// NOTE: we don't *need* to go through dispatch for the ARM-only
// implementation right now, but we will need it when we cover x86.
REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith);
REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans);
REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith)
REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
#else
#endif // defined(__aarch64__) && !defined(C10_MOBILE)

View File

@ -8,7 +8,7 @@ namespace at::native {
#if !defined(C10_MOBILE)
using fp16_dot_fn = float(*)(const Half*, const Half*, int64_t);
using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int);
DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub);
DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub);
DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub)
DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
#endif // !defined(C10_MOBILE)
} // namespace at::native

View File

@ -1295,15 +1295,15 @@ ALSO_REGISTER_AVX512_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_im
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(
softmax_backward_lastdim_kernel,
&softmax_backward_lastdim_kernel_impl);
&softmax_backward_lastdim_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(
log_softmax_backward_lastdim_kernel,
&log_softmax_backward_lastdim_kernel_impl);
&log_softmax_backward_lastdim_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(softmax_kernel, &softmax_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl)
ALSO_REGISTER_AVX512_DISPATCH(
log_softmax_backward_kernel,
&log_softmax_backward_kernel_impl);
&log_softmax_backward_kernel_impl)
} // namespace at::native

View File

@ -830,15 +830,15 @@ REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel)
REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel)
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel)
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round);
IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round)
IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc)
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan)
// The following kernels are compute-intensive & are compiled with both AVX512
// & AVX2
@ -871,19 +871,19 @@ REGISTER_DISPATCH(special_bessel_y1_stub, &CPU_CAPABILITY::bessel_y1_kernel)
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &CPU_CAPABILITY::modified_bessel_i0_kernel)
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2)
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh)
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma)
} // namespace at::native

View File

@ -760,6 +760,6 @@ std::tuple<Tensor, Tensor> conv_depthwise2d_backward_cuda(
grad_weight);
}
REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda);
REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda)
} // namespace at::native

View File

@ -695,7 +695,7 @@ std::tuple<Tensor, Tensor, Tensor> conv_depthwise3d_backward_cuda(
}
REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda);
REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda)
#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION
#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS

View File

@ -23,6 +23,6 @@ Tensor flatten_indices_cuda_kernel(const Tensor& indices, IntArrayRef size) {
}
REGISTER_CUDA_DISPATCH(flatten_indices_stub, &flatten_indices_cuda_kernel);
REGISTER_CUDA_DISPATCH(flatten_indices_stub, &flatten_indices_cuda_kernel)
} // namespace at::native

View File

@ -483,6 +483,6 @@ REGISTER_DISPATCH(put_stub, &put_kernel)
REGISTER_DISPATCH(take_stub, &take_kernel)
REGISTER_DISPATCH(flip_stub, &flip_kernel)
REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda);
REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda)
} // namespace at::native

View File

@ -684,7 +684,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
}
}
REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel);
REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel)
void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, double scale, int zero_point, bool unsafe) {
if (indices.size() > (size_t)self.dim()) {
@ -784,7 +784,7 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
}
}
REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_quantized);
REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_quantized)
} //anonymous
@ -1687,7 +1687,7 @@ void masked_fill_kernel_quantized(TensorIterator& iter, const Scalar& value, dou
});
}
REGISTER_CUDA_DISPATCH(masked_fill_kernel_quantized_stub, &masked_fill_kernel_quantized);
REGISTER_CUDA_DISPATCH(masked_fill_kernel_quantized_stub, &masked_fill_kernel_quantized)
} // anonymous namespace

View File

@ -138,19 +138,19 @@ void lazy_ldl_solve(
}
REGISTER_CUDA_DISPATCH(cholesky_stub, &lazy_cholesky_kernel)
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &lazy_cholesky_inverse_kernel);
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lazy_lu_factor);
REGISTER_CUDA_DISPATCH(ldl_factor_stub, &lazy_ldl_factor);
REGISTER_CUDA_DISPATCH(ldl_solve_stub, &lazy_ldl_solve);
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &lazy_triangular_solve_kernel);
REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel);
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel);
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel);
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel);
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel);
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &lazy_cholesky_inverse_kernel)
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lazy_lu_factor)
REGISTER_CUDA_DISPATCH(ldl_factor_stub, &lazy_ldl_factor)
REGISTER_CUDA_DISPATCH(ldl_solve_stub, &lazy_ldl_solve)
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &lazy_triangular_solve_kernel)
REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel)
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel)
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel)
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel)
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel)
REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel)
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve);
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel);
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve)
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel)
} // anonymous namespace
// Old style dispatches

View File

@ -828,6 +828,6 @@ std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cuda(
return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}
REGISTER_CUDA_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cuda);
REGISTER_CUDA_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cuda)
} // namespace at::native

View File

@ -1011,6 +1011,6 @@ std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose3d_backward_cuda(
return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}
REGISTER_CUDA_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cuda);
REGISTER_CUDA_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cuda)
} // namespace at::native

View File

@ -608,7 +608,7 @@ std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cuda(
return std::tie(grad_input, grad_weight, grad_bias);
}
REGISTER_CUDA_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cuda);
REGISTER_CUDA_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cuda);
REGISTER_CUDA_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cuda)
REGISTER_CUDA_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cuda)
} // namespace at::native

View File

@ -90,13 +90,13 @@ void aminmax_allreduce_kernel_impl(const Tensor& input, Tensor& min_result, Tens
} // namespace (anonymous)
REGISTER_CUDA_DISPATCH(min_stub, &min_kernel_impl);
REGISTER_CUDA_DISPATCH(max_stub, &max_kernel_impl);
REGISTER_CUDA_DISPATCH(min_all_stub, &min_all_kernel_impl);
REGISTER_CUDA_DISPATCH(max_all_stub, &max_all_kernel_impl);
REGISTER_CUDA_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel_impl);
REGISTER_CUDA_DISPATCH(aminmax_stub, &aminmax_kernel_impl);
REGISTER_CUDA_DISPATCH(min_stub, &min_kernel_impl)
REGISTER_CUDA_DISPATCH(max_stub, &max_kernel_impl)
REGISTER_CUDA_DISPATCH(min_all_stub, &min_all_kernel_impl)
REGISTER_CUDA_DISPATCH(max_all_stub, &max_all_kernel_impl)
REGISTER_CUDA_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel_impl)
REGISTER_CUDA_DISPATCH(aminmax_stub, &aminmax_kernel_impl)
REGISTER_CUDA_DISPATCH(norm_stub, &norm_kernel_cuda);
REGISTER_CUDA_DISPATCH(norm_stub, &norm_kernel_cuda)
} // namespace at::native

View File

@ -109,7 +109,7 @@ void cumprod_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim)
}
}
REGISTER_CUDA_DISPATCH(cumsum_stub, &cumsum_cuda_kernel);
REGISTER_CUDA_DISPATCH(cumprod_stub, &cumprod_cuda_kernel);
REGISTER_CUDA_DISPATCH(cumsum_stub, &cumsum_cuda_kernel)
REGISTER_CUDA_DISPATCH(cumprod_stub, &cumprod_cuda_kernel)
} // namespace at::native

View File

@ -122,6 +122,6 @@ void sort_cuda_kernel(
// TODO: we should handle this accordingly when we start using REGISTER_HIP_DISPATCH,
// since REGISTER_DISPATCH won't work in this cpp file.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CUDA_DISPATCH(sort_stub, &sort_cuda_kernel);
REGISTER_CUDA_DISPATCH(sort_stub, &sort_cuda_kernel)
} // namespace at::native

View File

@ -204,8 +204,8 @@ void sparse_mask_projection_out_cuda_kernel(
}
REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel)
REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel)
REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel)
} // namespace at::native

View File

@ -18,6 +18,6 @@ void isin_default_kernel_gpu(
} // anonymous namespace
REGISTER_CUDA_DISPATCH(isin_default_stub, &isin_default_kernel_gpu);
REGISTER_CUDA_DISPATCH(isin_default_stub, &isin_default_kernel_gpu)
} // namespace at::native

View File

@ -98,5 +98,5 @@ void mode_kernel_impl(
}
}
REGISTER_CUDA_DISPATCH(mode_stub, &mode_kernel_impl);
REGISTER_CUDA_DISPATCH(mode_stub, &mode_kernel_impl)
} // namespace at::native

View File

@ -1454,7 +1454,7 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
}
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1670,7 +1670,7 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i
}
}
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lu_factor);
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lu_factor)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1764,7 +1764,7 @@ void triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool u
}
}
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ orgqr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1782,7 +1782,7 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) {
#endif
}
REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
#ifdef USE_LINALG_SOLVER
@ -1794,7 +1794,7 @@ void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, b
#endif
}
REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel);
REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1878,7 +1878,7 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
#endif
}
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel)
template <typename scalar_t>
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
@ -2007,7 +2007,7 @@ void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, c
#endif
}
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -2093,7 +2093,7 @@ void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos,
});
}
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -2579,7 +2579,7 @@ if (n <= 8) {
#endif // ifdef USE_LINALG_SOLVER
}
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_kernel)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -2761,7 +2761,7 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul
}
}
REGISTER_CUDA_DISPATCH(lstsq_stub, &lstsq_kernel);
REGISTER_CUDA_DISPATCH(lstsq_stub, &lstsq_kernel)
#if defined(BUILD_LAZY_CUDA_LINALG)

View File

@ -2627,59 +2627,59 @@ std::pair<Tensor, hidden_type> _cudnn_impl(
std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
}
#define ONE_HIDDEN_RNN(NAME, MODE) \
void NAME##_cudnn( \
Tensor& output, \
Tensor& hy, \
const Tensor& input, \
const Tensor& hx, \
TensorList params, \
bool has_biases, \
int64_t num_layers, \
double dropout_p, \
bool train, \
bool bidirectional, \
bool batch_first) { \
std::tie(output, hy) = _cudnn_impl( \
input, \
hx, \
params, \
has_biases, \
MODE, \
num_layers, \
dropout_p, \
train, \
bidirectional, \
batch_first); \
} \
\
void NAME##_packed_cudnn( \
Tensor& output, \
Tensor& hy, \
const Tensor& data, \
const Tensor& batch_sizes, \
const Tensor& hx, \
TensorList params, \
bool has_biases, \
int64_t num_layers, \
double dropout_p, \
bool train, \
bool bidirectional) { \
std::tie(output, hy) = _cudnn_impl( \
data, \
batch_sizes, \
hx, \
params, \
has_biases, \
MODE, \
num_layers, \
dropout_p, \
train, \
bidirectional); \
} \
\
REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn); \
REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn);
#define ONE_HIDDEN_RNN(NAME, MODE) \
void NAME##_cudnn( \
Tensor& output, \
Tensor& hy, \
const Tensor& input, \
const Tensor& hx, \
TensorList params, \
bool has_biases, \
int64_t num_layers, \
double dropout_p, \
bool train, \
bool bidirectional, \
bool batch_first) { \
std::tie(output, hy) = _cudnn_impl( \
input, \
hx, \
params, \
has_biases, \
MODE, \
num_layers, \
dropout_p, \
train, \
bidirectional, \
batch_first); \
} \
\
void NAME##_packed_cudnn( \
Tensor& output, \
Tensor& hy, \
const Tensor& data, \
const Tensor& batch_sizes, \
const Tensor& hx, \
TensorList params, \
bool has_biases, \
int64_t num_layers, \
double dropout_p, \
bool train, \
bool bidirectional) { \
std::tie(output, hy) = _cudnn_impl( \
data, \
batch_sizes, \
hx, \
params, \
has_biases, \
MODE, \
num_layers, \
dropout_p, \
train, \
bidirectional); \
} \
\
REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn) \
REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn)
ONE_HIDDEN_RNN(gru, CUDNN_GRU)
ONE_HIDDEN_RNN(rnn_tanh, CUDNN_RNN_TANH)
@ -2743,8 +2743,8 @@ void lstm_packed_cudnn(
cy = std::get<1>(result.second);
}
REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn);
REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn);
REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn)
REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn)
} // namespace

View File

@ -1696,9 +1696,9 @@ Tensor miopen_convolution_relu(
}
}
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward);
REGISTER_CUDA_DISPATCH(miopen_convolution_transpose_backward_stub, &miopen_convolution_transpose_backward);
REGISTER_CUDA_DISPATCH(miopen_depthwise_convolution_backward_stub, &miopen_depthwise_convolution_backward);
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward)
REGISTER_CUDA_DISPATCH(miopen_convolution_transpose_backward_stub, &miopen_convolution_transpose_backward)
REGISTER_CUDA_DISPATCH(miopen_depthwise_convolution_backward_stub, &miopen_depthwise_convolution_backward)
}} // namespace

View File

@ -876,8 +876,8 @@ void NAME##_packed_miopen(Tensor& output, Tensor& hy, \
has_biases, MODE, num_layers, dropout_p, train, bidirectional); \
} \
\
REGISTER_CUDA_DISPATCH(NAME##_miopen_stub, &NAME##_miopen); \
REGISTER_CUDA_DISPATCH(NAME##_packed_miopen_stub, &NAME##_packed_miopen);
REGISTER_CUDA_DISPATCH(NAME##_miopen_stub, &NAME##_miopen) \
REGISTER_CUDA_DISPATCH(NAME##_packed_miopen_stub, &NAME##_packed_miopen)
ONE_HIDDEN_RNN(gru, miopenGRU)
ONE_HIDDEN_RNN(rnn_tanh, miopenRNNTANH)
@ -905,8 +905,8 @@ void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
cy = std::get<1>(result.second);
}
REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen);
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen);
REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen)
} // anonymous namespace
}} //namespace native.

View File

@ -575,7 +575,7 @@ Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
#else
namespace at { namespace native {
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub);
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub)
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
TORCH_CHECK(false, "fft: ATen not compiled with FFT support");

View File

@ -28,9 +28,9 @@ Tensor mkldnn_convolution(
TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
}
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub);
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub);
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub)
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub)
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub)
}}
@ -891,7 +891,7 @@ Tensor mkldnn_convolution_transpose_pointwise(
);
}
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_backward_stub, &mkldnn_convolution_backward);
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_backward_stub, &mkldnn_convolution_backward)
namespace{
Tensor mkldnn_convolution_transpose(
@ -1044,8 +1044,8 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
}
}
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_stub, &mkldnn_convolution_transpose);
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub, &mkldnn_convolution_transpose_backward);
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_stub, &mkldnn_convolution_transpose)
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub, &mkldnn_convolution_transpose_backward)
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(

View File

@ -71,7 +71,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
TORCH_CHECK(false, "mkldnn_rnn_layer_backward: ATen not compiled with MKLDNN support");
}
REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub)
} // namespace at::native

View File

@ -154,14 +154,14 @@ const std::map<c10::string_view, AttrFunction>& fusion_unary_attr_map() {
{"gelu", attr_func_gelu},
};
return fusion_attr_map;
};
}
const std::map<c10::string_view, ideep::algorithm>& fusion_unary_alg_map() {
static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
{"relu", {ideep::algorithm::eltwise_relu}},
};
return fusion_attr_map;
};
}
const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map() {
static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
@ -171,7 +171,7 @@ const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map() {
{"div", {ideep::algorithm::binary_div}},
};
return fusion_attr_map;
};
}
#endif // AT_MKLDNN_ENABLED()
}}

View File

@ -356,7 +356,7 @@ TORCH_IMPL_FUNC(bitwise_not_out_mps)(const Tensor& self, const Tensor& output) {
mps::_bitwise_not_out_mps(self, output);
}
REGISTER_MPS_DISPATCH(lshift_stub, &lshift_kernel_mps);
REGISTER_MPS_DISPATCH(rshift_stub, &rshift_kernel_mps);
REGISTER_MPS_DISPATCH(lshift_stub, &lshift_kernel_mps)
REGISTER_MPS_DISPATCH(rshift_stub, &rshift_kernel_mps)
} // namespace at::native

View File

@ -17,7 +17,7 @@
namespace at::native {
DEFINE_DISPATCH(nested_dense_elementwise_stub);
REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub);
REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub)
std::pair<NestedTensorImpl*, NestedTensorImpl*>
static get_elementwise_nested_tensor_impl(

View File

@ -114,7 +114,7 @@ void _nested_op_dense_esuhm_cuda(Tensor& result, const Tensor& self, const Tenso
});
}
REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda);
REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda)
} // namespace native
} // namespace at

View File

@ -4264,21 +4264,21 @@ void index_put_kernel_quantized_cpu(TensorIterator& iter, IntArrayRef index_size
// AVX2 kernels would be used instead. Ref: GH 56992.
#if defined(_WIN32)
REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
&dequantize_tensor_per_channel_affine_cpu);
&dequantize_tensor_per_channel_affine_cpu)
REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
&dequantize_tensor_per_channel_float_qparams_cpu);
&dequantize_tensor_per_channel_float_qparams_cpu)
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub,
&fake_quant_per_channel_cachemask_cpu);
&fake_quant_per_channel_cachemask_cpu)
REGISTER_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel)
REGISTER_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel)
#else
// These kernels are dispatched to AVX512
ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub,
&dequantize_tensor_per_channel_affine_cpu);
&dequantize_tensor_per_channel_affine_cpu)
ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
&dequantize_tensor_per_channel_float_qparams_cpu);
&dequantize_tensor_per_channel_float_qparams_cpu)
ALSO_REGISTER_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub,
&fake_quant_per_channel_cachemask_cpu);
&fake_quant_per_channel_cachemask_cpu)
ALSO_REGISTER_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel)
ALSO_REGISTER_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel)
#endif // CPU_CAPABILITY_AVX512 && _WIN32
@ -4286,17 +4286,17 @@ ALSO_REGISTER_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel)
// The kernels below are dispatched to AVX2 because they don't perform as well
// with AVX512. We might revisit this decision in the near future.
REGISTER_DISPATCH(dequantize_tensor_per_tensor_affine_stub,
&dequantize_tensor_per_tensor_affine_cpu);
&dequantize_tensor_per_tensor_affine_cpu)
REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
&fake_quantize_learnable_tensor_grad_kernel_cpu);
&fake_quantize_learnable_tensor_grad_kernel_cpu)
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
&fake_quantize_tensor_cachemask_kernel);
&fake_quantize_tensor_cachemask_kernel)
REGISTER_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub,
&fake_quantize_tensor_cachemask_tensor_qparams_kernel);
&fake_quantize_tensor_cachemask_tensor_qparams_kernel)
REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
&qadaptive_avg_pool2d_nhwc_kernel);
&qadaptive_avg_pool2d_nhwc_kernel)
REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
&qadaptive_avg_pool3d_ndhwc_kernel);
&qadaptive_avg_pool3d_ndhwc_kernel)
REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>)
REGISTER_DISPATCH(qadd_scalar_relu_stub, &qadd_scalar_kernel<true>)
REGISTER_DISPATCH(qadd_scalar_stub, &qadd_scalar_kernel<false>)
@ -4325,32 +4325,32 @@ REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel)
REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel)
REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel)
REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub,
&fake_quantize_learnable_channel_grad_kernel_cpu);
&fake_quantize_learnable_channel_grad_kernel_cpu)
REGISTER_DISPATCH(
quantize_tensor_per_tensor_affine_stub,
&quantize_tensor_per_tensor_affine_cpu);
&quantize_tensor_per_tensor_affine_cpu)
REGISTER_DISPATCH(
quantize_tensor_per_channel_affine_stub,
&quantize_tensor_per_channel_affine_cpu);
&quantize_tensor_per_channel_affine_cpu)
REGISTER_DISPATCH(
quantize_tensor_per_channel_float_qparams_stub,
&quantize_tensor_per_channel_float_qparams_cpu);
&quantize_tensor_per_channel_float_qparams_cpu)
REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel)
REGISTER_DISPATCH(quantized_groupnorm_nhwc_stub, &quantized_groupnorm_nhwc_kernel)
REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub,
&qupsample_bilinear2d_nhwc_kernel);
&qupsample_bilinear2d_nhwc_kernel)
REGISTER_DISPATCH(
quantize_tensor_per_tensor_affine_sub_byte_stub,
&quantize_tensor_per_tensor_affine_sub_byte_cpu);
&quantize_tensor_per_tensor_affine_sub_byte_cpu)
REGISTER_DISPATCH(
dequantize_tensor_per_tensor_affine_sub_byte_stub,
&dequantize_tensor_per_tensor_affine_sub_byte_cpu);
&dequantize_tensor_per_tensor_affine_sub_byte_cpu)
REGISTER_DISPATCH(
masked_fill_kernel_quantized_stub,
&masked_fill_kernel_quantized_cpu);
&masked_fill_kernel_quantized_cpu)
REGISTER_DISPATCH(
index_put_kernel_quantized_stub,
&index_put_kernel_quantized_cpu);
&index_put_kernel_quantized_cpu)
REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel)
REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel)
} // namespace at::native

View File

@ -22,11 +22,11 @@ Tensor flatten_indices_cpu_kernel(const Tensor& indices, IntArrayRef size) {
}
REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel);
REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel)
REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
} // namespace at::native

View File

@ -156,24 +156,24 @@ void sparse_mask_projection_out_cpu_kernel(
}
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel)
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
}

View File

@ -512,10 +512,10 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice
return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
}
SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc);
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr);
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc);
SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr)
SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc)
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr)
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc)
static DimVector _estimate_sparse_compressed_tensor_size(
const Tensor& compressed_indices,

View File

@ -433,42 +433,42 @@ Tensor& zero_sparse_csr_(Tensor& self) {
}
#define CREATE_UNARY_UFUNC(op_name) \
CREATE_UNARY_UFUNC_OUT(op_name); \
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name); \
CREATE_UNARY_UFUNC_INPLACE(op_name);
CREATE_UNARY_UFUNC_OUT(op_name) \
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name) \
CREATE_UNARY_UFUNC_INPLACE(op_name)
#define CREATE_UNARY_UFUNC_NO_INPLACE(op_name) \
CREATE_UNARY_UFUNC_OUT(op_name); \
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name);
CREATE_UNARY_UFUNC_OUT(op_name) \
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name)
// Exhaustive list of the unary ufuncs supported by sparse compressed
CREATE_UNARY_UFUNC(abs);
CREATE_UNARY_UFUNC(asin);
CREATE_UNARY_UFUNC(asinh);
CREATE_UNARY_UFUNC(atan);
CREATE_UNARY_UFUNC(atanh);
CREATE_UNARY_UFUNC(ceil);
CREATE_UNARY_UFUNC(deg2rad);
CREATE_UNARY_UFUNC(erf);
CREATE_UNARY_UFUNC(erfinv);
CREATE_UNARY_UFUNC(expm1);
CREATE_UNARY_UFUNC(floor);
CREATE_UNARY_UFUNC(frac);
CREATE_UNARY_UFUNC(log1p);
CREATE_UNARY_UFUNC(neg);
CREATE_UNARY_UFUNC(rad2deg);
CREATE_UNARY_UFUNC(sign);
CREATE_UNARY_UFUNC(sin);
CREATE_UNARY_UFUNC(sinh);
CREATE_UNARY_UFUNC(sgn);
CREATE_UNARY_UFUNC(sqrt);
CREATE_UNARY_UFUNC(tan);
CREATE_UNARY_UFUNC(tanh);
CREATE_UNARY_UFUNC(trunc);
CREATE_UNARY_UFUNC(conj_physical);
CREATE_UNARY_UFUNC(abs)
CREATE_UNARY_UFUNC(asin)
CREATE_UNARY_UFUNC(asinh)
CREATE_UNARY_UFUNC(atan)
CREATE_UNARY_UFUNC(atanh)
CREATE_UNARY_UFUNC(ceil)
CREATE_UNARY_UFUNC(deg2rad)
CREATE_UNARY_UFUNC(erf)
CREATE_UNARY_UFUNC(erfinv)
CREATE_UNARY_UFUNC(expm1)
CREATE_UNARY_UFUNC(floor)
CREATE_UNARY_UFUNC(frac)
CREATE_UNARY_UFUNC(log1p)
CREATE_UNARY_UFUNC(neg)
CREATE_UNARY_UFUNC(rad2deg)
CREATE_UNARY_UFUNC(sign)
CREATE_UNARY_UFUNC(sin)
CREATE_UNARY_UFUNC(sinh)
CREATE_UNARY_UFUNC(sgn)
CREATE_UNARY_UFUNC(sqrt)
CREATE_UNARY_UFUNC(tan)
CREATE_UNARY_UFUNC(tanh)
CREATE_UNARY_UFUNC(trunc)
CREATE_UNARY_UFUNC(conj_physical)
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
static CREATE_UNARY_UFUNC(relu);
static CREATE_UNARY_UFUNC(relu)
C10_DIAGNOSTIC_POP()
// With addition of `round.decimals` overload, using CREATE_UNARY_UFUNC leads
@ -512,14 +512,14 @@ Tensor& threshold_backward_sparse_compressed_out(
}
// angle, isneginf, isposinf and signbit currently don't have an inplace variant
CREATE_UNARY_UFUNC_NO_INPLACE(angle);
CREATE_UNARY_UFUNC_NO_INPLACE(isneginf);
CREATE_UNARY_UFUNC_NO_INPLACE(isposinf);
CREATE_UNARY_UFUNC_NO_INPLACE(signbit);
CREATE_UNARY_UFUNC_NO_INPLACE(angle)
CREATE_UNARY_UFUNC_NO_INPLACE(isneginf)
CREATE_UNARY_UFUNC_NO_INPLACE(isposinf)
CREATE_UNARY_UFUNC_NO_INPLACE(signbit)
// isnan and isinf don't have an out variant
CREATE_UNARY_UFUNC_FUNCTIONAL(isnan);
CREATE_UNARY_UFUNC_FUNCTIONAL(isinf);
CREATE_UNARY_UFUNC_FUNCTIONAL(isnan)
CREATE_UNARY_UFUNC_FUNCTIONAL(isinf)
template <typename scalar_t>
void addmm_out_sparse_csr_native_cpu(

View File

@ -1230,7 +1230,7 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j,
}
}
}
};
}
static Tensor& s_addmm_out_sparse_dense_cpu(
Tensor& r,

View File

@ -444,12 +444,12 @@ int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Ten
return static_cast<int64_t>(backend);
}
REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp);
REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp)
REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
int64_t _fused_sdp_choice_meta(
const Tensor& query_,

View File

@ -1387,7 +1387,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso
return at::Tensor();
}
REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda)
#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
namespace {

View File

@ -227,23 +227,23 @@ static bool EmbeddingLookupGenericSlowIdx(
}
// clang-format on
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false)
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false)
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false)
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false)
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true)
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true)
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true)
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true)
#undef EMBEDDING_IDX_SPECIALIZATION

View File

@ -10,16 +10,16 @@
C10_DEFINE_bool(
caffe2_threadpool_force_inline,
false,
"Force to always run jobs on the calling thread");
"Force to always run jobs on the calling thread")
// Whether or not threadpool caps apply to Android
C10_DEFINE_int(caffe2_threadpool_android_cap, true, "");
C10_DEFINE_int(caffe2_threadpool_android_cap, true, "")
// Whether or not threadpool caps apply to iOS and MacOS
C10_DEFINE_int(caffe2_threadpool_ios_cap, true, "");
C10_DEFINE_int(caffe2_threadpool_macos_cap, true, "");
C10_DEFINE_int(caffe2_threadpool_ios_cap, true, "")
C10_DEFINE_int(caffe2_threadpool_macos_cap, true, "")
C10_DEFINE_int(pthreadpool_size, 0, "Override the default thread pool size.");
C10_DEFINE_int(pthreadpool_size, 0, "Override the default thread pool size.")
namespace caffe2 {

View File

@ -14,7 +14,7 @@ namespace detail {
template <typename Derived>
class _DropoutNd : public torch::nn::Cloneable<Derived> {
public:
_DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)){};
_DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {}
explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) {
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)

View File

@ -227,7 +227,7 @@ class TORCH_API AdaptiveMaxPoolImpl : public torch::nn::Cloneable<Derived> {
const AdaptiveMaxPoolOptions<output_size_t>& options_)
: options(options_) {}
void reset() override{};
void reset() override {}
/// Pretty prints the `AdaptiveMaxPool{1,2,3}d` module into the given
/// `stream`.

View File

@ -128,7 +128,7 @@ class TORCH_API Optimizer {
std::unique_ptr<OptimizerOptions> defaults)
: Optimizer(
{OptimizerParamGroup(std::move(parameters))},
std::move(defaults)){};
std::move(defaults)) {}
/// Adds the given param_group to the optimizer's param_group list.
void add_param_group(const OptimizerParamGroup& param_group);

View File

@ -9,7 +9,7 @@
namespace caffe2 {
// Required for cpp_custom_type_hack to work
// NOLINTNEXTLINE(bugprone-exception-escape)
CAFFE_KNOWN_TYPE(at::RecordFunction);
CAFFE_KNOWN_TYPE(at::RecordFunction)
} // namespace caffe2
namespace torch::autograd::profiler {

View File

@ -77,7 +77,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// Subclasses must override this method to return the backend name
virtual const std::string getBackendName() const {
TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
};
}
virtual c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& /* tensors */,

View File

@ -39,7 +39,7 @@ C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING(
GlooDeviceRegistry,
::gloo::transport::Device,
const std::string& /* interface */,
const std::string& /* hostname */);
const std::string& /* hostname */)
#if GLOO_HAVE_TRANSPORT_TCP
static std::shared_ptr<::gloo::transport::Device> makeTCPDevice(
@ -62,8 +62,8 @@ static std::shared_ptr<::gloo::transport::Device> makeTCPDevice(
// Registry priority is per key identifier. We register TCP to `LINUX` for
// the flexibility of other application to override by priority. Register
// TCP to `TCP` for env "GLOO_DEVICE_TRANSPORT" override.
C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice)
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice)
#endif
#if GLOO_HAVE_TRANSPORT_TCP_TLS

View File

@ -97,7 +97,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
default:
TORCH_CHECK(false, "THis should never happen!");
}
};
}
static BackendType strToBackendType(const std::string& backend) {
if (backend == "undefined") {
@@ -113,7 +113,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
} else {
return BackendType::CUSTOM;
}
};
}
// Not used, set for backwards compatibility and only used for TypeDef in
// Ops.cpp
@@ -146,11 +146,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
virtual const std::string getBackendName() const {
return backendTypeToString(backendType_);
};
}
BackendType getBackendType() const {
return backendType_;
};
}
virtual void startCoalescing(c10::DeviceType deviceType) {
// only nccl has implemented startCoalescing so only execute for nccl

View File

@@ -316,7 +316,7 @@ parseWireSections(const void* data, size_t data_size) {
static const char* kMeta = "meta";
static const char* kPayload = "payload";
}; // namespace
} // namespace
c10::List<at::Tensor> cloneSparseTensors(
const std::vector<at::Tensor>& tensors) {

View File

@@ -578,7 +578,7 @@ RangeValue::RangeValue(
SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
return shared_from_this();
};
}
Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
if (static_len_) {

View File

@@ -2,6 +2,6 @@
namespace torch::jit::mobile::nnc {
C10_DEFINE_REGISTRY(NNCKernelRegistry, NNCKernel);
C10_DEFINE_REGISTRY(NNCKernelRegistry, NNCKernel)
} // namespace torch::jit::mobile::nnc
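Background note (illustration, not part of this commit): C10_DEFINE_REGISTRY, like the other registration macros touched in this series, expands to a definition that already carries its own terminating semicolon, so the invocation must not add another one. A rough sketch of the shape involved, using a simplified stand-in rather than the real c10 registry machinery:

// extra_semi_registry_demo.cpp (hypothetical illustration, not PyTorch code)
#include <string>

template <class T>
struct DemoRegistry {
  std::string name;
};

// The replacement text ends with a complete definition, semicolon included.
#define DEMO_DEFINE_REGISTRY(RegistryName, ObjectType) \
  DemoRegistry<ObjectType> RegistryName##Instance{#RegistryName};

struct DemoKernel {};

DEMO_DEFINE_REGISTRY(DemoKernelRegistry, DemoKernel)  // correct: no trailing ';'
// DEMO_DEFINE_REGISTRY(OtherRegistry, DemoKernel);   // would leave a stray ';' behind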

View File

@@ -78,7 +78,7 @@ class TransposeFrozenLinear {
node->replaceAllUsesWith(bias_result);
}
node->destroy();
};
}
void handleBlockAndSubblocks(Block* block) {}

View File

@@ -303,7 +303,7 @@ void MKLDNNLayerNormOp(Stack& stack, bool inplace) {
at::native::mkldnn_layer_norm_last_index_weight_bias_f32(
input, shape, weight, bias, eps, inplace);
push(stack, dst);
};
}
Operation BroadOp(const Node* node) {
return [](Stack& stack) {

View File

@@ -201,7 +201,7 @@ struct IntegerValueRefiner {
active_refinements_.pop_back();
return block_refinements;
};
}
std::optional<int64_t> tryFindRefinement(Value* v) {
for (const auto& ref : active_refinements_) {

View File

@@ -126,7 +126,7 @@ struct ListLenRefiner {
}
active_refinements_.pop_back();
return block_refinements;
};
}
std::optional<int64_t> tryFindRefinement(Value* v) {
for (const auto& ref : active_refinements_) {

View File

@@ -705,7 +705,7 @@ static bool is_module(
return module_name.value() == module_qualified_name;
}
return false;
};
}
bool aten_add_alpha_is_one(
const Match& match,

View File

@@ -57,7 +57,7 @@ std::shared_ptr<OperatorSet> nn_ops_first_input_preserving() {
"aten::hardswish_(Tensor self) -> Tensor",
});
return ops;
};
}
// Requirements:
// dims : Changed from first argument
@@ -70,5 +70,5 @@ std::shared_ptr<OperatorSet> ops_one_tensor_in_shape_transform() {
"aten::flatten(Tensor self, int start_dim, int end_dim) -> Tensor",
});
return ops;
};
}
} // namespace torch::jit
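Background note (illustration, not part of this commit): the two "};" to "}" edits above are the namespace-scope case; a function definition needs no terminating semicolon, and writing one produces an empty declaration. A tiny self-contained sketch with hypothetical names:

// extra_semi_function_demo.cpp (hypothetical illustration, not PyTorch code)
#include <memory>
#include <string>
#include <vector>

std::shared_ptr<std::vector<std::string>> demo_shape_preserving_ops() {
  auto ops = std::make_shared<std::vector<std::string>>();
  ops->push_back("aten::relu");
  ops->push_back("aten::relu_");
  return ops;
}  // no ';' here; adding one is exactly what this PR keeps cleaning up

int main() {
  return demo_shape_preserving_ops()->size() == 2 ? 0 : 1;
}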

View File

@@ -29,7 +29,7 @@ struct BooleanRefinementMapping {
ListRefinement true_refine,
ListRefinement false_refine)
: true_refine_(std::move(true_refine)),
false_refine_(std::move(false_refine)){};
false_refine_(std::move(false_refine)) {}
BooleanRefinementMapping() = default; // empty
static BooleanRefinementMapping FalseRefinements(

View File

@@ -36,7 +36,7 @@ std::vector<IValue> boxInputs(const ProcessedNode& pnode) {
} // namespace
C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor)
bool nativeOpIsRegistered(const c10::Symbol& op_name) {
const std::string name(op_name.toQualString());

View File

@@ -38,7 +38,7 @@ TORCH_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
return fn(n); \
} \
}; \
C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);
C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id)
TORCH_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
#define REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, ...) \
@@ -49,7 +49,7 @@ TORCH_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
} \
}; \
C10_REGISTER_CLASS( \
SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id);
SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id)
inline at::Tensor create_empty_from(const at::Tensor& t) {
return at::detail::empty_cpu(

View File

@@ -142,7 +142,7 @@ IValue pickle_load(const std::vector<char>& data) {
"pickle_load not supported on mobile "
"(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
};
}
// A specialized version of pickle_load that can load custom objects.
c10::IValue pickle_load_obj(std::string_view data) {

View File

@@ -87,7 +87,7 @@ void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) {
case ScalarType::Name: \
return callArg.Name##Ptr();
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:

View File

@@ -165,7 +165,7 @@ class CodeGen::CallArg {
memcpy(buffer_, &v, sizeof(Type)); \
data_ = (void*)buffer_; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR)
#undef ARG_TYPE_CTOR
void* data() const {
@@ -199,7 +199,7 @@ class CodeGen::CallArg {
TORCH_INTERNAL_ASSERT(data_ == (void*)buffer_); \
return (Type*)data_; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE)
#undef ARG_PTR_DEFINE
private:

View File

@@ -148,7 +148,7 @@ void dispatch_binary_op(std::ostream& os, const BinaryOpNode<Op>* v) {
case ScalarType::Name: \
visit_binary_op<Type>(os, v->lhs(), v->rhs(), v->expr_type()); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();

View File

@@ -375,7 +375,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
case ScalarType::Name: \
value = compare_select_op<T, Type>(lhs, rhs, retval1, retval2, cmp_op); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();
@@ -407,7 +407,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
value_ = compare_select_op_helper<Type>( \
lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();
@@ -418,7 +418,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
TORCH_API void visit(const Name##ImmPtr& v) override { \
value_ = InterpValue(v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT)
#undef IMM_VISIT
TORCH_API void visit(const BlockPtr& v) override {
@@ -472,7 +472,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
case ScalarType::Name: \
this->value_ = InterpValue(castValues<SrcType, Type>(src_dtype, v)); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE)
#undef DST_TYPE_CASE
#define DST_TYPE_CASE_QUANT(Type, Name, CppType) \
case ScalarType::Name: { \
@@ -507,7 +507,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
case ScalarType::Name: \
doCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE)
SRC_TYPE_CASE(c10::quint8, QUInt8);
SRC_TYPE_CASE(c10::qint8, QInt8);
#undef SRC_TYPE_CASE
@@ -615,7 +615,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<Type> v(lanes, value.as<Type>()); \
value_ = InterpValue(v); \
} break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();
@@ -758,7 +758,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
} \
value_ = InterpValue(val); \
} break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
TYPE_CASE(c10::quint8, QUInt8);
TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE
@@ -805,7 +805,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
ptr##Name[index[i]] = value[i]; \
} \
} break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
TYPE_CASE(c10::quint8, QUInt8);
TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE
@@ -1268,7 +1268,7 @@ void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {
impl_->bindVar(bufArg.var(), typed_data); \
break; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();
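Background note (illustration, not part of this commit): the many AT_FORALL_SCALAR_TYPES_AND3(..., TYPE_CASE) edits apply the same rule to X-macro invocations; the list macro already expands to complete case blocks (or complete declarations when used at class scope), so a trailing ';' on the invocation is redundant. A compact sketch with a hypothetical type list rather than the real ATen macro:

// extra_semi_xmacro_demo.cpp (hypothetical illustration, not PyTorch code)
#include <cstdio>

// X-macro listing a few demo scalar types; each entry is expanded through CASE(...).
#define DEMO_FORALL_TYPES(CASE) \
  CASE(float, Float, 0)         \
  CASE(double, Double, 1)       \
  CASE(int, Int, 2)

const char* demo_type_name(int tag) {
  switch (tag) {
#define TYPE_CASE(Type, Name, Tag) \
  case Tag:                        \
    return #Name;
    DEMO_FORALL_TYPES(TYPE_CASE)  // no trailing ';': the macro already emitted whole case blocks
#undef TYPE_CASE
    default:
      return "Unknown";
  }
}

int main() {
  std::printf("%s\n", demo_type_name(1));  // prints "Double"
  return 0;
}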

View File

@@ -30,7 +30,7 @@ class InterpValue {
Name##values.push_back(v); \
return; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
throw unsupported_dtype();
}
@@ -89,9 +89,9 @@ class InterpValue {
} \
return Name##values[0]; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH);
VALUE_AS_DISPATCH(c10::quint8, QUInt8);
VALUE_AS_DISPATCH(c10::qint8, QInt8);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH)
VALUE_AS_DISPATCH(c10::quint8, QUInt8)
VALUE_AS_DISPATCH(c10::qint8, QInt8)
#undef VALUE_AS_DISPATCH
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
@@ -102,9 +102,9 @@ VALUE_AS_DISPATCH(c10::qint8, QInt8);
} \
return Name##values; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH);
VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8);
VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH)
VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8)
VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8)
#undef VALUE_AS_VEC_DISPATCH
template <typename Type>

View File

@@ -87,7 +87,7 @@ ExprHandle ExprHandle::operator>>(const ExprHandle& other) const {
#define IMM_EXPR_DECLARE(Type, Name) \
ExprHandle::ExprHandle(Type v) : ExprHandle(Name##Imm::make(v)) {}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE)
#undef IMM_EXPR_DECLARE
ExprHandle sin(const ExprHandle& v) {

View File

@@ -112,7 +112,7 @@ class TORCH_API ExprHandle {
}
#define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE)
#undef IMM_EXPR_DECLARE
template <class Op>
@@ -274,7 +274,7 @@ class TORCH_API Buf : public ExprNode<Buf> {
ExprPtr initializer() const {
return initializer_;
};
}
ExprPtr qzero() const {
return qzero_;

View File

@@ -97,7 +97,7 @@ void DispatchParallel(
FOR_ALL_EXTERNAL_FUNCTIONS(DECLARE_EXTERNAL_FUNCTION)
#if AT_MKLDNN_ENABLED()
DECLARE_EXTERNAL_FUNCTION(nnc_mkldnn_prepacked_conv_run);
DECLARE_EXTERNAL_FUNCTION(nnc_mkldnn_prepacked_conv_run)
#endif
TORCH_API void nnc_aten_free(size_t bufs_num, void** ptrs) noexcept;

View File

@@ -119,7 +119,7 @@ using SyncThreadsPtr = NodePtr<SyncThreads>;
#define IMM_DECLARE(Type, Name) \
class Name##Imm; \
using Name##ImmPtr = NodePtr<Name##Imm>;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE)
#undef IMM_DECLARE
} // namespace torch::jit::tensorexpr

View File

@@ -86,7 +86,7 @@ class TORCH_API HashProvider : public IRVisitor {
CACHE_GUARD(); \
putHash(v, hash_combine(#Name, v->value())); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT)
#undef IMM_VISIT
void visit(const CastPtr& v) override;

View File

@@ -276,7 +276,7 @@ bool immediateIsPositive(const ExprPtr& e) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() > 0; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
return false;
}
@@ -286,7 +286,7 @@ bool immediateIsZero(const ExprPtr& e) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() == 0; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
return false;
}

View File

@@ -322,7 +322,7 @@ class Min : public BinaryOpNode<Min> {
private: \
Type value_; \
};
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE)
#undef IMM_DECLARE
// Get immediate by ScalarType.
@@ -332,7 +332,7 @@ ExprPtr getImmediateByType(ScalarType immType, T initialVal) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
return alloc<Name##Imm>(Type(initialVal));
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();
@@ -375,7 +375,7 @@ T immediateAs(const ExprPtr& e) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value(); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
throw unsupported_dtype();
return 0;
@@ -392,7 +392,7 @@ bool immediateEquals(const ExprPtr& e, T val) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() == val; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
throw unsupported_dtype();
return false;

View File

@@ -116,7 +116,7 @@ ExprPtr IRCloner::mutate(const CompareSelectPtr& v) {
ExprPtr IRCloner::mutate(const Name##ImmPtr& v) { \
return v; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE)
#undef IMM_MUTATE_DEFINE
ExprPtr IRCloner::mutate(const CastPtr& v) {

View File

@@ -25,7 +25,7 @@ class TORCH_API IRCloner : public IRMutator {
ExprPtr mutate(const CompareSelectPtr& v) override;
#define IMM_MUTATE_DECLARE(Type, Name) \
ExprPtr mutate(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE)
#undef IMM_MUTATE_DECLARE
ExprPtr mutate(const CastPtr& v) override;
ExprPtr mutate(const BitCastPtr& v) override;

View File

@@ -113,7 +113,7 @@ ExprPtr IRMutator::mutate(const CompareSelectPtr& v) {
ExprPtr IRMutator::mutate(const Name##ImmPtr& v) { \
return v; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE)
#undef IMM_MUTATE_DEFINE
ExprPtr IRMutator::mutate(const CastPtr& v) {

View File

@@ -231,7 +231,7 @@ static void formatImm(std::ostream& os, T v) {
void IRPrinter::visit(const Name##ImmPtr& v) { \
formatImm(os(), v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT
void IRPrinter::visit(const CastPtr& v) {

View File

@@ -32,7 +32,7 @@ class TORCH_API IRPrinter : public IRVisitor {
void visit(const RshiftPtr& v) override;
void visit(const CompareSelectPtr& v) override;
#define IMM_PRINT_VISIT(Type, Name) void visit(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT
void visit(const CastPtr& v) override;
void visit(const BitCastPtr& v) override;

View File

@@ -1293,7 +1293,7 @@ bool isOperandInMinMaxTerm(
return true;
}
return false;
};
}
// Simplifies the nested min-max pattern like:
// * Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))

View File

@@ -98,7 +98,7 @@ inline ExprPtr evaluateOp(const ExprPtr& v) {
Type val = eval.value<Type>(); \
return getImmediateByType(v->dtype().scalar_type(), val); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << v->dtype();

View File

@@ -76,7 +76,7 @@ void IRVisitor::visit(const CompareSelectPtr& v) {
#define IMM_VISIT(Type, Name) \
void IRVisitor::visit(const Name##ImmPtr& v) {}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT)
#undef IMM_VISIT
void IRVisitor::visit(const CastPtr& v) {

View File

@@ -329,7 +329,7 @@ class LLVMCodeGenImpl : public IRVisitor {
void visit(const CompareSelectPtr& v) override;
#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE)
#undef IMM_VISIT_DECLARE
void visit(const CastPtr& v) override;

View File

@@ -1753,7 +1753,7 @@ std::vector<ForPtr> LoopNest::distributeLoopAndParentsOverInnerLoops(
static bool areEqual(const ExprPtr& expr1, const ExprPtr& expr2) {
auto diff = IRSimplifier::simplify(alloc<Sub>(expr1, expr2));
return diff->isConstant() && (immediateAs<int>(diff) == 0);
};
}
static bool doesExprContainAnyVar(
const ExprPtr& expr,

View File

@@ -32,7 +32,7 @@ ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
case ScalarType::Name: \
e = cast<Type>(e); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
case ScalarType::QUInt8:
e = cast<c10::quint8>(e);

View File

@@ -272,7 +272,7 @@ void RegisterizerAnalysis::visit(const ForPtr& v) {
// having hoisted, now we can merge normally.
mergeCurrentScopeIntoParent();
};
}
void RegisterizerAnalysis::visit(const CondPtr& v) {
ExprPtr condition = v->condition();

View File

@@ -342,9 +342,9 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor {
stmtStack_.pop_front(); \
}
STMT_ON_STACK(AtomicAdd);
STMT_ON_STACK(Allocate);
STMT_ON_STACK(Free);
STMT_ON_STACK(AtomicAdd)
STMT_ON_STACK(Allocate)
STMT_ON_STACK(Free)
#undef STMT_ON_STACK

View File

@@ -22,8 +22,8 @@ AT_FORALL_SCALAR_TYPES_AND7(
Float8_e4m3fn,
Float8_e4m3fnuz,
DTYPE_DEFINE)
DTYPE_DEFINE(c10::quint8, QUInt8);
DTYPE_DEFINE(c10::qint8, QInt8);
DTYPE_DEFINE(c10::quint8, QUInt8)
DTYPE_DEFINE(c10::qint8, QInt8)
#undef DTYPE_DEFINE

Some files were not shown because too many files have changed in this diff.