[functorch] Batching rules for comparison ops

This commit is contained in:
Richard Zou
2021-05-05 13:20:07 -07:00
committed by Jon Janzen
parent f92fbeef74
commit cb51b16e59
2 changed files with 40 additions and 21 deletions

View File

@ -63,6 +63,31 @@ std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
return { std::move(result), std::move(result_batch_dim) };
}
template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> comparison_pointwise_batch_rule(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim) {
// compute max logical rank
auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
auto other_ = moveBatchDimToFront(other, other_batch_dim);
// If the dimensions aren't aligned, we need to line them up.
// Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
// Note that only tensors that have a batch dim need to be modified.
// Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
auto result = Func(tensor_, other_);
auto result_batch_dim = tensor_batch_dim.has_value() || other_batch_dim.has_value()
? optional<int64_t>{0} : nullopt;
return { std::move(result), std::move(result_batch_dim) };
}
std::tuple<Tensor,optional<int64_t>> pow_scalar_tensor_batch_rule(
const Scalar& other,
const Tensor& tensor, optional<int64_t> tensor_batch_dim) {
@ -100,6 +125,21 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
#define COMPARISON_POINTWISE(op) \
VMAP_SUPPORT(#op".Tensor", \
SINGLE_ARG(comparison_pointwise_batch_rule<TensorTensorType, &at::op>)); \
VMAP_SUPPORT(#op".Scalar", \
SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::op, const Scalar&>));
COMPARISON_POINTWISE(eq);
COMPARISON_POINTWISE(gt);
COMPARISON_POINTWISE(ge);
COMPARISON_POINTWISE(le);
COMPARISON_POINTWISE(lt);
COMPARISON_POINTWISE(ne);
#undef COMPARISON_POINTWISE
#undef SINGLE_ARG
#undef BINARY_POINTWISE_BATCH_RULE_SCALAR
#undef BINARY_POINTWISE_BATCH_RULE

View File

@ -1355,13 +1355,6 @@ Tensor new_empty_strided_batching_rule(
return physical_view.getPhysicalToLogicalMap().apply(result);
}
template <typename F, F Func>
Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
return physical_args[0].getPhysicalToLogicalMap().apply(result);
}
bool BatchedTensor_is_leaf(const Tensor& self) {
if (torch::autograd::impl::get_autograd_meta(self)) {
return torch::autograd::impl::get_autograd_meta(self)->grad_fn_ == nullptr;
@ -1671,20 +1664,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
// // m.impl("new_zeros", new_zeros_batching_rule);
// //
// // m.impl("contiguous", contiguous_batching_rule);
// //
// // // Comparison ops
// // #define COMPARISON_POINTWISE(op) \
// // m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
// // m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
// //
// // COMPARISON_POINTWISE(eq);
// // COMPARISON_POINTWISE(gt);
// // COMPARISON_POINTWISE(ge);
// // COMPARISON_POINTWISE(le);
// // COMPARISON_POINTWISE(lt);
// // COMPARISON_POINTWISE(ne);
// //
// #undef COMPARISON_POINTWISE
}
}