fix bugs in #166371 on "Avoid creating Python OpSchema in the DTensor dispatch fast path"

All we need to do is move a few checks around.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
This commit is contained in:
Scott Wolchok
2025-11-04 10:16:11 -08:00
2 changed files with 28 additions and 6 deletions

View File

@ -1208,11 +1208,15 @@ class NativeOpSchema {
NativeOpSchema(
const c10::OperatorHandle& op,
c10::SmallVector<IValueOrDTensorSpec, 8> comparison_key,
std::size_t comparison_key_hash)
std::size_t comparison_key_hash,
std::size_t args_schema_len)
: op_(op),
hash_(hash_combine(
std::hash<c10::OperatorHandle>()(op),
comparison_key_hash)),
hash_combine(
std::hash<c10::OperatorHandle>()(op),
comparison_key_hash),
args_schema_len)),
args_schema_len_(args_schema_len),
comparison_key_(std::move(comparison_key)) {}
bool operator==(const NativeOpSchema& rhs) const {
@ -1220,7 +1224,8 @@ class NativeOpSchema {
// equal, because comparison is occurring during a hash table
// lookup and we know the hashes are already equal. Therefore, we
// don't bother checking hash_ first.
return op_ == rhs.op_ && comparison_key_ == rhs.comparison_key_;
return op_ == rhs.op_ && args_schema_len_ == rhs.args_schema_len_ &&
comparison_key_ == rhs.comparison_key_;
}
std::size_t hash() const {
@ -1230,6 +1235,19 @@ class NativeOpSchema {
private:
const c10::OperatorHandle& op_;
std::size_t hash_;
// Subtle point: consider clamp.Tensor(Tensor self, Tensor?
// min=None, Tensor? max=None). The invocations clamp(t1, None, t2)
// and clamp(t1, t2, None) have the same comparison key (t1, t2)
// because we drop non-static non-tensor args from comparison. The
// only way we happen to be able to tell them apart is that we omit
// trailing defaulted arguments from the args tuple passed to
// __torch_dispatch__ (and hence to DTensor dispatch as well), so
// they have different args_schema_len_.
//
// I am preserving this existing behavior, but I suspect we should
// make an algorithm change to be less brittle, such as including
// None defaults for Tensor arguments in the comparison.
std::size_t args_schema_len_;
c10::SmallVector<IValueOrDTensorSpec, 8> comparison_key_;
};
@ -2030,7 +2048,11 @@ static std::optional<NativeOpSchema> create_native_op_schema(
}
}
return NativeOpSchema(op, std::move(comparison_key), comparison_key_hash);
return NativeOpSchema(
op,
std::move(comparison_key),
comparison_key_hash,
args_kwargs.num_positional_args());
}
using getter = PyObject* (*)(PyObject*, void*);

View File

@ -500,7 +500,7 @@ class OpDispatcher:
schema_info=runtime_schema_info,
)
if create_schema
else None,
else None, # type: ignore[arg-type]
args_schema,
tuple(local_args),
local_kwargs,