Overload _get_operation_for_overload_or_packet & friends to accept ArrayRef (#162219)

Avoids requiring vector allocation to call this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162219
Approved by: https://github.com/Skylion007
ghstack dependencies: #161591, #161595, #161633, #161634, #161692
This commit is contained in:
Scott Wolchok
2025-09-08 11:09:13 -07:00
committed by PyTorch MergeBot
parent 12db2a7889
commit a8a187b2cf
3 changed files with 56 additions and 3 deletions

View File

@ -1726,7 +1726,7 @@ void initJITBindings(PyObject* module) {
const py::args& args, const py::kwargs& kwargs) {
ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
return _get_operation_for_overload_or_packet(
{op}, symbol, args, kwargs, /*is_overload*/ true);
op, symbol, args, kwargs, /*is_overload*/ true);
});
auto func_dk =
py::cpp_function([op, symbol, allow_numbers_as_tensors](
@ -1735,7 +1735,7 @@ void initJITBindings(PyObject* module) {
const py::kwargs& kwargs) {
ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
return _get_operation_for_overload_or_packet(
{op}, symbol, args, kwargs, /*is_overload*/ true, dk_);
op, symbol, args, kwargs, /*is_overload*/ true, dk_);
});
return py::make_tuple(
func, func_dk, py::cast(op->getTags().vec()));

View File

@ -780,9 +780,17 @@ std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
const std::vector<std::shared_ptr<Operator>>& operations,
const py::args& args,
const py::kwargs& kwargs) {
return getOpWithStack(
c10::ArrayRef<std::shared_ptr<Operator>>(operations), args, kwargs);
}
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
c10::ArrayRef<std::shared_ptr<Operator>> operations,
const py::args& args,
const py::kwargs& kwargs) {
Stack stack;
if (operations.size() == 1) {
std::shared_ptr<Operator> op = operations.at(0);
std::shared_ptr<Operator> op = operations[0];
// Create a stack full of the arguments and keyword arguments.
stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt);
@ -834,6 +842,15 @@ py::object invokeOperatorFromPython(
const py::args& args,
const py::kwargs& kwargs,
std::optional<c10::DispatchKey> dk) {
return invokeOperatorFromPython(
c10::ArrayRef<std::shared_ptr<Operator>>(operations), args, kwargs, dk);
}
py::object invokeOperatorFromPython(
c10::ArrayRef<std::shared_ptr<Operator>> operations,
const py::args& args,
const py::kwargs& kwargs,
std::optional<c10::DispatchKey> dk) {
auto [found_op, stack] = getOpWithStack(operations, args, kwargs);
{
pybind11::gil_scoped_release no_gil_guard;
@ -912,6 +929,17 @@ py::object _get_operation_for_overload_or_packet(
const py::kwargs& kwargs,
bool is_overload,
std::optional<c10::DispatchKey> dk) {
return _get_operation_for_overload_or_packet(
c10::ArrayRef(operations), symbol, args, kwargs, is_overload, dk);
}
py::object _get_operation_for_overload_or_packet(
c10::ArrayRef<std::shared_ptr<Operator>> operations,
Symbol symbol,
const py::args& args,
const py::kwargs& kwargs,
bool is_overload,
std::optional<c10::DispatchKey> dk) {
std::string ns = symbol.ns().toUnqualString();
std::string method_name = symbol.toUnqualString();
std::string overload_name = operations[0]->schema().overload_name();

View File

@ -1275,12 +1275,27 @@ TORCH_PYTHON_API std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
const py::args& args,
const py::kwargs& kwargs);
// Efficient overload (does not require vector allocation) of the
// above for use from C++ code.
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
c10::ArrayRef<std::shared_ptr<Operator>> operations,
const py::args& args,
const py::kwargs& kwargs);
TORCH_PYTHON_API py::object invokeOperatorFromPython(
const std::vector<std::shared_ptr<Operator>>& operations,
const py::args& args,
const py::kwargs& kwargs,
std::optional<c10::DispatchKey> dk = std::nullopt);
// Efficient overload (does not require vector allocation) of the
// above for use from C++ code.
py::object invokeOperatorFromPython(
c10::ArrayRef<std::shared_ptr<Operator>> operations,
const py::args& args,
const py::kwargs& kwargs,
std::optional<c10::DispatchKey> dk = std::nullopt);
TORCH_PYTHON_API std::optional<py::object> _maybe_handle_torch_function(
const std::string& ns,
const std::string& method_name,
@ -1302,4 +1317,14 @@ TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet(
bool is_overload,
std::optional<c10::DispatchKey> dk = std::nullopt);
// Efficient overload (does not require vector allocation) of the
// above for use from C++ code.
py::object _get_operation_for_overload_or_packet(
c10::ArrayRef<std::shared_ptr<Operator>> operations,
Symbol symbol,
const py::args& args,
const py::kwargs& kwargs,
bool is_overload,
std::optional<c10::DispatchKey> dk = std::nullopt);
} // namespace torch::jit