Save/load op profiles (#151817)
Add the ability to save and load op profiles to a YAML file:
```python
op_profile = self.get_sample_op_profile()
# Save
save_op_profiles(op_profile, "op_profile.yaml")
# Load
loaded = load_op_profiles("op_profile.yaml")
assert op_profile == loaded
```
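The exact contents of `get_sample_op_profile()` are not shown here. As a rough illustration, a profile maps an operator name to a set of `OpProfile` entries built from `TensorMetadata`; the sketch below is hedged, with a made-up operator name and metadata values:

```python
import torch
from torch._library.fake_profile import OpProfile, TensorMetadata

# Illustrative only: a profile for a hypothetical "mylib.foo" op that takes two
# rank-2 float32 CPU tensors and returns one tensor with the same metadata.
meta = TensorMetadata(
    rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided
)
op_profile = {
    "mylib.foo.default": {OpProfile(args_profile=(meta, meta), out_profile=meta)},
}
```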
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151817
Approved by: https://github.com/zou3519
committed by PyTorch MergeBot
parent 8358eca2ce
commit 8f420a500a

setup.py
```diff
@@ -1181,6 +1181,7 @@ def main():
     extras_require = {
         "optree": ["optree>=0.13.0"],
         "opt-einsum": ["opt-einsum>=3.3"],
+        "pyyaml": ["pyyaml"],
     }
 
     # Read in README.md for our long_description
```
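pyyaml is added as an optional extra rather than a hard dependency, and the serialization helpers later in this diff correspondingly import `yaml` inside the function body. A minimal sketch of that lazy-import pattern, with a hypothetical helper name:

```python
def _profiles_to_yaml_str(data: dict) -> str:
    # Hypothetical helper: import yaml lazily so the module can still be
    # imported when the optional "pyyaml" extra is not installed.
    import yaml

    return yaml.dump(data, sort_keys=False)
```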
```diff
@@ -2,16 +2,20 @@
 # ruff: noqa: F841
 
 import collections
+import io
 import itertools
 import os
 import re
 import subprocess
 import sys
+import tempfile
 import typing
 import unittest
+from pathlib import Path
 from typing import *  # noqa: F403
 
 import numpy as np
+import yaml
 
 import torch._custom_ops as custom_ops
 import torch.testing._internal.optests as optests
@@ -20,7 +24,15 @@ import torch.utils.cpp_extension
 from functorch import make_fx
 from torch import Tensor
 from torch._custom_op.impl import CustomOp, infer_schema
-from torch._library.fake_profile import MissingOpProfile, OpProfile, TensorMetadata
+from torch._library.fake_profile import (
+    generate_yaml_from_profiles,
+    load_op_profiles,
+    MissingOpProfile,
+    OpProfile,
+    read_profiles_from_yaml,
+    save_op_profiles,
+    TensorMetadata,
+)
 from torch._library.infer_schema import tuple_to_list
 from torch._utils_internal import get_file_path_2  # @manual
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -39,6 +51,7 @@ from torch.testing._internal.common_utils import (
     scoped_load_inline,
     skipIfTorchDynamo,
     subtest,
+    TemporaryFileName,
     TestCase,
 )
 from torch.testing._internal.custom_op_db import numpy_nonzero
@@ -4601,6 +4614,46 @@ class TestOpProfiles(TestCase):
         with fm:
             self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)
 
+    def test_yaml(self):
+        op_profiles = self.get_sample_op_profile()
+        yaml_str = generate_yaml_from_profiles(op_profiles)
+        loaded = read_profiles_from_yaml(yaml_str)
+        self.assertEqual(op_profiles, loaded)
+
+    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
+    def test_save_to_file(self):
+        op_profile = self.get_sample_op_profile()
+
+        # Saving with buffer
+        buffer = io.BytesIO()
+        save_op_profiles(op_profile, buffer)
+        buffer.seek(0)
+        loaded = load_op_profiles(buffer)
+        self.assertEqual(op_profile, loaded)
+
+        # Saving with file
+        with tempfile.NamedTemporaryFile() as f:
+            save_op_profiles(op_profile, f.name)
+            f.seek(0)
+            loaded = load_op_profiles(f.name)
+            self.assertEqual(op_profile, loaded)
+
+        # Saving with Path
+        with TemporaryFileName() as fname:
+            path = Path(fname)
+            save_op_profiles(op_profile, path)
+            loaded = load_op_profiles(path)
+            self.assertEqual(op_profile, loaded)
+
+    def test_version(self):
+        op_profiles = self.get_sample_op_profile()
+        yaml_str = generate_yaml_from_profiles(op_profiles)
+        loaded = yaml.safe_load(yaml_str)
+        loaded["torch_version"] = "2.7"
+        yaml_str = yaml.dump(loaded, sort_keys=False)
+        with self.assertRaisesRegex(RuntimeError, "Unable to load outdated profile"):
+            loaded = read_profiles_from_yaml(yaml_str)
+
 
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
```
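The `test_yaml` round trip above works entirely at the YAML-string level. For illustration, a hedged sketch of that round trip; the operator name and tensor metadata are made up, and the key layout described in the comments comes from the serializer defined in the next hunk:

```python
import torch
import yaml
from torch._library.fake_profile import (
    generate_yaml_from_profiles,
    OpProfile,
    read_profiles_from_yaml,
    TensorMetadata,
)

meta = TensorMetadata(
    rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided
)
profiles = {"mylib.foo.default": {OpProfile(args_profile=(meta, meta), out_profile=meta)}}

yaml_str = generate_yaml_from_profiles(profiles)
doc = yaml.safe_load(yaml_str)
# Per the serializer below: a "torch_version" string plus an "operators" mapping of
# op name -> list of {"args_profile": [...], "out_profile": ...}, where dtype/layout
# are stored as serialized enum values and the device as a string.
assert set(doc) == {"torch_version", "operators"}
assert read_profiles_from_yaml(yaml_str) == profiles
```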
```diff
@@ -1,11 +1,14 @@
 import contextlib
+import io
 import logging
+import os
 from collections.abc import Generator
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Union
 
 import torch
 from torch._library.custom_ops import _maybe_get_opdef
+from torch.types import FileLike
 
 
 log = logging.getLogger(__name__)
@@ -182,3 +185,139 @@ def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Gene
             opdef = _maybe_get_opdef(op_str)
             assert opdef is not None
             opdef.register_fake(old_fake)
+
+
+def get_torch_version() -> str:
+    version = torch.__version__.split(".")
+    return f"{int(version[0])}.{int(version[1])}"
+
+
+def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
+    """
+    Generates a yaml string from the given operator profiles which can be saved
+    to a file. The yaml string can be loaded back into an operator profile
+    structure using `read_profiles_from_yaml`.
+    """
+    import yaml
+
+    from torch._export.serde.serialize import (
+        _TORCH_TO_SERIALIZE_DTYPE,
+        _TORCH_TO_SERIALIZE_LAYOUT,
+    )
+
+    def serialize_tensor_metadata(t: TensorMetadata) -> dict:
+        return {
+            "rank": t.rank,
+            "dtype": _TORCH_TO_SERIALIZE_DTYPE[t.dtype].value,
+            "device": str(t.device),
+            "layout": _TORCH_TO_SERIALIZE_LAYOUT[t.layout].value,
+        }
+
+    def serialize_op_profile(op: OpProfile) -> dict:
+        return {
+            "args_profile": [
+                serialize_tensor_metadata(arg)
+                for arg in op.args_profile
+                if arg is not None
+            ],
+            "out_profile": (
+                serialize_tensor_metadata(op.out_profile)
+                if isinstance(op.out_profile, TensorMetadata)
+                else [serialize_tensor_metadata(out) for out in op.out_profile]
+            ),
+        }
+
+    serialized_data = {
+        operator: [serialize_op_profile(profile) for profile in profiles]
+        for operator, profiles in op_profiles.items()
+    }
+    return yaml.dump(
+        {"torch_version": get_torch_version(), "operators": serialized_data},
+        sort_keys=False,
+    )
+
+
+def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> None:
+    """
+    Serializes the given operator profiles into a yaml format and saves it to
+    the given file. The operator profile can be loaded back using `load_op_profiles`.
+    """
+    yaml_str = generate_yaml_from_profiles(op_profiles)
+
+    if isinstance(f, (str, os.PathLike)):
+        f = os.fspath(f)
+
+        with open(f, "w") as file:
+            file.write(yaml_str)
+
+    elif isinstance(f, io.BytesIO):
+        f.write(yaml_str.encode("utf-8"))
+
+    else:
+        raise ValueError(f"Invalid type of file {f}")
+
+
+def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
+    """
+    Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
+    """
+    import yaml
+
+    from torch._export.serde.serialize import (
+        _SERIALIZE_TO_TORCH_DTYPE,
+        _SERIALIZE_TO_TORCH_LAYOUT,
+    )
+
+    def deserialize_tensor_metadata(data: dict) -> TensorMetadata:
+        return TensorMetadata(
+            rank=data["rank"],
+            dtype=_SERIALIZE_TO_TORCH_DTYPE[data["dtype"]],
+            device=torch.device(data["device"]),
+            layout=_SERIALIZE_TO_TORCH_LAYOUT[data["layout"]],
+        )
+
+    def deserialize_op_profile(data: dict) -> OpProfile:
+        args_profile = tuple(
+            deserialize_tensor_metadata(arg) for arg in data["args_profile"]
+        )
+        out_profile_data = data["out_profile"]
+        out_profile: Union[tuple[TensorMetadata], TensorMetadata] = (
+            tuple(deserialize_tensor_metadata(out) for out in out_profile_data)  # type: ignore[assignment]
+            if isinstance(out_profile_data, list)
+            else deserialize_tensor_metadata(out_profile_data)
+        )
+        return OpProfile(args_profile=args_profile, out_profile=out_profile)  # type: ignore[arg-type]
+
+    loaded_data = yaml.safe_load(yaml_str)
+    loaded_torch_version = loaded_data["torch_version"]
+
+    if loaded_torch_version != get_torch_version():
+        raise RuntimeError(
+            "Unable to load outdated profile. It was saved with torch version: "
+            f"{loaded_torch_version} but the current torch version is: {get_torch_version()}"
+        )
+
+    operators_data = loaded_data["operators"]
+    return {
+        operator: {deserialize_op_profile(profile) for profile in profiles}
+        for operator, profiles in operators_data.items()
+    }
+
+
+def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]:
+    """
+    Loads the saved operator profiles from `save_op_profiles`.
+    """
+    if isinstance(f, (str, os.PathLike)):
+        f = os.fspath(f)
+
+        with open(f) as file:
+            yaml_str = file.read()
+
+    elif isinstance(f, io.BytesIO):
+        yaml_str = f.read().decode("utf-8")
+
+    else:
+        raise ValueError(f"Invalid type of file {f}")
+
+    return read_profiles_from_yaml(yaml_str)
```
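For context, a hedged end-to-end sketch of how loaded profiles might be consumed: it assumes `unsafe_generate_fake_kernels` is used as a context manager (its Generator return type and the restore logic in the hunk above suggest this), and `"op_profile.yaml"` is the file saved by the example in the commit message.

```python
from torch._library.fake_profile import load_op_profiles, unsafe_generate_fake_kernels

# Load profiles saved earlier with save_op_profiles(...).
profiles = load_op_profiles("op_profile.yaml")

# Register fake kernels generated from the profiles for the duration of the block.
with unsafe_generate_fake_kernels(profiles):
    ...  # run fake-tensor tracing / export that needs fake kernels for these ops
```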