mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The FX graph segment in CompiledFxGraph does not include tensor metadata such as shape, stride, data type, and device. The AI system co-design team requested that this information be included in the FX graph segment so they can use it to project performance on different hardware. This diff modifies Graph::Node::format_node to include tensor metadata.

Before this diff, the triton kernel FX graph segment looked like the following:

```
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
```

After this diff:

```
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
```

If format_node cannot be changed, I can copy the code to caffe2/torch/_inductor/utils.py.

Differential Revision: D77973076

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159311
Approved by: https://github.com/angelayi
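As a rough illustration of how a downstream consumer might use the new annotations, the sketch below parses the dtype/shape/stride/device string out of one graph-fragment line. This is not part of the PR; the regular expression and the `parse_tensor_meta` helper are assumptions made only for illustration, based on the `f32[4, 4][4, 1]cuda:0` format shown above.

```python
import re

# Hypothetical helper (not part of the PR): extract dtype, shape, stride, and device
# from a format_node line such as
#   # %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[...]
_META_RE = re.compile(
    r'Tensor "(?P<dtype>\w+)\[(?P<shape>[^\]]*)\]\[(?P<stride>[^\]]*)\](?P<device>[^"]*)"'
)


def parse_tensor_meta(fragment_line: str):
    """Return (dtype, shape, stride, device) if the line carries tensor metadata, else None."""
    m = _META_RE.search(fragment_line)
    if m is None:
        return None
    shape = [int(s) for s in m["shape"].split(",") if s.strip()]
    stride = [int(s) for s in m["stride"].split(",") if s.strip()]
    return m["dtype"], shape, stride, m["device"]


print(parse_tensor_meta(
    '# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = '
    'call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})'
))
# ('f32', [4, 4], [4, 1], 'cuda:0')
```

Lines without the new metadata (for example the pre-diff `%permute_1` line) would simply return None under these assumptions.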
784 lines
30 KiB
Python
# Owner(s): ["oncall: profiler"]

import json
import os
import tempfile
import unittest
from typing import Any

import numpy as np

import torch
import torch.nn as nn
from torch import _dynamo as torchdynamo
from torch.autograd import (
    _record_function_with_args_enter,
    _record_function_with_args_exit,
)
from torch.profiler import (
    ExecutionTraceObserver,
    kineto_available,
    profile,
    record_function,
    supported_activities,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    skipCPUIf,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfHpu,
    skipIfTorchDynamo,
    TEST_HPU,
    TEST_XPU,
    TestCase,
)
from torch.utils._triton import has_triton


# If tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is to turn off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    pass

Json = dict[str, Any]


class TestExecutionTrace(TestCase):
    def payload(self, device, use_device=False):
        u = torch.randn(3, 4, 5, requires_grad=True)
        with record_function("## TEST 1 ##", "1, 2, 3"):
            inf_val = float("inf")
            neg_inf_val = float("-inf")
            nan_val = float("nan")
            rf_handle = _record_function_with_args_enter(
                "## TEST 2 ##",
                1,
                False,
                2.5,
                [u, u],
                (u, u),
                "hello",
                u,
                inf_val,
                neg_inf_val,
                nan_val,
            )
            x = torch.randn(10, 10, requires_grad=True)
            if use_device:
                x = x.to(device)
            y = torch.randn(10, 10, requires_grad=True)
            if use_device:
                y = y.to(device)
            z = x + y + x * y + x * y
            z.backward(z)
            gelu = nn.GELU()
            m = torch.randn(2)
            _ = gelu(m)
            if use_device:
                z = z.cpu()
            _record_function_with_args_exit(rf_handle)

    def get_execution_trace_root(self, output_file_name) -> Json:
        import gzip

        nodes = []
        with (
            gzip.open(output_file_name)
            if output_file_name.endswith(".gz")
            else open(output_file_name)
        ) as f:
            et_graph = json.load(f)
            assert "nodes" in et_graph
            nodes = et_graph["nodes"]
        return nodes

    def get_execution_trace_rf_ids(self, nodes: list[Json]) -> list[int]:
        """Returns a sorted list of rf_id (record function ids) in execution trace"""

        def get_rf_id(node):
            attrs = node["attrs"]
            for a in attrs:
                if a["name"] == "rf_id":
                    return a["value"]
            return None

        rf_ids_ = (
            get_rf_id(n)
            for n in nodes
            if n["name"] != "[pytorch|profiler|execution_trace|process]"
            and n["name"] != "[pytorch|profiler|execution_trace|thread]"
        )
        return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)

    def get_kineto_rf_ids(self, events: list[Json]) -> list[int]:
        """Returns a sorted list of Record function IDs for CPU operators and user annotations"""
        ops_and_annotations = (
            e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]
        )
        return sorted(
            e.get("args", {}).get("Record function id", -1) for e in ops_and_annotations
        )

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @skipIfHpu
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_execution_trace_with_kineto(self, device):
        trace_called_num = 0

        def trace_handler(p):
            nonlocal trace_called_num
            trace_called_num += 1

        use_device = (
            torch.profiler.ProfilerActivity.CUDA
            or torch.profiler.ProfilerActivity.XPU in supported_activities()
            or torch.profiler.ProfilerActivity.HPU in supported_activities()
        )
        # Create a temp file to save execution trace and kineto data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        kt = tempfile.NamedTemporaryFile(
            mode="w+t", suffix=".kineto.json", delete=False
        )
        kt.close()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            on_trace_ready=trace_handler,
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    self.payload(device, use_device=use_device)
                p.step()
        self.assertEqual(fp.name, p.execution_trace_observer.get_output_file_path())

        # Uncomment for debugging
        # print("Output kineto = ", kt.name)
        # print("Output ET = ", fp.name)

        p.export_chrome_trace(kt.name)
        self.assertEqual(trace_called_num, 1)

        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        self.assertTrue(found_root_node)
        # Since profiler trace is active for 2 iterations
        self.assertEqual(loop_count, 2)

        # Compare the collected Execution Trace and Kineto Trace
        # in terms of record func ID (rf_id) and External IDs
        # both of these should match for the same trace window.

        with open(kt.name) as f:
            kineto = json.load(f)
            events = kineto["traceEvents"]

        # Look up rf_ids in both Execution and Kineto trace as two lists.
        rf_ids_et = self.get_execution_trace_rf_ids(nodes)
        rf_ids_kineto = self.get_kineto_rf_ids(events)

        self.assertCountEqual(rf_ids_et, rf_ids_kineto)
        self.assertListEqual(
            rf_ids_et,
            rf_ids_kineto,
            msg=f"ET and kineto rf_id should exactly match\n"
            f" rf_ids_et = {rf_ids_et}\n"
            f" rf_ids_kineto = {rf_ids_kineto}\n",
        )

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @skipIfHpu
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_execution_trace_env_enabled_with_kineto(self, device):
        import os

        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1"
        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "1"
        trace_called_num = 0

        def trace_handler(p):
            nonlocal trace_called_num
            trace_called_num += 1

        use_device = (
            torch.profiler.ProfilerActivity.CUDA
            or torch.profiler.ProfilerActivity.XPU in supported_activities()
            or torch.profiler.ProfilerActivity.HPU in supported_activities()
        )
        # Create a temp file to save kineto data.
        kt = tempfile.NamedTemporaryFile(
            mode="w+t", suffix=".kineto.json", delete=False
        )
        kt.close()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            on_trace_ready=trace_handler,
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    self.payload(device, use_device=use_device)
                p.step()

        # Uncomment for debugging
        # print("Output kineto = ", kt.name)
        # print("Output ET = ", fp.name)

        p.export_chrome_trace(kt.name)
        self.assertEqual(trace_called_num, 1)
        et_path = p.execution_trace_observer.get_output_file_path()
        et_res_path = p.execution_trace_observer.get_resources_dir(et_path)
        # the path should be set up due to our env variables
        self.assertTrue(et_path is not None)
        # et_res_path should be an empty directory
        self.assertTrue(os.path.isdir(et_res_path))
        self.assertEqual(len(os.listdir(et_res_path)), 0)
        # Compare the collected Execution Trace and Kineto Trace
        # in terms of record func
        nodes = self.get_execution_trace_root(et_path)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        self.assertTrue(found_root_node)
        # Since profiler trace is active for 2 iterations
        self.assertEqual(loop_count, 2)

        # Compare the collected Execution Trace and Kineto Trace
        # in terms of record func ID (rf_id) and External IDs
        # both of these should match for the same trace window.

        with open(kt.name) as f:
            kineto = json.load(f)
            events = kineto["traceEvents"]

        # Look up rf_ids in both Execution and Kineto trace as two lists.
        rf_ids_et = self.get_execution_trace_rf_ids(nodes)
        rf_ids_kineto = self.get_kineto_rf_ids(events)

        self.assertCountEqual(rf_ids_et, rf_ids_kineto)
        self.assertListEqual(
            rf_ids_et,
            rf_ids_kineto,
            msg=f"ET and kineto rf_id should exactly match\n"
            f" rf_ids_et = {rf_ids_et}\n"
            f" rf_ids_kineto = {rf_ids_kineto}\n",
        )

    def test_execution_trace_alone(self, device):
        use_device = (
            torch.profiler.ProfilerActivity.CUDA
            or torch.profiler.ProfilerActivity.HPU in supported_activities()
            or torch.profiler.ProfilerActivity.XPU in supported_activities()
        )
        # Create a temp file to save execution trace data.
        # Use a gzip file to test compression codepath
        fp = tempfile.NamedTemporaryFile("w", suffix=".et.json.gz", delete=False)
        fp.close()
        expected_loop_events = 0

        et = ExecutionTraceObserver().register_callback(fp.name)

        et.start()
        for idx in range(5):
            expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(device, use_device=use_device)
        et.stop()

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        # Expected tensor object tuple size, in the form of:
        # [tensor_id, storage_id, offset, numel, itemsize, device_str]
        tensor_tuple_size = 6
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
            # Check if tensor tuple representation size is correct.
            if n["name"] == "## TEST 2 ##":
                assert len(n["inputs"]["values"][3][0]) == tensor_tuple_size
        assert found_root_node
        assert loop_count == expected_loop_events

    def test_execution_trace_env_disabled(self, device):
        import os

        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "0"
        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "0"
        use_device = (
            torch.profiler.ProfilerActivity.CUDA
            or torch.profiler.ProfilerActivity.HPU in supported_activities()
            or torch.profiler.ProfilerActivity.XPU in supported_activities()
        )

        with profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    self.payload(device, use_device=use_device)
                p.step()

        self.assertTrue(p.execution_trace_observer is None)

    @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
    @unittest.skipIf(
        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
        "need triton and device(CUDA or XPU) availability to run",
    )
    @skipCPUIf(True, "skip CPU device for testing profiling triton")
    def test_execution_trace_with_pt2(self, device):
        @torchdynamo.optimize("inductor")
        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x + c
            return x.cos()

        a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3))

        inputs = [a, b, c]
        with torch._inductor.config.patch(compile_threads=1):
            fn(*inputs)

        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False)
        fp.close()
        et = ExecutionTraceObserver()
        et.register_callback(fp.name)
        et.set_extra_resource_collection(True)

        with profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            execution_trace_observer=et,
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    fn(*inputs)
                p.step()

        nodes = self.get_execution_trace_root(fp.name)
        found_captured_triton_kernel_node = False
        found_call_compiled_fx_graph = False
        for n in nodes:
            assert "name" in n
            if "triton_" in n["name"]:
                for attr in n["attrs"]:
                    if attr["name"] == "kernel_file" and attr["value"] != "":
                        found_captured_triton_kernel_node = True
                        assert len(n["inputs"]["values"]) > 0
                        assert len(n["outputs"]["values"]) == 0
            elif "Call CompiledFxGraph" in n["name"]:
                found_call_compiled_fx_graph = True
        assert found_captured_triton_kernel_node
        assert found_call_compiled_fx_graph

    @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
    @unittest.skipIf(
        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
        "need triton and device(CUDA or XPU) availability to run",
    )
    @skipCPUIf(True, "skip CPU device for testing profiling triton")
    def test_execution_trace_env_enabled_with_pt2(self, device):
        # clean up the local cache for triton kernel
        from torch._inductor.codecache import PyCodeCache as PyCodeCache

        PyCodeCache.cache_clear(purge=True)

        import os

        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1"
        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "1"

        @torchdynamo.optimize("inductor")
        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x + c
            return x.cos()

        a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3))

        inputs = [a, b, c]
        with torch._inductor.config.patch(
            compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False
        ):
            fn(*inputs)

        with profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    fn(*inputs)
                p.step()

        et_path = p.execution_trace_observer.get_output_file_path()
        et_res_path = p.execution_trace_observer.get_resources_dir(et_path)
        # the path should be set up due to our env variables
        self.assertTrue(et_path is not None)
        # et_res_path should contain the two generated resource files
        self.assertTrue(os.path.isdir(et_res_path))
        self.assertEqual(len(os.listdir(et_res_path)), 2)
        nodes = self.get_execution_trace_root(et_path)
        found_captured_triton_kernel_node = False
        for n in nodes:
            assert "name" in n
            if "triton_" in n["name"]:
                for attr in n["attrs"]:
                    if attr["name"] == "kernel_file" and attr["value"] != "":
                        found_captured_triton_kernel_node = True
                        assert len(n["inputs"]["values"]) > 0
                        assert len(n["outputs"]["values"]) == 0
        assert found_captured_triton_kernel_node

    @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
    @unittest.skipIf(
        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
        "need triton and device(CUDA or XPU) availability to run",
    )
    @skipCPUIf(True, "skip CPU device for testing profiling triton")
    def test_triton_fx_graph_with_et(self, device):
        # clean up the local cache for triton kernel
        from torch._inductor.codecache import PyCodeCache as PyCodeCache

        PyCodeCache.cache_clear(purge=True)

        import os

        @torchdynamo.optimize("inductor")
        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x.sin()
            x = x.t() + c * 1111
            return x.cos()

        a, b, c = (
            torch.randn(4, 4, requires_grad=False).to(torch.device("cuda:0"))
            for _ in range(3)
        )

        inputs = [a, b, c]
        with torch._inductor.config.patch(
            compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False
        ):
            fn(*inputs)

        fp = tempfile.NamedTemporaryFile("w+t", suffix="fx_graph_et.json", delete=False)
        fp.close()
        et = ExecutionTraceObserver()
        et.register_callback(fp.name)
        et.set_extra_resource_collection(True)
        with profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=0, wait=1, warmup=1, active=1, repeat=1
            ),
            execution_trace_observer=et,
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    fn(*inputs)
                p.step()

        et_path = p.execution_trace_observer.get_output_file_path()
        et_res_path = p.execution_trace_observer.get_resources_dir(et_path)
        # the path should be set up since extra resource collection is enabled
        self.assertTrue(et_path is not None)
        # et_res_path should hold the triton kernel and FX graph source files
        self.assertTrue(os.path.isdir(et_res_path))
        for filename in os.listdir(et_res_path):
            file_path = os.path.join(et_res_path, filename)
            if os.path.isfile(file_path):
                with open(file_path) as file:
                    fx_graph_found = False
                    fx_graph = []
                    for line in file:
                        line = line.strip()
                        # There are two files in the directory, one is the source
                        # code of the triton kernel, and the other is the source code for FX graph.
                        # Only the FX graph file contains the string "# Graph fragment:".
                        if line.startswith("# Graph fragment:"):
                            fx_graph_found = True
                        elif fx_graph_found and line.startswith("#"):
                            fx_graph.append(line)
                        else:
                            fx_graph_found = False

                    if len(fx_graph) > 0:
                        assert (
                            fx_graph[0]
                            == '# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]'
                        )
                        assert (
                            fx_graph[1]
                            == '# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]'
                        )
                        assert (
                            fx_graph[2]
                            == '# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})'  # noqa: B950
                        )
                        assert (
                            fx_graph[3]
                            == '# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})'  # noqa: B950
                        )
                        assert (
                            fx_graph[4]
                            == '# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})'  # noqa: B950
                        )
                        assert (
                            fx_graph[5]
                            == '# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})'  # noqa: B950
                        )
                        assert (
                            fx_graph[6]
                            == '# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})'  # noqa: B950
                        )
                        assert fx_graph[7] == "# return %cos"

    def test_execution_trace_start_stop(self, device):
        use_device = (
            torch.profiler.ProfilerActivity.CUDA
            or torch.profiler.ProfilerActivity.XPU in supported_activities()
            or torch.profiler.ProfilerActivity.HPU in supported_activities()
        )
        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        expected_loop_events = 0
        et = ExecutionTraceObserver().register_callback(fp.name)
        for idx in range(10):
            if idx == 3:
                et.start()
            elif idx == 5:
                et.stop()
            elif idx == 8:
                et.start()
            elif idx == 9:
                et.stop()
            if et._execution_trace_running:
                expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(device, use_device=use_device)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        assert found_root_node
        assert loop_count == expected_loop_events

    def test_execution_trace_repeat_in_loop(self, device):
        use_device = (
            torch.profiler.ProfilerActivity.CUDA
            or torch.profiler.ProfilerActivity.XPU in supported_activities()
            or torch.profiler.ProfilerActivity.HPU in supported_activities()
        )
        iter_list = {3, 4, 6, 8}
        expected_loop_events = len(iter_list)
        output_files = []
        for idx in range(10):
            if idx in iter_list:
                # Create a temp file to save execution trace data.
                fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
                fp.close()
                output_files.append(fp.name)
                et = ExecutionTraceObserver().register_callback(fp.name)
                et.start()
            with record_function(f"## LOOP {idx} ##"):
                self.payload(device, use_device=use_device)
            if idx in iter_list:
                et.stop()
                et.unregister_callback()

        event_count = 0
        for et_file in output_files:
            nodes = self.get_execution_trace_root(et_file)
            found_root_node = False
            for n in nodes:
                assert "name" in n
                if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                    assert n["id"] == 1
                    found_root_node = True
                if n["name"].startswith("## LOOP "):
                    event_count += 1
            assert found_root_node
        assert event_count == expected_loop_events

    def test_execution_trace_no_capture(self):
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        et = ExecutionTraceObserver().register_callback(fp.name)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
        assert found_root_node

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
    def test_execution_trace_nested_tensor(self):
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()

        observer = ExecutionTraceObserver().register_callback(fp.name)

        def fn(nt):
            return nt.sin().cos()

        with torch.profiler.profile(execution_trace_observer=observer):
            for i in range(3):
                values = torch.rand((8 + i, 4 + i))
                offsets = torch.tensor([0, 2, 4, 6, 8 + i])
                nt = torch.nested.nested_tensor_from_jagged(values, offsets)
                fn(nt)

        nodes = self.get_execution_trace_root(fp.name)
        found_cos = False
        for n in nodes:
            assert "name" in n
            if "cos" in n["name"]:
                found_cos = True
        assert found_cos

    @unittest.skipIf(
        not TEST_CUDA,
        "need CUDA device availability to run",
    )
    def test_execution_trace_record_integral_tensor_range(self):
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()

        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_RANGE"] = "1"
        t1 = torch.tensor([[1, 2], [3, 4]]).cuda()
        t2 = torch.tensor([[0, 0], [1, 0]]).cuda()
        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(
                skip_first=0, wait=0, warmup=0, active=1, repeat=1
            ),
            record_shapes=True,
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            torch.gather(t1, 1, t2)
            p.step()

        nodes = self.get_execution_trace_root(fp.name)
        for n in nodes:
            assert "name" in n
            if "aten::gather" in n["name"]:
                for attr in n["attrs"]:
                    if attr["name"] == "tensor_range":
                        assert attr["value"] == '{"0":[1,4],"1":[0,1]}'

    @unittest.skipIf(
        not TEST_CUDA,
        "need CUDA device availability to run",
    )
    def test_execution_trace_record_integral_tensor_data(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            fp_name = os.path.join(temp_dir, "test.et.json")

            os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = (
                "aten::gather"
            )
            et = ExecutionTraceObserver()
            et.register_callback(fp_name)
            et.set_extra_resource_collection(True)

            t1 = torch.tensor([[1, 2], [3, 4]]).cuda()
            t2 = torch.tensor([[0, 0], [1, 0]]).cuda()
            with profile(
                activities=supported_activities(),
                schedule=torch.profiler.schedule(
                    skip_first=0, wait=0, warmup=0, active=1, repeat=1
                ),
                record_shapes=True,
                execution_trace_observer=et,
            ) as p:
                torch.gather(t1, 1, t2)
                p.step()

            resourceDir = fp_name.replace(".json", "_resources")
            assert os.path.exists(resourceDir + "/nid_4_tid_0.dat")
            assert os.path.exists(resourceDir + "/nid_4_tid_1.dat")

            t1 = np.fromfile(resourceDir + "/nid_4_tid_0.dat", dtype=np.int64)
            t2 = np.fromfile(resourceDir + "/nid_4_tid_1.dat", dtype=np.int64)
            assert (t1 == np.array([1, 2, 3, 4])).all()
            assert (t2 == np.array([0, 0, 1, 0])).all()


devices = ["cpu", "cuda"]
|
|
if TEST_XPU:
|
|
devices.append("xpu")
|
|
if TEST_HPU:
|
|
devices.append("hpu")
|
|
instantiate_device_type_tests(
|
|
TestExecutionTrace, globals(), allow_xpu="xpu" in devices, only_for=devices
|
|
)
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|