mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Overload _get_operation_for_overload_or_packet & friends to accept ArrayRef (#162219)
Avoids requiring vector allocation to call this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162219 Approved by: https://github.com/Skylion007 ghstack dependencies: #161591, #161595, #161633, #161634, #161692
This commit is contained in:
committed by
PyTorch MergeBot
parent
12db2a7889
commit
a8a187b2cf
@ -1726,7 +1726,7 @@ void initJITBindings(PyObject* module) {
|
||||
const py::args& args, const py::kwargs& kwargs) {
|
||||
ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
|
||||
return _get_operation_for_overload_or_packet(
|
||||
{op}, symbol, args, kwargs, /*is_overload*/ true);
|
||||
op, symbol, args, kwargs, /*is_overload*/ true);
|
||||
});
|
||||
auto func_dk =
|
||||
py::cpp_function([op, symbol, allow_numbers_as_tensors](
|
||||
@ -1735,7 +1735,7 @@ void initJITBindings(PyObject* module) {
|
||||
const py::kwargs& kwargs) {
|
||||
ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
|
||||
return _get_operation_for_overload_or_packet(
|
||||
{op}, symbol, args, kwargs, /*is_overload*/ true, dk_);
|
||||
op, symbol, args, kwargs, /*is_overload*/ true, dk_);
|
||||
});
|
||||
return py::make_tuple(
|
||||
func, func_dk, py::cast(op->getTags().vec()));
|
||||
|
@ -780,9 +780,17 @@ std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
|
||||
const std::vector<std::shared_ptr<Operator>>& operations,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs) {
|
||||
return getOpWithStack(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>>(operations), args, kwargs);
|
||||
}
|
||||
|
||||
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>> operations,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs) {
|
||||
Stack stack;
|
||||
if (operations.size() == 1) {
|
||||
std::shared_ptr<Operator> op = operations.at(0);
|
||||
std::shared_ptr<Operator> op = operations[0];
|
||||
// Create a stack full of the arguments and keyword arguments.
|
||||
stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt);
|
||||
|
||||
@ -834,6 +842,15 @@ py::object invokeOperatorFromPython(
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs,
|
||||
std::optional<c10::DispatchKey> dk) {
|
||||
return invokeOperatorFromPython(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>>(operations), args, kwargs, dk);
|
||||
}
|
||||
|
||||
py::object invokeOperatorFromPython(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>> operations,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs,
|
||||
std::optional<c10::DispatchKey> dk) {
|
||||
auto [found_op, stack] = getOpWithStack(operations, args, kwargs);
|
||||
{
|
||||
pybind11::gil_scoped_release no_gil_guard;
|
||||
@ -912,6 +929,17 @@ py::object _get_operation_for_overload_or_packet(
|
||||
const py::kwargs& kwargs,
|
||||
bool is_overload,
|
||||
std::optional<c10::DispatchKey> dk) {
|
||||
return _get_operation_for_overload_or_packet(
|
||||
c10::ArrayRef(operations), symbol, args, kwargs, is_overload, dk);
|
||||
}
|
||||
|
||||
py::object _get_operation_for_overload_or_packet(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>> operations,
|
||||
Symbol symbol,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs,
|
||||
bool is_overload,
|
||||
std::optional<c10::DispatchKey> dk) {
|
||||
std::string ns = symbol.ns().toUnqualString();
|
||||
std::string method_name = symbol.toUnqualString();
|
||||
std::string overload_name = operations[0]->schema().overload_name();
|
||||
|
@ -1275,12 +1275,27 @@ TORCH_PYTHON_API std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs);
|
||||
|
||||
// Efficient overload (does not require vector allocation) of the
|
||||
// above for use from C++ code.
|
||||
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>> operations,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs);
|
||||
|
||||
TORCH_PYTHON_API py::object invokeOperatorFromPython(
|
||||
const std::vector<std::shared_ptr<Operator>>& operations,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs,
|
||||
std::optional<c10::DispatchKey> dk = std::nullopt);
|
||||
|
||||
// Efficient overload (does not require vector allocation) of the
|
||||
// above for use from C++ code.
|
||||
py::object invokeOperatorFromPython(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>> operations,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs,
|
||||
std::optional<c10::DispatchKey> dk = std::nullopt);
|
||||
|
||||
TORCH_PYTHON_API std::optional<py::object> _maybe_handle_torch_function(
|
||||
const std::string& ns,
|
||||
const std::string& method_name,
|
||||
@ -1302,4 +1317,14 @@ TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet(
|
||||
bool is_overload,
|
||||
std::optional<c10::DispatchKey> dk = std::nullopt);
|
||||
|
||||
// Efficient overload (does not require vector allocation) of the
|
||||
// above for use from C++ code.
|
||||
py::object _get_operation_for_overload_or_packet(
|
||||
c10::ArrayRef<std::shared_ptr<Operator>> operations,
|
||||
Symbol symbol,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs,
|
||||
bool is_overload,
|
||||
std::optional<c10::DispatchKey> dk = std::nullopt);
|
||||
|
||||
} // namespace torch::jit
|
||||
|
Reference in New Issue
Block a user