[torch.compile] auto infer dynamic_arg_dims from type annotation (#9589)

This commit is contained in:
youkaichao
2024-10-22 13:43:37 -07:00
committed by GitHub
parent cd5601ac37
commit 17c79f3c36
3 changed files with 65 additions and 19 deletions

View File

@ -1,24 +1,58 @@
import inspect
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
logger = init_logger(__name__)
def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def support_torch_compile(
cls: Optional[type] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
"""
A decorator to add support for compiling the forward method of a class.
Usage 1: use directly as a decorator without arguments:
```python
@support_torch_compile
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
Usage 2: use as a decorator with arguments:
```python
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers.
Depending on the value of arguments:
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
of the `forward` method, based on the following default rules:
- if the argument is annotated as `torch.Tensor` or
`Optional[torch.Tensor]`, the first dimension will be
marked as dynamic.
- if the argument is annotated as `IntermediateTensors`, the first
dimension of all the tensors in the intermediate tensors
will be marked as dynamic.
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
if not hasattr(cls, 'forward'):
raise TypeError("decorated class should have a forward method.")
sig = inspect.signature(cls.forward)
for k in dynamic_arg_dims:
inferred_dynamic_arg_dims = dynamic_arg_dims
if inferred_dynamic_arg_dims is None:
inferred_dynamic_arg_dims = {}
for k, v in sig.parameters.items():
if v.annotation in [
torch.Tensor, Optional[torch.Tensor],
IntermediateTensors, Optional[IntermediateTensors]
]:
inferred_dynamic_arg_dims[k] = 0
logger.debug(("Inferred dynamic dimensions for "
"forward method of %s: %s"), cls,
list(inferred_dynamic_arg_dims.keys()))
if len(inferred_dynamic_arg_dims) == 0:
raise ValueError(
"No dynamic dimensions found in the forward method of "
f"{cls}. Please provide dynamic_arg_dims explicitly.")
for k in inferred_dynamic_arg_dims:
if k not in sig.parameters:
raise ValueError(
f"Argument {k} not found in the forward method of {cls}")
return _support_torch_compile(cls, dynamic_arg_dims)
return _support_torch_compile(cls, inferred_dynamic_arg_dims)
if cls is not None:
# use `support_torch_compile` as a decorator without arguments
assert isinstance(cls, type)
return cls_decorator_helper(cls)
return cls_decorator_helper

View File

@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class Gemma2Model(nn.Module):
def __init__(

View File

@ -268,13 +268,7 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(