Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[6/N] Fix Wextra-semi warning (#139605)

Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139605
Approved by: https://github.com/ezyang
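Every hunk in this commit makes the same one-character change: a registration or definition macro invoked with a trailing semicolon loses that semicolon, because the macro expansion (or the preceding `}`) already terminates the declaration, and the leftover `;` is what Clang's -Wextra-semi family of warnings flags. Each hunk pairs the old line (ending in `;`) with its replacement. The sketch below is a minimal reproduction of the warning and the fix under assumed names — MY_REGISTER_KERNEL and my_kernel are made up for illustration, not PyTorch's real macros.

    // extra_semi_demo.cpp -- minimal sketch, not PyTorch's actual macros.
    // Build with e.g.:  clang++ -std=c++17 -Wextra-semi -c extra_semi_demo.cpp
    #include <cstdio>

    void my_kernel(int x) { std::printf("%d\n", x); }

    // The expansion already ends in ';', so the call site must not add another.
    #define MY_REGISTER_KERNEL(stub, fn) \
      static void (*stub)(int) = (fn);

    // Old style: expands to "...;;" -- the second ';' is an empty declaration
    // at namespace scope, which extra-semicolon warnings can flag.
    // MY_REGISTER_KERNEL(old_stub, &my_kernel);

    // Fixed style, matching the replacement lines in the hunks below.
    MY_REGISTER_KERNEL(my_kernel_stub, &my_kernel)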
@@ -510,7 +510,7 @@ void OperatorEntry::reportSignatureError(const CppSignature& call_signature, con
 "This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
 "Please make sure that the function signature matches the signature in the operator registration call."
 );
-};
+}

 #ifndef STRIP_ERROR_MESSAGES
 static std::string post_process_dispatch_key_str(std::string dispatch_key) {
@@ -132,7 +132,7 @@ template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>,
 inline void cvt_to_fp32(const __m128i& a, __m256& o);
 template <> inline void cvt_to_fp32<BFloat16>(const __m128i& a, __m256& o) {
 cvtbf16_fp32(a, o);
-};
+}
 template <> inline void cvt_to_fp32<Half>(const __m128i& a, __m256& o) {
 cvtfp16_fp32(a, o);
 }
@@ -1071,8 +1071,8 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(c
 inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
 return cvt_from_fp32<type>(__m256(a), __m256(b)); \
 }
-CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
-CONVERT_VECTORIZED_INIT(Half, half);
+CONVERT_VECTORIZED_INIT(BFloat16, bfloat16)
+CONVERT_VECTORIZED_INIT(Half, half)

 #else // defined(CPU_CAPABILITY_AVX2)

@@ -1096,9 +1096,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
 convert(arr, arr2, K); \
 return Vectorized<type>::loadu(arr2); \
 }
-CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
+CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
 #if !(defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256))
-CONVERT_NON_VECTORIZED_INIT(Half, half);
+CONVERT_NON_VECTORIZED_INIT(Half, half)
 #endif

 #endif // defined(CPU_CAPABILITY_AVX2)

@@ -120,11 +120,11 @@ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
 }

 void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
-INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)));
+INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack)))
 }

 void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case) {
-INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)));
+INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack, grad_special_case)))
 }

 } // namespace at::functorch

@@ -10,7 +10,7 @@

 namespace at::mps {

-C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
+C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback)

 namespace HeapAllocator {

@@ -330,7 +330,7 @@ void gemv_fast_path<at::Half>(
 y,
 *incy);
 }
-INSTANTIATE(c10::BFloat16);
+INSTANTIATE(c10::BFloat16)
 #else
 template <>
 bool scal_use_fast_path<at::Half>(

@@ -1251,7 +1251,7 @@ embedding_bag(const Tensor &weight, const Tensor &indices,
 mode, sparse, per_sample_weights, include_last_offset, padding_idx);
 }
 return out;
-};
+}

 std::tuple<Tensor, Tensor, Tensor, Tensor>
 embedding_bag(const Tensor &weight, const Tensor &indices,

@@ -64,7 +64,7 @@ Tensor& max_unpooling2d_forward_out_cpu(
 }

 return output;
-};
+}

 Tensor max_unpooling2d_forward_cpu(
 const Tensor& self,

@@ -871,7 +871,7 @@ static std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cpu(
 return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
 }

-REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu);
+REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu)

 } // namespace native
 } // namespace at

@@ -741,7 +741,7 @@ static std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cpu(
 return std::tie(grad_input, grad_weight, grad_bias);
 }

-REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu);
-REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu);
+REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu)
+REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu)

 } // namespace at::native

@@ -14,7 +14,7 @@ using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const
 int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
 using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

-DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
+DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel)
 DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel)

 // averge pooling has same signature for forward and backward

@@ -1187,10 +1187,10 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
 DEFINE_DISPATCH(NAME##_miopen_stub); \
 DEFINE_DISPATCH(NAME##_packed_cudnn_stub); \
 DEFINE_DISPATCH(NAME##_packed_miopen_stub); \
-REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub); \
-REGISTER_NO_CPU_DISPATCH(NAME##_miopen_stub); \
-REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub); \
-REGISTER_NO_CPU_DISPATCH(NAME##_packed_miopen_stub); \
+REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub) \
+REGISTER_NO_CPU_DISPATCH(NAME##_miopen_stub) \
+REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub) \
+REGISTER_NO_CPU_DISPATCH(NAME##_packed_miopen_stub) \
 \
 std::tuple<Tensor, Tensor> NAME( \
 const Tensor& _input, \
@@ -1415,17 +1415,17 @@ static std::tuple<Tensor, Tensor> quantized_gru_data_legacy(
 using tanf_cell_type = SimpleCell<tanh_f, CellParams>;
 ONE_HIDDEN_RNN(rnn_tanh, tanf_cell_type)
 using relu_cell_type = SimpleCell<relu_f, CellParams>;
-ONE_HIDDEN_RNN(rnn_relu, relu_cell_type);
+ONE_HIDDEN_RNN(rnn_relu, relu_cell_type)

 DEFINE_DISPATCH(lstm_cudnn_stub);
 DEFINE_DISPATCH(lstm_packed_cudnn_stub);
 DEFINE_DISPATCH(lstm_miopen_stub);
 DEFINE_DISPATCH(lstm_packed_miopen_stub);
 DEFINE_DISPATCH(lstm_mkldnn_stub);
-REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub);
-REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub);
-REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub);
-REGISTER_NO_CPU_DISPATCH(lstm_packed_miopen_stub);
+REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub)
+REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub)
+REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub)
+REGISTER_NO_CPU_DISPATCH(lstm_packed_miopen_stub)

 std::tuple<Tensor, Tensor, Tensor> lstm(
 const Tensor& _input, TensorList hx,
@@ -1857,9 +1857,9 @@ static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
 // Quantized LSTM cell
 using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;

-DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
+DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx)

-static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
+static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx)

 // Helpers for simpler cells
 using simple_hx_type = const Tensor&;
@@ -1871,21 +1871,21 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
 using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
 using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;

-DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);
+DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx)

-static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);
+static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx)

 // Quantized RNN w/ ReLU cell
 using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
-DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
+DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx)
 using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
-static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);
+static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx)

 // Quantized RNN w/ tanh cell
 using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
-DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
+DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx)
 using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
-static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);
+static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx)

 namespace {

@@ -932,11 +932,11 @@ Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {

 Tensor special_multigammaln(const Tensor& self, int64_t p) {
 return self.mvlgamma(p);
-};
+}

 Tensor& special_multigammaln_out(const Tensor& self, int64_t p, Tensor& result) {
 return at::mvlgamma_out(result, self, p);
-};
+}

 std::tuple<Tensor, Tensor> frexp(const Tensor& self) {
 Tensor mantissa = at::empty_like(self);
@@ -1415,16 +1415,16 @@ REGISTER_DISPATCH(laguerre_polynomial_l_stub, &laguerre_polynomial_l_kernel)
 REGISTER_DISPATCH(legendre_polynomial_p_stub, &legendre_polynomial_p_kernel)
 REGISTER_DISPATCH(
 shifted_chebyshev_polynomial_t_stub,
-&shifted_chebyshev_polynomial_t_kernel);
+&shifted_chebyshev_polynomial_t_kernel)
 REGISTER_DISPATCH(
 shifted_chebyshev_polynomial_u_stub,
-&shifted_chebyshev_polynomial_u_kernel);
+&shifted_chebyshev_polynomial_u_kernel)
 REGISTER_DISPATCH(
 shifted_chebyshev_polynomial_v_stub,
-&shifted_chebyshev_polynomial_v_kernel);
+&shifted_chebyshev_polynomial_v_kernel)
 REGISTER_DISPATCH(
 shifted_chebyshev_polynomial_w_stub,
-&shifted_chebyshev_polynomial_w_kernel);
+&shifted_chebyshev_polynomial_w_kernel)
 // Might enable AVX512 dispatch after enabling explicit vectorization for them.
 REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel)
 REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel)

@@ -241,5 +241,5 @@ static void multinomial_with_replacement_kernel_impl(

 REGISTER_DISPATCH(
 multinomial_with_replacement_stub,
-&multinomial_with_replacement_kernel_impl);
+&multinomial_with_replacement_kernel_impl)
 } // namespace at::native

@@ -22,7 +22,7 @@ inline namespace CPU_CAPABILITY {
 constexpr auto kF32RegisterPairsPerIteration = 4;
 constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
 constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
-constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;;
+constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;

 namespace {
 template <typename T>
@@ -328,8 +328,8 @@ void fp16_gemv_trans(
 #if !defined(C10_MOBILE)
 // NOTE: we don't *need* to go through dispatch for the ARM-only
 // implementation right now, but we will need it when we cover x86.
-REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith);
-REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans);
+REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith)
+REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
 #else
 #endif // defined(__aarch64__) && !defined(C10_MOBILE)

@@ -8,7 +8,7 @@ namespace at::native {
 #if !defined(C10_MOBILE)
 using fp16_dot_fn = float(*)(const Half*, const Half*, int64_t);
 using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int);
-DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub);
-DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub);
+DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub)
+DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
 #endif // !defined(C10_MOBILE)
 } // namespace at::native

@@ -1295,15 +1295,15 @@ ALSO_REGISTER_AVX512_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_im
 ALSO_REGISTER_AVX512_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl)
 ALSO_REGISTER_AVX512_DISPATCH(
 softmax_backward_lastdim_kernel,
-&softmax_backward_lastdim_kernel_impl);
+&softmax_backward_lastdim_kernel_impl)
 ALSO_REGISTER_AVX512_DISPATCH(
 log_softmax_backward_lastdim_kernel,
-&log_softmax_backward_lastdim_kernel_impl);
+&log_softmax_backward_lastdim_kernel_impl)

 ALSO_REGISTER_AVX512_DISPATCH(softmax_kernel, &softmax_kernel_impl)
 ALSO_REGISTER_AVX512_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl)
 ALSO_REGISTER_AVX512_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl)
 ALSO_REGISTER_AVX512_DISPATCH(
 log_softmax_backward_kernel,
-&log_softmax_backward_kernel_impl);
+&log_softmax_backward_kernel_impl)
 } // namespace at::native

@ -830,15 +830,15 @@ REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel)
|
||||
REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel)
|
||||
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel)
|
||||
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round);
|
||||
IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round)
|
||||
IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan)
|
||||
|
||||
// The following kernels are compute-intensive & are compiled with both AVX512
|
||||
// & AVX2
|
||||
@ -871,19 +871,19 @@ REGISTER_DISPATCH(special_bessel_y1_stub, &CPU_CAPABILITY::bessel_y1_kernel)
|
||||
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &CPU_CAPABILITY::modified_bessel_i0_kernel)
|
||||
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel)
|
||||
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh);
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma);
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2)
|
||||
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh)
|
||||
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -760,6 +760,6 @@ std::tuple<Tensor, Tensor> conv_depthwise2d_backward_cuda(
|
||||
grad_weight);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda);
|
||||
REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -695,7 +695,7 @@ std::tuple<Tensor, Tensor, Tensor> conv_depthwise3d_backward_cuda(
|
||||
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda);
|
||||
REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda)
|
||||
|
||||
#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION
|
||||
#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS
|
||||
|
@ -23,6 +23,6 @@ Tensor flatten_indices_cuda_kernel(const Tensor& indices, IntArrayRef size) {
|
||||
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(flatten_indices_stub, &flatten_indices_cuda_kernel);
|
||||
REGISTER_CUDA_DISPATCH(flatten_indices_stub, &flatten_indices_cuda_kernel)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -483,6 +483,6 @@ REGISTER_DISPATCH(put_stub, &put_kernel)
|
||||
REGISTER_DISPATCH(take_stub, &take_kernel)
|
||||
REGISTER_DISPATCH(flip_stub, &flip_kernel)
|
||||
|
||||
REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda);
|
||||
REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -684,7 +684,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel);
|
||||
REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel)
|
||||
|
||||
void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, double scale, int zero_point, bool unsafe) {
|
||||
if (indices.size() > (size_t)self.dim()) {
|
||||
@ -784,7 +784,7 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_quantized);
|
||||
REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_quantized)
|
||||
} //anonymous
|
||||
|
||||
|
||||
@ -1687,7 +1687,7 @@ void masked_fill_kernel_quantized(TensorIterator& iter, const Scalar& value, dou
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(masked_fill_kernel_quantized_stub, &masked_fill_kernel_quantized);
|
||||
REGISTER_CUDA_DISPATCH(masked_fill_kernel_quantized_stub, &masked_fill_kernel_quantized)
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
@ -138,19 +138,19 @@ void lazy_ldl_solve(
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(cholesky_stub, &lazy_cholesky_kernel)
|
||||
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &lazy_cholesky_inverse_kernel);
|
||||
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lazy_lu_factor);
|
||||
REGISTER_CUDA_DISPATCH(ldl_factor_stub, &lazy_ldl_factor);
|
||||
REGISTER_CUDA_DISPATCH(ldl_solve_stub, &lazy_ldl_solve);
|
||||
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &lazy_triangular_solve_kernel);
|
||||
REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel);
|
||||
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel);
|
||||
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel);
|
||||
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel);
|
||||
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel);
|
||||
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &lazy_cholesky_inverse_kernel)
|
||||
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lazy_lu_factor)
|
||||
REGISTER_CUDA_DISPATCH(ldl_factor_stub, &lazy_ldl_factor)
|
||||
REGISTER_CUDA_DISPATCH(ldl_solve_stub, &lazy_ldl_solve)
|
||||
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &lazy_triangular_solve_kernel)
|
||||
REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel)
|
||||
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel)
|
||||
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel)
|
||||
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel)
|
||||
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel)
|
||||
REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel)
|
||||
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve);
|
||||
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel);
|
||||
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve)
|
||||
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel)
|
||||
} // anonymous namespace
|
||||
|
||||
// Old style dispatches
|
||||
|
@ -828,6 +828,6 @@ std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cuda(
|
||||
return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cuda);
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cuda)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -1011,6 +1011,6 @@ std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose3d_backward_cuda(
|
||||
return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cuda);
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cuda)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -608,7 +608,7 @@ std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cuda(
|
||||
return std::tie(grad_input, grad_weight, grad_bias);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cuda);
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cuda);
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cuda)
|
||||
REGISTER_CUDA_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cuda)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -90,13 +90,13 @@ void aminmax_allreduce_kernel_impl(const Tensor& input, Tensor& min_result, Tens
|
||||
|
||||
} // namespace (anonymous)
|
||||
|
||||
REGISTER_CUDA_DISPATCH(min_stub, &min_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(max_stub, &max_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(min_all_stub, &min_all_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(max_all_stub, &max_all_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(aminmax_stub, &aminmax_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(min_stub, &min_kernel_impl)
|
||||
REGISTER_CUDA_DISPATCH(max_stub, &max_kernel_impl)
|
||||
REGISTER_CUDA_DISPATCH(min_all_stub, &min_all_kernel_impl)
|
||||
REGISTER_CUDA_DISPATCH(max_all_stub, &max_all_kernel_impl)
|
||||
REGISTER_CUDA_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel_impl)
|
||||
REGISTER_CUDA_DISPATCH(aminmax_stub, &aminmax_kernel_impl)
|
||||
|
||||
REGISTER_CUDA_DISPATCH(norm_stub, &norm_kernel_cuda);
|
||||
REGISTER_CUDA_DISPATCH(norm_stub, &norm_kernel_cuda)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -109,7 +109,7 @@ void cumprod_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim)
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(cumsum_stub, &cumsum_cuda_kernel);
|
||||
REGISTER_CUDA_DISPATCH(cumprod_stub, &cumprod_cuda_kernel);
|
||||
REGISTER_CUDA_DISPATCH(cumsum_stub, &cumsum_cuda_kernel)
|
||||
REGISTER_CUDA_DISPATCH(cumprod_stub, &cumprod_cuda_kernel)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -122,6 +122,6 @@ void sort_cuda_kernel(
|
||||
// TODO: we should handle this accordingly when we start using REGISTER_HIP_DISPATCH,
|
||||
// since REGISTER_DISPATCH won't work in this cpp file.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
REGISTER_CUDA_DISPATCH(sort_stub, &sort_cuda_kernel);
|
||||
REGISTER_CUDA_DISPATCH(sort_stub, &sort_cuda_kernel)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -204,8 +204,8 @@ void sparse_mask_projection_out_cuda_kernel(
|
||||
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel);
|
||||
REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel);
|
||||
REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel);
|
||||
REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel)
|
||||
REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel)
|
||||
REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -18,6 +18,6 @@ void isin_default_kernel_gpu(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
REGISTER_CUDA_DISPATCH(isin_default_stub, &isin_default_kernel_gpu);
|
||||
REGISTER_CUDA_DISPATCH(isin_default_stub, &isin_default_kernel_gpu)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -98,5 +98,5 @@ void mode_kernel_impl(
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(mode_stub, &mode_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(mode_stub, &mode_kernel_impl)
|
||||
} // namespace at::native
|
||||
|
@ -1454,7 +1454,7 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
|
||||
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -1670,7 +1670,7 @@ static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& i
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lu_factor);
|
||||
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lu_factor)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -1764,7 +1764,7 @@ void triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool u
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
|
||||
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ orgqr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -1782,7 +1782,7 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) {
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
|
||||
REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
|
||||
void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
|
||||
#ifdef USE_LINALG_SOLVER
|
||||
@ -1794,7 +1794,7 @@ void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, b
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel);
|
||||
REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -1878,7 +1878,7 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel);
|
||||
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
|
||||
@ -2007,7 +2007,7 @@ void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, c
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
|
||||
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -2093,7 +2093,7 @@ void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos,
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
|
||||
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -2579,7 +2579,7 @@ if (n <= 8) {
|
||||
#endif // ifdef USE_LINALG_SOLVER
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_kernel);
|
||||
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -2761,7 +2761,7 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(lstsq_stub, &lstsq_kernel);
|
||||
REGISTER_CUDA_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
|
||||
|
||||
#if defined(BUILD_LAZY_CUDA_LINALG)
|
||||
|
@ -2627,59 +2627,59 @@ std::pair<Tensor, hidden_type> _cudnn_impl(
|
||||
std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
|
||||
}
|
||||
|
||||
#define ONE_HIDDEN_RNN(NAME, MODE) \
|
||||
void NAME##_cudnn( \
|
||||
Tensor& output, \
|
||||
Tensor& hy, \
|
||||
const Tensor& input, \
|
||||
const Tensor& hx, \
|
||||
TensorList params, \
|
||||
bool has_biases, \
|
||||
int64_t num_layers, \
|
||||
double dropout_p, \
|
||||
bool train, \
|
||||
bool bidirectional, \
|
||||
bool batch_first) { \
|
||||
std::tie(output, hy) = _cudnn_impl( \
|
||||
input, \
|
||||
hx, \
|
||||
params, \
|
||||
has_biases, \
|
||||
MODE, \
|
||||
num_layers, \
|
||||
dropout_p, \
|
||||
train, \
|
||||
bidirectional, \
|
||||
batch_first); \
|
||||
} \
|
||||
\
|
||||
void NAME##_packed_cudnn( \
|
||||
Tensor& output, \
|
||||
Tensor& hy, \
|
||||
const Tensor& data, \
|
||||
const Tensor& batch_sizes, \
|
||||
const Tensor& hx, \
|
||||
TensorList params, \
|
||||
bool has_biases, \
|
||||
int64_t num_layers, \
|
||||
double dropout_p, \
|
||||
bool train, \
|
||||
bool bidirectional) { \
|
||||
std::tie(output, hy) = _cudnn_impl( \
|
||||
data, \
|
||||
batch_sizes, \
|
||||
hx, \
|
||||
params, \
|
||||
has_biases, \
|
||||
MODE, \
|
||||
num_layers, \
|
||||
dropout_p, \
|
||||
train, \
|
||||
bidirectional); \
|
||||
} \
|
||||
\
|
||||
REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn); \
|
||||
REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn);
|
||||
#define ONE_HIDDEN_RNN(NAME, MODE) \
|
||||
void NAME##_cudnn( \
|
||||
Tensor& output, \
|
||||
Tensor& hy, \
|
||||
const Tensor& input, \
|
||||
const Tensor& hx, \
|
||||
TensorList params, \
|
||||
bool has_biases, \
|
||||
int64_t num_layers, \
|
||||
double dropout_p, \
|
||||
bool train, \
|
||||
bool bidirectional, \
|
||||
bool batch_first) { \
|
||||
std::tie(output, hy) = _cudnn_impl( \
|
||||
input, \
|
||||
hx, \
|
||||
params, \
|
||||
has_biases, \
|
||||
MODE, \
|
||||
num_layers, \
|
||||
dropout_p, \
|
||||
train, \
|
||||
bidirectional, \
|
||||
batch_first); \
|
||||
} \
|
||||
\
|
||||
void NAME##_packed_cudnn( \
|
||||
Tensor& output, \
|
||||
Tensor& hy, \
|
||||
const Tensor& data, \
|
||||
const Tensor& batch_sizes, \
|
||||
const Tensor& hx, \
|
||||
TensorList params, \
|
||||
bool has_biases, \
|
||||
int64_t num_layers, \
|
||||
double dropout_p, \
|
||||
bool train, \
|
||||
bool bidirectional) { \
|
||||
std::tie(output, hy) = _cudnn_impl( \
|
||||
data, \
|
||||
batch_sizes, \
|
||||
hx, \
|
||||
params, \
|
||||
has_biases, \
|
||||
MODE, \
|
||||
num_layers, \
|
||||
dropout_p, \
|
||||
train, \
|
||||
bidirectional); \
|
||||
} \
|
||||
\
|
||||
REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn) \
|
||||
REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn)
|
||||
|
||||
ONE_HIDDEN_RNN(gru, CUDNN_GRU)
|
||||
ONE_HIDDEN_RNN(rnn_tanh, CUDNN_RNN_TANH)
|
||||
@ -2743,8 +2743,8 @@ void lstm_packed_cudnn(
|
||||
cy = std::get<1>(result.second);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn);
|
||||
REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn);
|
||||
REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn)
|
||||
REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn)
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -1696,9 +1696,9 @@ Tensor miopen_convolution_relu(
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward);
|
||||
REGISTER_CUDA_DISPATCH(miopen_convolution_transpose_backward_stub, &miopen_convolution_transpose_backward);
|
||||
REGISTER_CUDA_DISPATCH(miopen_depthwise_convolution_backward_stub, &miopen_depthwise_convolution_backward);
|
||||
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward)
|
||||
REGISTER_CUDA_DISPATCH(miopen_convolution_transpose_backward_stub, &miopen_convolution_transpose_backward)
|
||||
REGISTER_CUDA_DISPATCH(miopen_depthwise_convolution_backward_stub, &miopen_depthwise_convolution_backward)
|
||||
|
||||
}} // namespace
|
||||
|
||||
|
@ -876,8 +876,8 @@ void NAME##_packed_miopen(Tensor& output, Tensor& hy, \
|
||||
has_biases, MODE, num_layers, dropout_p, train, bidirectional); \
|
||||
} \
|
||||
\
|
||||
REGISTER_CUDA_DISPATCH(NAME##_miopen_stub, &NAME##_miopen); \
|
||||
REGISTER_CUDA_DISPATCH(NAME##_packed_miopen_stub, &NAME##_packed_miopen);
|
||||
REGISTER_CUDA_DISPATCH(NAME##_miopen_stub, &NAME##_miopen) \
|
||||
REGISTER_CUDA_DISPATCH(NAME##_packed_miopen_stub, &NAME##_packed_miopen)
|
||||
|
||||
ONE_HIDDEN_RNN(gru, miopenGRU)
|
||||
ONE_HIDDEN_RNN(rnn_tanh, miopenRNNTANH)
|
||||
@ -905,8 +905,8 @@ void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
|
||||
cy = std::get<1>(result.second);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen);
|
||||
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen);
|
||||
REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
|
||||
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen)
|
||||
|
||||
} // anonymous namespace
|
||||
}} //namespace native.
|
||||
|
@ -575,7 +575,7 @@ Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization,
|
||||
#else
|
||||
|
||||
namespace at { namespace native {
|
||||
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub)
|
||||
|
||||
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
|
||||
TORCH_CHECK(false, "fft: ATen not compiled with FFT support");
|
||||
|
@ -28,9 +28,9 @@ Tensor mkldnn_convolution(
|
||||
TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub)
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub)
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub)
|
||||
|
||||
}}
|
||||
|
||||
@ -891,7 +891,7 @@ Tensor mkldnn_convolution_transpose_pointwise(
|
||||
);
|
||||
}
|
||||
|
||||
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_backward_stub, &mkldnn_convolution_backward);
|
||||
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_backward_stub, &mkldnn_convolution_backward)
|
||||
|
||||
namespace{
|
||||
Tensor mkldnn_convolution_transpose(
|
||||
@ -1044,8 +1044,8 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_stub, &mkldnn_convolution_transpose);
|
||||
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub, &mkldnn_convolution_transpose_backward);
|
||||
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_stub, &mkldnn_convolution_transpose)
|
||||
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub, &mkldnn_convolution_transpose_backward)
|
||||
|
||||
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
|
||||
m.impl(
|
||||
|
@ -71,7 +71,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
|
||||
TORCH_CHECK(false, "mkldnn_rnn_layer_backward: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub)
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
|
@ -154,14 +154,14 @@ const std::map<c10::string_view, AttrFunction>& fusion_unary_attr_map() {
|
||||
{"gelu", attr_func_gelu},
|
||||
};
|
||||
return fusion_attr_map;
|
||||
};
|
||||
}
|
||||
|
||||
const std::map<c10::string_view, ideep::algorithm>& fusion_unary_alg_map() {
|
||||
static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
|
||||
{"relu", {ideep::algorithm::eltwise_relu}},
|
||||
};
|
||||
return fusion_attr_map;
|
||||
};
|
||||
}
|
||||
|
||||
const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map() {
|
||||
static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
|
||||
@ -171,7 +171,7 @@ const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map() {
|
||||
{"div", {ideep::algorithm::binary_div}},
|
||||
};
|
||||
return fusion_attr_map;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // AT_MKLDNN_ENABLED()
|
||||
}}
|
||||
|
@ -356,7 +356,7 @@ TORCH_IMPL_FUNC(bitwise_not_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
mps::_bitwise_not_out_mps(self, output);
|
||||
}
|
||||
|
||||
REGISTER_MPS_DISPATCH(lshift_stub, &lshift_kernel_mps);
|
||||
REGISTER_MPS_DISPATCH(rshift_stub, &rshift_kernel_mps);
|
||||
REGISTER_MPS_DISPATCH(lshift_stub, &lshift_kernel_mps)
|
||||
REGISTER_MPS_DISPATCH(rshift_stub, &rshift_kernel_mps)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -17,7 +17,7 @@
|
||||
namespace at::native {
|
||||
|
||||
DEFINE_DISPATCH(nested_dense_elementwise_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub)
|
||||
|
||||
std::pair<NestedTensorImpl*, NestedTensorImpl*>
|
||||
static get_elementwise_nested_tensor_impl(
|
||||
|
@ -114,7 +114,7 @@ void _nested_op_dense_esuhm_cuda(Tensor& result, const Tensor& self, const Tenso
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda);
|
||||
REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda)
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -4264,21 +4264,21 @@ void index_put_kernel_quantized_cpu(TensorIterator& iter, IntArrayRef index_size
|
||||
// AVX2 kernels would be used instead. Ref: GH 56992.
|
||||
#if defined(_WIN32)
|
||||
REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
|
||||
&dequantize_tensor_per_channel_affine_cpu);
|
||||
&dequantize_tensor_per_channel_affine_cpu)
|
||||
REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
|
||||
&dequantize_tensor_per_channel_float_qparams_cpu);
|
||||
&dequantize_tensor_per_channel_float_qparams_cpu)
|
||||
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub,
|
||||
&fake_quant_per_channel_cachemask_cpu);
|
||||
&fake_quant_per_channel_cachemask_cpu)
|
||||
REGISTER_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel)
|
||||
REGISTER_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel)
|
||||
#else
|
||||
// These kernels are dispatched to AVX512
|
||||
ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub,
|
||||
&dequantize_tensor_per_channel_affine_cpu);
|
||||
&dequantize_tensor_per_channel_affine_cpu)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
|
||||
&dequantize_tensor_per_channel_float_qparams_cpu);
|
||||
&dequantize_tensor_per_channel_float_qparams_cpu)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub,
|
||||
&fake_quant_per_channel_cachemask_cpu);
|
||||
&fake_quant_per_channel_cachemask_cpu)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel)
|
||||
#endif // CPU_CAPABILITY_AVX512 && _WIN32
|
||||
@ -4286,17 +4286,17 @@ ALSO_REGISTER_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel)
|
||||
// The kernels below are dispatched to AVX2 because they don't perform as well
|
||||
// with AVX512. We might revisit this decision in the near future.
|
||||
REGISTER_DISPATCH(dequantize_tensor_per_tensor_affine_stub,
|
||||
&dequantize_tensor_per_tensor_affine_cpu);
|
||||
&dequantize_tensor_per_tensor_affine_cpu)
|
||||
REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
|
||||
&fake_quantize_learnable_tensor_grad_kernel_cpu);
|
||||
&fake_quantize_learnable_tensor_grad_kernel_cpu)
|
||||
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
|
||||
&fake_quantize_tensor_cachemask_kernel);
|
||||
&fake_quantize_tensor_cachemask_kernel)
|
||||
REGISTER_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub,
|
||||
&fake_quantize_tensor_cachemask_tensor_qparams_kernel);
|
||||
&fake_quantize_tensor_cachemask_tensor_qparams_kernel)
|
||||
REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
|
||||
&qadaptive_avg_pool2d_nhwc_kernel);
|
||||
&qadaptive_avg_pool2d_nhwc_kernel)
|
||||
REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
|
||||
&qadaptive_avg_pool3d_ndhwc_kernel);
|
||||
&qadaptive_avg_pool3d_ndhwc_kernel)
|
||||
REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>)
|
||||
REGISTER_DISPATCH(qadd_scalar_relu_stub, &qadd_scalar_kernel<true>)
|
||||
REGISTER_DISPATCH(qadd_scalar_stub, &qadd_scalar_kernel<false>)
|
||||
@ -4325,32 +4325,32 @@ REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel)
|
||||
REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel)
|
||||
REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel)
|
||||
REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub,
|
||||
&fake_quantize_learnable_channel_grad_kernel_cpu);
|
||||
&fake_quantize_learnable_channel_grad_kernel_cpu)
|
||||
REGISTER_DISPATCH(
|
||||
quantize_tensor_per_tensor_affine_stub,
|
||||
&quantize_tensor_per_tensor_affine_cpu);
|
||||
&quantize_tensor_per_tensor_affine_cpu)
|
||||
REGISTER_DISPATCH(
|
||||
quantize_tensor_per_channel_affine_stub,
|
||||
&quantize_tensor_per_channel_affine_cpu);
|
||||
&quantize_tensor_per_channel_affine_cpu)
|
||||
REGISTER_DISPATCH(
|
||||
quantize_tensor_per_channel_float_qparams_stub,
|
||||
&quantize_tensor_per_channel_float_qparams_cpu);
|
||||
&quantize_tensor_per_channel_float_qparams_cpu)
|
||||
REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel)
|
||||
REGISTER_DISPATCH(quantized_groupnorm_nhwc_stub, &quantized_groupnorm_nhwc_kernel)
|
||||
REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub,
|
||||
&qupsample_bilinear2d_nhwc_kernel);
|
||||
&qupsample_bilinear2d_nhwc_kernel)
|
||||
REGISTER_DISPATCH(
|
||||
quantize_tensor_per_tensor_affine_sub_byte_stub,
|
||||
&quantize_tensor_per_tensor_affine_sub_byte_cpu);
|
||||
&quantize_tensor_per_tensor_affine_sub_byte_cpu)
|
||||
REGISTER_DISPATCH(
|
||||
dequantize_tensor_per_tensor_affine_sub_byte_stub,
|
||||
&dequantize_tensor_per_tensor_affine_sub_byte_cpu);
|
||||
&dequantize_tensor_per_tensor_affine_sub_byte_cpu)
|
||||
REGISTER_DISPATCH(
|
||||
masked_fill_kernel_quantized_stub,
|
||||
&masked_fill_kernel_quantized_cpu);
|
||||
&masked_fill_kernel_quantized_cpu)
|
||||
REGISTER_DISPATCH(
|
||||
index_put_kernel_quantized_stub,
|
||||
&index_put_kernel_quantized_cpu);
|
||||
&index_put_kernel_quantized_cpu)
|
||||
REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel)
|
||||
REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel)
|
||||
} // namespace at::native
|
||||
|
@ -22,11 +22,11 @@ Tensor flatten_indices_cpu_kernel(const Tensor& indices, IntArrayRef size) {
|
||||
|
||||
}
|
||||
|
||||
REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel);
|
||||
REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
|
||||
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
|
||||
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
|
||||
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
|
||||
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
|
||||
REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel)
|
||||
REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -156,24 +156,24 @@ void sparse_mask_projection_out_cpu_kernel(
|
||||
|
||||
}
|
||||
|
||||
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
|
||||
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
|
||||
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
|
||||
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
|
||||
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
|
||||
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
|
||||
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel);
|
||||
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
|
||||
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
|
||||
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
|
||||
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
|
||||
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
|
||||
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
|
||||
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
|
||||
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
|
||||
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
|
||||
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
|
||||
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
|
||||
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
}
|
||||
|
@ -512,10 +512,10 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice
|
||||
return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
|
||||
}
|
||||
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc);
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr);
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc);
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr)
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc)
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr)
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc)
|
||||
|
||||
static DimVector _estimate_sparse_compressed_tensor_size(
|
||||
const Tensor& compressed_indices,
|
||||
|
@ -433,42 +433,42 @@ Tensor& zero_sparse_csr_(Tensor& self) {
|
||||
}
|
||||
|
||||
#define CREATE_UNARY_UFUNC(op_name) \
|
||||
CREATE_UNARY_UFUNC_OUT(op_name); \
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name); \
|
||||
CREATE_UNARY_UFUNC_INPLACE(op_name);
|
||||
CREATE_UNARY_UFUNC_OUT(op_name) \
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name) \
|
||||
CREATE_UNARY_UFUNC_INPLACE(op_name)
|
||||
|
||||
#define CREATE_UNARY_UFUNC_NO_INPLACE(op_name) \
|
||||
CREATE_UNARY_UFUNC_OUT(op_name); \
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name);
|
||||
CREATE_UNARY_UFUNC_OUT(op_name) \
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name)
|
||||
|
||||
// Exhaustive list of the unary ufuncs supported by sparse compressed
|
||||
CREATE_UNARY_UFUNC(abs);
|
||||
CREATE_UNARY_UFUNC(asin);
|
||||
CREATE_UNARY_UFUNC(asinh);
|
||||
CREATE_UNARY_UFUNC(atan);
|
||||
CREATE_UNARY_UFUNC(atanh);
|
||||
CREATE_UNARY_UFUNC(ceil);
|
||||
CREATE_UNARY_UFUNC(deg2rad);
|
||||
CREATE_UNARY_UFUNC(erf);
|
||||
CREATE_UNARY_UFUNC(erfinv);
|
||||
CREATE_UNARY_UFUNC(expm1);
|
||||
CREATE_UNARY_UFUNC(floor);
|
||||
CREATE_UNARY_UFUNC(frac);
|
||||
CREATE_UNARY_UFUNC(log1p);
|
||||
CREATE_UNARY_UFUNC(neg);
|
||||
CREATE_UNARY_UFUNC(rad2deg);
|
||||
CREATE_UNARY_UFUNC(sign);
|
||||
CREATE_UNARY_UFUNC(sin);
|
||||
CREATE_UNARY_UFUNC(sinh);
|
||||
CREATE_UNARY_UFUNC(sgn);
|
||||
CREATE_UNARY_UFUNC(sqrt);
|
||||
CREATE_UNARY_UFUNC(tan);
|
||||
CREATE_UNARY_UFUNC(tanh);
|
||||
CREATE_UNARY_UFUNC(trunc);
|
||||
CREATE_UNARY_UFUNC(conj_physical);
|
||||
CREATE_UNARY_UFUNC(abs)
|
||||
CREATE_UNARY_UFUNC(asin)
|
||||
CREATE_UNARY_UFUNC(asinh)
|
||||
CREATE_UNARY_UFUNC(atan)
|
||||
CREATE_UNARY_UFUNC(atanh)
|
||||
CREATE_UNARY_UFUNC(ceil)
|
||||
CREATE_UNARY_UFUNC(deg2rad)
|
||||
CREATE_UNARY_UFUNC(erf)
|
||||
CREATE_UNARY_UFUNC(erfinv)
|
||||
CREATE_UNARY_UFUNC(expm1)
|
||||
CREATE_UNARY_UFUNC(floor)
|
||||
CREATE_UNARY_UFUNC(frac)
|
||||
CREATE_UNARY_UFUNC(log1p)
|
||||
CREATE_UNARY_UFUNC(neg)
|
||||
CREATE_UNARY_UFUNC(rad2deg)
|
||||
CREATE_UNARY_UFUNC(sign)
|
||||
CREATE_UNARY_UFUNC(sin)
|
||||
CREATE_UNARY_UFUNC(sinh)
|
||||
CREATE_UNARY_UFUNC(sgn)
|
||||
CREATE_UNARY_UFUNC(sqrt)
|
||||
CREATE_UNARY_UFUNC(tan)
|
||||
CREATE_UNARY_UFUNC(tanh)
|
||||
CREATE_UNARY_UFUNC(trunc)
|
||||
CREATE_UNARY_UFUNC(conj_physical)
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
|
||||
static CREATE_UNARY_UFUNC(relu);
|
||||
static CREATE_UNARY_UFUNC(relu)
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
// With addition of `round.decimals` overload, using CREATE_UNARY_UFUNC leads
|
||||
@ -512,14 +512,14 @@ Tensor& threshold_backward_sparse_compressed_out(
|
||||
}
|
||||
|
||||
// angle, isneginf, isposinf and signbit currently don't have an inplace variant
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(angle);
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(isneginf);
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(isposinf);
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(signbit);
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(angle)
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(isneginf)
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(isposinf)
|
||||
CREATE_UNARY_UFUNC_NO_INPLACE(signbit)
|
||||
|
||||
// isnan and isinf don't have an out variant
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(isnan);
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(isinf);
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(isnan)
|
||||
CREATE_UNARY_UFUNC_FUNCTIONAL(isinf)
|
||||
|
||||
template <typename scalar_t>
|
||||
void addmm_out_sparse_csr_native_cpu(
|
||||
@ -1230,7 +1230,7 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j,
}
}
}
};
}
}

static Tensor& s_addmm_out_sparse_dense_cpu(
Tensor& r,

@ -444,12 +444,12 @@ int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Ten
return static_cast<int64_t>(backend);
}

REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp);
REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp)
REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)

int64_t _fused_sdp_choice_meta(
const Tensor& query_,

@ -1387,7 +1387,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso
return at::Tensor();
}

REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda)

#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
namespace {
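Also outside the diff proper: the REGISTER_*_DISPATCH and similar registration macros changed here appear to expand to complete definitions of registration objects, so they are self-terminating and any extra `;` at the call site is redundant. Whether a macro should be self-terminating or should require the caller's `;` is a design choice; one common way to demand exactly one `;` is to end the expansion with a declaration that the caller's semicolon completes. A sketch of both styles with hypothetical names (the real C10/ATen macros may be implemented differently):

#include <cstdio>

struct Registrar {
  explicit Registrar(const char* name) { std::printf("registered %s\n", name); }
};

// Style 1: self-terminating expansion; callers should not add ';'
// (an added ';' would be a redundant empty declaration).
#define REGISTER_SELF_TERMINATING(name) \
  static Registrar registrar_##name{#name};

// Style 2: the expansion ends with a dangling forward declaration, so every
// call site must supply exactly one ';' to complete it.
#define REGISTER_NEEDS_SEMI(name)            \
  static Registrar registrar_##name{#name}; \
  struct ForceSemi_##name

REGISTER_SELF_TERMINATING(alpha)  // no trailing ';'
REGISTER_NEEDS_SEMI(beta);        // ';' completes the dummy declaration

int main() { return 0; }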
@ -227,23 +227,23 @@ static bool EmbeddingLookupGenericSlowIdx(
}
// clang-format on

EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false)
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false)
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false)
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false)
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false)

EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true)
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true)
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true)
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true)
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true)

#undef EMBEDDING_IDX_SPECIALIZATION

@ -10,16 +10,16 @@
C10_DEFINE_bool(
caffe2_threadpool_force_inline,
false,
"Force to always run jobs on the calling thread");
"Force to always run jobs on the calling thread")

// Whether or not threadpool caps apply to Android
C10_DEFINE_int(caffe2_threadpool_android_cap, true, "");
C10_DEFINE_int(caffe2_threadpool_android_cap, true, "")

// Whether or not threadpool caps apply to iOS and MacOS
C10_DEFINE_int(caffe2_threadpool_ios_cap, true, "");
C10_DEFINE_int(caffe2_threadpool_macos_cap, true, "");
C10_DEFINE_int(caffe2_threadpool_ios_cap, true, "")
C10_DEFINE_int(caffe2_threadpool_macos_cap, true, "")

C10_DEFINE_int(pthreadpool_size, 0, "Override the default thread pool size.");
C10_DEFINE_int(pthreadpool_size, 0, "Override the default thread pool size.")

namespace caffe2 {

@ -14,7 +14,7 @@ namespace detail {
template <typename Derived>
class _DropoutNd : public torch::nn::Cloneable<Derived> {
public:
_DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)){};
_DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {}

explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) {
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)

@ -227,7 +227,7 @@ class TORCH_API AdaptiveMaxPoolImpl : public torch::nn::Cloneable<Derived> {
const AdaptiveMaxPoolOptions<output_size_t>& options_)
: options(options_) {}

void reset() override{};
void reset() override {}

/// Pretty prints the `AdaptiveMaxPool{1,2,3}d` module into the given
/// `stream`.

@ -128,7 +128,7 @@ class TORCH_API Optimizer {
std::unique_ptr<OptimizerOptions> defaults)
: Optimizer(
{OptimizerParamGroup(std::move(parameters))},
std::move(defaults)){};
std::move(defaults)) {}

/// Adds the given param_group to the optimizer's param_group list.
void add_param_group(const OptimizerParamGroup& param_group);
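The `{};` to `{}` changes above are the in-class flavor of the same warning: a `;` after the body of a function or constructor defined inside a class is a redundant empty member declaration, which is the case GCC documents -Wextra-semi for (Clang reports it too). A minimal sketch with a hypothetical class:

struct DropoutLike {
  explicit DropoutLike(double p) : p_(p) {};  // extra ';' after an in-class body: flagged by -Wextra-semi
  void reset() {}                             // fine: no trailing ';'
  double p_ = 0.5;
};

int main() {
  DropoutLike d(0.1);
  d.reset();
  return 0;
}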
@ -9,7 +9,7 @@
namespace caffe2 {
// Required for cpp_custom_type_hack to work
// NOLINTNEXTLINE(bugprone-exception-escape)
CAFFE_KNOWN_TYPE(at::RecordFunction);
CAFFE_KNOWN_TYPE(at::RecordFunction)
} // namespace caffe2

namespace torch::autograd::profiler {

@ -77,7 +77,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// Subclasses must override this method to return the backend name
virtual const std::string getBackendName() const {
TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
};
}

virtual c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& /* tensors */,

@ -39,7 +39,7 @@ C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING(
GlooDeviceRegistry,
::gloo::transport::Device,
const std::string& /* interface */,
const std::string& /* hostname */);
const std::string& /* hostname */)

#if GLOO_HAVE_TRANSPORT_TCP
static std::shared_ptr<::gloo::transport::Device> makeTCPDevice(

@ -62,8 +62,8 @@ static std::shared_ptr<::gloo::transport::Device> makeTCPDevice(
// Registry priority is per key identifier. We register TCP to `LINUX` for
// the flexibility of other application to override by priority. Register
// TCP to `TCP` for env "GLOO_DEVICE_TRANSPORT" override.
C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice)
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice)
#endif

#if GLOO_HAVE_TRANSPORT_TCP_TLS

@ -97,7 +97,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
default:
TORCH_CHECK(false, "THis should never happen!");
}
};
}

static BackendType strToBackendType(const std::string& backend) {
if (backend == "undefined") {

@ -113,7 +113,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
} else {
return BackendType::CUSTOM;
}
};
}

// Not used, set for backwards compatibility and only used for TypeDef in
// Ops.cpp

@ -146,11 +146,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {

virtual const std::string getBackendName() const {
return backendTypeToString(backendType_);
};
}

BackendType getBackendType() const {
return backendType_;
};
}

virtual void startCoalescing(c10::DeviceType deviceType) {
// only nccl has implemented startCoalescing so only execute for nccl

@ -316,7 +316,7 @@ parseWireSections(const void* data, size_t data_size) {

static const char* kMeta = "meta";
static const char* kPayload = "payload";
}; // namespace
} // namespace

c10::List<at::Tensor> cloneSparseTensors(
const std::vector<at::Tensor>& tensors) {
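One more variant worth naming: `}; // namespace` becomes `} // namespace`. Unlike a class definition, a namespace's closing brace takes no semicolon, so the stray `;` after it is another redundant empty declaration. A tiny sketch with a hypothetical namespace name:

namespace wire {
const char* kMeta = "meta";
}  // correct: no ';' after a namespace's closing brace
// Writing "};" here instead would leave a redundant empty declaration behind.

int main() { return wire::kMeta[0] == 'm' ? 0 : 1; }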
@ -578,7 +578,7 @@ RangeValue::RangeValue(

SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
return shared_from_this();
};
}

Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
if (static_len_) {

@ -2,6 +2,6 @@

namespace torch::jit::mobile::nnc {

C10_DEFINE_REGISTRY(NNCKernelRegistry, NNCKernel);
C10_DEFINE_REGISTRY(NNCKernelRegistry, NNCKernel)

} // namespace torch::jit::mobile::nnc

@ -78,7 +78,7 @@ class TransposeFrozenLinear {
node->replaceAllUsesWith(bias_result);
}
node->destroy();
};
}

void handleBlockAndSubblocks(Block* block) {}

@ -303,7 +303,7 @@ void MKLDNNLayerNormOp(Stack& stack, bool inplace) {
at::native::mkldnn_layer_norm_last_index_weight_bias_f32(
input, shape, weight, bias, eps, inplace);
push(stack, dst);
};
}

Operation BroadOp(const Node* node) {
return [](Stack& stack) {

@ -201,7 +201,7 @@ struct IntegerValueRefiner {

active_refinements_.pop_back();
return block_refinements;
};
}

std::optional<int64_t> tryFindRefinement(Value* v) {
for (const auto& ref : active_refinements_) {

@ -126,7 +126,7 @@ struct ListLenRefiner {
}
active_refinements_.pop_back();
return block_refinements;
};
}

std::optional<int64_t> tryFindRefinement(Value* v) {
for (const auto& ref : active_refinements_) {

@ -705,7 +705,7 @@ static bool is_module(
return module_name.value() == module_qualified_name;
}
return false;
};
}

bool aten_add_alpha_is_one(
const Match& match,

@ -57,7 +57,7 @@ std::shared_ptr<OperatorSet> nn_ops_first_input_preserving() {
"aten::hardswish_(Tensor self) -> Tensor",
});
return ops;
};
}

// Requirements:
// dims : Changed from first argument

@ -70,5 +70,5 @@ std::shared_ptr<OperatorSet> ops_one_tensor_in_shape_transform() {
"aten::flatten(Tensor self, int start_dim, int end_dim) -> Tensor",
});
return ops;
};
}
} // namespace torch::jit

@ -29,7 +29,7 @@ struct BooleanRefinementMapping {
ListRefinement true_refine,
ListRefinement false_refine)
: true_refine_(std::move(true_refine)),
false_refine_(std::move(false_refine)){};
false_refine_(std::move(false_refine)) {}
BooleanRefinementMapping() = default; // empty

static BooleanRefinementMapping FalseRefinements(

@ -36,7 +36,7 @@ std::vector<IValue> boxInputs(const ProcessedNode& pnode) {

} // namespace

C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor)

bool nativeOpIsRegistered(const c10::Symbol& op_name) {
const std::string name(op_name.toQualString());

@ -38,7 +38,7 @@ TORCH_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
return fn(n); \
} \
}; \
C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);
C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id)

TORCH_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
#define REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, ...) \

@ -49,7 +49,7 @@ TORCH_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
} \
}; \
C10_REGISTER_CLASS( \
SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id);
SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id)

inline at::Tensor create_empty_from(const at::Tensor& t) {
return at::detail::empty_cpu(

@ -142,7 +142,7 @@ IValue pickle_load(const std::vector<char>& data) {
"pickle_load not supported on mobile "
"(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
};
}

// A specialized version of pickle_load that can load custom objects.
c10::IValue pickle_load_obj(std::string_view data) {

@ -87,7 +87,7 @@ void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) {
case ScalarType::Name: \
return callArg.Name##Ptr();

AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE

default:

@ -165,7 +165,7 @@ class CodeGen::CallArg {
memcpy(buffer_, &v, sizeof(Type)); \
data_ = (void*)buffer_; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR)
#undef ARG_TYPE_CTOR

void* data() const {

@ -199,7 +199,7 @@ class CodeGen::CallArg {
TORCH_INTERNAL_ASSERT(data_ == (void*)buffer_); \
return (Type*)data_; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE)
#undef ARG_PTR_DEFINE

private:

@ -148,7 +148,7 @@ void dispatch_binary_op(std::ostream& os, const BinaryOpNode<Op>* v) {
case ScalarType::Name: \
visit_binary_op<Type>(os, v->lhs(), v->rhs(), v->expr_type()); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();

@ -375,7 +375,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
case ScalarType::Name: \
value = compare_select_op<T, Type>(lhs, rhs, retval1, retval2, cmp_op); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();

@ -407,7 +407,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
value_ = compare_select_op_helper<Type>( \
lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();

@ -418,7 +418,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
TORCH_API void visit(const Name##ImmPtr& v) override { \
value_ = InterpValue(v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT)
#undef IMM_VISIT

TORCH_API void visit(const BlockPtr& v) override {
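For readers skimming these tensorexpr hunks: AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE) is an X-macro list that applies the given macro once per scalar type, and each application already yields a complete case clause or definition, so the list invocation itself needs no trailing `;`. A simplified, self-contained sketch of the idiom, with a hypothetical FORALL_DEMO_TYPES list standing in for the ATen one:

#include <cstdio>

// Hypothetical X-macro list: applies the given macro once per (Type, Name) pair.
#define FORALL_DEMO_TYPES(_) \
  _(float, Float)            \
  _(double, Double)          \
  _(int, Int)

enum class DemoType { Float, Double, Int };

const char* typeName(DemoType t) {
  switch (t) {
// Each TYPE_CASE expansion is already a complete 'case' clause, so the
// list invocation below takes no trailing ';'.
#define TYPE_CASE(Type, Name) \
  case DemoType::Name:        \
    return #Type;
    FORALL_DEMO_TYPES(TYPE_CASE)
#undef TYPE_CASE
  }
  return "unknown";
}

int main() {
  std::printf("%s\n", typeName(DemoType::Double));  // prints "double"
  return 0;
}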
@ -472,7 +472,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
case ScalarType::Name: \
this->value_ = InterpValue(castValues<SrcType, Type>(src_dtype, v)); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE)
#undef DST_TYPE_CASE
#define DST_TYPE_CASE_QUANT(Type, Name, CppType) \
case ScalarType::Name: { \

@ -507,7 +507,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
case ScalarType::Name: \
doCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE)
SRC_TYPE_CASE(c10::quint8, QUInt8);
SRC_TYPE_CASE(c10::qint8, QInt8);
#undef SRC_TYPE_CASE

@ -615,7 +615,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<Type> v(lanes, value.as<Type>()); \
value_ = InterpValue(v); \
} break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();

@ -758,7 +758,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
} \
value_ = InterpValue(val); \
} break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
TYPE_CASE(c10::quint8, QUInt8);
TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE

@ -805,7 +805,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
ptr##Name[index[i]] = value[i]; \
} \
} break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
TYPE_CASE(c10::quint8, QUInt8);
TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE

@ -1268,7 +1268,7 @@ void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {
impl_->bindVar(bufArg.var(), typed_data); \
break; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();

@ -30,7 +30,7 @@ class InterpValue {
Name##values.push_back(v); \
return; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
throw unsupported_dtype();
}

@ -89,9 +89,9 @@ class InterpValue {
} \
return Name##values[0]; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH);
VALUE_AS_DISPATCH(c10::quint8, QUInt8);
VALUE_AS_DISPATCH(c10::qint8, QInt8);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH)
VALUE_AS_DISPATCH(c10::quint8, QUInt8)
VALUE_AS_DISPATCH(c10::qint8, QInt8)
#undef VALUE_AS_DISPATCH

#define VALUE_AS_VEC_DISPATCH(Type, Name) \

@ -102,9 +102,9 @@ VALUE_AS_DISPATCH(c10::qint8, QInt8);
} \
return Name##values; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH);
VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8);
VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH)
VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8)
VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8)
#undef VALUE_AS_VEC_DISPATCH

template <typename Type>

@ -87,7 +87,7 @@ ExprHandle ExprHandle::operator>>(const ExprHandle& other) const {

#define IMM_EXPR_DECLARE(Type, Name) \
ExprHandle::ExprHandle(Type v) : ExprHandle(Name##Imm::make(v)) {}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE)
#undef IMM_EXPR_DECLARE

ExprHandle sin(const ExprHandle& v) {

@ -112,7 +112,7 @@ class TORCH_API ExprHandle {
}

#define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE)
#undef IMM_EXPR_DECLARE

template <class Op>

@ -274,7 +274,7 @@ class TORCH_API Buf : public ExprNode<Buf> {

ExprPtr initializer() const {
return initializer_;
};
}

ExprPtr qzero() const {
return qzero_;

@ -97,7 +97,7 @@ void DispatchParallel(

FOR_ALL_EXTERNAL_FUNCTIONS(DECLARE_EXTERNAL_FUNCTION)
#if AT_MKLDNN_ENABLED()
DECLARE_EXTERNAL_FUNCTION(nnc_mkldnn_prepacked_conv_run);
DECLARE_EXTERNAL_FUNCTION(nnc_mkldnn_prepacked_conv_run)
#endif

TORCH_API void nnc_aten_free(size_t bufs_num, void** ptrs) noexcept;

@ -119,7 +119,7 @@ using SyncThreadsPtr = NodePtr<SyncThreads>;
#define IMM_DECLARE(Type, Name) \
class Name##Imm; \
using Name##ImmPtr = NodePtr<Name##Imm>;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE)
#undef IMM_DECLARE

} // namespace torch::jit::tensorexpr

@ -86,7 +86,7 @@ class TORCH_API HashProvider : public IRVisitor {
CACHE_GUARD(); \
putHash(v, hash_combine(#Name, v->value())); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT)
#undef IMM_VISIT

void visit(const CastPtr& v) override;

@ -276,7 +276,7 @@ bool immediateIsPositive(const ExprPtr& e) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() > 0; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
return false;
}

@ -286,7 +286,7 @@ bool immediateIsZero(const ExprPtr& e) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() == 0; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
return false;
}

@ -322,7 +322,7 @@ class Min : public BinaryOpNode<Min> {
private: \
Type value_; \
};
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE)
#undef IMM_DECLARE

// Get immediate by ScalarType.

@ -332,7 +332,7 @@ ExprPtr getImmediateByType(ScalarType immType, T initialVal) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
return alloc<Name##Imm>(Type(initialVal));
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
throw unsupported_dtype();

@ -375,7 +375,7 @@ T immediateAs(const ExprPtr& e) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value(); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
throw unsupported_dtype();
return 0;

@ -392,7 +392,7 @@ bool immediateEquals(const ExprPtr& e, T val) {
if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
return imm->value() == val; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
throw unsupported_dtype();
return false;

@ -116,7 +116,7 @@ ExprPtr IRCloner::mutate(const CompareSelectPtr& v) {
ExprPtr IRCloner::mutate(const Name##ImmPtr& v) { \
return v; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE)
#undef IMM_MUTATE_DEFINE

ExprPtr IRCloner::mutate(const CastPtr& v) {

@ -25,7 +25,7 @@ class TORCH_API IRCloner : public IRMutator {
ExprPtr mutate(const CompareSelectPtr& v) override;
#define IMM_MUTATE_DECLARE(Type, Name) \
ExprPtr mutate(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE)
#undef IMM_MUTATE_DECLARE
ExprPtr mutate(const CastPtr& v) override;
ExprPtr mutate(const BitCastPtr& v) override;

@ -113,7 +113,7 @@ ExprPtr IRMutator::mutate(const CompareSelectPtr& v) {
ExprPtr IRMutator::mutate(const Name##ImmPtr& v) { \
return v; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE)
#undef IMM_MUTATE_DEFINE

ExprPtr IRMutator::mutate(const CastPtr& v) {

@ -231,7 +231,7 @@ static void formatImm(std::ostream& os, T v) {
void IRPrinter::visit(const Name##ImmPtr& v) { \
formatImm(os(), v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT

void IRPrinter::visit(const CastPtr& v) {

@ -32,7 +32,7 @@ class TORCH_API IRPrinter : public IRVisitor {
void visit(const RshiftPtr& v) override;
void visit(const CompareSelectPtr& v) override;
#define IMM_PRINT_VISIT(Type, Name) void visit(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT
void visit(const CastPtr& v) override;
void visit(const BitCastPtr& v) override;

@ -1293,7 +1293,7 @@ bool isOperandInMinMaxTerm(
return true;
}
return false;
};
}

// Simplifies the nested min-max pattern like:
// * Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))

@ -98,7 +98,7 @@ inline ExprPtr evaluateOp(const ExprPtr& v) {
Type val = eval.value<Type>(); \
return getImmediateByType(v->dtype().scalar_type(), val); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
default:
LOG(FATAL) << "Unsupported datatype: " << v->dtype();

@ -76,7 +76,7 @@ void IRVisitor::visit(const CompareSelectPtr& v) {

#define IMM_VISIT(Type, Name) \
void IRVisitor::visit(const Name##ImmPtr& v) {}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT)
#undef IMM_VISIT

void IRVisitor::visit(const CastPtr& v) {

@ -329,7 +329,7 @@ class LLVMCodeGenImpl : public IRVisitor {
void visit(const CompareSelectPtr& v) override;

#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE)
#undef IMM_VISIT_DECLARE

void visit(const CastPtr& v) override;

@ -1753,7 +1753,7 @@ std::vector<ForPtr> LoopNest::distributeLoopAndParentsOverInnerLoops(
static bool areEqual(const ExprPtr& expr1, const ExprPtr& expr2) {
auto diff = IRSimplifier::simplify(alloc<Sub>(expr1, expr2));
return diff->isConstant() && (immediateAs<int>(diff) == 0);
};
}

static bool doesExprContainAnyVar(
const ExprPtr& expr,

@ -32,7 +32,7 @@ ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
case ScalarType::Name: \
e = cast<Type>(e); \
break;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
case ScalarType::QUInt8:
e = cast<c10::quint8>(e);

@ -272,7 +272,7 @@ void RegisterizerAnalysis::visit(const ForPtr& v) {

// having hoisted, now we can merge normally.
mergeCurrentScopeIntoParent();
};
}

void RegisterizerAnalysis::visit(const CondPtr& v) {
ExprPtr condition = v->condition();

@ -342,9 +342,9 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor {
stmtStack_.pop_front(); \
}

STMT_ON_STACK(AtomicAdd);
STMT_ON_STACK(Allocate);
STMT_ON_STACK(Free);
STMT_ON_STACK(AtomicAdd)
STMT_ON_STACK(Allocate)
STMT_ON_STACK(Free)

#undef STMT_ON_STACK

@ -22,8 +22,8 @@ AT_FORALL_SCALAR_TYPES_AND7(
Float8_e4m3fn,
Float8_e4m3fnuz,
DTYPE_DEFINE)
DTYPE_DEFINE(c10::quint8, QUInt8);
DTYPE_DEFINE(c10::qint8, QInt8);
DTYPE_DEFINE(c10::quint8, QUInt8)
DTYPE_DEFINE(c10::qint8, QInt8)

#undef DTYPE_DEFINE
Some files were not shown because too many files have changed in this diff.