Make custom op alias check consistent (#164576)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164576
Approved by: https://github.com/soulitzer
This commit is contained in:
albanD
2025-10-06 15:09:42 -04:00
committed by PyTorch MergeBot
parent 49f7d8d19d
commit 56d66ac0d7
2 changed files with 5 additions and 5 deletions

View File

@ -341,13 +341,13 @@ def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
"""
custom operators' outputs must not alias any inputs or other outputs.
"""
storages = {id(t.untyped_storage()) for t in prev if isinstance(t, torch.Tensor)}
storages = {t.untyped_storage()._cdata for t in prev if isinstance(t, torch.Tensor)}
tuple_result = result
if not isinstance(result, tuple):
tuple_result = (result,)
for tensor in iter_tensors(tuple_result, {}):
key = id(tensor.untyped_storage())
if id(tensor.untyped_storage()) in storages:
key = tensor.untyped_storage()._cdata
if tensor.untyped_storage()._cdata in storages:
raise RuntimeError(
f"{name} (with implementation in {get_module()}): "
f"The output of this custom operator (1) must not "

View File

@ -1113,7 +1113,7 @@ static PyObject* any_output_is_alias_to_input_or_output(
if (!t.storage()) {
return false;
}
auto* cp = t.storage().data_ptr().get_context();
auto* cp = t.storage().unsafeGetStorageImpl();
if (cp) {
s.insert(cp);
}
@ -1124,7 +1124,7 @@ static PyObject* any_output_is_alias_to_input_or_output(
if (!t.storage()) {
return false;
}
auto* cp = t.storage().data_ptr().get_context();
auto* cp = t.storage().unsafeGetStorageImpl();
if (!cp) {
return false;
}