Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[functorch] update binary_pointwise_batch_rule
@@ -22,7 +22,7 @@ static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& sec
 }
 
 template <typename F, F Func, typename... ExtraArgs>
-std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
+std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
     const Tensor& tensor, optional<int64_t> tensor_batch_dim,
     const Tensor& other, optional<int64_t> other_batch_dim,
     ExtraArgs... extra_args) {
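For context, functorch batch rules follow a fixed contract: every input tensor is paired with an optional batch dimension, and the rule returns the result together with the batch dimension of the output. Below is a toy, self-contained sketch of that contract only; the "Tensor" is a plain std::vector, the broadcasting is simplified, and none of this is functorch's real implementation.

// Toy sketch of the batch-rule contract, not functorch code: a "Tensor" is
// just a flat std::vector and an operand is batched when its optional batch
// dim is present (always dim 0 here).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <tuple>
#include <vector>

using Tensor = std::vector<double>;

std::tuple<Tensor, std::optional<int64_t>> add_batch_rule(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim) {
  Tensor result(std::max(tensor.size(), other.size()));
  for (size_t i = 0; i < result.size(); ++i) {
    // An unbatched operand is broadcast against the batched one.
    const double a = tensor_batch_dim ? tensor[i] : tensor[0];
    const double b = other_batch_dim ? other[i] : other[0];
    result[i] = a + b;
  }
  // The output carries a batch dimension (dim 0) iff either input did.
  std::optional<int64_t> out_batch_dim;
  if (tensor_batch_dim || other_batch_dim) {
    out_batch_dim = 0;
  }
  return std::make_tuple(std::move(result), out_batch_dim);
}

int main() {
  Tensor batched = {1.0, 2.0, 3.0};  // a "batch" of three scalars
  Tensor plain = {10.0};             // unbatched operand
  auto [out, dim] = add_batch_rule(batched, 0, plain, std::nullopt);
  for (double v : out) std::cout << v << ' ';          // 11 12 13
  std::cout << "(batch dim: " << (dim ? "0" : "none") << ")\n";
}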
@@ -55,6 +55,27 @@ std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
   return std::make_tuple( std::move(result), 0 );
 }
 
+template <typename A, A a, typename C>
+struct BinaryPointwiseBatchRuleHelper;
+
+template <typename F, F Func, typename T1, typename T2, typename... T>
+struct BinaryPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
+  static std::tuple<Tensor,optional<int64_t>> apply(
+      const Tensor& tensor, optional<int64_t> tensor_batch_dim,
+      const Tensor& other, optional<int64_t> other_batch_dim,
+      T... extra_args) {
+    return _binary_pointwise_batch_rule<F, Func, T...>(
+        tensor, tensor_batch_dim, other, other_batch_dim,
+        std::forward<T>(extra_args)...);
+  }
+};
+
+#define BINARY_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
+    BinaryPointwiseBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
+
 template <typename M, M Meth, typename... ExtraArgs>
 void binary_pointwise_inplace_batch_rule(
     Tensor& tensor, optional<int64_t> tensor_batch_dim,
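The new helper exists so the extra non-Tensor arguments of an op no longer have to be spelled out at every registration site: c10::guts::function_traits recovers the op's parameter types as a typelist, the partial specialization peels off the two leading Tensor parameters, and the remaining types become the ExtraArgs pack. The following is a minimal standalone sketch of that trick, with toy typelist/function_traits stand-ins for the c10::guts utilities; names such as add_like, binary_rule and RULE are illustrative only, not functorch APIs.

// Standalone sketch of the typelist trick, not functorch code.
#include <iostream>
#include <utility>

template <class... Ts> struct typelist {};

// Toy function_traits: exposes a function type's parameter list as a typelist.
template <class F> struct function_traits;
template <class R, class... Args>
struct function_traits<R(Args...)> {
  using parameter_types = typelist<Args...>;
};

struct Tensor { int value; };

// The "op" whose rule we want: two Tensors plus one extra (non-Tensor) arg.
Tensor add_like(const Tensor& a, const Tensor& b, int alpha) {
  return Tensor{a.value + alpha * b.value};
}

// Generic rule whose trailing pack holds the op's extra arguments.
template <class F, F Func, class... ExtraArgs>
Tensor binary_rule(const Tensor& a, const Tensor& b, ExtraArgs... extra) {
  return Func(a, b, std::forward<ExtraArgs>(extra)...);
}

// Partial specialization on the typelist peels off the two leading Tensor
// parameters, so the remaining T... become binary_rule's ExtraArgs.
template <class F, F Func, class TL> struct Helper;
template <class F, F Func, class T1, class T2, class... T>
struct Helper<F, Func, typelist<T1, T2, T...>> {
  static Tensor apply(const Tensor& a, const Tensor& b, T... extra) {
    return binary_rule<F, Func, T...>(a, b, std::forward<T>(extra)...);
  }
};

// Mirrors the shape of BINARY_POINTWISE_BATCH_RULE: everything is derived
// from the function name alone.
#define RULE(fn) \
  Helper<decltype(&fn), &fn, function_traits<decltype(fn)>::parameter_types>::apply

int main() {
  Tensor x{1}, y{2};
  std::cout << RULE(add_like)(x, y, 3).value << "\n";  // prints 7
}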
@@ -161,55 +182,54 @@ std::tuple<Tensor,optional<int64_t>> _s_where_batch_rule(
 }
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
-#define BINARY_POINTWISE_WITH_SCALAR(op) \
-  VMAP_SUPPORT(#op".Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &op, const Scalar&>));
-
+#define BINARY_POINTWISE2(op, overload) \
+  VMAP_SUPPORT(#op"."#overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
 #define BINARY_POINTWISE(op) \
-  VMAP_SUPPORT(#op".Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>));
+  VMAP_SUPPORT(#op, BINARY_POINTWISE_BATCH_RULE(ATEN_FN(op)));
+#define UNARY_POINTWISE2(op, overload) \
+  VMAP_SUPPORT(#op"."#overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
+#define UNARY_POINTWISE(op) \
+  VMAP_SUPPORT(#op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
 
-  BINARY_POINTWISE_WITH_SCALAR(add);
-  VMAP_SUPPORT("add.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(add, Scalar)));
-  VMAP_SUPPORT("atan2", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(atan2)), &at::atan2>));
+  BINARY_POINTWISE2(add, Tensor);
+  UNARY_POINTWISE2(add, Scalar);
+  BINARY_POINTWISE(atan2);
 
-  VMAP_SUPPORT("clamp", BASIC_UNARY_BATCH_RULE(ATEN_FN(clamp)));
+  UNARY_POINTWISE(clamp);
   VMAP_SUPPORT("clamp.Tensor", clamp_tensor_batch_rule);
-  VMAP_SUPPORT("clamp_min.Tensor",
-    SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(clamp_min, Tensor)), &at::clamp_min>));
-  VMAP_SUPPORT("clamp_min", BASIC_UNARY_BATCH_RULE(ATEN_FN(clamp_min)));
-  VMAP_SUPPORT("clamp_max.Tensor",
-    SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(clamp_max, Tensor)), &at::clamp_max>));
-  VMAP_SUPPORT("clamp_max", BASIC_UNARY_BATCH_RULE(ATEN_FN(clamp_max)));
+  BINARY_POINTWISE2(clamp_min, Tensor);
+  UNARY_POINTWISE(clamp_min);
+  BINARY_POINTWISE2(clamp_max, Tensor);
+  UNARY_POINTWISE(clamp_max);
 
-  BINARY_POINTWISE(div);
-  VMAP_SUPPORT("div.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(div, Scalar)));
-  VMAP_SUPPORT("div.Scalar_mode", BASIC_UNARY_BATCH_RULE(ATEN_FN2(div, Scalar_mode)));
-  VMAP_SUPPORT("div.Tensor_mode", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(div, Tensor_mode)), &at::div, c10::optional<string_view>>));
+  BINARY_POINTWISE2(div, Tensor);
+  UNARY_POINTWISE2(div, Scalar);
+  UNARY_POINTWISE2(div, Scalar_mode);
+  BINARY_POINTWISE2(div, Tensor_mode);
 
-  VMAP_SUPPORT("maximum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(maximum)), &at::maximum>));
-  VMAP_SUPPORT("minimum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(minimum)), &at::minimum>));
+  BINARY_POINTWISE(maximum);
+  BINARY_POINTWISE(minimum);
 
-  BINARY_POINTWISE(mul);
-  VMAP_SUPPORT("mul.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(mul, Scalar)));
+  BINARY_POINTWISE2(mul, Tensor);
+  UNARY_POINTWISE2(mul, Scalar);
 
   // at::pow has three out-of-place overloads
-  VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
-  VMAP_SUPPORT("pow.Tensor_Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(pow, Tensor_Scalar)));
+  BINARY_POINTWISE2(pow, Tensor_Tensor);
+  UNARY_POINTWISE2(pow, Tensor_Scalar);
   VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
 
-  BINARY_POINTWISE_WITH_SCALAR(sub);
-  VMAP_SUPPORT("sub.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(sub, Scalar)));
+  BINARY_POINTWISE2(sub, Tensor);
+  UNARY_POINTWISE2(sub, Scalar)
 
-  BINARY_POINTWISE_WITH_SCALAR(rsub);
-  VMAP_SUPPORT("rsub.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(rsub, Scalar)));
+  BINARY_POINTWISE2(rsub, Tensor);
+  UNARY_POINTWISE2(rsub, Scalar);
 
-  VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
-    binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
-  VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
-  VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
-    binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
+  BINARY_POINTWISE(sigmoid_backward);
+  BINARY_POINTWISE(tanh_backward);
+  BINARY_POINTWISE(threshold_backward);
 
-  VMAP_SUPPORT("fmin", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmin)), &at::fmin>));
-  VMAP_SUPPORT("fmax", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmax)), &at::fmax>));
+  BINARY_POINTWISE(fmin);
+  BINARY_POINTWISE(fmax);
 
   OP_DECOMPOSE2(max, other);
   OP_DECOMPOSE2(min, other);
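With the macros above, each registration shrinks to a single call that derives both the schema string and the batch rule from the op name and overload. A small sketch of the name-building only is shown below; the real VMAP_SUPPORT registers a kernel for the given schema, and here it is replaced by a printf purely so the preprocessor expansion can be observed. The #op "." #overload pattern stringizes the two tokens and relies on adjacent string-literal concatenation, so BINARY_POINTWISE2(div, Tensor) registers under "div.Tensor".

// Sketch of the schema-name construction, not functorch code.
#include <cstdio>

#define VMAP_SUPPORT(schema, rule) std::printf("register %s -> %s\n", schema, rule)
#define BINARY_POINTWISE2(op, overload) \
  VMAP_SUPPORT(#op "." #overload, "binary pointwise batch rule")
#define BINARY_POINTWISE(op) \
  VMAP_SUPPORT(#op, "binary pointwise batch rule")
#define UNARY_POINTWISE2(op, overload) \
  VMAP_SUPPORT(#op "." #overload, "basic unary batch rule")

int main() {
  BINARY_POINTWISE2(div, Tensor);   // register div.Tensor -> binary pointwise batch rule
  UNARY_POINTWISE2(div, Scalar);    // register div.Scalar -> basic unary batch rule
  BINARY_POINTWISE(atan2);          // register atan2 -> binary pointwise batch rule
}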
@@ -252,7 +272,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
 #define COMPARISON_POINTWISE(op) \
   VMAP_SUPPORT(#op".Tensor", \
     SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>)); \
-  VMAP_SUPPORT(#op".Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, Scalar)))
+  UNARY_POINTWISE2(op, Scalar)
 
   COMPARISON_POINTWISE(eq);
   COMPARISON_POINTWISE(gt);
@@ -263,10 +283,10 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
 
 #undef COMPARISON_POINTWISE
 #undef SINGLE_ARG
-#undef BINARY_POINTWISE_BATCH_RULE_SCALAR
-#undef BINARY_POINTWISE_BATCH_RULE
-#undef BINARY_POINTWISE_WITH_SCALAR
+#undef BINARY_POINTWISE2
 #undef BINARY_POINTWISE
+#undef UNARY_POINTWISE2
+#undef UNARY_POINTWISE
 }
 
 }}