Expose functions used in custom backend in torch_python dll (#148213)

Fixes #148208. There are solutions for exposing symbols implicitly from inline functions (i.e., inline function A calls non-inline function B in foo.h. Code includes foo.h has to see the symbol B in DLL).

Solution 1: tag the entire struct where the inline functions are defined as member functions with TORCH_PYTHON_API --- this PR does this for python_arg_parser.h. An alternative solution exists but will slow down dispatching a lot --- drop inline keyword and move implementation to .cc file.

Solution 2: tag individual functions with TORCH_PYTHON_API. This PR does this for python_tensor.h.

Related discussion about hiding torch_python symbols: https://github.com/pytorch/pytorch/pull/142214

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148213
Approved by: https://github.com/malfet
This commit is contained in:
Wei-Sheng Chin
2025-03-07 02:34:37 +00:00
committed by PyTorch MergeBot
parent dfb4094b9c
commit 9c9b05bc4f
2 changed files with 6 additions and 4 deletions

View File

@ -27,9 +27,9 @@ TORCH_PYTHON_API void py_set_default_dtype(PyObject* dtype_obj);
// TODO: This is nuts! There is no reason to let the default tensor type id
// change. Probably only store ScalarType, as that's the only flex point
// we support.
TORCH_API c10::DispatchKey get_default_dispatch_key();
TORCH_PYTHON_API c10::DispatchKey get_default_dispatch_key();
TORCH_PYTHON_API at::Device get_default_device();
// Gets the ScalarType for the default tensor type.
at::ScalarType get_default_scalar_type();
TORCH_PYTHON_API at::ScalarType get_default_scalar_type();
} // namespace torch::tensors

View File

@ -209,7 +209,7 @@ struct FunctionSignature {
// PythonArgs contains bound Python arguments for an actual invocation
// along with references to the matched signature.
struct PythonArgs {
struct TORCH_PYTHON_API PythonArgs {
PythonArgs(
bool traceable,
const FunctionSignature& signature,
@ -303,6 +303,8 @@ struct PythonArgs {
inline std::optional<c10::DispatchKeySet> toDispatchKeySetOptional(int i);
private:
// Non-inline functions' symbols are exposed to torch_python DLL
// via TORCH_PYTHON_API tag at struct level.
at::Tensor tensor_slow(int i);
at::Scalar scalar_slow(int i);
at::Scalar scalar_slow(PyObject* arg);
@ -320,7 +322,7 @@ struct FunctionParameter {
int64_t* failed_idx = nullptr);
void set_default_str(const std::string& str);
std::string type_name() const;
TORCH_PYTHON_API std::string type_name() const;
ParameterType type_;
bool optional;