mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR renames every cache_limit to recompile_limit via sed. Old config options are maintained via Config(alias='xyz') Pull Request resolved: https://github.com/pytorch/pytorch/pull/143709 Approved by: https://github.com/jansel
47 lines
966 B
Python
47 lines
966 B
Python
import time
|
|
import timeit
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch._dynamo.config
|
|
|
|
|
|
# to satisfy linter complaining about undefined variable
|
|
foo = None
|
|
|
|
args = [f"x{i}" for i in range(100)]
|
|
fn_str = f"""\
|
|
def foo({", ".join(args)}):
|
|
n = {" + ".join(arg + ".shape[0]" for arg in args)}
|
|
return x0 + n
|
|
"""
|
|
|
|
exec(fn_str, globals())
|
|
torch._dynamo.config.recompile_limit = 16
|
|
|
|
|
|
def bench(name, fn):
|
|
torch._dynamo.reset()
|
|
inps = [[torch.randn(i) for _ in range(100)] for i in range(10, 101, 10)]
|
|
|
|
def run_fn():
|
|
for inp in inps:
|
|
fn(*inp)
|
|
|
|
start = time.perf_counter()
|
|
for _ in range(3):
|
|
run_fn()
|
|
end = time.perf_counter()
|
|
|
|
results = timeit.repeat(lambda: run_fn(), number=1000, repeat=10)
|
|
print(f"{name} {np.median(results) * 1000:.1f}us (warmup={end - start:.1f}s)")
|
|
|
|
|
|
def main():
|
|
bench("compiled", torch.compile(foo, dynamic=False)) # type: ignore[F821]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|