#!/usr/bin/env python3
# mypy: allow-untyped-defs
"""
model_dump: a one-stop shop for TorchScript model inspection.

The goal of this tool is to provide a simple way to extract lots of
useful information from a TorchScript model and make it easy for humans
to consume.  It (mostly) replaces zipinfo, common uses of show_pickle,
and various ad-hoc analysis notebooks.

The tool extracts information from the model and serializes it as JSON.
That JSON can then be rendered by an HTML+JS page, either by loading the
JSON over HTTP or producing a fully self-contained page with all of the
code and data burned-in.
"""

# Maintainer notes follow.
"""
The implementation strategy has tension between 3 goals:
- Small file size.
- Fully self-contained.
- Easy, modern JS environment.

Using Preact and HTM achieves 1 and 2 with a decent result for 3.
However, the models I tested with result in ~1MB JSON output,
so even using something heavier like full React might be tolerable
if the build process can be worked out.

One principle I have followed that I think is very beneficial
is to keep the JSON data as close as possible to the model
and do most of the rendering logic on the client.
This makes for easier development (just refresh, usually),
allows for more laziness and dynamism, and lets us add more
views of the same data without bloating the HTML file.

Currently, this code doesn't actually load the model or even
depend on any part of PyTorch.  I don't know if that's an important
feature to maintain, but it's probably worth preserving the ability
to run at least basic analysis on models that cannot be loaded.

I think the easiest way to develop this code is to cd into model_dump and
run "python -m http.server", then load http://localhost:8000/skeleton.html
in the browser.  In another terminal, run
"python -m torch.utils.model_dump --style=json FILE > \
    torch/utils/model_dump/model_info.json"
every time you update the Python code or model.
When you update JS, just refresh.

Possible improvements:
- Fix various TODO comments in this file and the JS.
- Make the HTML much less janky, especially the auxiliary data panel.
- Make the auxiliary data panel start small, expand when
  data is available, and have a button to clear/contract.
- Clean up the JS.  There's a lot of copypasta because
  I don't really know how to use Preact.
- Make the HTML render and work nicely inside a Jupyter notebook.
- Add the ability for JS to choose the URL to load the JSON based
  on the page URL (query or hash).  That way we could publish the
  inlined skeleton once and have it load various JSON blobs.
- Add a button to expand all expandable sections so ctrl-F works well.
- Add hyperlinking from data to code, and code to code.
- Add hyperlinking from debug info to Diffusion.
- Make small tensor contents available.
- Do something nice for quantized models
  (they probably don't work at all right now).
"""

import argparse
import io
import itertools
import json
import os
import pickle
import pprint
import re
import sys
import urllib.parse
import warnings
import zipfile
from pathlib import Path

import torch.utils.show_pickle

# Extra files (JSONs, pickles) larger than this many rendered bytes are
# skipped so huge blobs don't bloat the output.
DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024

__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info',
           'get_inline_skeleton', 'burn_in_info', 'get_info_and_burn_skeleton']


def get_storage_info(storage):
    """Extract a JSON-friendly description of a tensor storage.

    *storage* is the FakeObject produced by DumpUnpickler for a persistent-id
    ("pers.obj") reference.  Its single argument is expected to be the 5-tuple
    ("storage", FakeClass(torch.<Type>Storage), key, device, numel); the result
    is [<Type>, key, device, numel].  Raises AssertionError if the shape of the
    object does not match those expectations.
    """
    if not isinstance(storage, torch.utils.show_pickle.FakeObject):
        raise AssertionError(f"storage is not FakeObject: {type(storage)}")
    if storage.module != "pers":
        raise AssertionError(f"storage.module is not 'pers': {storage.module!r}")
    if storage.name != "obj":
        raise AssertionError(f"storage.name is not 'obj': {storage.name!r}")
    if storage.state is not None:
        raise AssertionError(f"storage.state is not None: {storage.state!r}")
    if not isinstance(storage.args, tuple):
        raise AssertionError(f"storage.args is not a tuple: {type(storage.args)}")
    if len(storage.args) != 1:
        raise AssertionError(f"len(storage.args) is not 1: {len(storage.args)}")
    sa = storage.args[0]
    if not isinstance(sa, tuple):
        raise AssertionError(f"sa is not a tuple: {type(sa)}")
    if len(sa) != 5:
        raise AssertionError(f"len(sa) is not 5: {len(sa)}")
    if sa[0] != "storage":
        raise AssertionError(f"sa[0] is not 'storage': {sa[0]!r}")
    if not isinstance(sa[1], torch.utils.show_pickle.FakeClass):
        raise AssertionError(f"sa[1] is not FakeClass: {type(sa[1])}")
    if sa[1].module != "torch":
        raise AssertionError(f"sa[1].module is not 'torch': {sa[1].module!r}")
    if not sa[1].name.endswith("Storage"):
        raise AssertionError(f"sa[1].name does not end with 'Storage': {sa[1].name!r}")
    # "torch.FloatStorage" -> "Float", followed by the remaining storage args.
    storage_info = [sa[1].name.replace("Storage", "")] + list(sa[2:])
    return storage_info


def hierarchical_pickle(data):
    """Recursively convert unpickled model data into JSON-serializable form.

    Primitives pass through unchanged; tuples, dicts, and the various
    FakeObject stand-ins produced by DumpUnpickler (modules, tensors,
    qtensors, devices, ...) are rewritten as tagged dicts/lists that the
    JS frontend knows how to render.  Raises on object shapes it does not
    recognize rather than emitting something the frontend can't display.
    """
    if isinstance(data, (bool, int, float, str, type(None))):
        return data
    if isinstance(data, list):
        return [hierarchical_pickle(d) for d in data]
    if isinstance(data, tuple):
        return {
            "__tuple_values__": hierarchical_pickle(list(data)),
        }
    if isinstance(data, dict):
        return {
            "__is_dict__": True,
            "keys": hierarchical_pickle(list(data.keys())),
            "values": hierarchical_pickle(list(data.values())),
        }
    if isinstance(data, torch.utils.show_pickle.FakeObject):
        typename = f"{data.module}.{data.name}"
        if (
            typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
        ):
            if data.args != ():
                raise AssertionError("data.args is not ()")
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle(data.state),
            }
        if typename == "torch._utils._rebuild_tensor_v2":
            if data.state is not None:
                raise AssertionError("data.state is not None")
            storage, offset, size, stride, requires_grad, *_ = data.args
            storage_info = get_storage_info(storage)
            return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]}
        if typename == "torch._utils._rebuild_qtensor":
            if data.state is not None:
                raise AssertionError("data.state is not None")
            storage, offset, size, stride, quantizer, requires_grad, *_ = data.args
            storage_info = get_storage_info(storage)
            if not isinstance(quantizer, tuple):
                raise AssertionError("quantizer is not a tuple")
            if not isinstance(quantizer[0], torch.utils.show_pickle.FakeClass):
                raise AssertionError("quantizer[0] is not a FakeClass")
            if quantizer[0].module != "torch":
                raise AssertionError("quantizer[0].module is not torch")
            if quantizer[0].name == "per_tensor_affine":
                # per_tensor_affine carries (scale: float, zero_point: int).
                if len(quantizer) != 3:
                    raise AssertionError("len(quantizer) is not 3")
                if not isinstance(quantizer[1], float):
                    raise AssertionError("quantizer[1] is not a float")
                if not isinstance(quantizer[2], int):
                    raise AssertionError("quantizer[2] is not an int")
                quantizer_extra = list(quantizer[1:3])
            else:
                quantizer_extra = []
            quantizer_json = [quantizer[0].name] + quantizer_extra
            return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]}
        if typename == "torch.jit._pickle.restore_type_tag":
            if data.state is not None:
                raise AssertionError("data.state is not None")
            obj, typ = data.args
            if not isinstance(typ, str):
                raise AssertionError("typ is not a string")
            # The type tag is only meaningful to TorchScript; drop it.
            return hierarchical_pickle(obj)
        if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename):
            if data.state is not None:
                raise AssertionError("data.state is not None")
            ls, = data.args
            if not isinstance(ls, list):
                raise AssertionError("ls is not a list")
            return hierarchical_pickle(ls)
        if typename == "torch.device":
            if data.state is not None:
                raise AssertionError("data.state is not None")
            name, = data.args
            if not isinstance(name, str):
                raise AssertionError("name is not a string")
            # Just forget that it was a device and return the name.
            return name
        if typename == "builtin.UnicodeDecodeError":
            if data.state is not None:
                raise AssertionError("data.state is not None")
            msg, = data.args
            if not isinstance(msg, str):
                raise AssertionError("msg is not a string")
            # Hack: Pretend this is a module so we don't need custom serialization.
            # Hack: Wrap the message in a tuple so it looks like a nice state object.
            # TODO: Undo at least that second hack.  We should support string states.
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle((msg,)),
            }
        raise Exception(f"Can't prepare fake object of type for JS: {typename}")  # noqa: TRY002
    raise Exception(f"Can't prepare data of type for JS: {type(data)}")  # noqa: TRY002


def get_model_info(
        path_or_file,
        title=None,
        extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT):
    """Get JSON-friendly information about a model.

    The result is suitable for being saved as model_info.json,
    or passed to burn_in_info.
    """
    # Accept a path-like, a string path, or an open binary file object.
    if isinstance(path_or_file, os.PathLike):
        default_title = os.fspath(path_or_file)
        file_size = path_or_file.stat().st_size  # type: ignore[attr-defined]
    elif isinstance(path_or_file, str):
        default_title = path_or_file
        file_size = Path(path_or_file).stat().st_size
    else:
        default_title = "buffer"
        path_or_file.seek(0, io.SEEK_END)
        file_size = path_or_file.tell()
        path_or_file.seek(0)

    title = title or default_title

    with zipfile.ZipFile(path_or_file) as zf:
        # TorchScript archives put everything under a single top-level
        # directory; record it and verify it is consistent.
        path_prefix = None
        zip_files = []  # pyrefly: ignore  # bad-assignment
        for zi in zf.infolist():
            prefix = re.sub("/.*", "", zi.filename)
            if path_prefix is None:
                path_prefix = prefix
            elif prefix != path_prefix:
                raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
            zip_files.append(
                {
                    "filename": zi.filename,
                    "compression": zi.compress_type,
                    "compressed_size": zi.compress_size,
                    "file_size": zi.file_size,
                }
            )

        if path_prefix is None:
            raise AssertionError("path_prefix is None")
        version = zf.read(path_prefix + "/version").decode("utf-8").strip()

        def get_pickle(name):
            # Load <prefix>/<name>.pkl without importing any real classes,
            # then convert it to JSON-friendly form.
            if path_prefix is None:
                raise AssertionError("path_prefix is None")
            with zf.open(path_prefix + f"/{name}.pkl") as handle:
                raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
                return hierarchical_pickle(raw)

        model_data = get_pickle("data")
        constants = get_pickle("constants")

        # Intern strings that are likely to be reused.
        # Pickle automatically detects shared structure,
        # so reused strings are stored efficiently.
        # However, JSON has no way of representing this,
        # so we have to do it manually.
        interned_strings: dict[str, int] = {}

        def intern(s):
            if s not in interned_strings:
                interned_strings[s] = len(interned_strings)
            return interned_strings[s]

        code_files = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".py"):
                continue
            with zf.open(zi) as handle:
                raw_code = handle.read()
            with zf.open(zi.filename + ".debug_pkl") as handle:
                raw_debug = handle.read()

            # Parse debug info and add begin/end markers if not present
            # to ensure that we cover the entire source code.
            debug_info_t = pickle.loads(raw_debug)
            text_table = None

            if (len(debug_info_t) == 3 and
                    isinstance(debug_info_t[0], str) and
                    debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
                # Newer format: strings are stored once in a table and
                # referenced by index; expand them back to full records.
                _, text_table, content = debug_info_t

                def parse_new_format(line):
                    # (0, (('', '', 0), 0, 0))
                    num, ((text_indexes, fname_idx, offset), start, end), tag = line
                    text = ''.join(text_table[x] for x in text_indexes)  # type: ignore[index]
                    fname = text_table[fname_idx]  # type: ignore[index]
                    return num, ((text, fname, offset), start, end), tag

                debug_info_t = map(parse_new_format, content)

            debug_info = list(debug_info_t)
            if not debug_info:
                debug_info.append((0, (('', '', 0), 0, 0)))
            if debug_info[-1][0] != len(raw_code):
                debug_info.append((len(raw_code), (('', '', 0), 0, 0)))

            code_parts = []
            for di, di_next in itertools.pairwise(debug_info):
                start, source_range, *_ = di
                end = di_next[0]
                if end <= start:
                    raise AssertionError("end is not greater than start")
                source, s_start, s_end = source_range
                s_text, s_file, s_line = source
                # TODO: Handle this case better.  TorchScript ranges are in bytes,
                # but JS doesn't really handle byte strings.
                # if bytes and chars are not equivalent for this string,
                # zero out the ranges so we don't highlight the wrong thing.
                if len(s_text) != len(s_text.encode("utf-8")):
                    s_start = 0
                    s_end = 0
                text = raw_code[start:end]
                code_parts.append([text.decode("utf-8"), intern(s_file), s_line, intern(s_text), s_start, s_end])
            code_files[zi.filename] = code_parts

        extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json")
        extra_files_jsons = {}
        for zi in zf.infolist():
            if not extra_files_json_pattern.fullmatch(zi.filename):
                continue
            if zi.file_size > extra_file_size_limit:
                continue
            with zf.open(zi) as handle:
                try:
                    json_content = json.load(handle)
                    extra_files_jsons[zi.filename] = json_content
                except json.JSONDecodeError:
                    extra_files_jsons[zi.filename] = "INVALID JSON"

        # These pickles are rendered regardless of their size.
        always_render_pickles = {
            "bytecode.pkl",
        }
        extra_pickles = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".pkl"):
                continue
            with zf.open(zi) as handle:
                # TODO: handle errors here and just ignore the file?
                # NOTE: For a lot of these files (like bytecode),
                # we could get away with just unpickling, but this should be safer.
                obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
            buf = io.StringIO()
            pprint.pprint(obj, buf)
            contents = buf.getvalue()
            # Checked the rendered length instead of the file size
            # because pickles with shared structure can explode in size during rendering.
            if os.path.basename(zi.filename) not in always_render_pickles and \
                    len(contents) > extra_file_size_limit:
                continue
            extra_pickles[zi.filename] = contents

    return {
        "model": {
            "title": title,
            "file_size": file_size,
            "version": version,
            "zip_files": zip_files,
            "interned_strings": list(interned_strings),
            "code_files": code_files,
            "model_data": model_data,
            "constants": constants,
            "extra_files_jsons": extra_files_jsons,
            "extra_pickles": extra_pickles,
        }
    }


def get_inline_skeleton():
    """Get a fully-inlined skeleton of the frontend.

    The returned HTML page has no external network dependencies for code.
    It can load model_info.json over HTTP, or be passed to burn_in_info.
    """
    import importlib.resources

    # pyrefly: ignore  # bad-argument-type
    skeleton = importlib.resources.read_text(__package__, "skeleton.html")
    # pyrefly: ignore  # bad-argument-type
    js_code = importlib.resources.read_text(__package__, "code.js")
    for js_module in ["preact", "htm"]:
        # Inline each JS library as a data: URL so the page works offline.
        # pyrefly: ignore  # bad-argument-type
        js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
        js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
        js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)
    skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code)
    return skeleton


def burn_in_info(skeleton, info):
    """Burn model info into the HTML skeleton.

    The result will render the hard-coded model info and
    have no external network dependencies for code or data.
    """
    # Note that Python's json serializer does not escape slashes in strings.
    # Since we're inlining this JSON directly into a script tag, a string
    # containing "</script>" would end the script prematurely and
    # mess up our page.  Unconditionally escape fixes that.
    return skeleton.replace(
        "BURNED_IN_MODEL_INFO = null",
        "BURNED_IN_MODEL_INFO = " +
        json.dumps(info, sort_keys=True).replace("/", "\\/"))


def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
    """Convenience wrapper: extract model info and burn it into the skeleton."""
    model_info = get_model_info(path_or_bytesio, **kwargs)
    skeleton = get_inline_skeleton()
    page = burn_in_info(skeleton, model_info)
    return page


def main(argv, *, stdout=None):
    """CLI entry point: dump a model as JSON or a self-contained HTML page."""
    warnings.warn("torch.utils.model_dump is deprecated and will be removed in a future PyTorch release.")
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", choices=["json", "html"])
    parser.add_argument("--title")
    parser.add_argument("model")
    args = parser.parse_args(argv[1:])

    info = get_model_info(args.model, title=args.title)

    output = stdout or sys.stdout

    if args.style == "json":
        output.write(json.dumps(info, sort_keys=True) + "\n")
    elif args.style == "html":
        skeleton = get_inline_skeleton()
        page = burn_in_info(skeleton, info)
        output.write(page)
    else:
        raise Exception("Invalid style")  # noqa: TRY002