[functorch] Batching rules for comparison ops
@@ -63,6 +63,31 @@ std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
   return { std::move(result), std::move(result_batch_dim) };
 }
 
+template <typename F, F Func>
+std::tuple<Tensor,optional<int64_t>> comparison_pointwise_batch_rule(
+    const Tensor& tensor, optional<int64_t> tensor_batch_dim,
+    const Tensor& other, optional<int64_t> other_batch_dim) {
+  // compute max logical rank
+  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
+  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
+  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
+
+  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
+  auto other_ = moveBatchDimToFront(other, other_batch_dim);
+
+  // If the dimensions aren't aligned, we need to line them up.
+  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
+  // Note that only tensors that have a batch dim need to be modified.
+  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
+  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
+  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
+
+  auto result = Func(tensor_, other_);
+  auto result_batch_dim = tensor_batch_dim.has_value() || other_batch_dim.has_value()
+      ? optional<int64_t>{0} : nullopt;
+  return { std::move(result), std::move(result_batch_dim) };
+}
+
 std::tuple<Tensor,optional<int64_t>> pow_scalar_tensor_batch_rule(
     const Scalar& other,
     const Tensor& tensor, optional<int64_t> tensor_batch_dim) {
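A standalone sketch of what this rule does for one concrete case, not part of the commit: the shapes, the use of view for the rank padding, and the main function are all illustrative, and only public ATen calls (linked against libtorch) are used.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  int64_t B = 4;
  // Batched operand: logical shape [3], physical shape [B, 3], batch dim 0.
  at::Tensor tensor = at::randn({B, 3});
  // Unbatched operand: logical shape [2, 5, 3].
  at::Tensor other = at::randn({2, 5, 3});

  // Logical ranks are 1 and 3, so max_logical_rank = 3. Only the batched
  // operand gets padded: singleton dims are inserted right after the batch
  // dim, giving [B, 1, 1, 3], exactly as in the comment in the diff above.
  at::Tensor tensor_ = tensor.view({B, 1, 1, 3});

  // The comparison now broadcasts per-example; the result carries the
  // batch dim at position 0.
  at::Tensor result = at::eq(tensor_, other);
  std::cout << result.sizes() << "\n";  // prints [4, 2, 5, 3]
}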
@@ -100,6 +125,21 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
   VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
 
+#define COMPARISON_POINTWISE(op) \
+  VMAP_SUPPORT(#op".Tensor", \
+      SINGLE_ARG(comparison_pointwise_batch_rule<TensorTensorType, &at::op>)); \
+  VMAP_SUPPORT(#op".Scalar", \
+      SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::op, const Scalar&>));
+
+  COMPARISON_POINTWISE(eq);
+  COMPARISON_POINTWISE(gt);
+  COMPARISON_POINTWISE(ge);
+  COMPARISON_POINTWISE(le);
+  COMPARISON_POINTWISE(lt);
+  COMPARISON_POINTWISE(ne);
+
+#undef COMPARISON_POINTWISE
 #undef SINGLE_ARG
 #undef BINARY_POINTWISE_BATCH_RULE_SCALAR
 #undef BINARY_POINTWISE_BATCH_RULE
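For reference, expanding COMPARISON_POINTWISE(eq) by hand gives the following pair of registrations (TensorTensorType and TensorScalarType are presumably function-pointer typedefs defined earlier in this file, as they are also used by the neighboring pow registrations):

VMAP_SUPPORT("eq.Tensor",
    SINGLE_ARG(comparison_pointwise_batch_rule<TensorTensorType, &at::eq>));
VMAP_SUPPORT("eq.Scalar",
    SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::eq, const Scalar&>));

The Tensor overload needs the new two-operand rule; the Scalar overload is pointwise in its single tensor argument, so the existing unary rule suffices.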
@@ -1355,13 +1355,6 @@ Tensor new_empty_strided_batching_rule(
   return physical_view.getPhysicalToLogicalMap().apply(result);
 }
 
-template <typename F, F Func>
-Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
-  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
-  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
-  return physical_args[0].getPhysicalToLogicalMap().apply(result);
-}
-
 bool BatchedTensor_is_leaf(const Tensor& self) {
   if (torch::autograd::impl::get_autograd_meta(self)) {
     return torch::autograd::impl::get_autograd_meta(self)->grad_fn_ == nullptr;
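This hunk deletes the superseded rule from the legacy batching path: it aligned both operands to a common physical shape via BroadcastingVmapTransform::logicalToPhysical and mapped the result back to a logical tensor. The comparison_pointwise_batch_rule added above covers the same ops in the new batch-rule style.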
@@ -1671,20 +1664,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   // // m.impl("new_zeros", new_zeros_batching_rule);
   // //
   // // m.impl("contiguous", contiguous_batching_rule);
-  // //
-  // // // Comparison ops
-  // // #define COMPARISON_POINTWISE(op) \
-  // //   m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
-  // //   m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
-  // //
-  // // COMPARISON_POINTWISE(eq);
-  // // COMPARISON_POINTWISE(gt);
-  // // COMPARISON_POINTWISE(ge);
-  // // COMPARISON_POINTWISE(le);
-  // // COMPARISON_POINTWISE(lt);
-  // // COMPARISON_POINTWISE(ne);
-  // //
-  // #undef COMPARISON_POINTWISE
 }
 
 }
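On why the Scalar overloads never need a dedicated comparison rule (here and in the new registrations above): comparing a batched tensor against a scalar is pointwise in the single tensor argument, so the op can run on the physical tensor directly and the batch dim is untouched. A minimal standalone sketch, with illustrative shapes and only public ATen calls:

#include <ATen/ATen.h>

int main() {
  // Physical tensor for a batch of four logical vectors: [B, 3], bdim 0.
  at::Tensor physical = at::randn({4, 3});
  // at::ge(Tensor, Scalar) applies elementwise, so no rank alignment is
  // needed; the result keeps shape [4, 3] and the batch dim stays at 0.
  at::Tensor result = at::ge(physical, 0.5);
  return 0;
}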