[functorch] Add flip batching rule

Horace He
2021-06-27 11:49:09 -07:00
committed by Jon Janzen
parent b20e4decc4
commit c02cc07c96


@@ -112,6 +112,7 @@ std::tuple<Tensor,optional<int64_t>> repeat_batch_rule(
  return std::make_tuple(self_.repeat(sizes_with_bdim), 0);
}

std::tuple<Tensor,optional<int64_t>> diag_batch_rule(
    const Tensor& input,
    optional<int64_t> input_bdim,
@@ -144,14 +145,27 @@ std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
  return std::make_tuple(at::_unsafe_view(self, view_size), self_bdim);
}

std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef dims) {
  if (!self_bdim) {
    // No batch dimension on this input: defer to the regular flip.
    return std::make_tuple(at::flip(self, dims), nullopt);
  }
  auto self_ = moveBatchDimToFront(self, self_bdim);
  // The batch dim now occupies physical dim 0, so remap each logical flip dim
  // to its physical position before calling flip.
  VmapDimVector new_dims;
  for (auto i: dims) {
    new_dims.push_back(getPhysicalDim(self, true, i));
  }
  return std::make_tuple(at::flip(self_, new_dims), 0);
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
  VMAP_SUPPORT("flip", flip_batch_rule);
  VMAP_SUPPORT("tril", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(tril)), &at::tril, int64_t>));
  VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));
  VMAP_SUPPORT("repeat", repeat_batch_rule);
  VMAP_SUPPORT("_unsafe_view", _unsafe_view_batch_rule);
  VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
  VMAP_SUPPORT("diag", diag_batch_rule);
}
}}
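
For context, a minimal Python sketch of the behavior this rule enables, assuming the functorch vmap front end; the tensor names and shapes below are illustrative, not part of this commit. vmap(flip) over a batched input should match flipping each example individually, which is exactly the logical-to-physical dim remapping the batch rule performs.

# Illustrative usage sketch (assumed functorch.vmap API; not part of this diff).
import torch
from functorch import vmap

x = torch.arange(24.).reshape(3, 2, 4)   # batch of 3 examples, each of shape (2, 4)

# vmap treats dim 0 as the batch dim; flip's dims refer to each example's dims.
batched = vmap(lambda t: torch.flip(t, dims=[0, 1]))(x)

# Reference: flip every example on its own and stack the results.
expected = torch.stack([torch.flip(x[i], dims=[0, 1]) for i in range(3)])
assert torch.allclose(batched, expected)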