From 01a08739e006c9711fd9edf23fbe7cf23d43df2d Mon Sep 17 00:00:00 2001
From: Grace Ho <146482179+gracehonv@users.noreply.github.com>
Date: Tue, 19 Aug 2025 00:44:53 -0700
Subject: [PATCH] [misc] split engine_model into json file for nsys profile tool (#23117)

Signed-off-by: Grace Ho
Signed-off-by: Grace Ho <146482179+gracehonv@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung
---
 tools/profiler/nsys_profile_tools/README.md   |  53 +++--
 .../nsys_profile_tools/gputrc2graph.py        | 205 ++++--------
 .../nsys_profile_tools/vllm_engine_model.json |  63 ++++++
 3 files changed, 135 insertions(+), 186 deletions(-)
 create mode 100644 tools/profiler/nsys_profile_tools/vllm_engine_model.json

diff --git a/tools/profiler/nsys_profile_tools/README.md b/tools/profiler/nsys_profile_tools/README.md
index 75ae0811cc..9577efb68f 100644
--- a/tools/profiler/nsys_profile_tools/README.md
+++ b/tools/profiler/nsys_profile_tools/README.md
@@ -36,8 +36,7 @@ profiling and analyzing nsys profile output.
 ## Notes
 
 - Make sure you have pandas installed.
-- Make sure nsys is installed, and specify the path to the `nsys` command with
-  `--nsys_cmd` if it is not in your PATH.
+- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is installed, and specify the path to the `nsys` command with `--nsys_cmd` if it is not in your PATH.
 - For more details on available engines and models, see the help string in the
   script or run:
 
@@ -135,34 +134,31 @@ time which would cause a difference for the overall category.
 
 ## Example 3: add new classification for a new model
 
-Suppose there's a new model ABC that is available for engine DEF, and say there
-are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels
-have names with "*H*" or "*I*" in them, and attn kernels have names with "*J*"
-or "*K*" in them, add a new entry like so:
+To create a new engine DEF with model ABC, just add another JSON file in the same directory as
+gputrc2graph.py, using the same format as the other JSON files. The script automatically picks up all the JSON files in that directory as engine/model specifications.
 
-```python
-engine_model = {
-    'DEF': {
-        'ABC': {
-            'layer_anno': {
-                'Stage': {
-                    '.*': 'layer',
-                },
-                'Substage': {
-                    'H|I': 'gemm',
-                    'J|K': 'attn',
-                    'CUDA mem': 'non-gpu-H_D_memops',
-                    '.*': 'misc'
-                }
-            }
-        },
-    }
-    'vllm': {...}
+Then, for this new model, suppose there are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels
+have names with "*H*" or "*I*" in them, and the attn kernels have names with "*J*"
+or "*K*" in them. The new JSON file for this engine/model would look like the
+following:
+
+```json
+{
+    "DEF": {
+        "ABC": {
+            "H|I": "gemm",
+            "J|K": "attn",
+            "CUDA mem": "non-gpu-H_D_memops",
+            ".*": "misc"
+        }
+    }
+}
 ```
 
-Basically Substage is a dictionary with a list of key/value pairs, where the
-keys are regex's of the kernel names to be classified, and values are the
-classification bins which one wishes to compare across engines/models.
+Each entry in the dictionary consists of:
+
+- key: a regex matched against kernel names to classify them
+- value: the category to classify the matching kernels into
 
 The last 2 entries are common for all engine/models, consisting of CUDA memory
 operations and a 'misc' for anything that's leftover and can't be classified.
@@ -173,3 +169,6 @@ like the following:
 ```bash
 --infile new.nsys-rep,DEF,ABC,
 ```
+
+If the engine_DEF.json file already exists, just add the model as a new node in
+the existing engine file, after the other models.
diff --git a/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tools/profiler/nsys_profile_tools/gputrc2graph.py
index 8921e1f20f..42dfede9e9 100755
--- a/tools/profiler/nsys_profile_tools/gputrc2graph.py
+++ b/tools/profiler/nsys_profile_tools/gputrc2graph.py
@@ -15,132 +15,18 @@ logger = logging.getLogger(__name__)
 
 
 # helper data class for annotating kernels
-class EngineModelData:
-    # engine + model mappings
-    engine_model = {
-        'vllm': {
-            'llama': {
-                'layer_anno': {
-                    'Stage': {
-                        '.*': 'layer',
-                    },
-                    'Substage': {
-                        'gemm': 'gemm',
-                        'fused_moe_kernel|GroupProblemShape|group_gemm_starts':
-                        'moe_gemm',  #llama4
-                        'moe|sigmoid': 'moe',  #llama4
-                        'CatArrayBatched|prepare_inputs': 'prepare_next',
-                        'flash': 'attn',
-                        'ncclDevKernel|cross_device_reduce':
-                        'nccl_and_custom_ar',
-                        '_norm_': 'norm',
-                        'act_and_mul_': 'silu',
-                        'rotary_embedding_kernel': 'rope',
-                        'SoftMax': 'softmax',
-                        'elementwise': 'elementwise',
-                        'fp8_quant': 'quantize',
-                        'reduce_kernel': 'reduce',
-                        'triton': 'triton_kernel',
-                        'CUDA mem': 'non-gpu-H_D_memops',
-                        '.*': 'misc'
-                    }
-                }
-            },
-            'ds': {
-                'layer_anno': {
-                    'Stage': {
-                        '.*': 'layer',
-                    },
-                    'Substage': {
-                        'block_fp8|gemm_fp8_blockwise':
-                        'block_fp8_gemm',
-                        'fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal':
-                        'moe_gemm',
-                        'gemm|matmul|nvjet':
-                        'gemm',
-                        'moe|sigmoid|expert':
-                        'moe',
-                        '_fwd_|FlashAttn|_mla_|_attn_':
-                        'attn',
-                        'CatArrayBatched':
-                        'prepare_next',
-                        'ncclDevKernel|cross_device_reduce':
-                        'nccl_and_custom_ar',
-                        'Norm|_norm_':
-                        'norm',
-                        'sbtopk':
-                        'topk',
-                        'act_and_mul_':
-                        'activation',
-                        'compute_position_kernel':
-                        'rope',
-                        'elementwise':
-                        'elementwise',
-                        'fp8_quant|quant_fp8|cvt_fp16_to_fp4':
-                        'quantize',
-                        'reduce':
-                        'reduce',
-                        'SoftMax':
-                        'softmax',
-                        'triton':
-                        'triton_kernel',
-                        'CUDA mem':
-                        'non-gpu-H_D_memops',
-                        '.*':
-                        'misc'
-                    }
-                }
-            },
-            'gpt-oss': {
-                'layer_anno': {
-                    'Stage': {
-                        '.*': 'layer',
-                    },
-                    'Substage': {
-                        'block_fp8|gemm_fp8_blockwise':
-                        'block_fp8_gemm',
-                        'fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_'
-                        # this section is triton_moe_gemm
-                        '|matmul_ogs_|_topk_forward|_combined_routing'
-                        '|_sum_bitmatrix_rows|_compute_writeback_idx':
-                        'moe_gemm',
-                        'gemm|matmul|nvjet':
-                        'gemm',
-                        'moe|sigmoid|expert|splitKreduce':
-                        'moe',
-                        '_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha':
-                        'attn',
-                        'CatArrayBatched':
-                        'prepare_next',
-                        'ncclDevKernel|cross_device_reduce':
-                        'nccl_and_custom_ar',
-                        'Norm|_norm_':
-                        'norm',
-                        'sbtopk':
-                        'topk',
-                        'act_and_mul_':
-                        'activation',
-                        'compute_position_kernel':
-                        'rope',
-                        'elementwise':
-                        'elementwise',
-                        'fp8_quant|quant_fp8|cvt_fp16_to_fp4|quantize':
-                        'quantize',
-                        'reduce':
-                        'reduce',
-                        'SoftMax':
-                        'softmax',
-                        'triton':
-                        'triton_kernel',
-                        'CUDA mem':
-                        'non-gpu-H_D_memops',
-                        '.*':
-                        'misc'
-                    }
-                }
-            }
-        },
-    }
+def load_engine_model():
+    """ returns engine_model built from all json files in the current dir """
+    import glob
+    import json
+    engine_model = {}
+
+    json_files = glob.glob(
+        os.path.join(os.path.dirname(__file__) or ".", "*.json"))
+    for fname in json_files:
+        with open(fname, encoding="utf-8") as f:
+            engine_model.update(json.load(f))
+    return engine_model
 
 
 class GPUTrace2Graph:
@@ -148,8 +34,7 @@ class GPUTrace2Graph:
     """
     Parses output of nsys report, generates csv and
    bar chart output """
-    def __init__(self, nsys_cmd):
-        self.nsys_cmd = nsys_cmd
+    def __init__(self):
         import pandas as pd  # avoid importing till needed
         self.pd = pd
         self.pd.options.mode.copy_on_write = True
@@ -227,7 +112,7 @@ class GPUTrace2Graph:
         title = 'Model_Engine'
         x = 'Model_Engine'
         y = 'Elapsed Time (sec)'
-        color = 'Substage'
+        color = 'Category'
         """ generate kernel mapping table """
         # Sort Model_Engine categories by last field after underscore
         df['Model_Engine'] = self.pd.Categorical(
@@ -249,14 +134,13 @@ class GPUTrace2Graph:
         Generate data table with columns per Model_Engine into result.html
         """
         pivot_df = df.pivot_table(values='Elapsed Time (sec)',
-                                  index='Substage',
+                                  index='Category',
                                   columns='Model_Engine',
                                   aggfunc='sum',
                                   observed=False).round(2)
         # Add sum row at bottom
         pivot_df.loc['total_elapsed_sec'] = pivot_df.sum()
         pivot_df.fillna('').to_html('temp.html')
-        print('got')
         with (open(f'{output_name}.html', 'a', encoding='utf-8') as outfile,
               open('temp.html', encoding='utf-8') as infile):
             outfile.write(infile.read())
@@ -264,23 +148,22 @@ class GPUTrace2Graph:
 
         print(f'Finished generating: \n'
               f' {output_name}.html for stack bar chart \n'
-              f' {output_name}.csv for Kernel-Substage mapping')
+              f' {output_name}.csv for Kernel-Category mapping')
 
     def anno_gpu_kernname(self, df, mapping):
-        """ add "stage" and "substage" columns """
+        """ add "Category" column """
 
-        def anno_gpu_kernname_helper(name, stage):
-            for kern_name, val in mapping['layer_anno'][stage].items():
+        def anno_gpu_kernname_helper(name):
+            for kern_name, val in mapping.items():
                 if re.search(kern_name, name):
                     return val
 
-        for stage in ['Stage', 'Substage']:
-            df[stage] = df['Name'].apply(anno_gpu_kernname_helper, stage=stage)
+        df['Category'] = df['Name'].apply(anno_gpu_kernname_helper)
 
     def make_nongpu_row(self, df, nongpu_sec):
         """ this will append non-gpu time entry at end of df """
         nongpu_row = self.pd.DataFrame([df.iloc[-1]])
-        nongpu_row['Substage'] = nongpu_row['Name'] = 'CPU(non-GPU)'
+        nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)'
         nongpu_row['Instances'] = 1
         nongpu_row['Elapsed Time (sec)'] = nongpu_sec
         return (nongpu_row)
@@ -302,7 +185,7 @@ class GPUTrace2Graph:
             logger.info('generating %s', new_file)
             return True
 
-    def gen_sum_file(self, file):
+    def gen_sum_file(self, file, nsys_cmd):
         """
         generates sum file from nsys trace with times per kernel and
         returns the name of the sum file
@@ -318,17 +201,21 @@ class GPUTrace2Graph:
         sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv'
         if self.should_gen_file(nsys_stats_file, file):
             cmd = [
-                self.nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o',
+                nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o',
                 f'{file_dir}/{file_name}'
             ]
             cmd_str = ' '.join(cmd)
             logger.info('+ %s', cmd_str)
+            # estimate time based on calibrated 240M/min
+            file_size_mb = os.path.getsize(file) / 1e6
+            logger.info(
+                'nsys stats for %.2f MB file expected to take %.2f min',
+                file_size_mb, file_size_mb / 240)
             try:
-                subprocess.run(cmd)
+                subprocess.run(cmd, check=True)
             except Exception:
-                logger.error(
-                    "%s failed, specify --nsys_cmd for correct nsys path",
-                    cmd_str)
+                logger.error("%s failed; Use --nsys_cmd to specify nsys path",
+                             cmd_str)
                 exit(1)
         logger.info('generating non-overalapped sum %s', sum_file)
         self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
@@ -336,7 +223,7 @@ class GPUTrace2Graph:
         logger.info('Finished generating %s', sum_file)
         return sum_file
 
-    def gen_graph(self, in_file, out_dir, title):
+    def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
         """ generates graph and csv file from in_file into out_dir """
         # Initialize an empty DataFrame to store combined data
         combined_df = self.pd.DataFrame()
@@ -345,17 +232,16 @@ class GPUTrace2Graph:
             file_name = os.path.basename(file)
             if not file_dir:
                 file_dir = '.'
-            sum_file = self.gen_sum_file(file)
+            sum_file = self.gen_sum_file(file, nsys_cmd)
             # read kernel summary file
             df = self.pd.read_csv(sum_file)
             # annotate kernel to their categories
-            assert EngineModelData.engine_model.get(engine)
-            assert EngineModelData.engine_model[engine].get(model)
+            assert engine_model.get(engine), f'engine {engine} unknown'
+            assert engine_model[engine].get(model), f'model {model} unknown'
             # remove nsys-rep from file_name for shorter x-label
             file_name = file_name.replace('.nsys-rep', '')
             df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}'
-            self.anno_gpu_kernname(df,
-                                   EngineModelData.engine_model[engine][model])
+            self.anno_gpu_kernname(df, engine_model[engine][model])
             # patch in non-gpu time
             gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1)
             total_sec = round(float(total_sec), 1)
@@ -393,12 +279,12 @@ def main():
             "--out_dir results/ --title \"Model=gpt-oss vLLM chart\""),
         formatter_class=argparse.RawDescriptionHelpFormatter)
 
-    # Build help string showing available engine/model combinations
-    engine_model_help = []
-    for engine, models in EngineModelData.engine_model.items():
-        model_list = list(models.keys())
-        engine_model_help.append(f"{engine}:[{','.join(model_list)}]")
-    engine_model_str = ' '.join(engine_model_help)
+    # load supported engine_model
+    engine_model_supported = load_engine_model()
+    # Get a string representation of supported engine/model combinations
+    engine_model_supported_str = ', '.join(
+        f"{engine}:[{', '.join(models.keys())}]"
+        for engine, models in engine_model_supported.items())
     parser.add_argument(
         '--in_file',
         type=parse_tuple,
@@ -408,7 +294,7 @@ def main():
            'separated by space. Elapsed_nonprofiled_sec is runtime without '
            'profiling used to calculate non-gpu time. Specify 0 to use '
            'elapsed time from nsys-rep but that might inflate non-gpu time. '
-           f'Available engine:[model] are: {engine_model_str} '
+           f'Available engine:[model] are: {engine_model_supported_str} '
            f'Example: --infile d1.nsys-rep,vllm,llama,100 '
            'd2.nsys-rep,vllm,gpt-oss,102'),
         required=True)
@@ -418,8 +304,9 @@ def main():
        help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'),
        default="nsys")
     args = parser.parse_args()
-    gputrace = GPUTrace2Graph(args.nsys_cmd)
-    gputrace.gen_graph(args.in_file, args.out_dir, args.title)
+    gputrace = GPUTrace2Graph()
+    gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd,
+                       engine_model_supported)
 
 
 if __name__ == '__main__':
diff --git a/tools/profiler/nsys_profile_tools/vllm_engine_model.json b/tools/profiler/nsys_profile_tools/vllm_engine_model.json
new file mode 100644
index 0000000000..264c628dde
--- /dev/null
+++ b/tools/profiler/nsys_profile_tools/vllm_engine_model.json
@@ -0,0 +1,63 @@
+{
+    "vllm": {
+        "llama": {
+            "fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm",
+            "gemm|nvjet": "gemm",
+            "moe|sigmoid": "moe",
+            "CatArrayBatched|prepare_inputs": "prepare_next",
+            "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
+            "_norm_|Norm": "norm",
+            "act_and_mul_": "activation",
+            "Rotary": "rope",
+            "SoftMax": "softmax",
+            "flash|fmha": "attn",
+            "elementwise": "elementwise",
+            "fp8_quant|cvt_": "quantize",
+            "reduce_kernel": "reduce",
+            "triton": "triton_kernel",
+            "CUDA mem": "non-gpu-H_D_memops",
+            ".*": "misc"
+        },
+        "ds": {
+            "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
+            "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_": "moe_gemm",
+            "gemm|matmul|nvjet": "gemm",
+            "moe|sigmoid|expert": "moe",
+            "CatArrayBatched": "prepare_next",
+            "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
+            "Norm|_norm_": "norm",
+            "sbtopk": "topk",
+            "act_and_mul_": "activation",
+            "compute_position_kernel": "rope",
+            "elementwise": "elementwise",
+            "fp8_quant|quant_fp8|cvt_": "quantize",
+            "reduce": "reduce",
+            "SoftMax": "softmax",
+            "_fwd_|FlashAttn|_mla_|_attn_|fmha": "attn",
+            "triton": "triton_kernel",
+            "topk": "topk",
+            "CUDA mem": "non-gpu-H_D_memops",
+            ".*": "misc"
+        },
+        "gpt-oss": {
+            "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
+            "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm",
+            "gemm|matmul|nvjet": "gemm",
+            "moe|sigmoid|expert|splitKreduce": "moe",
+            "CatArrayBatched": "prepare_next",
+            "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
+            "Norm|_norm_": "norm",
+            "topk": "topk",
+            "act_and_mul_": "activation",
+            "compute_position_kernel": "rope",
+            "elementwise": "elementwise",
+            "fp8_quant|quant_fp8|cvt_|quantize": "quantize",
+            "reduce": "reduce",
+            "SoftMax": "softmax",
+            "_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha": "attn",
+            "triton": "triton_kernel",
+            "CUDA mem": "non-gpu-H_D_memops",
+            ".*": "misc"
+        }
+    }
+}
\ No newline at end of file
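
For anyone adding a new engine/model JSON as described in the README change above, a small standalone check can catch malformed files before running gputrc2graph.py. The snippet below is a minimal sketch, not part of this patch; `check_engine_model_jsons` is a hypothetical helper, and the only rules it enforces (string regex keys, string category values, and a trailing ".*" catch-all) come from the README's description rather than any formal schema.

```python
import glob
import json
import os


def check_engine_model_jsons(dir_path: str) -> None:
    """Load every *.json in dir_path and verify the engine/model mapping shape."""
    for fname in sorted(glob.glob(os.path.join(dir_path, "*.json"))):
        with open(fname, encoding="utf-8") as f:
            spec = json.load(f)
        for engine, models in spec.items():
            for model, mapping in models.items():
                # every entry must map a regex string to a category string
                assert all(
                    isinstance(k, str) and isinstance(v, str)
                    for k, v in mapping.items()), (engine, model)
                # the last entry should be the '.*' catch-all the README describes
                assert list(mapping)[-1] == ".*", (
                    f"{fname}: {engine}/{model} is missing the '.*' catch-all")
        print(f"{fname}: OK")


if __name__ == "__main__":
    # run from tools/profiler/nsys_profile_tools, next to gputrc2graph.py
    check_engine_model_jsons(os.path.dirname(__file__) or ".")
```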