[special] add torch.special namespace (#52296)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345

 * Add `torch.special` namespace
* Add `torch.special.gammaln` (alias to `torch.lgamma`)

TODO:
* Add proper entries for docs.
   * [x] Add .rst file entry
   * [x] Add documentation
   * [x] Update `lgamma` OpInfo entry for alias to `special.gammaln`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52296

Reviewed By: ngimel

Differential Revision: D26754890

Pulled By: mruberry

fbshipit-source-id: 73479f68989d6443ad07b7b02763fa98973c15f6
This commit is contained in:
kshitij12345
2021-03-04 00:00:09 -08:00
committed by Facebook GitHub Bot
parent c5b0c2fa8b
commit c4c77e2001
24 changed files with 215 additions and 7 deletions

View File

@ -1,7 +1,7 @@
# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn. torch._C._fft, or torch._C._linalg objects.
# torch._C._nn. torch._C._fft, torch._C._linalg or torch._C._special objects.
#
# Code tries to stick to the following rules:
@ -132,6 +132,9 @@ def is_py_fft_function(f: NativeFunction) -> bool:
def is_py_linalg_function(f: NativeFunction) -> bool:
return f.python_module == 'linalg'
def is_py_special_function(f: NativeFunction) -> bool:
return f.python_module == 'special'
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
@ -158,6 +161,9 @@ def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_pat
create_python_bindings(
fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False)
def create_python_bindings(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
@ -528,6 +534,7 @@ if(check_has_torch_function(self_)) {{
"torch.nn": "THPNNVariableFunctionsModule",
"torch.fft": "THPFFTVariableFunctionsModule",
"torch.linalg": "THPLinalgVariableFunctionsModule",
"torch.special": "THPSpecialVariableFunctionsModule",
}[module] if module else "THPVariableClass"
return f"""\