# Owner(s): ["module: inductor"]
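# Tests for torch._inductor.analysis.profile_analysis: augmenting chrome traces
# with estimated FLOP and memory-traffic data, diffing two traces against each
# other, and combining several traces into one.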

import json
import tempfile
import unittest
import uuid
from io import StringIO
from unittest.mock import patch

import torch
import torch.nn.functional as F
from torch._inductor.analysis.profile_analysis import (
    _augment_trace_helper,
    _create_extern_mapping,
    main,
)
from torch._inductor.utils import fresh_inductor_cache, tabulate_2d, zip_dicts
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    skipIf,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
from torch.testing._internal.inductor_utils import IS_BIG_GPU


example_profile = """
{
  "schemaVersion": 1,
  "deviceProperties": [
    {
      "id": 0, "name": "NVIDIA H100", "totalGlobalMem": 101997215744,
      "computeMajor": 9, "computeMinor": 0,
      "maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048,
      "regsPerBlock": 65536, "warpSize": 32,
      "sharedMemPerBlock": 49152, "numSms": 132,
      "regsPerMultiprocessor": 65536, "sharedMemPerBlockOptin": 232448, "sharedMemPerMultiprocessor": 233472
    }
  ],
  "cupti_version": 24,
  "cuda_runtime_version": 12060,
  "with_flops": 1,
  "record_shapes": 1,
  "cuda_driver_version": 12040,
  "profile_memory": 1,
  "trace_id": "301995E163ED42048FBD783860E6E7DC",
  "displayTimeUnit": "ms",
  "baseTimeNanoseconds": 1743521598000000000,
  "traceEvents": [
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
      "ts": 198093488368.463, "dur": 425.453,
      "args": {
        "External id": 1340,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1"], "Input type": ["float", "float", "", \
"ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar"], "Input Strides": [[150528, 1, 672, 3],\
[147, 1, 21, 3], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], \
[], []], "Ev Idx": 1339
      }
    },
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::_convolution", "pid": 1147039, "tid": 1147039,
      "ts": 198093488444.498, "dur": 341.867,
      "args": {
        "External id": 1341,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]",\
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList",\
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input Strides": \
[[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], \
[64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
      }
    },
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::addmm", "pid": 1147039, "tid": 1147039,
      "ts": 198093513655.849, "dur": 251.130,
      "args": {
        "External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
["", "", "", "1", "1", ""], "Input type": ["float", "float", "float", "Scalar", "Scalar", "float"], "Input Strides":\
[[1], [0, 1], [1, 2048], [], [], [1000, 1]], "Input Dims": [[1000], [1, 2048], [2048, 1000], [], [], [1, 1000]], \
"Ev Idx": 1618
      }
    },
    {
      "ph": "X", "cat": "kernel", "name": "void cutlass_addmm", "pid": 1147039, "tid": 1147039,
      "ts": 198093513655.849, "dur": 251.130,
      "args": {
        "External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
      }
    },
    {
      "ph": "X", "cat": "kernel", "name": "void convolution_kernel", "pid": 1147039, "tid": 1147039,
      "ts": 198093513655.849, "dur": 200.130,
      "args": {
        "External id": 1342, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
      }
    },
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
      "ts": 198093488444.498, "dur": 341.867,
      "args": {
        "External id": 1342,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", \
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList", \
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input \
Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": \
[[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
      }
    }
  ],
  "traceName": "/tmp/compiled_module_profile.json"
}
"""


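# Helper used by TestAnalysis.test_augment_trace_helper_unit: walks the trace
# events of an augmented profile and checks each recorded "kernel_flop" value
# against the expected values, in order.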
def verify_flops(self, expected_flops, out_profile):
    j = 0
    for i in range(len(out_profile["traceEvents"])):
        if "kernel_flop" in out_profile["traceEvents"][i]["args"]:
            self.assertEqual(
                out_profile["traceEvents"][i]["args"]["kernel_flop"],
                expected_flops[j],
            )
            j += 1


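# Build a random tensor of the requested dtype: randn for floating-point
# dtypes, randint for integer dtypes.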
def random_tensor(size, dtype, **kwargs):
    if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]:
        return torch.randn(size, dtype=dtype, **kwargs)
    elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]:
        return torch.randint(0, 100, size, dtype=dtype, **kwargs)
    else:
        raise ValueError("Unsupported data type")


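# Return a tensor factory T(*shape) bound to a fixed device and dtype.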
def cT(device, dtype):
    def T(*shape, requires_grad=False):
        return random_tensor(
            shape, requires_grad=requires_grad, device=device, dtype=dtype
        )

    return T


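# Thin wrapper around torch.utils.flop_counter.FlopCounterMode with the
# summary table display disabled.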
def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


TMP_DIR = tempfile.mkdtemp()


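# Generate a fresh pair of unique trace file paths under TMP_DIR.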
def trace_files():
    TRACE1 = f"{TMP_DIR}/trace1-{uuid.uuid4()}.json"
    TRACE2 = f"{TMP_DIR}/trace2-{uuid.uuid4()}.json"
    return TRACE1, TRACE2


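# Build a small model mixing conv2d, mm and (optionally) addmm/bmm/baddbmm so
# that FLOP-bearing kernels show up in the profile. When compile=True the model
# is wrapped in torch.compile with kernel benchmarking and bandwidth profiling
# enabled.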
def _test_model(device, dtype, compile=True, addmm=True, bmm=True):
    T = cT(device, dtype)

    def model():
        input_conv = T(1, 3, 56, 56)
        conv_weight = T(12, 3, 5, 5)

        # Increased matrix sizes
        B = 8
        M = 256
        N = 512
        K = 768
        mat1 = T(M, N)
        mat2 = T(N, K)

        batch_mat1 = T(B, M, N)
        batch_mat2 = T(B, N, K)

        conv_output = F.conv2d(input_conv, conv_weight)

        conv_output = conv_output * 10
        mm_output = torch.mm(mat1, mat2)
        ret = [
            conv_output.flatten(),
            mm_output.flatten(),
        ]

        if addmm:
            addmm_output = torch.addmm(
                torch.zeros(mm_output.shape, device=mat1.device, dtype=mat1.dtype),
                mat1,
                mat2,
            )
            ret.append(addmm_output.flatten())

        if bmm:
            bmm_output = torch.bmm(batch_mat1, batch_mat2)
            ret.append(bmm_output.flatten())

        if bmm and addmm:
            baddbmm_output = torch.baddbmm(
                torch.zeros(
                    1,
                    *mm_output.shape,
                    device=batch_mat1.device,
                    dtype=batch_mat1.dtype,
                ),
                batch_mat1,
                batch_mat2,
            )
            ret.append(baddbmm_output.flatten())

        return torch.cat(ret)

    if compile:
        return torch.compile(
            model, options={"benchmark_kernel": True, "profile_bandwidth": True}
        )
    return model


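# Pointwise-only model (add followed by sin); used for bandwidth-focused tests.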
def _pointwise_test_model(device, dtype, compile=True):
    T = cT(device, dtype)

    def model():
        M = 1024
        N = 512
        mat3 = T(M, N)
        mat4 = T(M, N)
        pointwise_output = torch.add(mat3, mat4).sin()
        return pointwise_output

    if compile:
        return torch.compile(
            model, options={"benchmark_kernel": True, "profile_bandwidth": True}
        )
    return model


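# argv[0] placeholder for the profile_analysis command line; the tests below
# drive main() by patching sys.argv, e.g.
# ["profile.py", "--augment_trace", <input trace>, <output trace>, <dtype>].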
prefix = ["profile.py"]


class TestUtils(TestCase):
    def test_tabulate2d(self):
        headers = ["Kernel", "Self H100 TIME (ms)", "Count", "Percent"]
        rows = [
            ["aten::mm", 0.500, 7, 0.0],
            ["aten::bmm", 0.400, 6, 0.0],
            ["aten::baddbmm", 0.300, 5, 0.0],
            ["aten::convolution", 0.200, 4, 0.0],
            ["aten::cudnn_convolution", 0.100, 3, 0.0],
        ]
        table = [
            " Kernel | Self H100 TIME (ms) | Count | Percent ",
            "-----------------------------------------------------------------",
            " aten::mm | 0.5 | 7 | 0.0 ",
            " aten::bmm | 0.4 | 6 | 0.0 ",
            " aten::baddbmm | 0.3 | 5 | 0.0 ",
            " aten::convolution | 0.2 | 4 | 0.0 ",
            " aten::cudnn_convolution | 0.1 | 3 | 0.0 ",
        ]
        res = tabulate_2d(rows, headers)
        for r, t in zip(res.split("\n"), table):
            self.assertEqual(r, t)

    def test_zip_dicts(self):
        d1 = {"a": 1, "b": 2}
        d2 = {"a": 3, "c": 4}
        res1 = zip_dicts(d1, d2, d1_default=32, d2_default=48)
        self.assertEqual(set(res1), {("a", 1, 3), ("b", 2, 48), ("c", 32, 4)})
        res2 = zip_dicts(d1, d2)
        self.assertEqual(set(res2), {("a", 1, 3), ("b", 2, None), ("c", None, 4)})


class TestAnalysis(TestCase):
    @skipIf(not SM80OrLater, "Requires SM80")
    def test_noop(self):
        with (
            patch("sys.stdout", new_callable=StringIO) as mock_stdout,
            patch("sys.argv", [*prefix]),
        ):
            main()
            self.assertEqual(mock_stdout.getvalue(), "")

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    def test_diff(self, device, dtype):
        """
        Test the --diff command; also exercises the nruns feature by profiling
        a single run against several repeated runs.
        """
        if device == "cpu" or torch.version.hip is not None:
            # TODO cpu support
            return
        om = _test_model(device, dtype)
        REPEAT = 5
        trace1, trace2 = trace_files()
        print(f"first trace {trace1}")
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p:
                om()
            p.export_chrome_trace(trace1)

        print(f"second trace {trace2}")
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p:
                for _ in range(REPEAT):
                    om()
            p.export_chrome_trace(trace2)

        print("diffing...")
        with patch(
            "sys.argv",
            [
                *prefix,
                "--diff",
                trace1,
                "foo",
                trace2,
                "bar",
                str(dtype).split(".")[-1],
                "--name_limit",
                "30",
            ],
        ):
            main()

    @skipIf(not SM80OrLater, "Requires SM80")
    def test_augment_trace_helper_unit(self):
        js = json.loads(example_profile)
        out_profile = _augment_trace_helper(js)
        expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0]
        verify_flops(self, expected_flops, out_profile)

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    @parametrize(
        "maxat",
        [
            (True, "TRITON"),
        ],
    )
    @skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune")
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_triton_has_metadata(self, device, dtype, maxat):
        """
        Make sure the chrome trace of a compiled model contains the expected
        Triton kernel metadata, in particular a Triton convolution kernel event.
        """
        if device == "cpu" or torch.version.hip is not None:
            return

        T = cT(device, dtype)
        input_conv = T(1, 3, 56, 56)
        conv_weight = T(12, 3, 5, 5)

        def om(i, w):
            # Convolution operation
            conv_output = F.conv2d(i, w)
            return conv_output

        max_autotune, backends = maxat
        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "max_autotune": max_autotune,
            },
        )

        def verify_triton(comp):
            torch._dynamo.reset()  # reset the cache
            with fresh_inductor_cache():
                with torch.profiler.profile(record_shapes=True) as profile:
                    comp(input_conv, conv_weight)

            trace1, _ = trace_files()
            profile.export_chrome_trace(trace1)
            with open(trace1) as f:
                out_profile = json.load(f)
            seen = False
            for event in out_profile["traceEvents"]:
                if "triton" in event["name"] and "conv" in event["name"]:
                    seen = True
            self.assertTrue(seen, "no triton conv found")

        verify_triton(comp_omni)

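    # Cross-check the augmented trace against FlopCounterMode: every mm / bmm /
    # baddbmm / conv kernel event should carry a "kernel_flop" value matching
    # the global FLOP counts observed for the same compiled model.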
    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.float16)
    @parametrize(
        "maxat",
        [
            (False, "ATEN,TRITON"),
            (True, "ATEN,TRITON"),
            (True, "ATEN"),
            (True, "TRITON"),
        ],
    )
    @unittest.skipIf(
        not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
    )
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_augment_trace_against_flop_counter(self, device, dtype, maxat):
        # parametrized over max-autotune settings, including a Triton-only backend
        max_autotune, backends = maxat
        if device == "cpu" or torch.version.hip is not None:
            return
        om = _test_model(device, dtype, compile=False)

        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "max_autotune": max_autotune,
            },
        )
        comp_omni()

        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as profile:
                comp_omni()

        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with FlopCounterMode() as mode:
                comp_omni()

        trace1, trace2 = trace_files()
        profile.export_chrome_trace(trace1)
        with patch(
            "sys.argv",
            [*prefix, "--augment_trace", trace1, trace2, str(dtype).split(".")[-1]],
        ):
            main()

        with open(trace2) as f:
            out_profile = json.load(f)

        flop_counts = mode.flop_counts
        extern_mapping = _create_extern_mapping(out_profile)

        seen_mm = False
        seen_bmm = False
        seen_baddbmm = False
        seen_conv = False
        for event in out_profile["traceEvents"]:
            if (
                "cat" not in event
                or event["cat"] != "kernel"
                or "args" not in event
                or "External id" not in event["args"]
            ):
                continue

            external_op = extern_mapping[event["args"]["External id"]][0]
            name: str = external_op["name"]
            self.assertNotEqual(name, None)
            self.assertEqual(type(name), str)
            if name.startswith("aten::mm") or "_mm_" in name:
                seen_mm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.mm],
                )
            if (
                name.startswith(
                    (
                        "aten::cudnn_convolution",
                        "aten::convolution",
                        "aten::_convolution",
                    )
                )
                or "conv" in name
            ):
                seen_conv = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.convolution],
                )
            if name.startswith("aten::baddbmm") or "_baddbmm_" in name:
                seen_baddbmm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.baddbmm],
                )
            if name.startswith("aten::bmm") or "_bmm_" in name:
                seen_bmm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.bmm],
                )
        self.assertTrue(seen_mm)
        self.assertTrue(seen_bmm)
        self.assertTrue(seen_baddbmm)
        self.assertTrue(seen_conv)

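    # Run the --analysis command on a profile of the pointwise model; mainly a
    # smoke test that the analysis pass runs and the exported trace remains
    # loadable.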
    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.float16)
    @parametrize(
        "maxat",
        [
            (False, "ATEN,TRITON"),
            (True, "ATEN,TRITON"),
            (True, "ATEN"),
            (True, "TRITON"),
        ],
    )
    @unittest.skipIf(
        not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
    )
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_pointwise_bandwidth(self, device, dtype, maxat):
        # parametrized over max-autotune settings, including a Triton-only backend
        max_autotune, backends = maxat
        if device == "cpu" or torch.version.hip is not None:
            return
        om = _pointwise_test_model(device, dtype, compile=False)
        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "max_autotune": max_autotune,
            },
        )
        comp_omni()

        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as profile:
                comp_omni()
            trace1, _ = trace_files()
            profile.export_chrome_trace(trace1)

        with patch(
            "sys.argv",
            [*prefix, "--analysis", trace1, str(dtype).split(".")[-1]],
        ):
            main()

        with open(trace1) as f:
            out_profile = json.load(f)

        for event in out_profile["traceEvents"]:
            if event["name"] == "triton_poi_fused_add_randn_sin_0":
                event["args"]["kernel_num_gb"] = 0.002097168

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.float16)
    def test_combine_profiles(self, device, dtype):
        """
        Test combining multiple profiles into a single profile.
        """
        if device == "cpu" or torch.version.hip is not None:
            return

        # Create three different models to generate different traces
        om1 = _test_model(device, dtype, addmm=True, bmm=False)
        om2 = _test_model(device, dtype, addmm=False, bmm=True)
        om3 = _pointwise_test_model(device, dtype)

        # Generate three separate traces
        trace1, trace2 = trace_files()
        trace3 = f"{TMP_DIR}/trace3-{uuid.uuid4()}.json"
        combined_trace = f"{TMP_DIR}/combined-{uuid.uuid4()}.json"

        # Generate first trace
        torch._dynamo.reset()
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p1:
                om1()
            p1.export_chrome_trace(trace1)

        # Generate second trace
        torch._dynamo.reset()
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p2:
                om2()
            p2.export_chrome_trace(trace2)

        # Generate third trace
        torch._dynamo.reset()
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p3:
                om3()
            p3.export_chrome_trace(trace3)

        # Combine the three traces
        with patch(
            "sys.argv",
            [
                *prefix,
                "--combine",
                trace1,
                trace2,
                trace3,
                combined_trace,
            ],
        ):
            main()

        # Verify the combined trace exists and contains expected data
        with open(combined_trace) as f:
            combined_profile = json.load(f)

        # Load original traces for comparison
        with open(trace1) as f:
            profile1 = json.load(f)
        with open(trace2) as f:
            profile2 = json.load(f)
        with open(trace3) as f:
            profile3 = json.load(f)

        # Verify trace events are combined
        expected_event_count = (
            len(profile1["traceEvents"])
            + len(profile2["traceEvents"])
            + len(profile3["traceEvents"])
        )
        self.assertEqual(len(combined_profile["traceEvents"]), expected_event_count)

        # Verify device properties are present
        self.assertIn("deviceProperties", combined_profile)
        self.assertGreater(len(combined_profile["deviceProperties"]), 0)

        # Verify some trace events from each original profile are present
        combined_event_names = {
            event["name"] for event in combined_profile["traceEvents"]
        }

        # Check that we have events from each original profile
        profile1_event_names = {event["name"] for event in profile1["traceEvents"]}
        profile2_event_names = {event["name"] for event in profile2["traceEvents"]}
        profile3_event_names = {event["name"] for event in profile3["traceEvents"]}

        # At least some events from each profile should be in the combined profile
        self.assertTrue(profile1_event_names.intersection(combined_event_names))
        self.assertTrue(profile2_event_names.intersection(combined_event_names))
        self.assertTrue(profile3_event_names.intersection(combined_event_names))


instantiate_device_type_tests(TestAnalysis, globals())

if __name__ == "__main__":
    run_tests()