mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Replace runtime type parameterization (#155221)
See:
```
>>> import timeit; print(f"OrderedSet[str](): {timeit.timeit('OrderedSet[str]()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s, OrderedSet(): {timeit.timeit('OrderedSet()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s")
```
> `OrderedSet[str]()`: 0.354622s, OrderedSet(): 0.095376s
Type parameterization should be on type hint, not in runtime.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155221
Approved by: https://github.com/Skylion007, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
7dcc77e422
commit
0827464002
@ -1688,7 +1688,7 @@ class KernelArgs:
|
||||
# after you do a call into this kernel, which buffers actually contain
|
||||
# updated data? Modeled off of python_argdefs.
|
||||
def live_output_buffers(self) -> OrderedSet[str]:
|
||||
live_outs = OrderedSet[str]()
|
||||
live_outs: OrderedSet[str] = OrderedSet()
|
||||
for inplaced in unique(self.inplace_buffers.values()):
|
||||
if isinstance(inplaced, RemovedArg):
|
||||
continue
|
||||
@ -1948,16 +1948,16 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
||||
self.num_reduction = 0
|
||||
|
||||
self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
|
||||
self.must_keep_buffers = OrderedSet[str]()
|
||||
self.store_buffer_names = OrderedSet[str]()
|
||||
self.must_keep_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.store_buffer_names: OrderedSet[str] = OrderedSet()
|
||||
self._load_mask: Optional[str] = None
|
||||
self._load_other: Union[None, int, float] = None
|
||||
# OrderedSet in set_current_node
|
||||
self.current_node: Optional[SchedulerNode] = None
|
||||
self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None
|
||||
|
||||
self.removed_buffers = OrderedSet[str]()
|
||||
self.inplaced_to_remove = OrderedSet[str]()
|
||||
self.removed_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
|
||||
|
||||
# key: the buffer to write
|
||||
# value: the buffer to read and whose memory can be reused for
|
||||
@ -2144,7 +2144,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
||||
for buf in self.store_buffer_names
|
||||
if buf in scheduler.name_to_buf
|
||||
)
|
||||
names_to_remove = OrderedSet[str]()
|
||||
names_to_remove: OrderedSet[str] = OrderedSet()
|
||||
for name in self.store_buffer_names:
|
||||
if (
|
||||
name not in self.must_keep_buffers
|
||||
|
||||
@ -4897,7 +4897,7 @@ class CppScheduling(BaseScheduling):
|
||||
# https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950
|
||||
# where the buffer is with size of last dim and contiguous.
|
||||
# Only support this typical case at first.
|
||||
visited_scheduler_nodes = OrderedSet[str]()
|
||||
visited_scheduler_nodes: OrderedSet[str] = OrderedSet()
|
||||
for scheduler_node in node.get_nodes():
|
||||
# all users inside same OuterLoopFusedSchedulerNode
|
||||
assert isinstance(scheduler_node, SchedulerNode)
|
||||
|
||||
@ -1341,7 +1341,7 @@ class CppGemmTemplate(CppTemplate):
|
||||
reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
|
||||
epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = []
|
||||
fake_buffers: list[ir.Buffer] = []
|
||||
Y_aliases = OrderedSet[str]()
|
||||
Y_aliases: OrderedSet[str] = OrderedSet()
|
||||
|
||||
use_local_acc = (
|
||||
self.layout.dtype != torch.float
|
||||
|
||||
@ -56,7 +56,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
if not hasattr(self, "device"):
|
||||
self.device = "cpu"
|
||||
# must be initialized prior to calling super().__init__()
|
||||
self.included_devices = OrderedSet[str]()
|
||||
self.included_devices: OrderedSet[str] = OrderedSet()
|
||||
super().__init__()
|
||||
self.declare = "auto "
|
||||
self.declare_maybe_reference = "decltype(auto) "
|
||||
@ -66,14 +66,14 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
self.supports_intermediate_hooks = False
|
||||
self.kernel_callsite_id = count()
|
||||
self.int_array_id = count() # for int array local variable declarations
|
||||
self.declared_int_array_vars = OrderedSet[str]()
|
||||
self.declared_int_array_vars: OrderedSet[str] = OrderedSet()
|
||||
self.tmp_tensor_id = count() # for tmp tensor local variable declarations
|
||||
self.arg_var_id = count()
|
||||
self.used_cached_devices = OrderedSet[str]()
|
||||
self.used_cached_dtypes = OrderedSet[str]()
|
||||
self.used_cached_layouts = OrderedSet[str]()
|
||||
self.used_cached_memory_formats = OrderedSet[str]()
|
||||
self.used_cond_predicate = OrderedSet[str]()
|
||||
self.used_cached_devices: OrderedSet[str] = OrderedSet()
|
||||
self.used_cached_dtypes: OrderedSet[str] = OrderedSet()
|
||||
self.used_cached_layouts: OrderedSet[str] = OrderedSet()
|
||||
self.used_cached_memory_formats: OrderedSet[str] = OrderedSet()
|
||||
self.used_cond_predicate: OrderedSet[str] = OrderedSet()
|
||||
self.cached_output_id = count()
|
||||
self.scalar_to_tensor_id = count()
|
||||
self.custom_op_wrapper_loaded = False
|
||||
|
||||
@ -224,7 +224,7 @@ class MultiKernel:
|
||||
|
||||
def codegen_nan_check(self):
|
||||
wrapper = V.graph.wrapper_code
|
||||
seen = OrderedSet[str]()
|
||||
seen: OrderedSet[str] = OrderedSet()
|
||||
for k in self.kernels:
|
||||
_, call_args, precompile_args, _ = k.args.python_argdefs()
|
||||
for arg, precompile_arg in zip(call_args, precompile_args):
|
||||
|
||||
@ -1259,8 +1259,8 @@ class SIMDScheduling(BaseScheduling):
|
||||
done = OrderedSet[scheduler.BaseSchedulerNode]()
|
||||
# Writes with a reduced shape, meaning they are only present once the
|
||||
# reduction loop has ended
|
||||
not_ready_yet_nodes = OrderedSet[str]()
|
||||
current_loop_buffer_usage = OrderedSet[str]()
|
||||
not_ready_yet_nodes: OrderedSet[str] = OrderedSet()
|
||||
current_loop_buffer_usage: OrderedSet[str] = OrderedSet()
|
||||
maybe_split_index: Optional[int] = None
|
||||
|
||||
def fits_in_main_body(n):
|
||||
@ -2327,7 +2327,7 @@ class SIMDScheduling(BaseScheduling):
|
||||
|
||||
return default_tiling, None
|
||||
|
||||
seen_names = OrderedSet[str]()
|
||||
seen_names: OrderedSet[str] = OrderedSet()
|
||||
candidate_tiles: Counter[CandidateTiling] = collections.Counter()
|
||||
for node in EnableReduction.filter(node_schedule):
|
||||
for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel):
|
||||
|
||||
@ -123,7 +123,7 @@ class SIMDKernelFeatures:
|
||||
return bool(self.op_counts().get(op_name))
|
||||
|
||||
def get_mutations(self) -> OrderedSet[str]:
|
||||
mutations = OrderedSet[str]()
|
||||
mutations: OrderedSet[str] = OrderedSet()
|
||||
for node in self.scheduler_nodes():
|
||||
for buf in node.get_outputs():
|
||||
mutations.update(buf.get_mutations())
|
||||
@ -132,7 +132,7 @@ class SIMDKernelFeatures:
|
||||
@cache_on_self
|
||||
def select_index_dtype(self) -> torch.dtype:
|
||||
# Gather all used buffer names
|
||||
buffer_names = OrderedSet[str]()
|
||||
buffer_names: OrderedSet[str] = OrderedSet()
|
||||
for node in self.scheduler_nodes():
|
||||
buffer_names.update(node.get_buffer_names())
|
||||
buffer_names.update(node.used_buffer_names())
|
||||
|
||||
@ -749,7 +749,7 @@ class TritonCSEVariable(CSEVariable):
|
||||
def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None:
|
||||
super().__init__(name, bounds, dtype)
|
||||
# We'll use this to track which masks the variable needs when used for indirect indexing
|
||||
self.mask_vars = OrderedSet[str]()
|
||||
self.mask_vars: OrderedSet[str] = OrderedSet()
|
||||
assert dtype is not None, "TritonCSEVariable must have dtype"
|
||||
|
||||
def update_on_args(self, name, args, kwargs):
|
||||
@ -1769,7 +1769,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
index_vars = index.free_symbols
|
||||
has_rindex = False
|
||||
|
||||
mask_vars: OrderedSet[str] = OrderedSet[str]()
|
||||
mask_vars: OrderedSet[str] = OrderedSet()
|
||||
for var in sorted(index_vars, key=operator.attrgetter("name")):
|
||||
assert isinstance(var, sympy.Symbol)
|
||||
has_rindex = has_rindex or symbol_is_type(
|
||||
@ -1811,7 +1811,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
|
||||
have_dense = True
|
||||
have_loop_vars = False
|
||||
dense_mask_vars = OrderedSet[str]()
|
||||
dense_mask_vars: OrderedSet[str] = OrderedSet()
|
||||
|
||||
for tree in self.active_range_trees():
|
||||
if index_vars.intersection(tree.var_list):
|
||||
@ -3550,7 +3550,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
|
||||
)
|
||||
|
||||
mutated_args = OrderedSet[str]()
|
||||
mutated_args: OrderedSet[str] = OrderedSet()
|
||||
for mutation in self.mutations:
|
||||
if mutation in self.args.input_buffers:
|
||||
mutated_args.add(self.args.input_buffers[mutation])
|
||||
|
||||
@ -536,7 +536,7 @@ class ComboKernel(Kernel):
|
||||
return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]
|
||||
|
||||
def get_mutated_args_sub_kernels(self) -> list[str]:
|
||||
mutated_args = OrderedSet[str]()
|
||||
mutated_args: OrderedSet[str] = OrderedSet()
|
||||
for sub_kernel in self.sub_kernels:
|
||||
for mutation in sub_kernel.mutations:
|
||||
if mutation in sub_kernel.args.input_buffers:
|
||||
|
||||
@ -215,7 +215,7 @@ def user_defined_kernel_grid_fn_code(
|
||||
else:
|
||||
assert len(grids) > 1
|
||||
assert len(grids) == len(configs)
|
||||
seen = OrderedSet[str]()
|
||||
seen: OrderedSet[str] = OrderedSet()
|
||||
# sort the configs from the largest # of kwargs to the smallest to
|
||||
# emit the grids in the order of (approximately) decreasing specificity
|
||||
# TODO(aakhundov): the sorting below is generally not sufficient, so
|
||||
@ -857,7 +857,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||
self.kernel_autotune_defs = IndentedBuffer()
|
||||
self.kernel_autotune_calls = IndentedBuffer()
|
||||
self.subgraph_definitions = IndentedBuffer()
|
||||
self.kernel_autotune_names = OrderedSet[str]()
|
||||
self.kernel_autotune_names: OrderedSet[str] = OrderedSet()
|
||||
# Map key is the kernel argument name; value is a tuple of the resulting example
|
||||
# tensor name with the kernel where that tensor was most recently used.
|
||||
self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {}
|
||||
@ -877,7 +877,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||
self.last_seen_device_guard_index: Optional[int] = None
|
||||
self.supports_intermediate_hooks = True
|
||||
self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {}
|
||||
self.unbacked_symbol_decls = OrderedSet[str]() # str of sympy.Symbol
|
||||
self.unbacked_symbol_decls: OrderedSet[str] = (
|
||||
OrderedSet()
|
||||
) # str of sympy.Symbol
|
||||
self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
self.launcher_fn_name = None
|
||||
# This function can be overridden to change the launcher name
|
||||
@ -921,9 +923,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||
|
||||
self.add_import_once = add_import_once
|
||||
self._metas: dict[str, str] = {}
|
||||
self._meta_vars = OrderedSet[str]()
|
||||
self._meta_vars: OrderedSet[str] = OrderedSet()
|
||||
self.multi_kernel_state = MultiKernelState()
|
||||
self.already_codegened_subgraphs = OrderedSet[str]()
|
||||
self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet()
|
||||
self.allocated_workspaces: dict[str, Any] = {}
|
||||
|
||||
# intermediate tensor value printing utility
|
||||
|
||||
@ -343,7 +343,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
|
||||
self.graph_inputs_original: dict[str, InputBuffer] = {}
|
||||
self.partition_maps: Optional[list[GraphPartitionMap]] = None
|
||||
self.zero_dim_cpu_tensor_list = OrderedSet[str]()
|
||||
self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet()
|
||||
self.device_types: OrderedSet[str] = (
|
||||
const_module.device_types if const_module else OrderedSet()
|
||||
)
|
||||
@ -380,12 +380,12 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
] = {}
|
||||
self.seen_subgraphs: dict[str, ir.Subgraph] = {}
|
||||
self.constant_reprs: dict[str, str] = {}
|
||||
self.removed_operations = OrderedSet[str]()
|
||||
self.removed_buffers = OrderedSet[str]()
|
||||
self.removed_inplace_buffers = OrderedSet[str]()
|
||||
self.mutated_buffers = OrderedSet[str]()
|
||||
self.never_reuse_buffers = OrderedSet[str]()
|
||||
self.inplaced_to_remove = OrderedSet[str]()
|
||||
self.removed_operations: OrderedSet[str] = OrderedSet()
|
||||
self.removed_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.mutated_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
|
||||
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
|
||||
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
|
||||
# See `ProxyExecutor Design Note` in ir.py for more details
|
||||
@ -401,7 +401,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
self.current_node: torch.fx.Node = None # type: ignore[assignment]
|
||||
self.lists: dict[str, list[str]] = {}
|
||||
self.mutated_inputs = OrderedSet[str]()
|
||||
self.mutated_inputs: OrderedSet[str] = OrderedSet()
|
||||
self.mutated_input_idxs: list[int] = []
|
||||
self.name_to_buffer: dict[str, ir.Buffer] = {}
|
||||
self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
|
||||
@ -466,14 +466,14 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# This can either be a graph input or the output of fallback
|
||||
# kernels.
|
||||
self.unaligned_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.no_fuse_buffer_names = OrderedSet[str]()
|
||||
self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()
|
||||
|
||||
self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
|
||||
# more aggressive prologue fusion
|
||||
self.invoke_quant_ops: OrderedSet[str] = OrderedSet()
|
||||
|
||||
# Below field is related to printing debug intermediate tensor values info for debugging
|
||||
self.all_codegen_kernel_names = OrderedSet[str]()
|
||||
self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet()
|
||||
|
||||
# state used by for Kernel.workspace
|
||||
self.workspace_id = itertools.count()
|
||||
|
||||
@ -433,7 +433,7 @@ def enabled_metric_tables() -> OrderedSet[str]:
|
||||
|
||||
@lru_cache
|
||||
def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
|
||||
enabled = OrderedSet[str]()
|
||||
enabled: OrderedSet[str] = OrderedSet()
|
||||
for name in config_str.split(","):
|
||||
name = name.strip()
|
||||
if not name:
|
||||
|
||||
@ -1561,7 +1561,7 @@ def register_replacement(
|
||||
return pattern.pattern
|
||||
|
||||
|
||||
_serialized_patterns = OrderedSet[str]()
|
||||
_serialized_patterns: OrderedSet[str] = OrderedSet()
|
||||
|
||||
|
||||
def _serialize_pattern(
|
||||
@ -2235,7 +2235,7 @@ def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
|
||||
|
||||
# TODO: remove in follow up diff, used internally
|
||||
_seen_patterns = OrderedSet[str]()
|
||||
_seen_patterns: OrderedSet[str] = OrderedSet()
|
||||
|
||||
|
||||
def get_arg_value(
|
||||
|
||||
@ -212,7 +212,7 @@ class BaseSchedulerNode:
|
||||
|
||||
def _init_from_node(self, node: ir.Operation) -> None:
|
||||
self.node: Optional[ir.Operation] = node
|
||||
self.ancestors = OrderedSet[str]()
|
||||
self.ancestors: OrderedSet[str] = OrderedSet()
|
||||
self.last_usage = OrderedSet[
|
||||
str
|
||||
]() # buffers that won't be used after this kernel
|
||||
@ -325,7 +325,7 @@ class BaseSchedulerNode:
|
||||
)
|
||||
|
||||
def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
|
||||
used_names = OrderedSet[str]()
|
||||
used_names: OrderedSet[str] = OrderedSet()
|
||||
|
||||
deps = [
|
||||
dep.name
|
||||
@ -1238,7 +1238,7 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
|
||||
@cache_on_self
|
||||
def _get_atomic_add_buffers(self) -> OrderedSet[str]:
|
||||
buffers_store_as_atomic_add = OrderedSet[str]()
|
||||
buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
|
||||
if isinstance(self._body, LoopBody):
|
||||
for node in self._body.get_nodes():
|
||||
if (
|
||||
@ -1441,7 +1441,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
super().set_last_usage(future_used_buffers, mutation_real_name)
|
||||
# Set self.last_usage on the snodes
|
||||
# This will be used for optimisations within the kernel
|
||||
future_used_buffers = OrderedSet[str]()
|
||||
future_used_buffers: OrderedSet[str] = OrderedSet()
|
||||
for node in reversed(self.snodes):
|
||||
node.set_last_usage(future_used_buffers, mutation_real_name)
|
||||
future_used_buffers.update(node.last_usage)
|
||||
@ -2032,7 +2032,7 @@ class Scheduler:
|
||||
self.post_grad_graph_id = next(_post_grad_graph_counter)
|
||||
self._graph_partition_counter = itertools.count()
|
||||
|
||||
self.completed_operations = OrderedSet[str]()
|
||||
self.completed_operations: OrderedSet[str] = OrderedSet()
|
||||
self.available_buffer_names = OrderedSet(
|
||||
[
|
||||
*V.graph.graph_inputs.keys(),
|
||||
@ -2134,7 +2134,7 @@ class Scheduler:
|
||||
self.debug_draw_graph()
|
||||
|
||||
# used during codegen:
|
||||
self.buffer_names_to_free = OrderedSet[str]()
|
||||
self.buffer_names_to_free: OrderedSet[str] = OrderedSet()
|
||||
|
||||
# fx graph node to the position it appears in the graph
|
||||
# for debug attribution
|
||||
@ -2194,7 +2194,7 @@ class Scheduler:
|
||||
raise NotImplementedError(node)
|
||||
|
||||
def create_foreach_nodes(self) -> None:
|
||||
removed_node_names = OrderedSet[str]()
|
||||
removed_node_names: OrderedSet[str] = OrderedSet()
|
||||
fe_nodes = []
|
||||
kept_node_names = self.name_to_fused_node.keys()
|
||||
|
||||
@ -2515,7 +2515,7 @@ class Scheduler:
|
||||
return result
|
||||
|
||||
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]:
|
||||
unmet_deps = OrderedSet[str]()
|
||||
unmet_deps: OrderedSet[str] = OrderedSet()
|
||||
if isinstance(
|
||||
snode,
|
||||
(
|
||||
@ -2567,7 +2567,7 @@ class Scheduler:
|
||||
# note self.nodes is topologically sorted
|
||||
name_to_ancestors: dict[str, OrderedSet[str]] = {}
|
||||
for node in self.nodes:
|
||||
ancestors = OrderedSet[str]()
|
||||
ancestors: OrderedSet[str] = OrderedSet()
|
||||
for dep in node.unmet_dependencies:
|
||||
dep_node_name = self.name_to_buf[dep.name].defining_op_name()
|
||||
ancestors.add(dep_node_name)
|
||||
|
||||
Reference in New Issue
Block a user