[functorch] update binary_pointwise_batch_rule

This commit is contained in:
Richard Zou
2021-07-29 06:27:14 -07:00
committed by Jon Janzen
parent 4ce294d25c
commit fd7e524b4e

View File

@ -22,7 +22,7 @@ static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& sec
}
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim,
ExtraArgs... extra_args) {
@ -55,6 +55,27 @@ std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
return std::make_tuple( std::move(result), 0 );
}
template <typename A, A a, typename C>
struct BinaryPointwiseBatchRuleHelper;
template <typename F, F Func, typename T1, typename T2, typename... T>
struct BinaryPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
static std::tuple<Tensor,optional<int64_t>> apply(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim,
T... extra_args) {
return _binary_pointwise_batch_rule<F, Func, T...>(
tensor, tensor_batch_dim, other, other_batch_dim,
std::forward<T>(extra_args)...);
}
};
#define BINARY_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
BinaryPointwiseBatchRuleHelper<\
decltype(&fn),\
&fn,\
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
template <typename M, M Meth, typename... ExtraArgs>
void binary_pointwise_inplace_batch_rule(
Tensor& tensor, optional<int64_t> tensor_batch_dim,
@ -161,55 +182,54 @@ std::tuple<Tensor,optional<int64_t>> _s_where_batch_rule(
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
#define BINARY_POINTWISE_WITH_SCALAR(op) \
VMAP_SUPPORT(#op".Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &op, const Scalar&>));
#define BINARY_POINTWISE2(op, overload) \
VMAP_SUPPORT(#op"."#overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
#define BINARY_POINTWISE(op) \
VMAP_SUPPORT(#op".Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>));
VMAP_SUPPORT(#op, BINARY_POINTWISE_BATCH_RULE(ATEN_FN(op)));
#define UNARY_POINTWISE2(op, overload) \
VMAP_SUPPORT(#op"."#overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
#define UNARY_POINTWISE(op) \
VMAP_SUPPORT(#op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
BINARY_POINTWISE_WITH_SCALAR(add);
VMAP_SUPPORT("add.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(add, Scalar)));
VMAP_SUPPORT("atan2", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(atan2)), &at::atan2>));
BINARY_POINTWISE2(add, Tensor);
UNARY_POINTWISE2(add, Scalar);
BINARY_POINTWISE(atan2);
VMAP_SUPPORT("clamp", BASIC_UNARY_BATCH_RULE(ATEN_FN(clamp)));
UNARY_POINTWISE(clamp);
VMAP_SUPPORT("clamp.Tensor", clamp_tensor_batch_rule);
VMAP_SUPPORT("clamp_min.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(clamp_min, Tensor)), &at::clamp_min>));
VMAP_SUPPORT("clamp_min", BASIC_UNARY_BATCH_RULE(ATEN_FN(clamp_min)));
VMAP_SUPPORT("clamp_max.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(clamp_max, Tensor)), &at::clamp_max>));
VMAP_SUPPORT("clamp_max", BASIC_UNARY_BATCH_RULE(ATEN_FN(clamp_max)));
BINARY_POINTWISE2(clamp_min, Tensor);
UNARY_POINTWISE(clamp_min);
BINARY_POINTWISE2(clamp_max, Tensor);
UNARY_POINTWISE(clamp_max);
BINARY_POINTWISE(div);
VMAP_SUPPORT("div.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(div, Scalar)));
VMAP_SUPPORT("div.Scalar_mode", BASIC_UNARY_BATCH_RULE(ATEN_FN2(div, Scalar_mode)));
VMAP_SUPPORT("div.Tensor_mode", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(div, Tensor_mode)), &at::div, c10::optional<string_view>>));
BINARY_POINTWISE2(div, Tensor);
UNARY_POINTWISE2(div, Scalar);
UNARY_POINTWISE2(div, Scalar_mode);
BINARY_POINTWISE2(div, Tensor_mode);
VMAP_SUPPORT("maximum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(maximum)), &at::maximum>));
VMAP_SUPPORT("minimum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(minimum)), &at::minimum>));
BINARY_POINTWISE(maximum);
BINARY_POINTWISE(minimum);
BINARY_POINTWISE(mul);
VMAP_SUPPORT("mul.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(mul, Scalar)));
BINARY_POINTWISE2(mul, Tensor);
UNARY_POINTWISE2(mul, Scalar);
// at::pow has three out-of-place overloads
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
VMAP_SUPPORT("pow.Tensor_Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(pow, Tensor_Scalar)));
BINARY_POINTWISE2(pow, Tensor_Tensor);
UNARY_POINTWISE2(pow, Tensor_Scalar);
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
BINARY_POINTWISE_WITH_SCALAR(sub);
VMAP_SUPPORT("sub.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(sub, Scalar)));
BINARY_POINTWISE2(sub, Tensor);
UNARY_POINTWISE2(sub, Scalar)
BINARY_POINTWISE_WITH_SCALAR(rsub);
VMAP_SUPPORT("rsub.Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(rsub, Scalar)));
BINARY_POINTWISE2(rsub, Tensor);
UNARY_POINTWISE2(rsub, Scalar);
VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
BINARY_POINTWISE(sigmoid_backward);
BINARY_POINTWISE(tanh_backward);
BINARY_POINTWISE(threshold_backward);
VMAP_SUPPORT("fmin", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmin)), &at::fmin>));
VMAP_SUPPORT("fmax", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmax)), &at::fmax>));
BINARY_POINTWISE(fmin);
BINARY_POINTWISE(fmax);
OP_DECOMPOSE2(max, other);
OP_DECOMPOSE2(min, other);
@ -252,7 +272,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
#define COMPARISON_POINTWISE(op) \
VMAP_SUPPORT(#op".Tensor", \
SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>)); \
VMAP_SUPPORT(#op".Scalar", BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, Scalar)))
UNARY_POINTWISE2(op, Scalar)
COMPARISON_POINTWISE(eq);
COMPARISON_POINTWISE(gt);
@ -263,10 +283,10 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
#undef COMPARISON_POINTWISE
#undef SINGLE_ARG
#undef BINARY_POINTWISE_BATCH_RULE_SCALAR
#undef BINARY_POINTWISE_BATCH_RULE
#undef BINARY_POINTWISE_WITH_SCALAR
#undef BINARY_POINTWISE2
#undef BINARY_POINTWISE
#undef UNARY_POINTWISE2
#undef UNARY_POINTWISE
}
}}