Replace runtime type parameterization (#155221)

See:

```
>>> import timeit; print(f"OrderedSet[str](): {timeit.timeit('OrderedSet[str]()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s, OrderedSet(): {timeit.timeit('OrderedSet()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s")
```
> `OrderedSet[str]()`: 0.354622s, `OrderedSet()`: 0.095376s

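Type parameterization belongs in the type hint, not at runtime: subscripting the class on every call (`OrderedSet[str]()`) is roughly 3.7x slower than calling `OrderedSet()` directly in the measurement above, while an annotation such as `x: OrderedSet[str] = OrderedSet()` gives type checkers the same information without that cost.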

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155221
Approved by: https://github.com/Skylion007, https://github.com/jansel
This commit is contained in:
eellison
2025-06-05 06:57:09 -07:00
committed by PyTorch MergeBot
parent 7dcc77e422
commit 0827464002
14 changed files with 55 additions and 53 deletions

View File

@ -1688,7 +1688,7 @@ class KernelArgs:
# after you do a call into this kernel, which buffers actually contain
# updated data? Modeled off of python_argdefs.
def live_output_buffers(self) -> OrderedSet[str]:
live_outs = OrderedSet[str]()
live_outs: OrderedSet[str] = OrderedSet()
for inplaced in unique(self.inplace_buffers.values()):
if isinstance(inplaced, RemovedArg):
continue
@ -1948,16 +1948,16 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.num_reduction = 0
self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
self.must_keep_buffers = OrderedSet[str]()
self.store_buffer_names = OrderedSet[str]()
self.must_keep_buffers: OrderedSet[str] = OrderedSet()
self.store_buffer_names: OrderedSet[str] = OrderedSet()
self._load_mask: Optional[str] = None
self._load_other: Union[None, int, float] = None
# OrderedSet in set_current_node
self.current_node: Optional[SchedulerNode] = None
self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None
self.removed_buffers = OrderedSet[str]()
self.inplaced_to_remove = OrderedSet[str]()
self.removed_buffers: OrderedSet[str] = OrderedSet()
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
# key: the buffer to write
# value: the buffer to read and whose memory can be reused for
@ -2144,7 +2144,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
for buf in self.store_buffer_names
if buf in scheduler.name_to_buf
)
names_to_remove = OrderedSet[str]()
names_to_remove: OrderedSet[str] = OrderedSet()
for name in self.store_buffer_names:
if (
name not in self.must_keep_buffers

View File

@ -4897,7 +4897,7 @@ class CppScheduling(BaseScheduling):
# https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950
# where the buffer is with size of last dim and contiguous.
# Only support this typical case at first.
visited_scheduler_nodes = OrderedSet[str]()
visited_scheduler_nodes: OrderedSet[str] = OrderedSet()
for scheduler_node in node.get_nodes():
# all users inside same OuterLoopFusedSchedulerNode
assert isinstance(scheduler_node, SchedulerNode)

View File

@ -1341,7 +1341,7 @@ class CppGemmTemplate(CppTemplate):
reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = []
fake_buffers: list[ir.Buffer] = []
Y_aliases = OrderedSet[str]()
Y_aliases: OrderedSet[str] = OrderedSet()
use_local_acc = (
self.layout.dtype != torch.float

View File

@ -56,7 +56,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
if not hasattr(self, "device"):
self.device = "cpu"
# must be initialized prior to calling super().__init__()
self.included_devices = OrderedSet[str]()
self.included_devices: OrderedSet[str] = OrderedSet()
super().__init__()
self.declare = "auto "
self.declare_maybe_reference = "decltype(auto) "
@ -66,14 +66,14 @@ class CppWrapperCpu(PythonWrapperCodegen):
self.supports_intermediate_hooks = False
self.kernel_callsite_id = count()
self.int_array_id = count() # for int array local variable declarations
self.declared_int_array_vars = OrderedSet[str]()
self.declared_int_array_vars: OrderedSet[str] = OrderedSet()
self.tmp_tensor_id = count() # for tmp tensor local variable declarations
self.arg_var_id = count()
self.used_cached_devices = OrderedSet[str]()
self.used_cached_dtypes = OrderedSet[str]()
self.used_cached_layouts = OrderedSet[str]()
self.used_cached_memory_formats = OrderedSet[str]()
self.used_cond_predicate = OrderedSet[str]()
self.used_cached_devices: OrderedSet[str] = OrderedSet()
self.used_cached_dtypes: OrderedSet[str] = OrderedSet()
self.used_cached_layouts: OrderedSet[str] = OrderedSet()
self.used_cached_memory_formats: OrderedSet[str] = OrderedSet()
self.used_cond_predicate: OrderedSet[str] = OrderedSet()
self.cached_output_id = count()
self.scalar_to_tensor_id = count()
self.custom_op_wrapper_loaded = False

View File

@ -224,7 +224,7 @@ class MultiKernel:
def codegen_nan_check(self):
wrapper = V.graph.wrapper_code
seen = OrderedSet[str]()
seen: OrderedSet[str] = OrderedSet()
for k in self.kernels:
_, call_args, precompile_args, _ = k.args.python_argdefs()
for arg, precompile_arg in zip(call_args, precompile_args):

View File

@ -1259,8 +1259,8 @@ class SIMDScheduling(BaseScheduling):
done = OrderedSet[scheduler.BaseSchedulerNode]()
# Writes with a reduced shape, meaning they are only present once the
# reduction loop has ended
not_ready_yet_nodes = OrderedSet[str]()
current_loop_buffer_usage = OrderedSet[str]()
not_ready_yet_nodes: OrderedSet[str] = OrderedSet()
current_loop_buffer_usage: OrderedSet[str] = OrderedSet()
maybe_split_index: Optional[int] = None
def fits_in_main_body(n):
@ -2327,7 +2327,7 @@ class SIMDScheduling(BaseScheduling):
return default_tiling, None
seen_names = OrderedSet[str]()
seen_names: OrderedSet[str] = OrderedSet()
candidate_tiles: Counter[CandidateTiling] = collections.Counter()
for node in EnableReduction.filter(node_schedule):
for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel):

View File

@ -123,7 +123,7 @@ class SIMDKernelFeatures:
return bool(self.op_counts().get(op_name))
def get_mutations(self) -> OrderedSet[str]:
mutations = OrderedSet[str]()
mutations: OrderedSet[str] = OrderedSet()
for node in self.scheduler_nodes():
for buf in node.get_outputs():
mutations.update(buf.get_mutations())
@ -132,7 +132,7 @@ class SIMDKernelFeatures:
@cache_on_self
def select_index_dtype(self) -> torch.dtype:
# Gather all used buffer names
buffer_names = OrderedSet[str]()
buffer_names: OrderedSet[str] = OrderedSet()
for node in self.scheduler_nodes():
buffer_names.update(node.get_buffer_names())
buffer_names.update(node.used_buffer_names())

View File

@ -749,7 +749,7 @@ class TritonCSEVariable(CSEVariable):
def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None:
super().__init__(name, bounds, dtype)
# We'll use this to track which masks the variable needs when used for indirect indexing
self.mask_vars = OrderedSet[str]()
self.mask_vars: OrderedSet[str] = OrderedSet()
assert dtype is not None, "TritonCSEVariable must have dtype"
def update_on_args(self, name, args, kwargs):
@ -1769,7 +1769,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
index_vars = index.free_symbols
has_rindex = False
mask_vars: OrderedSet[str] = OrderedSet[str]()
mask_vars: OrderedSet[str] = OrderedSet()
for var in sorted(index_vars, key=operator.attrgetter("name")):
assert isinstance(var, sympy.Symbol)
has_rindex = has_rindex or symbol_is_type(
@ -1811,7 +1811,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
have_dense = True
have_loop_vars = False
dense_mask_vars = OrderedSet[str]()
dense_mask_vars: OrderedSet[str] = OrderedSet()
for tree in self.active_range_trees():
if index_vars.intersection(tree.var_list):
@ -3550,7 +3550,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
)
mutated_args = OrderedSet[str]()
mutated_args: OrderedSet[str] = OrderedSet()
for mutation in self.mutations:
if mutation in self.args.input_buffers:
mutated_args.add(self.args.input_buffers[mutation])

View File

@ -536,7 +536,7 @@ class ComboKernel(Kernel):
return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]
def get_mutated_args_sub_kernels(self) -> list[str]:
mutated_args = OrderedSet[str]()
mutated_args: OrderedSet[str] = OrderedSet()
for sub_kernel in self.sub_kernels:
for mutation in sub_kernel.mutations:
if mutation in sub_kernel.args.input_buffers:

View File

@ -215,7 +215,7 @@ def user_defined_kernel_grid_fn_code(
else:
assert len(grids) > 1
assert len(grids) == len(configs)
seen = OrderedSet[str]()
seen: OrderedSet[str] = OrderedSet()
# sort the configs from the largest # of kwargs to the smallest to
# emit the grids in the order of (approximately) decreasing specificity
# TODO(aakhundov): the sorting below is generally not sufficient, so
@ -857,7 +857,7 @@ class PythonWrapperCodegen(CodeGen):
self.kernel_autotune_defs = IndentedBuffer()
self.kernel_autotune_calls = IndentedBuffer()
self.subgraph_definitions = IndentedBuffer()
self.kernel_autotune_names = OrderedSet[str]()
self.kernel_autotune_names: OrderedSet[str] = OrderedSet()
# Map key is the kernel argument name; value is a tuple of the resulting example
# tensor name with the kernel where that tensor was most recently used.
self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {}
@ -877,7 +877,9 @@ class PythonWrapperCodegen(CodeGen):
self.last_seen_device_guard_index: Optional[int] = None
self.supports_intermediate_hooks = True
self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {}
self.unbacked_symbol_decls = OrderedSet[str]() # str of sympy.Symbol
self.unbacked_symbol_decls: OrderedSet[str] = (
OrderedSet()
) # str of sympy.Symbol
self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet()
self.launcher_fn_name = None
# This function can be overridden to change the launcher name
@ -921,9 +923,9 @@ class PythonWrapperCodegen(CodeGen):
self.add_import_once = add_import_once
self._metas: dict[str, str] = {}
self._meta_vars = OrderedSet[str]()
self._meta_vars: OrderedSet[str] = OrderedSet()
self.multi_kernel_state = MultiKernelState()
self.already_codegened_subgraphs = OrderedSet[str]()
self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet()
self.allocated_workspaces: dict[str, Any] = {}
# intermediate tensor value printing utility

View File

@ -343,7 +343,7 @@ class GraphLowering(torch.fx.Interpreter):
self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
self.graph_inputs_original: dict[str, InputBuffer] = {}
self.partition_maps: Optional[list[GraphPartitionMap]] = None
self.zero_dim_cpu_tensor_list = OrderedSet[str]()
self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet()
self.device_types: OrderedSet[str] = (
const_module.device_types if const_module else OrderedSet()
)
@ -380,12 +380,12 @@ class GraphLowering(torch.fx.Interpreter):
] = {}
self.seen_subgraphs: dict[str, ir.Subgraph] = {}
self.constant_reprs: dict[str, str] = {}
self.removed_operations = OrderedSet[str]()
self.removed_buffers = OrderedSet[str]()
self.removed_inplace_buffers = OrderedSet[str]()
self.mutated_buffers = OrderedSet[str]()
self.never_reuse_buffers = OrderedSet[str]()
self.inplaced_to_remove = OrderedSet[str]()
self.removed_operations: OrderedSet[str] = OrderedSet()
self.removed_buffers: OrderedSet[str] = OrderedSet()
self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
self.mutated_buffers: OrderedSet[str] = OrderedSet()
self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details
@ -401,7 +401,7 @@ class GraphLowering(torch.fx.Interpreter):
self.current_node: torch.fx.Node = None # type: ignore[assignment]
self.lists: dict[str, list[str]] = {}
self.mutated_inputs = OrderedSet[str]()
self.mutated_inputs: OrderedSet[str] = OrderedSet()
self.mutated_input_idxs: list[int] = []
self.name_to_buffer: dict[str, ir.Buffer] = {}
self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
@ -466,14 +466,14 @@ class GraphLowering(torch.fx.Interpreter):
# This can either be a graph input or the output of fallback
# kernels.
self.unaligned_buffers: OrderedSet[str] = OrderedSet()
self.no_fuse_buffer_names = OrderedSet[str]()
self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()
self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
# more aggressive prologue fusion
self.invoke_quant_ops: OrderedSet[str] = OrderedSet()
# Below field is related to printing debug intermediate tensor values info for debugging
self.all_codegen_kernel_names = OrderedSet[str]()
self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet()
# state used by for Kernel.workspace
self.workspace_id = itertools.count()

View File

@ -433,7 +433,7 @@ def enabled_metric_tables() -> OrderedSet[str]:
@lru_cache
def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
enabled = OrderedSet[str]()
enabled: OrderedSet[str] = OrderedSet()
for name in config_str.split(","):
name = name.strip()
if not name:

View File

@ -1561,7 +1561,7 @@ def register_replacement(
return pattern.pattern
_serialized_patterns = OrderedSet[str]()
_serialized_patterns: OrderedSet[str] = OrderedSet()
def _serialize_pattern(
@ -2235,7 +2235,7 @@ def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
# TODO: remove in follow up diff, used internally
_seen_patterns = OrderedSet[str]()
_seen_patterns: OrderedSet[str] = OrderedSet()
def get_arg_value(

View File

@ -212,7 +212,7 @@ class BaseSchedulerNode:
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
self.ancestors = OrderedSet[str]()
self.ancestors: OrderedSet[str] = OrderedSet()
self.last_usage = OrderedSet[
str
]() # buffers that won't be used after this kernel
@ -325,7 +325,7 @@ class BaseSchedulerNode:
)
def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
used_names = OrderedSet[str]()
used_names: OrderedSet[str] = OrderedSet()
deps = [
dep.name
@ -1238,7 +1238,7 @@ class SchedulerNode(BaseSchedulerNode):
@cache_on_self
def _get_atomic_add_buffers(self) -> OrderedSet[str]:
buffers_store_as_atomic_add = OrderedSet[str]()
buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
if isinstance(self._body, LoopBody):
for node in self._body.get_nodes():
if (
@ -1441,7 +1441,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
super().set_last_usage(future_used_buffers, mutation_real_name)
# Set self.last_usage on the snodes
# This will be used for optimisations within the kernel
future_used_buffers = OrderedSet[str]()
future_used_buffers: OrderedSet[str] = OrderedSet()
for node in reversed(self.snodes):
node.set_last_usage(future_used_buffers, mutation_real_name)
future_used_buffers.update(node.last_usage)
@ -2032,7 +2032,7 @@ class Scheduler:
self.post_grad_graph_id = next(_post_grad_graph_counter)
self._graph_partition_counter = itertools.count()
self.completed_operations = OrderedSet[str]()
self.completed_operations: OrderedSet[str] = OrderedSet()
self.available_buffer_names = OrderedSet(
[
*V.graph.graph_inputs.keys(),
@ -2134,7 +2134,7 @@ class Scheduler:
self.debug_draw_graph()
# used during codegen:
self.buffer_names_to_free = OrderedSet[str]()
self.buffer_names_to_free: OrderedSet[str] = OrderedSet()
# fx graph node to the position it appears in the graph
# for debug attribution
@ -2194,7 +2194,7 @@ class Scheduler:
raise NotImplementedError(node)
def create_foreach_nodes(self) -> None:
removed_node_names = OrderedSet[str]()
removed_node_names: OrderedSet[str] = OrderedSet()
fe_nodes = []
kept_node_names = self.name_to_fused_node.keys()
@ -2515,7 +2515,7 @@ class Scheduler:
return result
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]:
unmet_deps = OrderedSet[str]()
unmet_deps: OrderedSet[str] = OrderedSet()
if isinstance(
snode,
(
@ -2567,7 +2567,7 @@ class Scheduler:
# note self.nodes is topologically sorted
name_to_ancestors: dict[str, OrderedSet[str]] = {}
for node in self.nodes:
ancestors = OrderedSet[str]()
ancestors: OrderedSet[str] = OrderedSet()
for dep in node.unmet_dependencies:
dep_node_name = self.name_to_buf[dep.name].defining_op_name()
ancestors.add(dep_node_name)