Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00.
Context: https://github.com/pytorch/torchdynamo/issues/1588 This PR moves [TorchDynamo](https://github.com/pytorch/torchdynamo) and TorchInductor into PyTorch core. - `torchdynamo` becomes `torch._dynamo` - `torchinductor` becomes `torch._inductor` This PR was generated by running `copy_to_core.sh` in https://github.com/pytorch/torchdynamo/pull/1538 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86461 Approved by: https://github.com/voznesenskym
108 lines · 2.4 KiB · Python
import torch
|
|
|
|
import torch._inductor.triton_ops
|
|
from torch.profiler import profile, ProfilerActivity, record_function
|
|
|
|
# Explicitly enable TensorFloat-32 math for cuBLAS matmuls and cuDNN
# convolutions.  Both flags already default to True, but setting them here
# keeps the benchmark's numeric mode stable even if the defaults change.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Benchmark problem size, shared by every profiled configuration.
BATCH = 32
IN_C = 56
IN_H = 56
IN_W = 64
# NOTE(review): the original value order yields KERNEL_N=3, KERNEL_H=3,
# KERNEL_W=64, which looks like KERNEL_N and KERNEL_W may have been
# swapped upstream — values preserved as-is; confirm before relying on them.
KERNEL_N = 3
KERNEL_H = 3
KERNEL_W = 64
stride = (1, 1)
padding = (0, 0)
dilation = (1, 1)
groups = 1
dtype = torch.float32
def profile_op(
    # provider
    provider,
    # Tensor dimensions
    BATCH,
    IN_C,
    IN_H,
    IN_W,
    KERNEL_N,
    KERNEL_H,
    KERNEL_W,
    # parameters of conv
    stride=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    groups=1,
    dtype=torch.float16,
    layout="nhwc",
    warmup=25,
    rep=50,
):
    """Profile one conv2d backend and print its CUDA kernel-time table.

    Allocates random CUDA input/weight/bias tensors, optionally converts
    the input and weight to channels-last memory format, runs ``warmup``
    untimed iterations, then records ``rep`` iterations under the PyTorch
    profiler and prints the top-10 ops sorted by total CUDA time.

    Raises:
        ValueError: if ``provider`` is neither "cublas" nor "triton".
    """
    # Random tensors, allocated in NCHW order on the GPU.
    inp = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
    weight = torch.randn(
        (KERNEL_N, IN_C // groups, KERNEL_H, KERNEL_W), dtype=dtype, device="cuda"
    )
    bias = torch.randn((KERNEL_N), dtype=dtype, device="cuda")

    # Switch to channels-last (NHWC) storage when requested; the triton
    # conv expects this layout for best performance.
    if layout == "nhwc":
        inp = inp.to(memory_format=torch.channels_last)
        weight = weight.to(memory_format=torch.channels_last)

    # Bind the backend under test to a zero-argument closure.
    if provider == "cublas":

        def fn():
            return torch.conv2d(inp, weight, bias, stride, padding, dilation, groups)

    elif provider == "triton":

        def fn():
            return torch._inductor.triton_ops.conv(
                inp, weight, bias, stride, padding, dilation, False, (0, 0), groups
            )

    else:
        raise ValueError(f"{provider} not supported")

    # Warm-up iterations are deliberately excluded from the profile.
    for _ in range(warmup):
        fn()

    # Record the timed iterations under a single labelled region.
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof, \
            record_function("model_inference"):
        for _ in range(rep):
            fn()

    print("Profiling ", provider)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# Profile the identical convolution with both backends so their kernel-time
# tables can be compared side by side.
for provider in ["cublas", "triton"]:
    profile_op(
        provider,
        # Tensor dimensions
        BATCH, IN_C, IN_H, IN_W,
        KERNEL_N, KERNEL_H, KERNEL_W,
        # conv parameters
        stride, padding, dilation, groups,
        dtype=dtype, layout="nhwc", warmup=25, rep=50,
    )