from abc import ABC, abstractmethod

import torch

from vllm.compilation.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
    get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger

logger = init_logger(__name__)


class InductorPass(ABC):
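    """
    A custom graph pass for torch.compile's Inductor backend: a callable
    that transforms a torch.fx.Graph in place and can dump the graph to a
    file at configured stages for debugging.
    """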

    @abstractmethod
    def __call__(self, graph: torch.fx.Graph):
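        """Apply the pass to the given graph, mutating it in place."""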
        raise NotImplementedError

    def __init__(self, config: CompilationConfig):
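        """Store the compilation config, which controls graph dumping."""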
        self.config = config

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
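        """
        Dump the Python source of the graph to a file named after the stage
        (and the tensor-parallel rank, if distributed), but only if the
        stage is enabled in config.dump_graph_stages.
        """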
        if stage in self.config.dump_graph_stages:
            # Include the rank in the filename in the distributed setting,
            # so that workers do not overwrite each other's dumps.
            parallel = p_is_init() and get_tp_world_size() > 1
            rank = f"-{get_tp_rank()}" if parallel else ""
            filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"

            logger.info("Printing graph to %s", filepath)
            with open(filepath, "w") as f:
                src = graph.python_code(root_module="self", verbose=True).src
                # Add imports so the dumped file is valid Python on its own
                # (the generated code references torch and torch.device).
                print("import torch; from torch import device", file=f)
                print(src, file=f)
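

# Example subclass (an illustrative sketch, not part of vLLM): a pass that
# runs dead-code elimination and dumps the graph before and after itself.
# The class name and the stage names "before_dce"/"after_dce" are assumptions
# chosen for this example; torch.fx.Graph.eliminate_dead_code() is a real
# PyTorch API.
class ExampleDCEPass(InductorPass):

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_dce")
        graph.eliminate_dead_code()
        self.dump_graph(graph, "after_dce")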