mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Revert "Inductor logging + analysis of torch.profile (#149697)"
This reverts commit e5afbe31245287a92fe328c404b3557e5c5eca73.
Reverted https://github.com/pytorch/pytorch/pull/149697 on behalf of https://github.com/malfet due to Broke rocm, see 642687af29/1
([comment](https://github.com/pytorch/pytorch/pull/149697#issuecomment-2942415600))
This commit is contained in:
@ -61,7 +61,6 @@ from .codegen.triton import (
|
||||
from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta
|
||||
from .codegen.wrapper import pexpr
|
||||
from .exc import CUDACompileError
|
||||
from .fx_utils import count_flops_fx
|
||||
from .ir import ChoiceCaller, PrimitiveInfoType
|
||||
from .ops_handler import StoreMode
|
||||
from .runtime.benchmarking import benchmarker
|
||||
@ -483,20 +482,12 @@ class TritonTemplateKernel(TritonKernel):
|
||||
ninplace_args = len(unique(self.args.inplace_buffers.values()))
|
||||
num_bytes = []
|
||||
for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
|
||||
size = V.graph.sizevars.size_hints(inp.get_size(), fallback=0)
|
||||
size = V.graph.sizevars.size_hints(inp.get_size())
|
||||
numel = functools.reduce(operator.mul, size, 1)
|
||||
dtype_size = get_dtype_size(inp.get_dtype())
|
||||
num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
|
||||
return sum(num_bytes)
|
||||
|
||||
def estimate_flops(self) -> int:
|
||||
for node in self.input_nodes:
|
||||
for fx_node in node._current_origins:
|
||||
f = count_flops_fx(fx_node)
|
||||
if f is not None:
|
||||
return V.graph.sizevars.size_hint(f, fallback=0)
|
||||
return 0
|
||||
|
||||
def jit_lines(self):
|
||||
if self.use_jit:
|
||||
return "@triton.jit"
|
||||
@ -532,9 +523,6 @@ class TritonTemplateKernel(TritonKernel):
|
||||
if config.profile_bandwidth or config.benchmark_kernel:
|
||||
num_gb = self.estimate_kernel_num_bytes() / 1e9
|
||||
inductor_meta["kernel_num_gb"] = num_gb
|
||||
if config.benchmark_kernel:
|
||||
flops = self.estimate_flops()
|
||||
inductor_meta["kernel_flop"] = flops
|
||||
|
||||
template_args = f"""
|
||||
num_stages={self.num_stages},
|
||||
|
Reference in New Issue
Block a user