mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make custom op alias check consistent (#164576)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164576 Approved by: https://github.com/soulitzer ghstack dependencies: #164467
This commit is contained in:
@ -341,13 +341,13 @@ def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
|
||||
"""
|
||||
custom operators' outputs must not alias any inputs or other outputs.
|
||||
"""
|
||||
storages = {id(t.untyped_storage()) for t in prev if isinstance(t, torch.Tensor)}
|
||||
storages = {t.untyped_storage()._cdata for t in prev if isinstance(t, torch.Tensor)}
|
||||
tuple_result = result
|
||||
if not isinstance(result, tuple):
|
||||
tuple_result = (result,)
|
||||
for tensor in iter_tensors(tuple_result, {}):
|
||||
key = id(tensor.untyped_storage())
|
||||
if id(tensor.untyped_storage()) in storages:
|
||||
key = tensor.untyped_storage()._cdata
|
||||
if tensor.untyped_storage()._cdata in storages:
|
||||
raise RuntimeError(
|
||||
f"{name} (with implementation in {get_module()}): "
|
||||
f"The output of this custom operator (1) must not "
|
||||
|
@ -1113,7 +1113,7 @@ static PyObject* any_output_is_alias_to_input_or_output(
|
||||
if (!t.storage()) {
|
||||
return false;
|
||||
}
|
||||
auto* cp = t.storage().data_ptr().get_context();
|
||||
auto* cp = t.storage().unsafeGetStorageImpl();
|
||||
if (cp) {
|
||||
s.insert(cp);
|
||||
}
|
||||
@ -1124,7 +1124,7 @@ static PyObject* any_output_is_alias_to_input_or_output(
|
||||
if (!t.storage()) {
|
||||
return false;
|
||||
}
|
||||
auto* cp = t.storage().data_ptr().get_context();
|
||||
auto* cp = t.storage().unsafeGetStorageImpl();
|
||||
if (!cp) {
|
||||
return false;
|
||||
}
|
||||
|
Reference in New Issue
Block a user