mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Bugfix][Precompile][vLLM] Support for pickling einops for aot_autograd serialization in vLLM (#165359)
Fixes issue with compiling `Qwen2_5_vl` in https://github.com/vllm-project/vllm/pull/23207 (issue happens with `aot_autograd_cache`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165359 Approved by: https://github.com/jamesjwu
This commit is contained in:
committed by
PyTorch MergeBot
parent
ffc7552e01
commit
83f9baf413
@ -427,12 +427,9 @@ class _OpPickleData:
|
|||||||
return cls._pickle_op(name, _OpOverloadPickleData, options)
|
return cls._pickle_op(name, _OpOverloadPickleData, options)
|
||||||
elif isinstance(op, torch._ops.OpOverloadPacket):
|
elif isinstance(op, torch._ops.OpOverloadPacket):
|
||||||
return cls._pickle_op(name, _OpOverloadPacketPickleData, options)
|
return cls._pickle_op(name, _OpOverloadPacketPickleData, options)
|
||||||
elif name.startswith(("builtins.", "math.", "torch.")):
|
elif name.startswith(_OpFunctionPickleData.SUPPORTED_ROOTS):
|
||||||
root, detail = name.split(".", 1)
|
root, detail = name.split(".", 1)
|
||||||
return _OpBuiltinPickleData(root, detail)
|
return _OpFunctionPickleData(root, detail)
|
||||||
elif name.startswith("operator."):
|
|
||||||
_, detail = name.split(".", 1)
|
|
||||||
return _OpOperatorPickleData(detail)
|
|
||||||
else:
|
else:
|
||||||
# TODO: raise a BypassFxGraphCache so we will just bypass this one...
|
# TODO: raise a BypassFxGraphCache so we will just bypass this one...
|
||||||
raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
|
raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
|
||||||
@ -506,7 +503,16 @@ class _OpOverloadPacketPickleData(_OpPickleData):
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
class _OpBuiltinPickleData(_OpPickleData):
|
class _OpFunctionPickleData(_OpPickleData):
|
||||||
|
"""
|
||||||
|
Supports pickling a set of standard/common functions
|
||||||
|
These must be prefixed with the full namespace in order to properly
|
||||||
|
be pickled (i.e `einops.rearrange` and not `from einops import rearrange`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Static variable listing supported root names
|
||||||
|
SUPPORTED_ROOTS = ("builtins.", "math.", "torch.", "operator.", "einops.")
|
||||||
|
|
||||||
def __init__(self, root: str, name: str) -> None:
|
def __init__(self, root: str, name: str) -> None:
|
||||||
self.root = root
|
self.root = root
|
||||||
self.name = name
|
self.name = name
|
||||||
@ -520,20 +526,18 @@ class _OpBuiltinPickleData(_OpPickleData):
|
|||||||
return self._getattr_by_name(math, self.name)
|
return self._getattr_by_name(math, self.name)
|
||||||
elif self.root == "torch":
|
elif self.root == "torch":
|
||||||
return self._getattr_by_name(torch, self.name)
|
return self._getattr_by_name(torch, self.name)
|
||||||
|
elif self.root == "operator":
|
||||||
|
import operator
|
||||||
|
|
||||||
|
return self._getattr_by_name(operator, self.name)
|
||||||
|
elif self.root == "einops":
|
||||||
|
import einops
|
||||||
|
|
||||||
|
return self._getattr_by_name(einops, self.name)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class _OpOperatorPickleData(_OpPickleData):
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
self.name = name
|
|
||||||
|
|
||||||
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
|
||||||
import operator
|
|
||||||
|
|
||||||
return self._getattr_by_name(operator, self.name)
|
|
||||||
|
|
||||||
|
|
||||||
class _GraphPickleData:
|
class _GraphPickleData:
|
||||||
def __init__(self, graph: torch.fx.Graph, options: Options) -> None:
|
def __init__(self, graph: torch.fx.Graph, options: Options) -> None:
|
||||||
self.tracer_cls = graph._tracer_cls
|
self.tracer_cls = graph._tracer_cls
|
||||||
|
Reference in New Issue
Block a user