mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ez] remove unused functions (#156466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156466 Approved by: https://github.com/jingsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
2ee23175d9
commit
a67eb1a0d6
@ -18,65 +18,6 @@ from .common import TensorArg, WorkspaceArg
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_kernel_argdefs(kernel):
|
||||
arg_defs, _, _, _ = kernel.args.python_argdefs()
|
||||
return [x.name for x in arg_defs]
|
||||
|
||||
|
||||
def _get_all_args(args_list, arg_types_list=None):
|
||||
all_args = max(args_list, key=len)[:]
|
||||
arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
|
||||
for args in args_list:
|
||||
assert OrderedSet(args).issubset(OrderedSet(all_args)), (
|
||||
f"{args} v.s. {all_args}"
|
||||
)
|
||||
|
||||
return all_args, arg_types
|
||||
|
||||
|
||||
def get_all_kernel_argdefs(kernels):
|
||||
"""
|
||||
The logic here must match with `get_all_call_args`, except no need to get arg_types here
|
||||
"""
|
||||
argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]
|
||||
|
||||
return _get_all_args(argdefs_list)[0]
|
||||
|
||||
|
||||
def get_all_call_args(call_args_list, arg_types_list):
|
||||
"""
|
||||
Passed in the call_args for each subkernel and return the call_args for the
|
||||
combined multi-kernel.
|
||||
|
||||
Note an algorithm as follows does not always work:
|
||||
```
|
||||
all_call_args: Dict[
|
||||
Any, None
|
||||
] = {} # use a dict rather than set to maintain insertion order
|
||||
for call_args in call_args_list:
|
||||
all_call_args.update({arg: None for arg in call_args})
|
||||
|
||||
all_call_args = list(all_call_args.keys())
|
||||
```
|
||||
It will fail if any kernel has the same argument passed in multiple times.
|
||||
Check test_pass_same_arg_multi_times in test_multi_kernel.py
|
||||
|
||||
Instead, we pick the longest call args and assert that other call args are
|
||||
a subset of it.
|
||||
"""
|
||||
return _get_all_args(call_args_list, arg_types_list)
|
||||
|
||||
|
||||
def get_numel_argdefs(kernel):
|
||||
numel_argdefs = [
|
||||
f"{tree.prefix}numel"
|
||||
for tree in kernel.range_trees
|
||||
if not tree.is_reduction or kernel.inside_reduction
|
||||
]
|
||||
|
||||
return numel_argdefs
|
||||
|
||||
|
||||
class MultiKernelState:
|
||||
"""
|
||||
Maintain state of multi-kernel compilation so we don't define duplicated
|
||||
|
Reference in New Issue
Block a user