mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Add flip batching rule
This commit is contained in:
@ -112,6 +112,7 @@ std::tuple<Tensor,optional<int64_t>> repeat_batch_rule(
|
||||
return std::make_tuple(self_.repeat(sizes_with_bdim), 0);
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>> diag_batch_rule(
|
||||
const Tensor& input,
|
||||
optional<int64_t> input_bdim,
|
||||
@ -144,14 +145,27 @@ std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
|
||||
return std::make_tuple(at::_unsafe_view(self, view_size), self_bdim);
|
||||
}
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef dims) {
|
||||
if (!self_bdim) {
|
||||
return std::make_tuple(at::flip(self, dims), nullopt);
|
||||
}
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
VmapDimVector new_dims;
|
||||
for (auto i: dims) {
|
||||
new_dims.push_back(getPhysicalDim(self, true, i));
|
||||
}
|
||||
return std::make_tuple(at::flip(self_, new_dims), 0);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
|
||||
VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
|
||||
VMAP_SUPPORT("repeat", repeat_batch_rule);
|
||||
VMAP_SUPPORT("diag", diag_batch_rule);
|
||||
VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));
|
||||
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
|
||||
VMAP_SUPPORT("flip", flip_batch_rule);
|
||||
VMAP_SUPPORT("tril", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(tril)), &at::tril, int64_t>));
|
||||
VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));
|
||||
VMAP_SUPPORT("repeat", repeat_batch_rule);
|
||||
VMAP_SUPPORT("_unsafe_view", _unsafe_view_batch_rule);
|
||||
VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
|
||||
}
|
||||
|
||||
}}
|
||||
|
Reference in New Issue
Block a user