Louie Tsai
2025-03-27 09:15:35 +00:00
committed by PyTorch MergeBot
parent e6afb51805
commit 7aacbab0b3

View File

@ -4,9 +4,9 @@ Profiling to understand torch.compile performance
What to use torch.profiler for:
-------------------------------
torch.profiler is helpful for understanding the performance of your program at a kernel-level granularity - for example, it can show graph breaks and GPU utilization at the level of the program. The data provided by the profiler can often help users understand where to investigate further to understand model performance.
torch.profiler is helpful for understanding the performance of your program at a kernel-level granularity - for example, it can show graph breaks and resources utilization at the level of the program. The data provided by the profiler can often help users understand where to investigate further to understand model performance.
To understand kernel-level performance, other tools exist. NVIDIA's ncu tool can be used, or :ref:`inductor's profiling tools <torchinductor-gpu-profiling>`.
To understand kernel-level performance, other tools exist, such as `Nvidia Nsight compute tool <https://developer.nvidia.com/nsight-compute>`_, `AMD Omnitrace <https://rocm.docs.amd.com/projects/omnitrace/en/latest/>`_, Intel® VTune™ Profiler or :ref:`inductor's profiling tools <torchinductor-gpu-profiling>` can be used.
See also the `general pytorch profiler guide <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_.
@ -24,8 +24,10 @@ Basics of using torch.profiler and viewing traces
import torch
from torchvision.models import resnet18
model = resnet18().cuda()
inputs = [torch.randn((5, 3, 224, 224), device='cuda') for _ in range(10)]
device = 'cuda' # or 'cpu', 'xpu', etc.
model = resnet18().to(device)
inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)]
model_c = torch.compile(model)
@ -52,9 +54,9 @@ Here, we observe:
* CompiledFunction and CompiledFunctionBackward events, which correspond to the dynamo-compiled regions.
* CPU events at the top, and GPU events at the bottom.
**Flows between CPU and GPU events**
**Flows between CPU and accelerator events**
Every kernel on the GPU occurs after being launched by code running on the CPU. The profiler can draw connections (i.e. “flows”) between the GPU and CPU events to show which CPU event launched a GPU kernel. This is particularly helpful because, with a few exceptions, GPU kernels are launched asynchronously.
Every kernel on the accelerator occurs after being launched by code running on the CPU. The profiler can draw connections (i.e. “flows”) between the accelerator and CPU events to show which CPU event launched a accelerator kernel. This is particularly helpful because, with a few exceptions, accelerator kernels are launched asynchronously.
To view a flow connection, click on a GPU kernel and click “ac2g”:
@ -90,8 +92,10 @@ See an example below:
import torch
from torchvision.models import resnet18
model = resnet18().cuda()
inputs = [torch.randn((5, 3, 224, 224), device='cuda') for _ in range(10)]
# user can switch between cuda and xpu
device = 'cuda'
model = resnet18().to(device)
inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)]
model_c = torch.compile(model)
@ -103,7 +107,7 @@ See an example below:
def fn(x):
return x.sin().relu()
x = torch.rand((2, 2), device='cuda', requires_grad=True)
x = torch.rand((2, 2), device=device, requires_grad=True)
fn_c = torch.compile(fn)
out = fn_c(x)
out.sum().backward()
@ -120,6 +124,7 @@ See an example below:
.. figure:: _static/img/profiling_torch_compile/compilation_profiling.png
:alt: A visualization in the chrome://trace viewer, showing dynamo and inductor compilation steps
Note a few things:
* The first invocation should occur *during* profiling in order to capture compilation
@ -146,6 +151,8 @@ See the synthetic example below for a demonstration:
import torch
import torch._dynamo
# user can switch between cuda and xpu
device = 'cuda'
class ModelWithBreaks(torch.nn.Module):
def __init__(self):
@ -172,9 +179,8 @@ See the synthetic example below for a demonstration:
mod4 = self.mod4(mod3)
return mod4
model = ModelWithBreaks().cuda()
inputs = [torch.randn((128, 128), device='cuda') for _ in range(10)]
model = ModelWithBreaks().to(device)
inputs = [torch.randn((128, 128), device=device) for _ in range(10)]
model_c = torch.compile(model)