Compare commits

..

2 Commits

Author SHA1 Message Date
895795f07c [ROCm][CI] forward fix kineto submodule bump (#166421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166421
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-28 17:40:23 +00:00
2dc56456cb refactor: pull _replace_node common functionality out of Scheduler.finalize_multi_template_buffers (#163368)
Pull replace_node function out of Scheduler.finalize_multi_template_buffers(). This is needed by the next PR (#163369). As part of this also pull the _replace_operation_buffer() up to top-level since it needed no self references.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163368
Approved by: https://github.com/PaulZhang12
2025-10-28 17:21:52 +00:00
4 changed files with 92 additions and 77 deletions

View File

@ -892,10 +892,16 @@ fn(torch.randn(5))
os.remove(
file_path
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)
orig_maxDiff = unittest.TestCase.maxDiff
unittest.TestCase.maxDiff = None
try:
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)
except Exception:
unittest.TestCase.maxDiff = orig_maxDiff
raise
@make_settings_test("torch._dynamo.eval_frame")
def test_log_traced_frames(self, records):

View File

@ -529,7 +529,7 @@ class TestProfiler(TestCase):
found_mm = True
if "gemm" in e.name.lower() or "Cijk" in e.name:
found_gemm = True
if "memcpy" in e.name.lower():
if "memcpy" in e.name.lower() or "__amd_rocclr_copyBuffer" in e.name:
found_memcpy = True
if use_cuda:
self.assertTrue(found_gemm)

View File

@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True
# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = False
enable_cpp_symbolic_shape_guards = not is_fbcode()
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True

View File

@ -409,9 +409,10 @@ class SchedulerDonatedBuffer(SchedulerBuffer):
class BaseSchedulerNode:
ancestors: OrderedSet[str]
debug_device_str: Callable[[BaseSchedulerNode], list[str]]
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
last_usage: OrderedSet[str]
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
@ -420,22 +421,24 @@ class BaseSchedulerNode:
min_order: int
max_order: int
mpi_node: MemoryPlanningInfoForNode
mutation_renames: dict[str, str]
node: Optional[ir.Operation]
outputs: list[SchedulerBuffer]
outputs_by_name: dict[str, SchedulerBuffer]
override_estimated_runtime: Optional[float] = None
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)
self.scheduler = scheduler
self.debug_device_str = lambda *args, **kwargs: []
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
self.ancestors: OrderedSet[str] = OrderedSet()
self.last_usage = OrderedSet[
str
]() # buffers that won't be used after this kernel
self.node = node
self.ancestors = OrderedSet()
self.last_usage = OrderedSet() # buffers that won't be used after this kernel
self.written = False
self.outputs: list[SchedulerBuffer] = [
self.outputs = [
SchedulerBuffer(
scheduler=self.scheduler,
node=output,
@ -443,16 +446,14 @@ class BaseSchedulerNode:
)
for output in node.get_outputs()
]
self.outputs_by_name: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for buf in self.outputs
}
self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
# mutation_renames for the current node. Due to potential
# more mutations happening later, this can be different
# to Scheduler.mutation_renames. Also this dict should be small
# since only mutation information relevant to the deps for this
# node is stored here.
self.mutation_renames: dict[str, str] = {}
self.mutation_renames = {}
def __repr__(self) -> str:
return f"{type(self).__name__}(name={self.get_name()!r})"
@ -2435,6 +2436,34 @@ def pick_loop_order(
return order
def _replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
@dataclasses.dataclass
class NodeUser:
node: Union[BaseSchedulerNode, OutputNode]
@ -3336,33 +3365,6 @@ class Scheduler:
will force completion of compilation and benchmarking.
"""
def replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
for i, node in enumerate(self.nodes):
if isinstance(node, SchedulerNode) and isinstance(
node.node, ir.MultiTemplateBuffer
@ -3416,40 +3418,47 @@ class Scheduler:
assign_origin_node(out_tensorbox, multi_node.origin_node)
out_buffer.layout = multi_node.layout
replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
self._replace_node(out_buffer, multi_node, i, node)
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def _replace_node(
self,
out_buffer: ir.OperationBuffer,
multi_node: ir.MultiTemplateBuffer,
i: int,
node: SchedulerNode,
) -> None:
_replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(
node.read_writes.reads, node.unmet_dependencies
):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
return any(