mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Improve assert perf in _python_dispatch._correct_storage_aliasing (#161317)
This assertion was expensive because of is_traceable_wrapper_subclass. Finding a cheap check to run first that's likely to let us skip the rest seems to improve things significantly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161317 Approved by: https://github.com/ezyang, https://github.com/XilunWu, https://github.com/bdhirsh ghstack dependencies: #161301, #161292, #161304, #161308, #161315
This commit is contained in:
committed by
PyTorch MergeBot
parent
0c459f2921
commit
302d860157
@ -536,7 +536,16 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||
# in theory if a subclass that needs this API wants to sometimes return
|
||||
# plain tensors, we could remove the assert and just not perform the aliasing,
|
||||
# but it seems safer to learn more about this case first.
|
||||
if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
|
||||
#
|
||||
# Performance note: This is all just to assert that the argument and result
|
||||
# types match, checking that is cheaper than is_traceable_wrapper_subclass_type,
|
||||
# and multiple returns are relatively unlikely, so just check up front!
|
||||
arg_type = type(arg)
|
||||
ret_type = type(ret)
|
||||
if arg_type is not ret_type and (
|
||||
is_traceable_wrapper_subclass_type(arg_type)
|
||||
or is_traceable_wrapper_subclass_type(ret_type)
|
||||
):
|
||||
ret_list = ret if isinstance(ret, list) else [ret]
|
||||
for r in ret_list:
|
||||
assert type(arg) == type(
|
||||
|
Reference in New Issue
Block a user