Improve assert perf in _python_dispatch._correct_storage_aliasing (#161317)

This assertion was expensive because of is_traceable_wrapper_subclass. Finding a cheap check to run first that's likely to let us skip the rest seems to improve things significantly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161317
Approved by: https://github.com/ezyang, https://github.com/XilunWu, https://github.com/bdhirsh
ghstack dependencies: #161301, #161292, #161304, #161308, #161315
This commit is contained in:
Scott Wolchok
2025-08-29 15:54:12 -07:00
committed by PyTorch MergeBot
parent 0c459f2921
commit 302d860157

View File

@ -536,7 +536,16 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
# in theory if a subclass that needs this API wants to sometimes return
# plain tensors, we could remove the assert and just not perform the aliasing,
# but it seems safer to learn more about this case first.
if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
#
# Performance note: This is all just to assert that the argument and result
# types match, checking that is cheaper than is_traceable_wrapper_subclass_type,
# and multiple returns are relatively unlikely, so just check up front!
arg_type = type(arg)
ret_type = type(ret)
if arg_type is not ret_type and (
is_traceable_wrapper_subclass_type(arg_type)
or is_traceable_wrapper_subclass_type(ret_type)
):
ret_list = ret if isinstance(ret, list) else [ret]
for r in ret_list:
assert type(arg) == type(