mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129755 Approved by: https://github.com/zou3519 ghstack dependencies: #129752
125 lines
2.4 KiB
Python
125 lines
2.4 KiB
Python
import torch
|
|
import torch.fx as fx
|
|
from functorch import make_fx
|
|
from torch._functorch.compile_utils import fx_graph_cse
|
|
from torch.profiler import profile, ProfilerActivity
|
|
|
|
|
|
def profile_it(f, inp):
|
|
for _ in range(5):
|
|
f(inp)
|
|
|
|
itr = 5
|
|
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
|
|
for _ in range(itr):
|
|
f(inp)
|
|
|
|
timing = prof.key_averages()
|
|
cuda_time_total = 0
|
|
for e in timing:
|
|
cuda_time_total = cuda_time_total + e.cuda_time_total
|
|
return cuda_time_total / itr
|
|
|
|
|
|
def profile_function(name, f, inp):
    """Benchmark ``f`` before and after FX common-subexpression elimination.

    ``f`` is traced with ``make_fx`` on ``inp``. ``fx_graph_cse`` builds a
    deduplicated graph, which is wrapped in a new ``GraphModule``. Both
    modules are timed with ``profile_it`` (original first, then the CSE'd
    one). One comma-separated line is printed: name, original time, CSE
    time, number of nodes removed, original node count.

    TorchScript-compiled variants are intentionally not benchmarked, because
    scripting already performs some CSE of its own.
    """
    traced = make_fx(f)(inp)
    cse_module = fx.GraphModule(traced, fx_graph_cse(traced.graph))

    time_before = profile_it(traced, inp)
    time_after = profile_it(cse_module, inp)

    nodes_before = len(traced.graph.nodes)
    nodes_removed = nodes_before - len(cse_module.graph.nodes)

    print(f"{name}, {time_before}, {time_after}, {nodes_removed}, {nodes_before}")
|
|
|
|
|
|
# Fixed-seed CUDA generator so every run benchmarks the same input tensor.
# Generator.manual_seed returns the generator itself, so it can be chained.
g_gpu = torch.Generator(device="cuda").manual_seed(2147483647)
inp = torch.randn(1 << 20, device="cuda", generator=g_gpu)
|
|
|
|
|
|
def f1(x):
    """Apply cosine twice, ``cos(cos(x))``; no repeated subexpression exists."""
    inner = x.cos()
    return inner.cos()
|
|
|
|
|
|
# Baseline: chained cos calls give CSE nothing identical to merge.
profile_function(name="f1", f=f1, inp=inp)
|
|
|
|
|
|
def fsum(x):
    """Add together four separately computed ``x.sum()`` reductions.

    The four identical reductions are the duplicates a CSE pass can merge.
    """
    s1, s2, s3, s4 = x.sum(), x.sum(), x.sum(), x.sum()
    return s1 + s2 + s3 + s4
|
|
|
|
|
|
# Four identical x.sum() reductions.
profile_function(name="fsum", f=fsum, inp=inp)
|
|
|
|
|
|
def fconcat(x):
    """Build ``cat((x, x))`` twice, independently, and return their sum."""
    left, right = torch.cat((x, x)), torch.cat((x, x))
    return left + right
|
|
|
|
|
|
# Two identical torch.cat((x, x)) calls.
profile_function(name="fconcat", f=fconcat, inp=inp)
|
|
|
|
|
|
def fsum2(x):
    """Accumulate ``x.sum()`` 31 times, recomputing the reduction at every step."""
    total = x.sum()
    steps_left = 30
    while steps_left > 0:
        total = total + x.sum()
        steps_left -= 1
    return total
|
|
|
|
|
|
# 31 identical x.sum() reductions feeding a running total.
profile_function(name="fsum2", f=fsum2, inp=inp)
|
|
|
|
|
|
def fsummulti(x):
    """Run three rounds of ``acc = (acc + x.sum()) * x.sum()`` from ``acc = 0``.

    Each round evaluates ``x.sum()`` twice (add first, then multiply), so
    six identical reductions appear in total.
    """
    acc = 0
    for _round in range(3):
        acc = (acc + x.sum()) * x.sum()
    return acc
|
|
|
|
|
|
# Interleaved add/mul over six identical x.sum() reductions.
profile_function(name="fsummulti", f=fsummulti, inp=inp)
|
|
|
|
|
|
def fsummulti2(x):
    """Run 30 rounds of ``acc = (acc + x.sum()) * x.sum()`` from ``acc = 0``.

    Each round evaluates ``x.sum()`` twice (add first, then multiply), so
    sixty identical reductions appear in total.
    """
    acc = 0
    rounds_done = 0
    while rounds_done < 30:
        acc = (acc + x.sum()) * x.sum()
        rounds_done += 1
    return acc
|
|
|
|
|
|
# Interleaved add/mul over sixty identical x.sum() reductions.
profile_function(name="fsummulti2", f=fsummulti2, inp=inp)
|
|
|
|
|
|
def fcos(x):
    """Sum ``x.cos()`` three times, recomputing the cosine on every step.

    Uses ``acc = acc + ...`` rather than ``+=`` so no in-place add is traced.
    """
    acc = 0
    rounds = 0
    while rounds < 3:
        acc = acc + x.cos()
        rounds += 1
    return acc
|
|
|
|
|
|
# Three identical elementwise x.cos() calls.
profile_function(name="fcos", f=fcos, inp=inp)
|
|
|
|
|
|
def fcos2(x):
    """Sum ``x.cos()`` 30 times, recomputing the cosine on every step.

    Uses ``acc = acc + ...`` rather than ``+=`` so no in-place add is traced.
    """
    acc = 0
    rounds = 0
    while rounds < 30:
        acc = acc + x.cos()
        rounds += 1
    return acc
|
|
|
|
|
|
# Thirty identical elementwise x.cos() calls.
profile_function(name="fcos2", f=fcos2, inp=inp)
|