Port OpSchema.__post_init__ and OpSchema._recompute_comparison_key to C++ (#161695)
I initially didn't see good results porting this, but it was apparently because of pybind11 function calling overhead. (pybind11's object-handling primitives seem fine enough.) I'm interested in setting up nanobind, but this demonstrates it's not blocking.

Differential Revision: [D81530102](https://our.internmc.facebook.com/intern/diff/D81530102)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161695
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent bd964cbbfb
commit 76a841fd47
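As a rough way to sanity-check the call-overhead point in the commit message, one could micro-benchmark the now C++-backed OpSchema._recompute_comparison_key with timeit. This is only a sketch: the None-valued op and mesh below mirror the new test rather than real DTensor usage, and no timing numbers are claimed here.

import timeit

from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import OpSchema

# Build a small schema the same way the new test does; op=None and mesh=None
# keep the snippet self-contained.
dts = DTensorSpec(mesh=None, placements=tuple(), tensor_meta=None)
schema = OpSchema(op=None, args_schema=(dts, [dts]), kwargs_schema={})

# _recompute_comparison_key now delegates to the C++ binding
# torch._C._DTensor_OpSchema_recompute_comparison_key, so this times the raw
# C-API call path rather than the old pure-Python loop.
print(timeit.timeit(schema._recompute_comparison_key, number=100_000))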
@@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import random

from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._op_schema import OpSchema, RuntimeSchemaInfo
from torch.testing._internal.common_utils import run_tests, TestCase


@@ -10,12 +12,108 @@ class TestOpSchema(TestCase):
    def test_equality_checks_lists_of_dtensor_spec(self):
        """If x == y, then we must have h(x) == h(y)."""
        dts = DTensorSpec(mesh=None, placements=tuple(), tensor_meta=None)
        schema1 = OpSchema(op=None, args_schema=[dts, [dts]], kwargs_schema={})
        schema2 = OpSchema(op=None, args_schema=[dts, [dts, dts]], kwargs_schema={})
        schema1 = OpSchema(op=None, args_schema=(dts, [dts]), kwargs_schema={})
        schema2 = OpSchema(op=None, args_schema=(dts, [dts, dts]), kwargs_schema={})
        # This is a regression test; these schemas used to compare equal.
        self.assertNotEqual(schema1, schema2)
        self.assertNotEqual(hash(schema1), hash(schema2))

    def test_equality_respects_static_attributes(self):
        def _get_sample_op_schemas(static_arg_val, static_kwarg_val):
            dts = DTensorSpec(mesh=None, placements=tuple(), tensor_meta=None)
            static_argnum = 2
            static_kwargkey = ["statickwarg"]
            annotated_schemas = [
                (False, False, None),
                (True, False, RuntimeSchemaInfo(static_argnum=static_argnum)),
                (False, True, RuntimeSchemaInfo(static_kwargkey=static_kwargkey)),
                (
                    True,
                    True,
                    RuntimeSchemaInfo(
                        static_argnum=static_argnum, static_kwargkey=static_kwargkey
                    ),
                ),
            ]

            # non-tensor args show up in hash iff the argnum is static/
            # kwargs show up in hash iff their name is in static_kwargkey.
            # random elements are random because they are not supposed to matter for
            # equality at all.
            args_schema = (dts, random.randint(1, 1000000), static_arg_val)
            kwargs_schema = {
                "ignoredkwarg": random.randint(1, 1000000),
                "statickwarg": static_kwarg_val,
            }
            return [
                (
                    has_static_arg,
                    has_static_kwarg,
                    OpSchema(
                        op=None,
                        args_schema=args_schema,
                        kwargs_schema=kwargs_schema,
                        schema_info=si,
                    ),
                )
                for (has_static_arg, has_static_kwarg, si) in annotated_schemas
            ]

        for lhs_has_static_arg, lhs_has_static_kwarg, lhs in _get_sample_op_schemas(
            1, 2
        ):
            # Static arg/kwarg both match
            for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
                1, 2
            ):
                if (
                    lhs_has_static_arg == rhs_has_static_arg
                    and lhs_has_static_kwarg == rhs_has_static_kwarg
                ):
                    self.assertEqual(lhs, rhs)
                else:
                    self.assertNotEqual(lhs, rhs)

            # Static arg mismatch
            for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
                3, 2
            ):
                if (
                    lhs_has_static_arg
                    or rhs_has_static_arg
                    or lhs_has_static_kwarg != rhs_has_static_kwarg
                ):
                    self.assertNotEqual(lhs, rhs)
                else:
                    self.assertEqual(lhs, rhs)

            # Static kwarg mismatch
            for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
                1, 3
            ):
                if (
                    lhs_has_static_kwarg
                    or rhs_has_static_kwarg
                    or lhs_has_static_arg != rhs_has_static_arg
                ):
                    self.assertNotEqual(lhs, rhs)
                else:
                    self.assertEqual(lhs, rhs)

            # Static arg/kwarg both mismatch
            for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
                3, 4
            ):
                if (
                    lhs_has_static_arg
                    or rhs_has_static_arg
                    or lhs_has_static_kwarg
                    or rhs_has_static_kwarg
                ):
                    self.assertNotEqual(lhs, rhs)
                else:
                    self.assertEqual(lhs, rhs)


if __name__ == "__main__":
    run_tests()
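To make the hashing rule exercised above concrete: non-tensor args enter the comparison key only from static_argnum onward, and kwargs only when named in static_kwargkey. The following is a hedged stand-alone sketch of that rule; it reuses the None-valued op and mesh from the test and is not representative DTensor usage.

from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import OpSchema, RuntimeSchemaInfo

dts = DTensorSpec(mesh=None, placements=tuple(), tensor_meta=None)

# Without schema_info, the trailing non-tensor arg stays out of the comparison
# key, so these two schemas compare (and hash) equal despite differing args.
a = OpSchema(op=None, args_schema=(dts, 1), kwargs_schema={})
b = OpSchema(op=None, args_schema=(dts, 2), kwargs_schema={})
assert a == b and hash(a) == hash(b)

# Marking argnum 1 as static pulls that value into the comparison key, so the
# same pair now compares unequal.
info = RuntimeSchemaInfo(static_argnum=1)
c = OpSchema(op=None, args_schema=(dts, 1), kwargs_schema={}, schema_info=info)
d = OpSchema(op=None, args_schema=(dts, 2), kwargs_schema={}, schema_info=info)
assert c != d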
@@ -41,6 +41,7 @@ from torch._C import (
from torch._prims_common import DeviceLikeType
from torch.autograd.graph import Node as _Node
from torch.cuda import _POOL_HANDLE
from torch.distributed.tensor._op_schema import OpSchema
from torch.fx.node import Node as FxNode
from torch.package import PackageExporter
from torch.storage import TypedStorage, UntypedStorage
@@ -1942,6 +1943,9 @@ class TensorBase(metaclass=_TensorMeta):

_TensorBase = TensorBase

def _DTensor_OpSchema_post_init(self: OpSchema) -> None: ...
def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ...

# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...
def _set_thread_name(name: str) -> None: ...
@@ -852,6 +852,243 @@ static PyObject* THPVariable_make_dtensor(
  END_HANDLE_TH_ERRORS
}

static py::handle get_dtensor_spec_class() {
#if IS_PYBIND_2_13_PLUS
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
      storage;
  return storage
      .call_once_and_store_result([]() -> py::object {
        return py::module::import("torch")
            .attr("distributed")
            .attr("tensor")
            .attr("_dtensor_spec")
            .attr("DTensorSpec");
      })
      .get_stored();
#else
  static py::handle dtensor_spec_class = py::object(py::module::import("torch")
                                                        .attr("distributed")
                                                        .attr("tensor")
                                                        .attr("_dtensor_spec")
                                                        .attr("DTensorSpec"))
                                             .release();
  return dtensor_spec_class;
#endif
}

static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
  const auto dtensor_spec_class = get_dtensor_spec_class();
  if (py::isinstance(arg, dtensor_spec_class)) {
    return true;
  }
  if (!PyList_Check(arg.ptr())) {
    return false;
  }
  py::list arg_list = py::reinterpret_borrow<py::list>(arg);
  for (const auto e : arg_list) {
    if (!e.is_none() && !py::isinstance(e, dtensor_spec_class)) {
      return false;
    }
  }
  return true;
}

#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
  _(_comparison_key)                        \
  _(args_schema)                            \
  _(has_symints)                            \
  _(kwargs_schema)                          \
  _(op)                                     \
  _(schema_info)                            \
  _(shape)                                  \
  _(static_argnum)                          \
  _(static_kwargkey)                        \
  _(tensor_meta)

struct DTensorInternedStrings {
#define DECLARE_INTERNED_STRING_VARIABLE(s) PyObject* s;
  FOR_EACH_DTENSOR_INTERNED_STRING(DECLARE_INTERNED_STRING_VARIABLE)
#undef DECLARE_INTERNED_STRING_VARIABLE
};

static DTensorInternedStrings dtensor_interned_strings;

static bool intern_dtensor_strings() {
#define INTERN_DTENSOR_STRING(s)                                           \
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dtensor_interned_strings.s == nullptr); \
  dtensor_interned_strings.s = PyUnicode_InternFromString(#s);             \
  if (dtensor_interned_strings.s == nullptr) {                             \
    return false;                                                          \
  }

  FOR_EACH_DTENSOR_INTERNED_STRING(INTERN_DTENSOR_STRING);
#undef INTERN_DTENSOR_STRING
  return true;
}

static bool checked_not(PyObject* obj) {
  int result = PyObject_Not(obj);
  if (result == -1) {
    throw py::error_already_set();
  }
  return result;
}

static bool DTensor_OpSchema_recompute_comparison_key_impl(
    PyObject* self,
    const py::tuple& args_schema) {
  py::object static_kwargkey;
  size_t static_argnum = 0;
  const py::handle self_handle = py::handle(self);
  const py::handle schema_info =
      self_handle.attr(dtensor_interned_strings.schema_info);
  if (checked_not(schema_info.ptr())) {
    static_argnum = args_schema.size();
    static_kwargkey = py::none();
  } else {
    static_argnum = py::cast<size_t>(
        schema_info.attr(dtensor_interned_strings.static_argnum));
    static_kwargkey =
        schema_info.attr(dtensor_interned_strings.static_kwargkey);
  }
  c10::SmallVector<py::object, 8> args_to_hash;
  size_t idx = 0;
  for (const auto& e : args_schema) {
    if (idx >= static_argnum || arg_type_tensor_or_tensor_list_like(e)) {
      if (PyList_Check(e.ptr())) {
        args_to_hash.push_back(
            py::reinterpret_steal<py::object>(PyList_AsTuple(e.ptr())));
      } else {
        args_to_hash.push_back(py::reinterpret_borrow<py::object>(e));
      }
    }
    idx++;
  }
  py::tuple args_to_hash_tup(args_to_hash.size());
  for (const auto idx : c10::irange(args_to_hash.size())) {
    args_to_hash_tup[idx] = std::move(args_to_hash[idx]);
  }
  PyObject* comparison_key = nullptr;
  if (!static_kwargkey.is_none()) {
    if (!PyList_Check(static_kwargkey.ptr())) {
      PyErr_SetString(
          PyExc_TypeError, "self.schema_info.static_kwargkey must be a list!");
      return false;
    }
    py::list static_kwargkey_list =
        py::reinterpret_borrow<py::list>(static_kwargkey);
    auto raw_kwargs_schema =
        self_handle.attr(dtensor_interned_strings.kwargs_schema);
    if (!PyDict_Check(raw_kwargs_schema.ptr())) {
      PyErr_SetString(PyExc_TypeError, "self.kwargs_schema must be a dict!");
      return false;
    }
    py::tuple kwargs_to_hash(static_kwargkey_list.size());
    int idx = 0;
    auto kwargs_schema = py::reinterpret_borrow<py::dict>(raw_kwargs_schema);
    for (const auto& k : static_kwargkey_list) {
      PyObject* item = PyDict_GetItemWithError(kwargs_schema.ptr(), k.ptr());
      if (item) {
        kwargs_to_hash[idx++] = py::reinterpret_borrow<py::object>(item);
      } else if (PyErr_Occurred()) {
        return false;
      } else {
        kwargs_to_hash[idx++] = py::none();
      }
    }
    comparison_key = PyTuple_Pack(
        3,
        self_handle.attr(dtensor_interned_strings.op).ptr(),
        args_to_hash_tup.ptr(),
        kwargs_to_hash.ptr());
  } else {
    comparison_key = PyTuple_Pack(
        2,
        self_handle.attr(dtensor_interned_strings.op).ptr(),
        args_to_hash_tup.release().ptr());
  }
  if (!comparison_key) {
    return false;
  }
  self_handle.attr(dtensor_interned_strings._comparison_key) =
      py::reinterpret_steal<py::object>(comparison_key);

  return true;
}

static PyObject* DTensor_OpSchema_recompute_comparison_key(
    PyObject* mod,
    PyObject* self) {
  HANDLE_TH_ERRORS
  const py::handle self_handle = py::handle(self);
  const py::handle raw_args_schema =
      self_handle.attr(dtensor_interned_strings.args_schema);
  if (!PyTuple_Check(raw_args_schema.ptr())) {
    PyErr_SetString(PyExc_TypeError, "DTensor.args_schema must be a tuple!");
    return nullptr;
  }
  py::tuple args_schema = py::reinterpret_borrow<py::tuple>(raw_args_schema);
  if (!DTensor_OpSchema_recompute_comparison_key_impl(self, args_schema)) {
    return nullptr;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* DTensor_OpSchema_post_init(PyObject* mod, PyObject* self) {
  HANDLE_TH_ERRORS
  const py::handle self_handle = py::handle(self);
  const py::handle raw_args_schema =
      self_handle.attr(dtensor_interned_strings.args_schema);
  if (!PyTuple_Check(raw_args_schema.ptr())) {
    PyErr_SetString(
        PyExc_TypeError,
        "DTensor_OpSchema_post_init requires self.args_schema to be a tuple!");
    return nullptr;
  }
  py::tuple args_schema = py::reinterpret_borrow<py::tuple>(raw_args_schema);
  if (!DTensor_OpSchema_recompute_comparison_key_impl(self, args_schema)) {
    return nullptr;
  }

  const auto dtensor_spec_class = get_dtensor_spec_class();
  bool has_symints = false;
  for (const auto& a : args_schema) {
    if (Py_TYPE(a.ptr()) != (PyTypeObject*)(dtensor_spec_class.ptr()) &&
        !py::isinstance(a, dtensor_spec_class)) {
      continue;
    }
    const py::handle tensor_meta = a.attr(dtensor_interned_strings.tensor_meta);
    if (tensor_meta.is_none()) {
      continue;
    }
    const auto contains_any_symint = [](const py::tuple& sequence) {
      for (const auto& s : sequence) {
        if (THPUtils_checkLong(s.ptr())) {
          continue;
        }
        if (torch::is_symint(s)) {
          return true;
        }
      }
      return false;
    };
    // Specifically it's supposed to be torch.Size.
    py::object raw_shape = tensor_meta.attr(dtensor_interned_strings.shape);
    if (!PyTuple_Check(raw_shape.ptr())) {
      PyErr_SetString(PyExc_TypeError, "OpSchema.shape must be a tuple!");
      return nullptr;
    }
    const auto shape = py::reinterpret_steal<py::tuple>(raw_shape.release());
    if (contains_any_symint(shape)) {
      has_symints = true;
    }
  }
  self_handle.attr(dtensor_interned_strings.has_symints) = has_symints;
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

using getter = PyObject* (*)(PyObject*, void*);
using setter = int (*)(PyObject*, PyObject*, void*);

@@ -1762,6 +1999,18 @@ static PyMethodDef extra_methods[] = {
    {"_use_count", THPVariable__use_count, METH_NOARGS, nullptr},
    {nullptr}};

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef extra_functions[] = {
    {"_DTensor_OpSchema_post_init",
     DTensor_OpSchema_post_init,
     METH_O,
     nullptr},
    {"_DTensor_OpSchema_recompute_comparison_key",
     DTensor_OpSchema_recompute_comparison_key,
     METH_O,
     nullptr},
    {nullptr}};

struct THPVariableMeta {
  PyHeapTypeObject base;
};
@@ -2488,5 +2737,10 @@ bool THPVariable_initModule(PyObject* module) {
  torch::autograd::initTorchFunctions(module);
  torch::autograd::initTensorImplConversion(module);
  torch::utils::validate_numpy_for_dlpack_deleter_bug();

  if (!intern_dtensor_strings()) {
    return false;
  }
  PyModule_AddFunctions(module, extra_functions);
  return true;
}
@@ -24,12 +24,16 @@ These schema definitions enable the DTensor system to:
"""

from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Optional, Union
from typing_extensions import deprecated

import torch
from torch._C import (
    _DTensor_OpSchema_post_init,
    _DTensor_OpSchema_recompute_comparison_key,
)
from torch._ops import OpOverload
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -331,6 +335,8 @@ class OpSchema:

    _comparison_key: Optional[tuple[object, ...]] = None

    has_symints: bool = field(init=False)

    @property
    def args_spec(self) -> tuple[DTensorSpec, ...]:
        """
@@ -386,14 +392,7 @@ class OpSchema:
        return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"

    def __post_init__(self) -> None:
        has_symints = False
        for a in self.args_schema:
            if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
                if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
                    has_symints = True
                    break
        self.has_symints = has_symints
        self._recompute_comparison_key()
        _DTensor_OpSchema_post_init(self)

    def arg_type_tensor_or_tensor_list_like(self, arg: object) -> bool:
        is_tensor = isinstance(arg, DTensorSpec)
@@ -479,26 +478,8 @@ class OpSchema:
    def is_view_op(self) -> bool:
        return self.op._schema._is_view_op()

    def _recompute_comparison_key(self):
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        args_to_hash = tuple(
            tuple(e) if isinstance(e, list) else e
            for i, e in enumerate(self.args_schema)
            if self.arg_type_tensor_or_tensor_list_like(e) or i >= static_argnum
        )
        if static_kwargkey is not None:
            kwargs_to_hash = tuple(
                self.kwargs_schema.get(k, None) for k in static_kwargkey
            )
            self._comparison_key = (self.op, args_to_hash, kwargs_to_hash)
        else:
            self._comparison_key = (self.op, args_to_hash)
    def _recompute_comparison_key(self) -> None:
        _DTensor_OpSchema_recompute_comparison_key(self)

    def __hash__(self) -> int:
        return hash(self._comparison_key)