[inductor] Refactor BaseSchedulerNode.__init__ (#135400)

Might be a small compile time improvement since we remove a call to extract_read_writes().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135400
Approved by: https://github.com/oulgen
ghstack dependencies: #135286, #135306, #135377
This commit is contained in:
Jason Ansel
2024-09-07 19:54:41 -07:00
committed by PyTorch MergeBot
parent 16f5155992
commit 53290ca00b

View File

@ -159,27 +159,27 @@ class BaseSchedulerNode:
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
# For non-"grouped" nodes (i.e. regular SchedulerNode),
# .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
min_order: int
max_order: int
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
self.set_read_writes(node.get_read_writes())
self.ancestors: OrderedSet[str] = OrderedSet()
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
# For non-"grouped" nodes (i.e. regular SchedulerNode),
# .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
self.min_order: int
self.max_order: int
self.last_usage: OrderedSet[
str
] = OrderedSet() # buffers that won't be used after this kernel
self.written = False
self.outputs: List[SchedulerBuffer] = [
SchedulerBuffer(
scheduler=scheduler,
scheduler=self.scheduler,
node=output,
defining_op=self,
)
@ -799,6 +799,11 @@ kernel_name_to_op = {
class ExternKernelSchedulerNode(BaseSchedulerNode):
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
super().__init__(scheduler)
self._init_from_node(node)
self.set_read_writes(node.get_read_writes())
def debug_str_extra(self) -> str:
return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}"
@ -811,7 +816,10 @@ class ExternKernelSchedulerNode(BaseSchedulerNode):
class NopKernelSchedulerNode(BaseSchedulerNode):
pass
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
super().__init__(scheduler)
self._init_from_node(node)
self.set_read_writes(node.get_read_writes())
class SchedulerNode(BaseSchedulerNode):
@ -820,7 +828,8 @@ class SchedulerNode(BaseSchedulerNode):
scheduler: Scheduler,
node: Union[ir.ComputedBuffer, ir.TemplateBuffer],
) -> None:
super().__init__(scheduler, node)
super().__init__(scheduler)
self._init_from_node(node)
self._compute_attrs()
def _compute_attrs(
@ -1121,7 +1130,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
refresh_group_node_dependencies(self)
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
# NB: No need to call super().__init__() because we don't need to re-use any of its logic.
super().__init__(scheduler)
init_group_node(self, scheduler, snodes)
self.users: List[NodeUser] = []
self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
@ -1590,7 +1599,7 @@ class GroupedSchedulerNode(BaseSchedulerNode):
return grouped_snode
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
# NB: No need to call super().__init__() because we don't need to re-use any of its logic.
super().__init__(scheduler)
init_group_node(self, scheduler, snodes)
def unpack(self) -> List[BaseSchedulerNode]: