Revert "Inductor logging + analysis of torch.profile (#149697)"

This reverts commit e5afbe31245287a92fe328c404b3557e5c5eca73.

Reverted https://github.com/pytorch/pytorch/pull/149697 on behalf of https://github.com/malfet due to Broke rocm, see 642687af29/1 ([comment](https://github.com/pytorch/pytorch/pull/149697#issuecomment-2942415600))
This commit is contained in:
PyTorch MergeBot
2025-06-05 01:38:13 +00:00
parent 642687af29
commit 5e03433443
19 changed files with 74 additions and 1847 deletions

View File

@ -61,7 +61,6 @@ from .codegen.triton import (
from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta
from .codegen.wrapper import pexpr
from .exc import CUDACompileError
from .fx_utils import count_flops_fx
from .ir import ChoiceCaller, PrimitiveInfoType
from .ops_handler import StoreMode
from .runtime.benchmarking import benchmarker
@ -483,20 +482,12 @@ class TritonTemplateKernel(TritonKernel):
ninplace_args = len(unique(self.args.inplace_buffers.values()))
num_bytes = []
for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
size = V.graph.sizevars.size_hints(inp.get_size(), fallback=0)
size = V.graph.sizevars.size_hints(inp.get_size())
numel = functools.reduce(operator.mul, size, 1)
dtype_size = get_dtype_size(inp.get_dtype())
num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
return sum(num_bytes)
def estimate_flops(self) -> int:
    """Best-effort estimate of the FLOP count for this template kernel.

    Scans the kernel's input IR nodes and, for each FX node recorded in
    ``node._current_origins``, asks ``count_flops_fx`` for a flop count.
    The first non-None count found is resolved to a concrete int via the
    graph's size-var hints; returns 0 if no origin yields a count.

    NOTE(review): this method is part of a diff hunk in class
    TritonTemplateKernel per the hunk header — confirm against the full file.
    """
    for node in self.input_nodes:
        for fx_node in node._current_origins:
            # count_flops_fx returns None when no flop formula is known
            # for this FX node's op.
            f = count_flops_fx(fx_node)
            if f is not None:
                # Resolve a possibly-symbolic expression to an int;
                # fallback=0 when no concrete hint is available.
                return V.graph.sizevars.size_hint(f, fallback=0)
    return 0
def jit_lines(self):
if self.use_jit:
return "@triton.jit"
@ -532,9 +523,6 @@ class TritonTemplateKernel(TritonKernel):
if config.profile_bandwidth or config.benchmark_kernel:
num_gb = self.estimate_kernel_num_bytes() / 1e9
inductor_meta["kernel_num_gb"] = num_gb
if config.benchmark_kernel:
flops = self.estimate_flops()
inductor_meta["kernel_flop"] = flops
template_args = f"""
num_stages={self.num_stages},