Port OpSchema.__post_init__ and OpSchema._recompute_comparison_key to C++ (#161695)

I initially didn't see good results porting this, but it was apparently because of pybind11 function calling overhead. (pybind11's object-handling primitives seem fine enough.) I'm interested in setting up nanobind, but this demonstrates it's not blocking.

Differential Revision: [D81530102](https://our.internmc.facebook.com/intern/diff/D81530102)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161695
Approved by: https://github.com/ezyang
This commit is contained in:
Scott Wolchok
2025-09-18 13:08:44 -07:00
committed by PyTorch MergeBot
parent bd964cbbfb
commit 76a841fd47
4 changed files with 369 additions and 32 deletions

View File

@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import random
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._op_schema import OpSchema, RuntimeSchemaInfo
from torch.testing._internal.common_utils import run_tests, TestCase
@ -10,12 +12,108 @@ class TestOpSchema(TestCase):
def test_equality_checks_lists_of_dtensor_spec(self):
"""If x == y, then we must have h(x) == h(y)."""
dts = DTensorSpec(mesh=None, placements=tuple(), tensor_meta=None)
schema1 = OpSchema(op=None, args_schema=[dts, [dts]], kwargs_schema={})
schema2 = OpSchema(op=None, args_schema=[dts, [dts, dts]], kwargs_schema={})
schema1 = OpSchema(op=None, args_schema=(dts, [dts]), kwargs_schema={})
schema2 = OpSchema(op=None, args_schema=(dts, [dts, dts]), kwargs_schema={})
# This is a regression test; these schemas used to compare equal.
self.assertNotEqual(schema1, schema2)
self.assertNotEqual(hash(schema1), hash(schema2))
def test_equality_respects_static_attributes(self):
def _get_sample_op_schemas(static_arg_val, static_kwarg_val):
dts = DTensorSpec(mesh=None, placements=tuple(), tensor_meta=None)
static_argnum = 2
static_kwargkey = ["statickwarg"]
annotated_schemas = [
(False, False, None),
(True, False, RuntimeSchemaInfo(static_argnum=static_argnum)),
(False, True, RuntimeSchemaInfo(static_kwargkey=static_kwargkey)),
(
True,
True,
RuntimeSchemaInfo(
static_argnum=static_argnum, static_kwargkey=static_kwargkey
),
),
]
# non-tensor args show up in hash iff the argnum is static/
# kwargs show up in hash iff their name is in static_kwargkey.
# random elements are random because they are not supposed to matter for
# equality at all.
args_schema = (dts, random.randint(1, 1000000), static_arg_val)
kwargs_schema = {
"ignoredkwarg": random.randint(1, 1000000),
"statickwarg": static_kwarg_val,
}
return [
(
has_static_arg,
has_static_kwarg,
OpSchema(
op=None,
args_schema=args_schema,
kwargs_schema=kwargs_schema,
schema_info=si,
),
)
for (has_static_arg, has_static_kwarg, si) in annotated_schemas
]
for lhs_has_static_arg, lhs_has_static_kwarg, lhs in _get_sample_op_schemas(
1, 2
):
# Static arg/kwarg both match
for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
1, 2
):
if (
lhs_has_static_arg == rhs_has_static_arg
and lhs_has_static_kwarg == rhs_has_static_kwarg
):
self.assertEqual(lhs, rhs)
else:
self.assertNotEqual(lhs, rhs)
# Static arg mismatch
for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
3, 2
):
if (
lhs_has_static_arg
or rhs_has_static_arg
or lhs_has_static_kwarg != rhs_has_static_kwarg
):
self.assertNotEqual(lhs, rhs)
else:
self.assertEqual(lhs, rhs)
# Static kwarg mismatch
for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
1, 3
):
if (
lhs_has_static_kwarg
or rhs_has_static_kwarg
or lhs_has_static_arg != rhs_has_static_arg
):
self.assertNotEqual(lhs, rhs)
else:
self.assertEqual(lhs, rhs)
# Static arg/kwarg both mismatch
for rhs_has_static_arg, rhs_has_static_kwarg, rhs in _get_sample_op_schemas(
3, 4
):
if (
lhs_has_static_arg
or rhs_has_static_arg
or lhs_has_static_kwarg
or rhs_has_static_kwarg
):
self.assertNotEqual(lhs, rhs)
else:
self.assertEqual(lhs, rhs)
if __name__ == "__main__":
run_tests()

View File

@ -41,6 +41,7 @@ from torch._C import (
from torch._prims_common import DeviceLikeType
from torch.autograd.graph import Node as _Node
from torch.cuda import _POOL_HANDLE
from torch.distributed.tensor._op_schema import OpSchema
from torch.fx.node import Node as FxNode
from torch.package import PackageExporter
from torch.storage import TypedStorage, UntypedStorage
@ -1942,6 +1943,9 @@ class TensorBase(metaclass=_TensorMeta):
_TensorBase = TensorBase
def _DTensor_OpSchema_post_init(self: OpSchema) -> None: ...
def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ...
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...
def _set_thread_name(name: str) -> None: ...

View File

@ -852,6 +852,243 @@ static PyObject* THPVariable_make_dtensor(
END_HANDLE_TH_ERRORS
}
static py::handle get_dtensor_spec_class() {
#if IS_PYBIND_2_13_PLUS
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
return storage
.call_once_and_store_result([]() -> py::object {
return py::module::import("torch")
.attr("distributed")
.attr("tensor")
.attr("_dtensor_spec")
.attr("DTensorSpec");
})
.get_stored();
#else
static py::handle dtensor_spec_class = py::object(py::module::import("torch")
.attr("distributed")
.attr("tensor")
.attr("_dtensor_spec")
.attr("DTensorSpec"))
.release();
return dtensor_spec_class;
#endif
}
static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
const auto dtensor_spec_class = get_dtensor_spec_class();
if (py::isinstance(arg, dtensor_spec_class)) {
return true;
}
if (!PyList_Check(arg.ptr())) {
return false;
}
py::list arg_list = py::reinterpret_borrow<py::list>(arg);
for (const auto e : arg_list) {
if (!e.is_none() && !py::isinstance(e, dtensor_spec_class)) {
return false;
}
}
return true;
}
#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \
_(_comparison_key) \
_(args_schema) \
_(has_symints) \
_(kwargs_schema) \
_(op) \
_(schema_info) \
_(shape) \
_(static_argnum) \
_(static_kwargkey) \
_(tensor_meta)
struct DTensorInternedStrings {
#define DECLARE_INTERNED_STRING_VARIABLE(s) PyObject* s;
FOR_EACH_DTENSOR_INTERNED_STRING(DECLARE_INTERNED_STRING_VARIABLE)
#undef DECLARE_INTERNED_STRING_VARIABLE
};
static DTensorInternedStrings dtensor_interned_strings;
static bool intern_dtensor_strings() {
#define INTERN_DTENSOR_STRING(s) \
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dtensor_interned_strings.s == nullptr); \
dtensor_interned_strings.s = PyUnicode_InternFromString(#s); \
if (dtensor_interned_strings.s == nullptr) { \
return false; \
}
FOR_EACH_DTENSOR_INTERNED_STRING(INTERN_DTENSOR_STRING);
#undef INTERN_DTENSOR_STRING
return true;
}
static bool checked_not(PyObject* obj) {
int result = PyObject_Not(obj);
if (result == -1) {
throw py::error_already_set();
}
return result;
}
static bool DTensor_OpSchema_recompute_comparison_key_impl(
PyObject* self,
const py::tuple& args_schema) {
py::object static_kwargkey;
size_t static_argnum = 0;
const py::handle self_handle = py::handle(self);
const py::handle schema_info =
self_handle.attr(dtensor_interned_strings.schema_info);
if (checked_not(schema_info.ptr())) {
static_argnum = args_schema.size();
static_kwargkey = py::none();
} else {
static_argnum = py::cast<size_t>(
schema_info.attr(dtensor_interned_strings.static_argnum));
static_kwargkey =
schema_info.attr(dtensor_interned_strings.static_kwargkey);
}
c10::SmallVector<py::object, 8> args_to_hash;
size_t idx = 0;
for (const auto& e : args_schema) {
if (idx >= static_argnum || arg_type_tensor_or_tensor_list_like(e)) {
if (PyList_Check(e.ptr())) {
args_to_hash.push_back(
py::reinterpret_steal<py::object>(PyList_AsTuple(e.ptr())));
} else {
args_to_hash.push_back(py::reinterpret_borrow<py::object>(e));
}
}
idx++;
}
py::tuple args_to_hash_tup(args_to_hash.size());
for (const auto idx : c10::irange(args_to_hash.size())) {
args_to_hash_tup[idx] = std::move(args_to_hash[idx]);
}
PyObject* comparison_key = nullptr;
if (!static_kwargkey.is_none()) {
if (!PyList_Check(static_kwargkey.ptr())) {
PyErr_SetString(
PyExc_TypeError, "self.schema_info.static_kwargkey must be a list!");
return false;
}
py::list static_kwargkey_list =
py::reinterpret_borrow<py::list>(static_kwargkey);
auto raw_kwargs_schema =
self_handle.attr(dtensor_interned_strings.kwargs_schema);
if (!PyDict_Check(raw_kwargs_schema.ptr())) {
PyErr_SetString(PyExc_TypeError, "self.kwargs_schema must be a dict!");
return false;
}
py::tuple kwargs_to_hash(static_kwargkey_list.size());
int idx = 0;
auto kwargs_schema = py::reinterpret_borrow<py::dict>(raw_kwargs_schema);
for (const auto& k : static_kwargkey_list) {
PyObject* item = PyDict_GetItemWithError(kwargs_schema.ptr(), k.ptr());
if (item) {
kwargs_to_hash[idx++] = py::reinterpret_borrow<py::object>(item);
} else if (PyErr_Occurred()) {
return false;
} else {
kwargs_to_hash[idx++] = py::none();
}
}
comparison_key = PyTuple_Pack(
3,
self_handle.attr(dtensor_interned_strings.op).ptr(),
args_to_hash_tup.ptr(),
kwargs_to_hash.ptr());
} else {
comparison_key = PyTuple_Pack(
2,
self_handle.attr(dtensor_interned_strings.op).ptr(),
args_to_hash_tup.release().ptr());
}
if (!comparison_key) {
return false;
}
self_handle.attr(dtensor_interned_strings._comparison_key) =
py::reinterpret_steal<py::object>(comparison_key);
return true;
}
static PyObject* DTensor_OpSchema_recompute_comparison_key(
PyObject* mod,
PyObject* self) {
HANDLE_TH_ERRORS
const py::handle self_handle = py::handle(self);
const py::handle raw_args_schema =
self_handle.attr(dtensor_interned_strings.args_schema);
if (!PyTuple_Check(raw_args_schema.ptr())) {
PyErr_SetString(PyExc_TypeError, "DTensor.args_schema must be a tuple!");
return nullptr;
}
py::tuple args_schema = py::reinterpret_borrow<py::tuple>(raw_args_schema);
if (!DTensor_OpSchema_recompute_comparison_key_impl(self, args_schema)) {
return nullptr;
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* DTensor_OpSchema_post_init(PyObject* mod, PyObject* self) {
HANDLE_TH_ERRORS
const py::handle self_handle = py::handle(self);
const py::handle raw_args_schema =
self_handle.attr(dtensor_interned_strings.args_schema);
if (!PyTuple_Check(raw_args_schema.ptr())) {
PyErr_SetString(
PyExc_TypeError,
"DTensor_OpSchema_post_init requires self.args_schema to be a tuple!");
return nullptr;
}
py::tuple args_schema = py::reinterpret_borrow<py::tuple>(raw_args_schema);
if (!DTensor_OpSchema_recompute_comparison_key_impl(self, args_schema)) {
return nullptr;
}
const auto dtensor_spec_class = get_dtensor_spec_class();
bool has_symints = false;
for (const auto& a : args_schema) {
if (Py_TYPE(a.ptr()) != (PyTypeObject*)(dtensor_spec_class.ptr()) &&
!py::isinstance(a, dtensor_spec_class)) {
continue;
}
const py::handle tensor_meta = a.attr(dtensor_interned_strings.tensor_meta);
if (tensor_meta.is_none()) {
continue;
}
const auto contains_any_symint = [](const py::tuple& sequence) {
for (const auto& s : sequence) {
if (THPUtils_checkLong(s.ptr())) {
continue;
}
if (torch::is_symint(s)) {
return true;
}
}
return false;
};
// Specifically it's supposed to be torch.Size.
py::object raw_shape = tensor_meta.attr(dtensor_interned_strings.shape);
if (!PyTuple_Check(raw_shape.ptr())) {
PyErr_SetString(PyExc_TypeError, "OpSchema.shape must be a tuple!");
return nullptr;
}
const auto shape = py::reinterpret_steal<py::tuple>(raw_shape.release());
if (contains_any_symint(shape)) {
has_symints = true;
}
}
self_handle.attr(dtensor_interned_strings.has_symints) = has_symints;
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
using getter = PyObject* (*)(PyObject*, void*);
using setter = int (*)(PyObject*, PyObject*, void*);
@ -1762,6 +1999,18 @@ static PyMethodDef extra_methods[] = {
{"_use_count", THPVariable__use_count, METH_NOARGS, nullptr},
{nullptr}};
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef extra_functions[] = {
{"_DTensor_OpSchema_post_init",
DTensor_OpSchema_post_init,
METH_O,
nullptr},
{"_DTensor_OpSchema_recompute_comparison_key",
DTensor_OpSchema_recompute_comparison_key,
METH_O,
nullptr},
{nullptr}};
struct THPVariableMeta {
PyHeapTypeObject base;
};
@ -2488,5 +2737,10 @@ bool THPVariable_initModule(PyObject* module) {
torch::autograd::initTorchFunctions(module);
torch::autograd::initTensorImplConversion(module);
torch::utils::validate_numpy_for_dlpack_deleter_bug();
if (!intern_dtensor_strings()) {
return false;
}
PyModule_AddFunctions(module, extra_functions);
return true;
}

View File

@ -24,12 +24,16 @@ These schema definitions enable the DTensor system to:
"""
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Optional, Union
from typing_extensions import deprecated
import torch
from torch._C import (
_DTensor_OpSchema_post_init,
_DTensor_OpSchema_recompute_comparison_key,
)
from torch._ops import OpOverload
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@ -331,6 +335,8 @@ class OpSchema:
_comparison_key: Optional[tuple[object, ...]] = None
has_symints: bool = field(init=False)
@property
def args_spec(self) -> tuple[DTensorSpec, ...]:
"""
@ -386,14 +392,7 @@ class OpSchema:
return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"
def __post_init__(self) -> None:
has_symints = False
for a in self.args_schema:
if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
has_symints = True
break
self.has_symints = has_symints
self._recompute_comparison_key()
_DTensor_OpSchema_post_init(self)
def arg_type_tensor_or_tensor_list_like(self, arg: object) -> bool:
is_tensor = isinstance(arg, DTensorSpec)
@ -479,26 +478,8 @@ class OpSchema:
def is_view_op(self) -> bool:
return self.op._schema._is_view_op()
def _recompute_comparison_key(self):
if not self.schema_info:
static_argnum = len(self.args_schema)
static_kwargkey = None
else:
static_argnum = self.schema_info.static_argnum
static_kwargkey = self.schema_info.static_kwargkey
args_to_hash = tuple(
tuple(e) if isinstance(e, list) else e
for i, e in enumerate(self.args_schema)
if self.arg_type_tensor_or_tensor_list_like(e) or i >= static_argnum
)
if static_kwargkey is not None:
kwargs_to_hash = tuple(
self.kwargs_schema.get(k, None) for k in static_kwargkey
)
self._comparison_key = (self.op, args_to_hash, kwargs_to_hash)
else:
self._comparison_key = (self.op, args_to_hash)
def _recompute_comparison_key(self) -> None:
_DTensor_OpSchema_recompute_comparison_key(self)
def __hash__(self) -> int:
return hash(self._comparison_key)