[v1][torch.compile] support managing cudagraph buffer (#10203)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: youkaichao
Date: 2024-11-11 11:10:27 -08:00 (committed by GitHub)
Commit: 330e82d34a (parent d7a4f2207b)

4 changed files with 59 additions and 8 deletions

piecewise_compilation_config.json

@@ -1,4 +1,5 @@
 {
     "use_cudagraph": true,
-    "non_cudagraph_ops": ["silly.attention"]
+    "non_cudagraph_ops": ["silly.attention"],
+    "cudagraph_copy_inputs": true
 }
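The test below consumes this file through the VLLM_TORCH_COMPILE_CONFIG environment variable, as a minimal sketch of the wiring (the path here is a placeholder):

import os

# Point the compiler at the JSON config above before constructing the model
# (mirrors the test below; the path is a placeholder, not a real location).
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = "/path/to/piecewise_compilation_config.json"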


@@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config

-    input_buffer = torch.randn(100).cuda()
+    inputs = torch.randn(100).cuda()

     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -92,15 +92,15 @@ def test_simple_piecewise_compile():
     ):
         with set_compile_context([1, 2]):
-            model(input_buffer)
+            model(inputs)

-            model(input_buffer[:2])
-            model(input_buffer[:1])
+            model(torch.randn(2).cuda())
+            model(torch.randn(1).cuda())

-            input_buffer[:2].zero_()
+            input = torch.zeros(2).cuda()
             global global_counter
             global_counter = 0
-            output = model(input_buffer[:2])
+            output = model(input)
             assert global_counter == 2
             assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
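The test change above is the user-visible effect of the new flag: the caller no longer has to route every call through one persistent input_buffer, because the backend now copies fresh tensors into its own static buffer before replay. The constraint being worked around is a general CUDA graph property, illustrated here with raw torch.cuda graph capture (independent of vLLM; requires a CUDA device, and real code typically warms up the workload before capture):

import torch

# A replayed CUDA graph reads its inputs from the exact memory captured,
# so new data must be copied into the captured ("static") tensor.
static_in = torch.zeros(2, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2

static_in.copy_(torch.tensor([1., 3.], device="cuda"))
g.replay()
print(static_out)  # tensor([2., 6.], device='cuda:0')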


@@ -389,6 +389,8 @@ class VllmBackend:
     returned_callable: Callable
     # Inductor passes to run on the graph pre-defunctionalization
     post_grad_passes: Sequence[Callable]
+    sym_tensor_indices: List[int]
+    input_buffers: List[torch.Tensor]

     def __init__(self, post_grad_passes: Sequence[Callable] = ()):
         global global_graph_pool
@@ -401,6 +403,9 @@ class VllmBackend:
         self.graph_pool = global_graph_pool
         self.post_grad_passes = post_grad_passes

+        self.sym_tensor_indices = []
+        self.input_buffers = []

         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
@@ -461,8 +466,47 @@ class VllmBackend:
         self._called = True

-        return self.split_gm
+        if not self.compilation_configs.use_cudagraph or \
+                not self.compilation_configs.cudagraph_copy_inputs:
+            return self.split_gm
+
+        # if we need to copy input buffers for cudagraph
+        from torch._guards import detect_fake_mode
+        fake_mode = detect_fake_mode()
+        fake_args = [
+            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
+            for t in example_inputs
+        ]
+
+        # indices of tensors that have symbolic shapes (batch size)
+        self.sym_tensor_indices = [
+            i for i, x in enumerate(fake_args)
+            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
+        ]
+
+        # compiler-managed cudagraph input buffers; we assume the first run
+        # with symbolic shapes has the maximum size among all the tensors
+        self.input_buffers = [
+            example_inputs[x].clone() for x in self.sym_tensor_indices
+        ]
+
+        def copy_and_call(*args):
+            list_args = list(args)
+            for i, index in enumerate(self.sym_tensor_indices):
+                runtime_tensor = list_args[index]
+                runtime_shape = runtime_tensor.shape[0]
+                static_tensor = self.input_buffers[i][:runtime_shape]
+                # copy the runtime tensor into the static buffer
+                static_tensor.copy_(runtime_tensor)
+                # replace the tensor in list_args with the static buffer
+                list_args[index] = static_tensor
+            return self.split_gm(*list_args)
+
+        return copy_and_call


 @dataclasses.dataclass
 class ConcreteSizeEntry:
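To make the control flow above easier to follow, here is a self-contained sketch of the same copy-and-call pattern outside vLLM (make_copy_and_call and its parameters are illustrative, not part of the codebase; assumes a CUDA device):

import torch
from typing import Callable

# Illustrative helper (not vLLM API): allocate one static buffer sized for the
# largest expected batch, then route every call through a copy into a view of it.
def make_copy_and_call(fn: Callable, max_size: int) -> Callable:
    buffer = torch.zeros(max_size, device="cuda")  # compiler-managed input buffer

    def copy_and_call(x: torch.Tensor):
        static_view = buffer[: x.shape[0]]  # slice matching the runtime batch size
        static_view.copy_(x)                # copy into the stable storage
        # downstream consumers (e.g. cudagraph replay) see the same
        # memory addresses on every call
        return fn(static_view)

    return copy_and_call

f = make_copy_and_call(lambda t: t * 2, max_size=100)
print(f(torch.randn(4, device="cuda")))
print(f(torch.randn(2, device="cuda")))  # smaller batch reuses the same buffer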


@@ -32,6 +32,11 @@ class CompilationConfig(BaseModel):
         It means the first several runs will be treated as warmup runs.
         Only after that, the execution will be recorded, and the recorded
         cudagraph will be used for subsequent runs.
+    - cudagraph_copy_inputs: whether to copy input tensors for
+        cudagraph. If the caller can guarantee that the same input buffers
+        are always used, it can set this to False. Otherwise, it should
+        set this to True, and the compiler will copy the input to an
+        internally managed buffer. Default is False.
     - Inductor compilation:
         - use_inductor: whether to use inductor compilation.
             - False: inductor compilation is not used. graph runs in eager.
@@ -78,6 +83,7 @@ class CompilationConfig(BaseModel):
     non_cudagraph_ops: List[str] = Field(default_factory=list)
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[List[int]] = None
+    cudagraph_copy_inputs: bool = False

     dump_graph_stages: List[str] = Field(default_factory=list)
     dump_graph_dir: Path = Field(default=Path("."))
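Since CompilationConfig is a Pydantic model, the new field can also be set programmatically rather than via the JSON file, as a hedged sketch (the import path is an assumption about this commit's layout):

from vllm.compilation.config import CompilationConfig  # path assumed

cfg = CompilationConfig(
    use_cudagraph=True,
    non_cudagraph_ops=["silly.attention"],
    cudagraph_copy_inputs=True,  # backend copies inputs into managed buffers
)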