Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-15 06:48:48 +08:00)
fix bugs in #166371 on "Avoid creating Python OpSchema in the DTensor dispatch fast path"
All we need to do is move a few checks around. cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
@@ -1208,11 +1208,15 @@ class NativeOpSchema {
   NativeOpSchema(
       const c10::OperatorHandle& op,
       c10::SmallVector<IValueOrDTensorSpec, 8> comparison_key,
-      std::size_t comparison_key_hash)
+      std::size_t comparison_key_hash,
+      std::size_t args_schema_len)
       : op_(op),
         hash_(hash_combine(
-            std::hash<c10::OperatorHandle>()(op),
-            comparison_key_hash)),
+            hash_combine(
+                std::hash<c10::OperatorHandle>()(op),
+                comparison_key_hash),
+            args_schema_len)),
+        args_schema_len_(args_schema_len),
         comparison_key_(std::move(comparison_key)) {}

   bool operator==(const NativeOpSchema& rhs) const {
@@ -1220,7 +1224,8 @@ class NativeOpSchema {
     // equal, because comparison is occurring during a hash table
     // lookup and we know the hashes are already equal. Therefore, we
     // don't bother checking hash_ first.
-    return op_ == rhs.op_ && comparison_key_ == rhs.comparison_key_;
+    return op_ == rhs.op_ && args_schema_len_ == rhs.args_schema_len_ &&
+        comparison_key_ == rhs.comparison_key_;
   }

   std::size_t hash() const {
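Taken together, the constructor and operator== hunks make the number of positional arguments part of the cached schema's identity: it is folded into the hash via a nested hash_combine and checked in equality alongside the operator and the comparison key. Below is a minimal Python sketch of that identity, assuming a boost-style hash_combine; CachedOpSchema and its field names are illustrative stand-ins, not the real C++ class.

from typing import Any, Tuple


def hash_combine(seed: int, value: int) -> int:
    # Boost-style hash_combine, analogous to the helper the C++ code relies on.
    return (seed ^ (value + 0x9E3779B9 + (seed << 6) + (seed >> 2))) & (2**64 - 1)


class CachedOpSchema:
    """Illustrative stand-in for NativeOpSchema's identity after this change."""

    def __init__(self, op_name: str, comparison_key: Tuple[Any, ...],
                 comparison_key_hash: int, args_schema_len: int) -> None:
        self.op_name = op_name                  # stands in for op_
        self.comparison_key = comparison_key    # stands in for comparison_key_
        self.comparison_key_hash = comparison_key_hash
        self.args_schema_len = args_schema_len  # the newly added field

    def __hash__(self) -> int:
        # Mirrors hash_combine(hash_combine(hash(op), comparison_key_hash), args_schema_len).
        inner = hash_combine(hash(self.op_name), self.comparison_key_hash)
        return hash_combine(inner, self.args_schema_len)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CachedOpSchema):
            return NotImplemented
        # Check the cheap scalar fields before the (potentially longer) key,
        # in the same spirit as the C++ operator==.
        return (self.op_name == other.op_name
                and self.args_schema_len == other.args_schema_len
                and self.comparison_key == other.comparison_key)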
@@ -1230,6 +1235,19 @@ class NativeOpSchema {
  private:
   const c10::OperatorHandle& op_;
   std::size_t hash_;
+  // Subtle point: consider clamp.Tensor(Tensor self, Tensor?
+  // min=None, Tensor? max=None). The invocations clamp(t1, None, t2)
+  // and clamp(t1, t2, None) have the same comparison key (t1, t2)
+  // because we drop non-static non-tensor args from comparison. The
+  // only way we happen to be able to tell them apart is that we omit
+  // trailing defaulted arguments from the args tuple passed to
+  // __torch_dispatch__ (and hence to DTensor dispatch as well), so
+  // they have different args_schema_len_.
+  //
+  // I am preserving this existing behavior, but I suspect we should
+  // make an algorithm change to be less brittle, such as including
+  // None defaults for Tensor arguments in the comparison.
+  std::size_t args_schema_len_;
   c10::SmallVector<IValueOrDTensorSpec, 8> comparison_key_;
 };
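The new comment above is the core of the fix, and the collision it describes is easy to reproduce at the Python level. Here is a small sketch under the comment's assumptions; tensor_only_key is an illustrative stand-in for the fast path's comparison-key extraction, not the actual DTensor code.

import torch


def tensor_only_key(args):
    # Illustrative stand-in: keep tensor arguments, drop non-static
    # non-tensor arguments such as a None default.
    return tuple(a for a in args if isinstance(a, torch.Tensor))


t1 = torch.randn(4)
t2 = torch.randn(4)

# clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)
# clamp(t1, None, t2): max is supplied, so the intervening None stays in args.
args_a = (t1, None, t2)
# clamp(t1, t2, None): the trailing defaulted max is omitted from the args
# tuple handed to __torch_dispatch__, so only two arguments arrive.
args_b = (t1, t2)

key_a, key_b = tensor_only_key(args_a), tensor_only_key(args_b)
# Same comparison key (t1, t2) for both calls...
assert len(key_a) == len(key_b) and all(x is y for x, y in zip(key_a, key_b))
# ...but the argument counts differ, which is what args_schema_len_ records.
assert len(args_a) != len(args_b)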
@@ -2030,7 +2048,11 @@ static std::optional<NativeOpSchema> create_native_op_schema(
     }
   }

-  return NativeOpSchema(op, std::move(comparison_key), comparison_key_hash);
+  return NativeOpSchema(
+      op,
+      std::move(comparison_key),
+      comparison_key_hash,
+      args_kwargs.num_positional_args());
 }

 using getter = PyObject* (*)(PyObject*, void*);
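At the creation site the only change is threading the positional-argument count through to the constructor. Continuing the hypothetical CachedOpSchema sketch from above (the cache dict and lookup_or_create helper are illustrative, not the fast path's real caching code):

from typing import Any, Dict, Tuple

# Continuing the CachedOpSchema sketch defined earlier.
_schema_cache: Dict["CachedOpSchema", "CachedOpSchema"] = {}


def lookup_or_create(op_name: str, comparison_key: Tuple[Any, ...],
                     comparison_key_hash: int,
                     num_positional_args: int) -> "CachedOpSchema":
    # The creation site now also records how many positional arguments the
    # call actually carried, mirroring args_kwargs.num_positional_args().
    schema = CachedOpSchema(op_name, comparison_key, comparison_key_hash,
                            num_positional_args)
    return _schema_cache.setdefault(schema, schema)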
@@ -500,7 +500,7 @@ class OpDispatcher:
                     schema_info=runtime_schema_info,
                 )
                 if create_schema
-                else None,
+                else None,  # type: ignore[arg-type]
                 args_schema,
                 tuple(local_args),
                 local_kwargs,