mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update
[ghstack-poisoned]
This commit is contained in:
@ -10,6 +10,7 @@ from typing_extensions import override, Self
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._guards import TracingContext
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
|
||||
from torch._subclasses.meta_utils import (
|
||||
MetaConverter,
|
||||
@ -422,7 +423,14 @@ class _OpPickleData:
|
||||
if isinstance(op, str):
|
||||
return _OpStrPickleData(op)
|
||||
|
||||
if isinstance(getattr(op, "__wrapped__", None), AOTCompiledArtifact):
|
||||
assert hasattr(op, "__wrapped__")
|
||||
artifact = op.__wrapped__
|
||||
assert isinstance(artifact, AOTCompiledArtifact)
|
||||
return _OpPrecompiledPickleData(artifact)
|
||||
|
||||
name = torch.fx.Node._pretty_print_target(op)
|
||||
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
return cls._pickle_op(name, _OpOverloadPickleData, options)
|
||||
elif isinstance(op, torch._ops.OpOverloadPacket):
|
||||
@ -503,6 +511,21 @@ class _OpOverloadPacketPickleData(_OpPickleData):
|
||||
return obj
|
||||
|
||||
|
||||
class _OpPrecompiledPickleData(_OpPickleData):
|
||||
def __init__(self, artifact: AOTCompiledArtifact) -> None:
|
||||
self.contents = artifact.serialize()
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
||||
precompiled_artifact = AOTCompiledArtifact.deserialize(self.contents)
|
||||
import functools
|
||||
|
||||
@functools.wraps(precompiled_artifact)
|
||||
def wrapped(*args: Any) -> Any:
|
||||
return precompiled_artifact(*args)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
class _OpFunctionPickleData(_OpPickleData):
|
||||
"""
|
||||
Supports pickling a set of standard/common functions
|
||||
|
Reference in New Issue
Block a user