Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147547 Approved by: https://github.com/bobrenjc93
36 lines · 897 B · Python
# mypy: allow-untyped-defs
"""Example usage of the Strobelight CLI function profiler.

Demonstrates the two ways of attaching the profiler with the ``strobelight``
decorator:

1. ``@strobelight(**kwargs)`` -- use the default profiler, optionally passing
   profiler arguments as keywords.
2. ``@strobelight(profiler, **kwargs)`` -- pass an explicitly constructed
   ``StrobelightCLIFunctionProfiler`` instance.

Everything runs under the ``__main__`` guard, so importing this module has no
side effects beyond the imports themselves.
"""

import torch
from torch._strobelight.cli_function_profiler import (
    strobelight,
    StrobelightCLIFunctionProfiler,
)


if __name__ == "__main__":

    def fn(x, y, z):
        return x * y + z

    # Use the decorator with the default profiler, optionally passing profiler
    # arguments as keywords.
    @strobelight(sample_each=10000, stop_at_error=False)
    @torch.compile()
    def work():
        # NOTE(review): torch._dynamo.reset() is called from inside a
        # torch.compile'd function, clearing Dynamo's caches mid-run.
        # Presumably intentional: it keeps producing compilation work for the
        # profiler to sample. Confirm before "fixing".
        for _ in range(10):
            torch._dynamo.reset()
            for j in range(5):
                torch._dynamo.reset()
                fn(torch.rand(j, j), torch.rand(j, j), torch.rand(j, j))

    work()

    # Or pass an explicit profiler instance.
    profiler = StrobelightCLIFunctionProfiler(stop_at_error=False)

    @strobelight(profiler, sample_tags=["something", "another"])
    def work2():
        # Pure-Python busy loop (presumably just to give the profiler CPU time
        # to sample). The accumulator is named ``total`` rather than ``sum`` so
        # the builtin is not shadowed.
        total = 0
        for _ in range(100_000_000):
            total += 1  # noqa: SIM113

    work2()