Save/load op profiles (#151817)

Add ability to save/load op profiles into a yaml file:
```python
op_profile = self.get_sample_op_profile()

# Save
save_op_profiles(op_profile, "op_profile.yaml")
# Load
loaded = load_op_profiles("op_profile.yaml")

assert op_profile == loaded
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151817
Approved by: https://github.com/zou3519
This commit is contained in:
angelayi
2025-04-29 23:11:28 +00:00
committed by PyTorch MergeBot
parent 8358eca2ce
commit 8f420a500a
3 changed files with 194 additions and 1 deletions

View File

@ -1181,6 +1181,7 @@ def main():
extras_require = {
"optree": ["optree>=0.13.0"],
"opt-einsum": ["opt-einsum>=3.3"],
"pyyaml": ["pyyaml"],
}
# Read in README.md for our long_description

View File

@ -2,16 +2,20 @@
# ruff: noqa: F841
import collections
import io
import itertools
import os
import re
import subprocess
import sys
import tempfile
import typing
import unittest
from pathlib import Path
from typing import * # noqa: F403
import numpy as np
import yaml
import torch._custom_ops as custom_ops
import torch.testing._internal.optests as optests
@ -20,7 +24,15 @@ import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.fake_profile import MissingOpProfile, OpProfile, TensorMetadata
from torch._library.fake_profile import (
generate_yaml_from_profiles,
load_op_profiles,
MissingOpProfile,
OpProfile,
read_profiles_from_yaml,
save_op_profiles,
TensorMetadata,
)
from torch._library.infer_schema import tuple_to_list
from torch._utils_internal import get_file_path_2 # @manual
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@ -39,6 +51,7 @@ from torch.testing._internal.common_utils import (
scoped_load_inline,
skipIfTorchDynamo,
subtest,
TemporaryFileName,
TestCase,
)
from torch.testing._internal.custom_op_db import numpy_nonzero
@ -4601,6 +4614,46 @@ class TestOpProfiles(TestCase):
with fm:
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)
def test_yaml(self):
op_profiles = self.get_sample_op_profile()
yaml_str = generate_yaml_from_profiles(op_profiles)
loaded = read_profiles_from_yaml(yaml_str)
self.assertEqual(op_profiles, loaded)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_save_to_file(self):
op_profile = self.get_sample_op_profile()
# Saving with buffer
buffer = io.BytesIO()
save_op_profiles(op_profile, buffer)
buffer.seek(0)
loaded = load_op_profiles(buffer)
self.assertEqual(op_profile, loaded)
# Saving with file
with tempfile.NamedTemporaryFile() as f:
save_op_profiles(op_profile, f.name)
f.seek(0)
loaded = load_op_profiles(f.name)
self.assertEqual(op_profile, loaded)
# Saving with Path
with TemporaryFileName() as fname:
path = Path(fname)
save_op_profiles(op_profile, path)
loaded = load_op_profiles(path)
self.assertEqual(op_profile, loaded)
def test_version(self):
op_profiles = self.get_sample_op_profile()
yaml_str = generate_yaml_from_profiles(op_profiles)
loaded = yaml.safe_load(yaml_str)
loaded["torch_version"] = "2.7"
yaml_str = yaml.dump(loaded, sort_keys=False)
with self.assertRaisesRegex(RuntimeError, "Unable to load outdated profile"):
loaded = read_profiles_from_yaml(yaml_str)
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)

View File

@ -1,11 +1,14 @@
import contextlib
import io
import logging
import os
from collections.abc import Generator
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
from torch._library.custom_ops import _maybe_get_opdef
from torch.types import FileLike
log = logging.getLogger(__name__)
@ -182,3 +185,139 @@ def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Gene
opdef = _maybe_get_opdef(op_str)
assert opdef is not None
opdef.register_fake(old_fake)
def get_torch_version() -> str:
version = torch.__version__.split(".")
return f"{int(version[0])}.{int(version[1])}"
def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
"""
Generates a yaml string from the given operator profiles which can be saved
to a file. The yaml string can be loaded back into an operator profile
structure using `read_profiles_from_yaml`.
"""
import yaml
from torch._export.serde.serialize import (
_TORCH_TO_SERIALIZE_DTYPE,
_TORCH_TO_SERIALIZE_LAYOUT,
)
def serialize_tensor_metadata(t: TensorMetadata) -> dict:
return {
"rank": t.rank,
"dtype": _TORCH_TO_SERIALIZE_DTYPE[t.dtype].value,
"device": str(t.device),
"layout": _TORCH_TO_SERIALIZE_LAYOUT[t.layout].value,
}
def serialize_op_profile(op: OpProfile) -> dict:
return {
"args_profile": [
serialize_tensor_metadata(arg)
for arg in op.args_profile
if arg is not None
],
"out_profile": (
serialize_tensor_metadata(op.out_profile)
if isinstance(op.out_profile, TensorMetadata)
else [serialize_tensor_metadata(out) for out in op.out_profile]
),
}
serialized_data = {
operator: [serialize_op_profile(profile) for profile in profiles]
for operator, profiles in op_profiles.items()
}
return yaml.dump(
{"torch_version": get_torch_version(), "operators": serialized_data},
sort_keys=False,
)
def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> None:
"""
Serializes the given operator profiles into a yaml format and saves it to
the given file. The operator profile can be loaded back using `load_op_profiles`.
"""
yaml_str = generate_yaml_from_profiles(op_profiles)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
with open(f, "w") as file:
file.write(yaml_str)
elif isinstance(f, io.BytesIO):
f.write(yaml_str.encode("utf-8"))
else:
raise ValueError(f"Invalid type of file {f}")
def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
"""
Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
"""
import yaml
from torch._export.serde.serialize import (
_SERIALIZE_TO_TORCH_DTYPE,
_SERIALIZE_TO_TORCH_LAYOUT,
)
def deserialize_tensor_metadata(data: dict) -> TensorMetadata:
return TensorMetadata(
rank=data["rank"],
dtype=_SERIALIZE_TO_TORCH_DTYPE[data["dtype"]],
device=torch.device(data["device"]),
layout=_SERIALIZE_TO_TORCH_LAYOUT[data["layout"]],
)
def deserialize_op_profile(data: dict) -> OpProfile:
args_profile = tuple(
deserialize_tensor_metadata(arg) for arg in data["args_profile"]
)
out_profile_data = data["out_profile"]
out_profile: Union[tuple[TensorMetadata], TensorMetadata] = (
tuple(deserialize_tensor_metadata(out) for out in out_profile_data) # type: ignore[assignment]
if isinstance(out_profile_data, list)
else deserialize_tensor_metadata(out_profile_data)
)
return OpProfile(args_profile=args_profile, out_profile=out_profile) # type: ignore[arg-type]
loaded_data = yaml.safe_load(yaml_str)
loaded_torch_version = loaded_data["torch_version"]
if loaded_torch_version != get_torch_version():
raise RuntimeError(
"Unable to load outdated profile. It was saved with torch version: "
f"{loaded_torch_version} but the current torch version is: {get_torch_version()}"
)
operators_data = loaded_data["operators"]
return {
operator: {deserialize_op_profile(profile) for profile in profiles}
for operator, profiles in operators_data.items()
}
def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]:
"""
Loads the saved operator profiles from `save_op_profiles`.
"""
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
with open(f) as file:
yaml_str = file.read()
elif isinstance(f, io.BytesIO):
yaml_str = f.read().decode("utf-8")
else:
raise ValueError(f"Invalid type of file {f}")
return read_profiles_from_yaml(yaml_str)