mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Refactor BaseSchedulerNode.__init__ (#135400)
Might be a small compile time improvement since we remove a call to extract_read_writes(). Pull Request resolved: https://github.com/pytorch/pytorch/pull/135400 Approved by: https://github.com/oulgen ghstack dependencies: #135286, #135306, #135377
This commit is contained in:
committed by
PyTorch MergeBot
parent
16f5155992
commit
53290ca00b
@ -159,27 +159,27 @@ class BaseSchedulerNode:
|
||||
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
|
||||
read_writes: dependencies.ReadWrites
|
||||
unmet_dependencies: OrderedSet[Dep]
|
||||
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
|
||||
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
|
||||
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
|
||||
# For non-"grouped" nodes (i.e. regular SchedulerNode),
|
||||
# .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
|
||||
min_order: int
|
||||
max_order: int
|
||||
|
||||
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
|
||||
def __init__(self, scheduler: Scheduler) -> None:
|
||||
self.scheduler: Scheduler = scheduler
|
||||
|
||||
def _init_from_node(self, node: ir.Operation) -> None:
|
||||
self.node: Optional[ir.Operation] = node
|
||||
self.set_read_writes(node.get_read_writes())
|
||||
self.ancestors: OrderedSet[str] = OrderedSet()
|
||||
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
|
||||
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
|
||||
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
|
||||
# For non-"grouped" nodes (i.e. regular SchedulerNode),
|
||||
# .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
|
||||
self.min_order: int
|
||||
self.max_order: int
|
||||
self.last_usage: OrderedSet[
|
||||
str
|
||||
] = OrderedSet() # buffers that won't be used after this kernel
|
||||
self.written = False
|
||||
|
||||
self.outputs: List[SchedulerBuffer] = [
|
||||
SchedulerBuffer(
|
||||
scheduler=scheduler,
|
||||
scheduler=self.scheduler,
|
||||
node=output,
|
||||
defining_op=self,
|
||||
)
|
||||
@ -799,6 +799,11 @@ kernel_name_to_op = {
|
||||
|
||||
|
||||
class ExternKernelSchedulerNode(BaseSchedulerNode):
|
||||
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
|
||||
super().__init__(scheduler)
|
||||
self._init_from_node(node)
|
||||
self.set_read_writes(node.get_read_writes())
|
||||
|
||||
def debug_str_extra(self) -> str:
|
||||
return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}"
|
||||
|
||||
@ -811,7 +816,10 @@ class ExternKernelSchedulerNode(BaseSchedulerNode):
|
||||
|
||||
|
||||
class NopKernelSchedulerNode(BaseSchedulerNode):
|
||||
pass
|
||||
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
|
||||
super().__init__(scheduler)
|
||||
self._init_from_node(node)
|
||||
self.set_read_writes(node.get_read_writes())
|
||||
|
||||
|
||||
class SchedulerNode(BaseSchedulerNode):
|
||||
@ -820,7 +828,8 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
scheduler: Scheduler,
|
||||
node: Union[ir.ComputedBuffer, ir.TemplateBuffer],
|
||||
) -> None:
|
||||
super().__init__(scheduler, node)
|
||||
super().__init__(scheduler)
|
||||
self._init_from_node(node)
|
||||
self._compute_attrs()
|
||||
|
||||
def _compute_attrs(
|
||||
@ -1121,7 +1130,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
refresh_group_node_dependencies(self)
|
||||
|
||||
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
|
||||
# NB: No need to call super().__init__() because we don't need to re-use any of its logic.
|
||||
super().__init__(scheduler)
|
||||
init_group_node(self, scheduler, snodes)
|
||||
self.users: List[NodeUser] = []
|
||||
self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
|
||||
@ -1590,7 +1599,7 @@ class GroupedSchedulerNode(BaseSchedulerNode):
|
||||
return grouped_snode
|
||||
|
||||
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
|
||||
# NB: No need to call super().__init__() because we don't need to re-use any of its logic.
|
||||
super().__init__(scheduler)
|
||||
init_group_node(self, scheduler, snodes)
|
||||
|
||||
def unpack(self) -> List[BaseSchedulerNode]:
|
||||
|
||||
Reference in New Issue
Block a user