[Bugfix][Precompile][vLLM] Support for pickling einops for aot_autograd serialization in vLLM (#165359)

Fixes issue with compiling `Qwen2_5_vl` in https://github.com/vllm-project/vllm/pull/23207 (issue happens with `aot_autograd_cache`)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165359
Approved by: https://github.com/jamesjwu
This commit is contained in:
Lucas Kabela
2025-10-15 20:00:20 +00:00
committed by PyTorch MergeBot
parent ffc7552e01
commit 83f9baf413

View File

@ -427,12 +427,9 @@ class _OpPickleData:
return cls._pickle_op(name, _OpOverloadPickleData, options) return cls._pickle_op(name, _OpOverloadPickleData, options)
elif isinstance(op, torch._ops.OpOverloadPacket): elif isinstance(op, torch._ops.OpOverloadPacket):
return cls._pickle_op(name, _OpOverloadPacketPickleData, options) return cls._pickle_op(name, _OpOverloadPacketPickleData, options)
elif name.startswith(("builtins.", "math.", "torch.")): elif name.startswith(_OpFunctionPickleData.SUPPORTED_ROOTS):
root, detail = name.split(".", 1) root, detail = name.split(".", 1)
return _OpBuiltinPickleData(root, detail) return _OpFunctionPickleData(root, detail)
elif name.startswith("operator."):
_, detail = name.split(".", 1)
return _OpOperatorPickleData(detail)
else: else:
# TODO: raise a BypassFxGraphCache so we will just bypass this one... # TODO: raise a BypassFxGraphCache so we will just bypass this one...
raise NotImplementedError(f"TARGET: {type(op)} {op} {name}") raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
@ -506,7 +503,16 @@ class _OpOverloadPacketPickleData(_OpPickleData):
return obj return obj
class _OpBuiltinPickleData(_OpPickleData): class _OpFunctionPickleData(_OpPickleData):
"""
Supports pickling a set of standard/common functions
These must be prefixed with the full namespace in order to properly
be pickled (i.e `einops.rearrange` and not `from einops import rearrange`)
"""
# Static variable listing supported root names
SUPPORTED_ROOTS = ("builtins.", "math.", "torch.", "operator.", "einops.")
def __init__(self, root: str, name: str) -> None: def __init__(self, root: str, name: str) -> None:
self.root = root self.root = root
self.name = name self.name = name
@ -520,20 +526,18 @@ class _OpBuiltinPickleData(_OpPickleData):
return self._getattr_by_name(math, self.name) return self._getattr_by_name(math, self.name)
elif self.root == "torch": elif self.root == "torch":
return self._getattr_by_name(torch, self.name) return self._getattr_by_name(torch, self.name)
elif self.root == "operator":
import operator
return self._getattr_by_name(operator, self.name)
elif self.root == "einops":
import einops
return self._getattr_by_name(einops, self.name)
else: else:
raise NotImplementedError raise NotImplementedError
class _OpOperatorPickleData(_OpPickleData):
def __init__(self, name: str) -> None:
self.name = name
def unpickle(self, unpickle_state: _UnpickleState) -> object:
import operator
return self._getattr_by_name(operator, self.name)
class _GraphPickleData: class _GraphPickleData:
def __init__(self, graph: torch.fx.Graph, options: Options) -> None: def __init__(self, graph: torch.fx.Graph, options: Options) -> None:
self.tracer_cls = graph._tracer_cls self.tracer_cls = graph._tracer_cls