[functorch] moved some decompositions from batchingregistrations.cpp out

commit 23fc2e0f6e (parent 71042b1b16)
Author:    Horace He
Date:      2021-06-29 18:16:30 -07:00
Committer: Jon Janzen

4 changed files with 5 additions and 4 deletions
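
The diff does not show the OP_DECOMPOSE macro's definition. Judging from the hand-written m.impl(...) registrations it replaces in the last file, a plausible expansion is the sketch below (hypothetical, not taken from this diff; ATEN_FN is assumed to resolve the op's canonical schema type):

    // Hypothetical expansion of OP_DECOMPOSE, inferred from the registrations
    // this commit deletes below. Registering native::op for the functorch
    // batched key makes vmap run the op's composite decomposition, so only
    // the primitive ops it calls into need real batching rules.
    #define OP_DECOMPOSE(op) \
      m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));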

--- File 1 of 4 ---

@@ -203,6 +203,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("nll_loss_forward", nll_loss_forward_plumbing);
   OP_DECOMPOSE(nll_loss_nd);
   OP_DECOMPOSE(nll_loss);
+  OP_DECOMPOSE(cross_entropy_loss);
   m.impl("nll_loss_backward", nll_loss_backward_plumbing);
 }
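
cross_entropy_loss is a composite op, roughly log_softmax followed by nll_loss, so decomposing it lets the nll_loss_forward/nll_loss_backward plumbing registered above cover it too. A simplified sketch of that decomposition (an assumption; the real native::cross_entropy_loss also threads weight, reduction, and ignore_index):

    // Simplified sketch of the cross_entropy_loss decomposition, not the
    // exact ATen source. Under vmap, only log_softmax and the nll_loss ops
    // it lowers to need batching rules.
    Tensor cross_entropy_loss_sketch(const Tensor& self, const Tensor& target) {
      return at::nll_loss_nd(
          at::log_softmax(self, /*dim=*/1, self.scalar_type()), target);
    }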

--- File 2 of 4 ---

@@ -52,6 +52,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   UNARY_POINTWISE_ALL(ceil);
   UNARY_POINTWISE_ALL(cos);
   UNARY_POINTWISE_ALL(cosh);
+  OP_DECOMPOSE(conj);
   UNARY_POINTWISE(_conj);
   UNARY_POINTWISE_ALL(deg2rad);
   UNARY_POINTWISE_ALL(digamma);
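
Note that conj gets OP_DECOMPOSE while _conj keeps a pointwise rule. A simplified sketch of why that split works (an assumption about the composite's shape; the actual native::conj may differ in details):

    // Sketch of the conj decomposition: real inputs pass through unchanged,
    // complex inputs dispatch to _conj, the primitive that the
    // UNARY_POINTWISE rule above actually handles.
    Tensor conj_sketch(const Tensor& self) {
      if (!self.is_complex()) {
        return self;            // conj is a no-op on real tensors
      }
      return at::_conj(self);   // primitive op with its own batch rule
    }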

--- File 3 of 4 ---

@@ -151,9 +151,12 @@ std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef dims) {
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("diag", diag_batch_rule);
+  OP_DECOMPOSE(expand_as);
   m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
   VMAP_SUPPORT("flip", flip_batch_rule);
+  OP_DECOMPOSE(meshgrid);
+  OP_DECOMPOSE(narrow);
   m.impl("trace", trace_decomp);
   VMAP_SUPPORT("tril", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(tril)), &at::tril, int64_t>));
   VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));

--- File 4 of 4 ---

@@ -1337,8 +1337,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   // m.impl("log_softmax.int", log_softmax_batching_rule);
   m.impl("_log_softmax", _log_softmax_batching_rule);
   m.impl("is_complex", native::is_complex);
-  m.impl("conj", native::conj);
-  m.impl("cross_entropy_loss", native::cross_entropy_loss);
   //
   // // inplace operations
   // m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
@@ -1356,12 +1354,10 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
   m.impl("diagonal", diagonal_batching_rule);
   m.impl("expand", expand_batching_rule);
-  m.impl("expand_as", native::expand_as); // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
   // NB: static_cast because there's another variant of narrow. However, we don't
   // want to support the other variant yet bc it isn't documented...
-  m.impl("narrow", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t,int64_t)>(native::narrow)); // composite wrt autograd
   m.impl("numpy_T", native::numpy_T); // composite wrt autograd
   m.impl("permute", permute_batching_rule);
   m.impl("reshape", reshape_batching_rule);