Context: https://github.com/pytorch/torchdynamo/issues/1588

This PR moves TorchDynamo (https://github.com/pytorch/torchdynamo) and TorchInductor into PyTorch core:

- `torchdynamo` becomes `torch._dynamo`
- `torchinductor` becomes `torch._inductor`

This PR was generated by running `copy_to_core.sh` in https://github.com/pytorch/torchdynamo/pull/1538

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86461
Approved by: https://github.com/voznesenskym
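For scripts written against the standalone repos, only the import paths change. A minimal before/after sketch (package names as stated in the commit message above; the exact submodule layout is assumed to carry over unchanged):

    # before: standalone packages
    import torchdynamo
    import torchinductor.config

    # after: merged into PyTorch core
    import torch._dynamo
    import torch._inductor.config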
# flake8: noqa
import model  # local helper module defining resnet50_layers (conv layer shape tuples)
import torch

import torch._dynamo
import torch._inductor.config
import triton
from prettytable import PrettyTable

# torch._inductor.config.debug = True
torch._inductor.config.triton.convolution = "triton"  # use Triton conv kernels rather than ATen
torch._inductor.config.triton.dense_indexing = True
torch.manual_seed(0)
useCudaGraph = True

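# Each fusion pattern below comes in two versions: an eager-mode reference and
# a TorchInductor-compiled one (via @torch._dynamo.optimize("inductor")), so
# the benchmark can time the two side by side.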
class Func(object):
    # conv
    @torch._dynamo.optimize("inductor")
    def conv_torchinductor(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        return y

    # conv
    def conv(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        return y

    # conv+bias
    @torch._dynamo.optimize("inductor")
    def conv_add_torchinductor(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
        return y

    # conv+bias
    def conv_add(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
        return y

    # relu(conv)
    @torch._dynamo.optimize("inductor")
    def conv_relu_torchinductor(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        return torch.relu(y)

    # relu(conv)
    def conv_relu(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        return torch.relu(y)

    # relu(conv+bias)
    @torch._dynamo.optimize("inductor")
    def conv_add_relu_torchinductor(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
        return torch.relu(y)

    # relu(conv+bias)
    def conv_add_relu(x, w, bias, stride, padding, dilation, groups):
        y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
        return torch.relu(y)
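    # The bn variants below apply an inference-mode torch.batch_norm
    # (training=False, fixed running stats) on top of the conv output.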
    # bn(conv)
    @torch._dynamo.optimize("inductor")
    def conv_bn_torchinductor(
        x,
        w,
        bias,
        stride,
        padding,
        dilation,
        groups,
        running_mean,
        running_var,
        bn_weight,
        bn_bias,
    ):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        y = torch.batch_norm(
            y,
            weight=bn_weight,
            bias=bn_bias,
            running_mean=running_mean,
            running_var=running_var,
            training=False,
            momentum=1,
            eps=1e-5,
            cudnn_enabled=True,
        )
        return y

    # bn(conv)
    def conv_bn(
        x,
        w,
        bias,
        stride,
        padding,
        dilation,
        groups,
        running_mean,
        running_var,
        bn_weight,
        bn_bias,
    ):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        y = torch.batch_norm(
            y,
            weight=bn_weight,
            bias=bn_bias,
            running_mean=running_mean,
            running_var=running_var,
            training=False,
            momentum=1,
            eps=1e-5,
            cudnn_enabled=True,
        )
        return y
    # relu(bn(conv))
    @torch._dynamo.optimize("inductor")
    def conv_bn_relu_torchinductor(
        x,
        w,
        bias,
        stride,
        padding,
        dilation,
        groups,
        running_mean,
        running_var,
        bn_weight,
        bn_bias,
    ):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        y = torch.batch_norm(
            y,
            weight=bn_weight,
            bias=bn_bias,
            running_mean=running_mean,
            running_var=running_var,
            training=False,
            momentum=1,
            eps=1e-5,
            cudnn_enabled=True,
        )
        return torch.relu(y)

    # relu(bn(conv))
    def conv_bn_relu(
        x,
        w,
        bias,
        stride,
        padding,
        dilation,
        groups,
        running_mean,
        running_var,
        bn_weight,
        bn_bias,
    ):
        y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
        y = torch.batch_norm(
            y,
            weight=bn_weight,
            bias=bn_bias,
            running_mean=running_mean,
            running_var=running_var,
            training=False,
            momentum=1,
            eps=1e-5,
            cudnn_enabled=True,
        )
        return torch.relu(y)

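# cuda_graph wraps `fn` for CUDA Graph replay: warm up on a side stream,
# capture the kernel sequence once, then return a callable that copies fresh
# inputs into the captured tensors and replays the graph, removing
# per-iteration CPU launch overhead from the measurement.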
def cuda_graph(fn, x, w, bias):
    # keep pristine copies of the inputs; replay reuses the captured tensors
    new_x = x.clone()
    new_w = w.clone()
    if bias is not None:
        new_bias = bias.clone()

    # warm up on a side stream before cudagraph capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            fn()
    torch.cuda.current_stream().wait_stream(s)

    # capture
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn()

    def replay():
        # refresh the static input tensors in place, then replay the graph
        x.copy_(new_x)
        w.copy_(new_w)
        if bias is not None:
            bias.copy_(new_bias)
        return g.replay()

    return replay

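# bench times one conv layer shape: it allocates inputs for the requested
# fusion pattern, runs the eager and TorchInductor variants under
# triton.testing.do_bench, and appends a row of TFLOPS numbers to the table.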
def bench(layer_params, layer_id, p, fusion_types=[""]):
|
|
BATCH = 32
|
|
IN_H, IN_W, IN_C, KERNEL_H, KERNEL_W, KERNEL_N, stride, padding = layer_params
|
|
dilation, groups = (1, 1), 1
|
|
dtype = torch.float32
|
|
|
|
OUT_H = (
|
|
IN_H + 2 * padding[0] - dilation[0] * (KERNEL_H - 1) - 1 + stride[0]
|
|
) // stride[0]
|
|
OUT_W = (
|
|
IN_W + 2 * padding[1] - dilation[1] * (KERNEL_W - 1) - 1 + stride[1]
|
|
) // stride[1]
|
|
tflops = (
|
|
lambda ms: 2.0
|
|
* BATCH
|
|
* OUT_H
|
|
* OUT_W
|
|
* IN_C
|
|
* KERNEL_H
|
|
* KERNEL_W
|
|
* KERNEL_N
|
|
/ ms
|
|
* 1e-9
|
|
)
|
|
|
|
# allocate inputs, nchw
|
|
x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
|
|
w = torch.randn(
|
|
(KERNEL_N, IN_C // groups, KERNEL_H, KERNEL_W), dtype=dtype, device="cuda"
|
|
)
|
|
|
|
row = [layer_id]
|
|
for fusion_type in fusion_types:
|
|
|
|
if fusion_type == "":
|
|
conv_torchinductor = getattr(Func, "conv_torchinductor")
|
|
conv = getattr(Func, "conv")
|
|
else:
|
|
conv_torchinductor = getattr(Func, f"conv_{fusion_type}_torchinductor")
|
|
conv = getattr(Func, f"conv_{fusion_type}")
|
|
|
|
if "add" in fusion_type:
|
|
bias = torch.randn((KERNEL_N,), dtype=dtype, device="cuda")
|
|
else:
|
|
bias = None
|
|
|
|
args = (x, w, bias, stride, padding, dilation, groups)
|
|
|
|
if "bn" in fusion_type:
|
|
running_mean = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
|
|
running_var = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
|
|
bn_weight = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
|
|
bn_bias = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
|
|
args += (
|
|
running_mean,
|
|
running_var,
|
|
bn_weight,
|
|
bn_bias,
|
|
)
|
|
|
|
def fn_conv():
|
|
return conv(*args)
|
|
|
|
def fn_conv_torchinductor():
|
|
return conv_torchinductor(*args)
|
|
|
|
if useCudaGraph:
|
|
fn_conv = cuda_graph(fn_conv, x, w, bias)
|
|
|
|
torch_conv_ms, _, _ = triton.testing.do_bench(fn_conv)
|
|
triton_conv_ms, _, _ = triton.testing.do_bench(fn_conv_torchinductor)
|
|
row.extend([tflops(torch_conv_ms), tflops(triton_conv_ms)])
|
|
|
|
p.add_row(row)
|
|
|
|
|
|
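# Driver: one "torch" and one "triton" TFLOPS column per fusion type,
# benchmarked over every layer shape in model.resnet50_layers.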
fusion_types = ["", "add", "relu", "add_relu", "bn", "bn_relu"]
|
|
p = PrettyTable()
|
|
field_names = ["layer"]
|
|
for fusion_type in fusion_types:
|
|
if fusion_type == "":
|
|
field_names.append("torch conv")
|
|
field_names.append("triton conv")
|
|
else:
|
|
field_names.append(f"torch conv+{fusion_type}")
|
|
field_names.append(f"triton conv+{fusion_type}")
|
|
|
|
p.field_names = field_names
|
|
p.float_format = ".3"
|
|
for id, layer in enumerate(model.resnet50_layers):
|
|
bench(layer, id, p, fusion_types)
|
|
|
|
print(p)
|