mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] moved some decompositions from batchingregistrations.cpp out
This commit is contained in:
@ -203,6 +203,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
m.impl("nll_loss_forward", nll_loss_forward_plumbing);
|
||||
OP_DECOMPOSE(nll_loss_nd);
|
||||
OP_DECOMPOSE(nll_loss);
|
||||
OP_DECOMPOSE(cross_entropy_loss);
|
||||
m.impl("nll_loss_backward", nll_loss_backward_plumbing);
|
||||
}
|
||||
|
||||
|
@ -52,6 +52,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
UNARY_POINTWISE_ALL(ceil);
|
||||
UNARY_POINTWISE_ALL(cos);
|
||||
UNARY_POINTWISE_ALL(cosh);
|
||||
OP_DECOMPOSE(conj);
|
||||
UNARY_POINTWISE(_conj);
|
||||
UNARY_POINTWISE_ALL(deg2rad);
|
||||
UNARY_POINTWISE_ALL(digamma);
|
||||
|
@ -151,9 +151,12 @@ std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, optiona
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT("diag", diag_batch_rule);
|
||||
|
||||
OP_DECOMPOSE(expand_as);
|
||||
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
|
||||
VMAP_SUPPORT("flip", flip_batch_rule);
|
||||
OP_DECOMPOSE(meshgrid);
|
||||
OP_DECOMPOSE(narrow);
|
||||
m.impl("trace", trace_decomp);
|
||||
VMAP_SUPPORT("tril", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(tril)), &at::tril, int64_t>));
|
||||
VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));
|
||||
|
@ -1337,8 +1337,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
// m.impl("log_softmax.int", log_softmax_batching_rule);
|
||||
m.impl("_log_softmax", _log_softmax_batching_rule);
|
||||
m.impl("is_complex", native::is_complex);
|
||||
m.impl("conj", native::conj);
|
||||
m.impl("cross_entropy_loss", native::cross_entropy_loss);
|
||||
//
|
||||
// // inplace operations
|
||||
// m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
|
||||
@ -1356,12 +1354,10 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
|
||||
m.impl("diagonal", diagonal_batching_rule);
|
||||
m.impl("expand", expand_batching_rule);
|
||||
m.impl("expand_as", native::expand_as); // composite wrt autograd
|
||||
m.impl("movedim.intlist", movedim_batching_rule);
|
||||
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
|
||||
// NB: static_cast because there's another variant of narrow. However, we don't
|
||||
// want to support the other variant yet bc it isn't documented...
|
||||
m.impl("narrow", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t,int64_t)>(native::narrow)); // composite wrt autograd
|
||||
m.impl("numpy_T", native::numpy_T); // composite wrt autograd
|
||||
m.impl("permute", permute_batching_rule);
|
||||
m.impl("reshape", reshape_batching_rule);
|
||||
|
Reference in New Issue
Block a user