mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[special] add torch.special
namespace (#52296)
Summary: Reference: https://github.com/pytorch/pytorch/issues/50345 * Add `torch.special` namespace * Add `torch.special.gammaln` (alias to `torch.lgamma`) TODO: * Add proper entries for docs. * [x] Add .rst file entry * [x] Add documentation * [x] Update `lgamma` OpInfo entry for alias to `special.gammaln`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/52296 Reviewed By: ngimel Differential Revision: D26754890 Pulled By: mruberry fbshipit-source-id: 73479f68989d6443ad07b7b02763fa98973c15f6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c5b0c2fa8b
commit
c4c77e2001
@ -207,6 +207,7 @@ libtorch_python_generated_sources = [
|
||||
"torch/csrc/autograd/generated/python_nn_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_fft_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_linalg_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_special_functions.cpp",
|
||||
]
|
||||
|
||||
genrule(
|
||||
|
@ -423,7 +423,6 @@ _(aten, leaky_relu) \
|
||||
_(aten, leaky_relu_backward) \
|
||||
_(aten, leaky_relu_forward) \
|
||||
_(aten, lerp) \
|
||||
_(aten, lgamma) \
|
||||
_(aten, linear) \
|
||||
_(aten, linspace) \
|
||||
_(aten, log) \
|
||||
|
@ -303,6 +303,8 @@ namespace c10 {
|
||||
_(aten, swapdims_) \
|
||||
_(aten, movedim) \
|
||||
_(aten, moveaxis) \
|
||||
_(aten, lgamma) \
|
||||
_(aten, special_gammaln) \
|
||||
_(aten, has_torch_function) \
|
||||
FORALL_ATEN_BASE_SYMBOLS(_) \
|
||||
_(onnx, Add) \
|
||||
|
@ -541,7 +541,8 @@ name that we skip in python binding generation, e.g. `*_backward`. Check
|
||||
The generated bindings are either exposed as methods on python_variable or functions on
|
||||
the torch._C._nn (marked with `python_module: nn`),
|
||||
torch._C._fft (marked with `python_module: fft`),
|
||||
or torch._C._linalg (marked with `python_module: linalg`) objects.
|
||||
torch._C._linalg (marked with `python_module: linalg`) objects,
|
||||
or torch._C._special (marked with `python_module: special`) objects.
|
||||
|
||||
### Can it handle being passed Variables?
|
||||
|
||||
|
@ -670,6 +670,11 @@ Tensor& lgamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_fl
|
||||
Tensor lgamma(const Tensor& self) { return unary_op_impl_float(self, lgamma_stub); }
|
||||
Tensor& lgamma_(Tensor& self) { return unary_op_impl_(self, at::lgamma_out); }
|
||||
|
||||
// alias for lgamma, implements special.gammanln equivalent to
|
||||
// scipy.special.gammaln
|
||||
Tensor special_gammaln(const Tensor& self) { return self.lgamma(); }
|
||||
Tensor& special_gammaln_out(const Tensor& self, Tensor& result) { return at::lgamma_out(result, self); }
|
||||
|
||||
DEFINE_DISPATCH(abs_stub);
|
||||
DEFINE_DISPATCH(angle_stub);
|
||||
DEFINE_DISPATCH(real_stub);
|
||||
|
@ -8826,6 +8826,23 @@
|
||||
- func: _remove_batch_dim(Tensor self, int level, int batch_size, int out_dim) -> Tensor
|
||||
variants: function
|
||||
|
||||
## Functions related to the `torch.special` namespace
|
||||
# Note [special namespace binding]
|
||||
# Functions in the special python module should have their names start with
|
||||
# "special_" underscore and be bound to the desired Python name in
|
||||
# torch/special/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/special.h.
|
||||
# The "special_" names should be hidden from the user and not documented.
|
||||
|
||||
- func: special_gammaln(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: special_gammaln_out
|
||||
|
||||
## Functions related to the fast Fourier transform and the torch.fft namespace
|
||||
# Note [FFT namespace binding]
|
||||
# Functions in the fft python module should have their names start with
|
||||
|
@ -435,6 +435,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_fft_functions.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_linalg_functions.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
|
||||
)
|
||||
|
||||
set(GENERATED_H_PYTHON
|
||||
@ -479,6 +480,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
"${TOOLS_PATH}/autograd/templates/python_nn_functions.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/python_fft_functions.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/python_linalg_functions.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
|
||||
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py"
|
||||
"${TOOLS_PATH}/autograd/deprecated.yaml"
|
||||
|
@ -66,6 +66,7 @@ Features described in this documentation are classified by release status:
|
||||
torch.hub <hub>
|
||||
torch.jit <jit>
|
||||
torch.linalg <linalg>
|
||||
torch.special <special>
|
||||
torch.overrides
|
||||
profiler
|
||||
nn.init
|
||||
|
21
docs/source/special.rst
Normal file
21
docs/source/special.rst
Normal file
@ -0,0 +1,21 @@
|
||||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
torch.special
|
||||
=============
|
||||
|
||||
The torch.special module, modeled after SciPy's `special <https://docs.scipy.org/doc/scipy/reference/special.html>`_ module.
|
||||
|
||||
This module is in BETA. New functions are still being added, and some
|
||||
functions may change in future PyTorch releases. See the documentation of each
|
||||
function for details.
|
||||
|
||||
.. automodule:: torch.special
|
||||
:noindex:
|
||||
|
||||
.. currentmodule:: torch.special
|
||||
|
||||
Functions
|
||||
-----------------------
|
||||
|
||||
.. autofunction:: gammaln
|
@ -27,6 +27,7 @@ set(TORCH_API_TEST_SOURCES
|
||||
${TORCH_API_TEST_DIR}/sequential.cpp
|
||||
${TORCH_API_TEST_DIR}/transformer.cpp
|
||||
${TORCH_API_TEST_DIR}/serialize.cpp
|
||||
${TORCH_API_TEST_DIR}/special.cpp
|
||||
${TORCH_API_TEST_DIR}/static.cpp
|
||||
${TORCH_API_TEST_DIR}/support.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
|
||||
|
13
test/cpp/api/special.cpp
Normal file
13
test/cpp/api/special.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include <torch/special.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
// Simple test that verifies the special namespace is registered properly
|
||||
// properly in C++
|
||||
TEST(SpecialTest, special) {
|
||||
auto t = torch::randn(128, torch::kDouble);
|
||||
torch::special::gammaln(t);
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
# Generates Python bindings for ATen functions
|
||||
#
|
||||
# The bindings are generated as methods on python_variable or functions on the
|
||||
# torch._C._nn. torch._C._fft, or torch._C._linalg objects.
|
||||
# torch._C._nn. torch._C._fft, torch._C._linalg or torch._C._special objects.
|
||||
#
|
||||
|
||||
# Code tries to stick to the following rules:
|
||||
@ -132,6 +132,9 @@ def is_py_fft_function(f: NativeFunction) -> bool:
|
||||
def is_py_linalg_function(f: NativeFunction) -> bool:
|
||||
return f.python_module == 'linalg'
|
||||
|
||||
def is_py_special_function(f: NativeFunction) -> bool:
|
||||
return f.python_module == 'special'
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
# Main Function
|
||||
@ -158,6 +161,9 @@ def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_pat
|
||||
create_python_bindings(
|
||||
fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
|
||||
|
||||
create_python_bindings(
|
||||
fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False)
|
||||
|
||||
def create_python_bindings(
|
||||
fm: FileManager,
|
||||
pairs: Sequence[PythonSignatureNativeFunctionPair],
|
||||
@ -528,6 +534,7 @@ if(check_has_torch_function(self_)) {{
|
||||
"torch.nn": "THPNNVariableFunctionsModule",
|
||||
"torch.fft": "THPFFTVariableFunctionsModule",
|
||||
"torch.linalg": "THPLinalgVariableFunctionsModule",
|
||||
"torch.special": "THPSpecialVariableFunctionsModule",
|
||||
}[module] if module else "THPVariableClass"
|
||||
|
||||
return f"""\
|
||||
|
73
tools/autograd/templates/python_special_functions.cpp
Normal file
73
tools/autograd/templates/python_special_functions.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
// ${generated_comment}
|
||||
|
||||
#include "torch/csrc/Device.h"
|
||||
#include "torch/csrc/DynamicTypes.h"
|
||||
#include "torch/csrc/Exceptions.h"
|
||||
#include "torch/csrc/autograd/python_special_functions.h"
|
||||
#include "torch/csrc/autograd/python_variable.h"
|
||||
#include "torch/csrc/autograd/utils/wrap_outputs.h"
|
||||
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
|
||||
#include "torch/csrc/autograd/generated/variable_factories.h"
|
||||
#include "torch/csrc/utils/out_types.h"
|
||||
#include "torch/csrc/utils/pycfunction_helpers.h"
|
||||
#include "torch/csrc/utils/python_arg_parser.h"
|
||||
#include "torch/csrc/utils/structseq.h"
|
||||
#include "torch/csrc/utils/cuda_lazy_init.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
using at::Tensor;
|
||||
using at::Device;
|
||||
using at::Layout;
|
||||
using at::Scalar;
|
||||
using at::ScalarType;
|
||||
using at::Backend;
|
||||
using at::OptionalDeviceGuard;
|
||||
using at::DeviceGuard;
|
||||
using at::TensorOptions;
|
||||
using at::IntArrayRef;
|
||||
using at::Generator;
|
||||
using at::TensorList;
|
||||
using at::Dimname;
|
||||
using at::DimnameList;
|
||||
|
||||
using torch::utils::check_out_type_matches;
|
||||
using namespace torch::autograd::utils;
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
// generated forward declarations start here
|
||||
|
||||
${py_forwards}
|
||||
|
||||
static PyMethodDef special_functions[] = {
|
||||
${py_method_defs}
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyObject* THPSpecialVariableFunctionsModule = NULL;
|
||||
|
||||
void initSpecialFunctions(PyObject* module) {
|
||||
static struct PyModuleDef def = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"torch._C._special",
|
||||
NULL,
|
||||
-1,
|
||||
special_functions
|
||||
};
|
||||
PyObject* special = PyModule_Create(&def);
|
||||
THPSpecialVariableFunctionsModule = special;
|
||||
if (!special) {
|
||||
throw python_error();
|
||||
}
|
||||
// steals a reference to special
|
||||
if (PyModule_AddObject(module, "_special", special) != 0) {
|
||||
throw python_error();
|
||||
}
|
||||
}
|
||||
|
||||
// generated methods start here
|
||||
|
||||
${py_methods}
|
||||
|
||||
}} // namespace torch::autograd
|
@ -16,6 +16,7 @@ GENERATED_CPP = [
|
||||
"autograd/generated/python_nn_functions.cpp",
|
||||
"autograd/generated/python_fft_functions.cpp",
|
||||
"autograd/generated/python_linalg_functions.cpp",
|
||||
"autograd/generated/python_special_functions.cpp",
|
||||
"autograd/generated/python_torch_functions.cpp",
|
||||
"autograd/generated/python_variable_methods.cpp",
|
||||
]
|
||||
@ -631,6 +632,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
|
||||
"autograd/generated/python_nn_functions.cpp",
|
||||
"autograd/generated/python_fft_functions.cpp",
|
||||
"autograd/generated/python_linalg_functions.cpp",
|
||||
"autograd/generated/python_special_functions.cpp",
|
||||
"autograd/generated/python_torch_functions.cpp",
|
||||
"autograd/generated/python_variable_methods.cpp",
|
||||
]]
|
||||
|
@ -650,6 +650,7 @@ from torch import optim as optim
|
||||
import torch.optim._multi_tensor
|
||||
from torch import multiprocessing as multiprocessing
|
||||
from torch import sparse as sparse
|
||||
from torch import special as special
|
||||
import torch.utils.backcompat
|
||||
from torch import onnx as onnx
|
||||
from torch import jit as jit
|
||||
|
@ -4357,13 +4357,15 @@ add_docstr(torch.lgamma,
|
||||
r"""
|
||||
lgamma(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the logarithm of the gamma function on :attr:`input`.
|
||||
Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = \log \Gamma(\text{input}_{i})
|
||||
\text{out}_{i} = \ln \Gamma(|\text{input}_{i}|)
|
||||
""" + """
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
@ -34,6 +34,7 @@
|
||||
#include <torch/csrc/autograd/python_nn_functions.h>
|
||||
#include <torch/csrc/autograd/python_fft_functions.h>
|
||||
#include <torch/csrc/autograd/python_linalg_functions.h>
|
||||
#include <torch/csrc/autograd/python_special_functions.h>
|
||||
#include <torch/csrc/autograd/python_legacy_variable.h>
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
#include <torch/csrc/multiprocessing/init.h>
|
||||
@ -831,6 +832,7 @@ PyObject* initModule() {
|
||||
torch::autograd::initNNFunctions(module);
|
||||
torch::autograd::initFFTFunctions(module);
|
||||
torch::autograd::initLinalgFunctions(module);
|
||||
torch::autograd::initSpecialFunctions(module);
|
||||
torch::autograd::init_legacy_variable(module);
|
||||
torch::python::init_bindings(module);
|
||||
#ifdef USE_CUDA
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <torch/nn.h>
|
||||
#include <torch/optim.h>
|
||||
#include <torch/serialize.h>
|
||||
#include <torch/special.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
#include <torch/autograd.h>
|
||||
|
20
torch/csrc/api/include/torch/special.h
Normal file
20
torch/csrc/api/include/torch/special.h
Normal file
@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace torch {
|
||||
namespace special {
|
||||
|
||||
/// Computes the natural logarithm of the absolute value of the gamma function
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.gammaln.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::gammaln(t);
|
||||
/// ```
|
||||
inline Tensor gammaln(const Tensor& self) {
|
||||
return torch::special_gammaln(self);
|
||||
}
|
||||
|
||||
}} // torch::special
|
7
torch/csrc/autograd/python_special_functions.h
Normal file
7
torch/csrc/autograd/python_special_functions.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
void initSpecialFunctions(PyObject* module);
|
||||
|
||||
}} // namespace torch::autograd
|
@ -105,7 +105,7 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
|
||||
{aten::swapaxes, aten::transpose},
|
||||
{aten::swapaxes_, aten::transpose_},
|
||||
{aten::moveaxis, aten::movedim},
|
||||
};
|
||||
{aten::special_gammaln, aten::lgamma}};
|
||||
return alias_map;
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ from typing import Dict, Optional
|
||||
|
||||
_builtin_table: Optional[Dict[int, str]] = None
|
||||
|
||||
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg) # type: ignore
|
||||
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._special) # type: ignore
|
||||
|
||||
_builtin_ops = [
|
||||
# Pairs of (function, op_name)
|
||||
|
29
torch/special/__init__.py
Normal file
29
torch/special/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch._C import _add_docstr, _special # type: ignore
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
gammaln = _add_docstr(_special.special_gammaln,
|
||||
r"""
|
||||
gammaln(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = \ln \Gamma(|\text{input}_{i}|)
|
||||
""" + """
|
||||
Args:
|
||||
input (Tensor): the input tensor.
|
||||
|
||||
Keyword args:
|
||||
out (Tensor, optional): the output tensor.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.arange(0.5, 2, 0.5)
|
||||
>>> torch.special.gammaln(a)
|
||||
tensor([ 0.5724, 0.0000, -0.1208])
|
||||
|
||||
""")
|
@ -2641,6 +2641,7 @@ if TEST_SCIPY:
|
||||
),
|
||||
UnaryUfuncInfo('lgamma',
|
||||
ref=reference_lgamma,
|
||||
aliases=('special.gammaln', ),
|
||||
decorators=(precisionOverride({torch.float16: 7e-1}),),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
|
||||
|
Reference in New Issue
Block a user