Compare commits

193 Commits

| SHA1 | Author | Date | |
|---|---|---|---|
| 651c614aa4 | |||
| d3a5bd9fb7 | |||
| e8ef4c0820 | |||
| 348897af31 | |||
| 9d9072a069 | |||
| 928de46888 | |||
| 29678cd213 | |||
| d0740dff1b | |||
| de89472897 | |||
| e7c8555d06 | |||
| ec3b5ce9cc | |||
| 6368e777a8 | |||
| 875afe38ab | |||
| ee8217e5be | |||
| 980dd4a2c4 | |||
| 8285736840 | |||
| 91fce82c6f | |||
| ac5cf86aa6 | |||
| 6a6119554c | |||
| b95ee898fe | |||
| 9eed4d1f3e | |||
| 6b5296aa3a | |||
| ee92b58b3a | |||
| 09ff7f106a | |||
| acbed3ef40 | |||
| 66d18a7fb0 | |||
| ba0bfd40e2 | |||
| 84e4e37d14 | |||
| a60b353005 | |||
| ebe4d1db3a | |||
| b5a10eb0ef | |||
| 0967102c6d | |||
| e2fb71ec9f | |||
| f936657eb6 | |||
| 6f88f762bf | |||
| 202351d5bf | |||
| 2e8e49fce3 | |||
| a8e98aee0c | |||
| bb1ba58f06 | |||
| 7bedab5748 | |||
| 20f7cc4cde | |||
| 649aa730c5 | |||
| a19bc5c628 | |||
| 28e616c4e3 | |||
| 30e775281d | |||
| 21877b0d75 | |||
| cf5cb1e33e | |||
| 03ffd0a022 | |||
| a425bd9a9a | |||
| bbbf86565f | |||
| 9f6be8692e | |||
| f187877945 | |||
| 947b794146 | |||
| 8d926e91f1 | |||
| 4ee52bb169 | |||
| 7d7e3b78a3 | |||
| f98b745a81 | |||
| 2d1e86f1b1 | |||
| 1ac4ccf73c | |||
| 2ac4d5e2bf | |||
| 3302f0aef3 | |||
| 6f2dd6c37e | |||
| bc0644574c | |||
| 400b8289f7 | |||
| c1026311b5 | |||
| 2b1c116b5a | |||
| cc796b1358 | |||
| f029ef94d7 | |||
| 95592fa00a | |||
| fbe66e1d0b | |||
| 90979c38f8 | |||
| e21d7687a9 | |||
| ff36139ffc | |||
| e3e79e9e8a | |||
| b9fe4616f9 | |||
| 64ca424e75 | |||
| b5f93d0631 | |||
| a58936966f | |||
| dd54a4b026 | |||
| eda1a7cad3 | |||
| f04908cae7 | |||
| ab019eea75 | |||
| 9841d48a10 | |||
| 3272d7a0b7 | |||
| 0bb1e885a0 | |||
| d6545ad22e | |||
| 90eb3f43ca | |||
| e67b4f2c2a | |||
| d6770d1f23 | |||
| b9cecc2635 | |||
| 898285c9bf | |||
| a62de9ecfd | |||
| 4042d192f5 | |||
| 1117aa1411 | |||
| 080438477f | |||
| 4b5bcf8906 | |||
| 852ef5b4f5 | |||
| db09d4ad83 | |||
| c957c741d9 | |||
| c07ece5ca4 | |||
| 7a9c20c715 | |||
| 005ba458b5 | |||
| 320a622ec4 | |||
| c9927c1a6a | |||
| fbd80ad409 | |||
| 22379d5513 | |||
| 1696725879 | |||
| 002800f081 | |||
| e15932bb60 | |||
| ce741ba3e4 | |||
| bf87484efa | |||
| 8ce9c50d40 | |||
| 32b6816e55 | |||
| c128d69856 | |||
| 55b28b1eee | |||
| e11222333f | |||
| 28873a2799 | |||
| 0080d8329d | |||
| 0d93f15694 | |||
| becd7a56f1 | |||
| 75471386de | |||
| d2b2eed67c | |||
| 4b6f069b6f | |||
| 791d79de32 | |||
| 94d2f59895 | |||
| 75c0ca9d43 | |||
| 2a4ec90854 | |||
| 85ebcda94d | |||
| d64bf1646c | |||
| a41c20435e | |||
| eedac9dba0 | |||
| 14f9c72bfd | |||
| ad5f2fe34c | |||
| 4f8584756d | |||
| 65fc1c3127 | |||
| c393af6cd7 | |||
| 0c04ce3234 | |||
| 73b3de79ea | |||
| d1744376ae | |||
| 805de738f6 | |||
| 1b151ed181 | |||
| e06f504a76 | |||
| 462ae5220a | |||
| 66c54aa9c3 | |||
| 735ecfff61 | |||
| a57d13cc96 | |||
| 79af7e96a0 | |||
| 621980bdc0 | |||
| aa84c92ef6 | |||
| f7389f4763 | |||
| 55fe8a81ec | |||
| e8ddc08ec8 | |||
| 1b0bd0fe8a | |||
| 20044cab7a | |||
| 64f23c2900 | |||
| d4c7755ca8 | |||
| aa39e42c5a | |||
| 953f28cf9a | |||
| c0d00f5be6 | |||
| 58a072be15 | |||
| 82ad323dee | |||
| df5dd3c68e | |||
| 2d867b55fa | |||
| d7a1c6d614 | |||
| 7d5a155e4a | |||
| 1dde34e0f8 | |||
| 6fc2a38b11 | |||
| c487a221ee | |||
| 9925c17940 | |||
| 8c4b2592fb | |||
| cf21a9bd5c | |||
| 16c3e295a8 | |||
| bda41c70dd | |||
| 453bafb96f | |||
| 328d231c17 | |||
| b4b195b360 | |||
| 20b0d88d16 | |||
| 2bdea7ac11 | |||
| 58df2883cb | |||
| 6d7d95a70a | |||
| 96853af5a8 | |||
| dbed69058c | |||
| 7b6ae94059 | |||
| c6dfc3cdbe | |||
| 51be365143 | |||
| c894836108 | |||
| 75beba29b5 | |||
| ddfdf470ae | |||
| b6fbb9a565 | |||
| 2179e4f4c5 | |||
| a945fcc2ae | |||
| be54f8e5c4 | |||
| b396cb4998 | 
							
								
								
									
.github/workflows/publish.yml (vendored, new file, 102 additions)

@@ -0,0 +1,102 @@
+# This workflow will upload a Python Package to Release asset
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Create Release
+
+on:
+  push:
+    tags:
+      - v*
+
+# Needed to create release and upload assets
+permissions:
+  contents: write
+
+jobs:
+  release:
+    # Retrieve tag and create release
+    name: Create Release
+    runs-on: ubuntu-latest
+    outputs:
+      upload_url: ${{ steps.create_release.outputs.upload_url }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+
+      - name: Extract branch info
+        shell: bash
+        run: |
+          echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
+
+      - name: Create Release
+        id: create_release
+        uses: "actions/github-script@v6"
+        env:
+          RELEASE_TAG: ${{ env.release_tag }}
+        with:
+          github-token: "${{ secrets.GITHUB_TOKEN }}"
+          script: |
+            const script = require('.github/workflows/scripts/create_release.js')
+            await script(github, context, core)
+
+  wheel:
+    name: Build Wheel
+    runs-on: ${{ matrix.os }}
+    needs: release
+
+    strategy:
+      fail-fast: false
+      matrix:
+          os: ['ubuntu-20.04']
+          python-version: ['3.8', '3.9', '3.10', '3.11']
+          pytorch-version: ['2.0.1']
+          cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+
+      - name: Set up Linux Env
+        if: ${{ runner.os == 'Linux' }}
+        run: |
+          bash -x .github/workflows/scripts/env.sh
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+            python-version: ${{ matrix.python-version }}
+
+      - name: Install CUDA ${{ matrix.cuda-version }}
+        run: |
+          bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
+
+      - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
+        run: |
+          bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
+
+      - name: Build wheel
+        shell: bash
+        run: |
+          bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
+          wheel_name=$(ls dist/*whl | xargs -n 1 basename)
+          asset_name=${wheel_name//"linux"/"manylinux1"}
+          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
+          echo "asset_name=${asset_name}" >> $GITHUB_ENV
+
+      - name: Upload Release Asset
+        uses: actions/upload-release-asset@v1
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          upload_url: ${{ needs.release.outputs.upload_url }}
+          asset_path: ./dist/${{ env.wheel_name }}
+          asset_name: ${{ env.asset_name }}
+          asset_content_type: application/*
+
+      # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
+      # - name: Publish package
+      #   uses: pypa/gh-action-pypi-publish@release/v1.8
+      #   with:
+      #     repository-url: https://test.pypi.org/legacy/
+      #     password: ${{ secrets.PYPI_API_TOKEN }}
+      #     skip-existing: true
							
								
								
									
.github/workflows/pylint.yml (vendored, 2 changes)

@@ -28,4 +28,4 @@ jobs:
         pip install pylint==2.8.2
     - name: Analysing the code with pylint
       run: |
-        pylint vllm
+        pylint vllm tests
							
								
								
									
.github/workflows/scripts/build.sh (vendored, new file, 15 additions)

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+python_executable=python$1
+cuda_home=/usr/local/cuda-$2
+
+# Update paths
+PATH=${cuda_home}/bin:$PATH
+LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
+
+# Install requirements
+$python_executable -m pip install wheel packaging
+$python_executable -m pip install -r requirements.txt
+
+# Build
+$python_executable setup.py bdist_wheel --dist-dir=dist
							
								
								
									
.github/workflows/scripts/create_release.js (vendored, new file, 20 additions)

@@ -0,0 +1,20 @@
+// Uses Github's API to create the release and wait for result.
+// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
+
+module.exports = async (github, context, core) => {
+    try {
+        const response = await github.rest.repos.createRelease({
+            draft: false,
+            generate_release_notes: true,
+            name: process.env.RELEASE_TAG,
+            owner: context.repo.owner,
+            prerelease: false,
+            repo: context.repo.repo,
+            tag_name: process.env.RELEASE_TAG,
+        });
+
+        core.setOutput('upload_url', response.data.upload_url);
+    } catch (error) {
+        core.setFailed(error.message);
+    }
+}
							
								
								
									
.github/workflows/scripts/cuda-install.sh (vendored, new file, 18 additions)

@@ -0,0 +1,18 @@
+#!/bin/bash
+
+# Replace '.' with '-' ex: 11.8 -> 11-8
+cuda_version=$(echo $1 | tr "." "-")
+# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
+OS=$(echo $2 | tr -d ".\-")
+
+# Installs CUDA
+wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
+sudo dpkg -i cuda-keyring_1.1-1_all.deb
+rm cuda-keyring_1.1-1_all.deb
+sudo apt -qq update
+sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
+sudo apt clean
+
+# Test nvcc
+PATH=/usr/local/cuda-$1/bin:${PATH}
+nvcc --version
							
								
								
									
.github/workflows/scripts/env.sh (vendored, new file, 56 additions)

@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# This file installs common linux environment tools
+
+export LANG C.UTF-8
+
+# python_version=$1
+
+sudo    apt-get update && \
+sudo    apt-get install -y --no-install-recommends \
+        software-properties-common \
+
+sudo    apt-get install -y --no-install-recommends \
+        build-essential \
+        apt-utils \
+        ca-certificates \
+        wget \
+        git \
+        vim \
+        libssl-dev \
+        curl \
+        unzip \
+        unrar \
+        cmake \
+        net-tools \
+        sudo \
+        autotools-dev \
+        rsync \
+        jq \
+        openssh-server \
+        tmux \
+        screen \
+        htop \
+        pdsh \
+        openssh-client \
+        lshw \
+        dmidecode \
+        util-linux \
+        automake \
+        autoconf \
+        libtool \
+        net-tools \
+        pciutils \
+        libpci-dev \
+        libaio-dev \
+        libcap2 \
+        libtinfo5 \
+        fakeroot \
+        devscripts \
+        debhelper \
+        nfs-common
+
+# Remove github bloat files to free up disk space
+sudo rm -rf "/usr/local/share/boost"
+sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+sudo rm -rf "/usr/share/dotnet"
							
								
								
									
.github/workflows/scripts/pytorch-install.sh (vendored, new file, 15 additions)

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+python_executable=python$1
+pytorch_version=$2
+cuda_version=$3
+
+# Install torch
+$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
+$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}
+
+# Print version information
+$python_executable --version
+$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
+$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
+$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
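One detail worth spelling out is how this script turns the workflow's matrix values into a pip requirement: `${cuda_version//./}` strips the dots from the CUDA version to build the `cuXYZ` wheel tag. Below is a small, hedged Python sketch of the same string manipulation, using the PyTorch 2.0.1 / CUDA 11.8 values from the matrix in publish.yml above; it is illustrative only and not part of the patch.

```python
# Mirror the version-string logic from pytorch-install.sh.
# The concrete values come from the publish.yml build matrix shown earlier.
pytorch_version = "2.0.1"
cuda_version = "11.8"

cuda_tag = "cu" + cuda_version.replace(".", "")       # bash: cu${cuda_version//./} -> "cu118"
requirement = f"torch=={pytorch_version}+{cuda_tag}"  # "torch==2.0.1+cu118"
index_url = f"https://download.pytorch.org/whl/{cuda_tag}"

print(f"pip install {requirement} --extra-index-url {index_url}")
```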
							
								
								
									
.github/workflows/yapf.yml (vendored, 2 changes)

@@ -28,4 +28,4 @@ jobs:
         pip install toml==0.10.2
     - name: Running yapf
       run: |
-        yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
+        yapf --diff --recursive vllm tests
							
								
								
									
.gitignore (vendored, 4 additions)

@@ -173,3 +173,7 @@ cython_debug/

 # Sphinx documentation
 _build/
+
+# vim swap files
+*.swo
+*.swp
@@ -8,7 +8,7 @@
 [MASTER]

 # Files or directories to be skipped. They should be base names, not paths.
-ignore=docs,parallel_utils
+ignore=docs

 # Files or directories matching the regex patterns are skipped. The regex
 # matches against base names, not paths.
							
								
								
									
README.md (62 changes)

@@ -10,13 +10,18 @@ Easy, fast, and cheap LLM serving for everyone
 </h3>

 <p align="center">
-| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
+| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |

 </p>

 ---

 *Latest News* 🔥
+- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
+- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
+- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
+- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
+- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
 - [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
 - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

@@ -33,21 +38,28 @@ vLLM is fast with:

 vLLM is flexible and easy to use with:

-- Seamless integration with popular HuggingFace models
+- Seamless integration with popular Hugging Face models
 - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
 - Tensor parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server

-vLLM seamlessly supports many Huggingface models, including the following architectures:
+vLLM seamlessly supports many Hugging Face models, including the following architectures:

+- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
+- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
-- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)

 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

@@ -62,37 +74,19 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
 - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
 - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)

-## Performance
-
-vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
-For details, check out our [blog post](https://vllm.ai).
-
-<p align="center">
-  <picture>
-  <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_dark.png">
-  <img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_light.png" width="45%">
-  </picture>
-  <picture>
-  <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_dark.png">
-  <img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_light.png" width="45%">
-  </picture>
-  <br>
-  <em> Serving throughput when each request asks for 1 output completion. </em>
-</p>
-
-<p align="center">
-  <picture>
-  <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_dark.png">
-  <img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_light.png" width="45%">
-  </picture>
-  <picture>
-  <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_dark.png">
-  <img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_light.png" width="45%">
-  </picture>  <br>
-  <em> Serving throughput when each request asks for 3 output completions. </em>
-</p>
-
 ## Contributing

 We welcome and value any contributions and collaborations.
 Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.

+## Citation
+
+If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
+```bibtex
+@inproceedings{kwon2023efficient,
+  title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
+  author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
+  booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
+  year={2023}
+}
+```
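As a companion to the feature list and install instructions above, here is a minimal sketch of the offline inference API that the README refers to. The model name, prompt, and sampling values are illustrative choices, not taken from the diff.

```python
from vllm import LLM, SamplingParams

# Any architecture from the supported-models list above can be substituted here.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)
```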
@@ -18,9 +18,12 @@ def main(args: argparse.Namespace):
     llm = LLM(
         model=args.model,
         tokenizer=args.tokenizer,
+        quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
         max_num_batched_tokens=args.batch_size * args.input_len,
+        trust_remote_code=args.trust_remote_code,
+        dtype=args.dtype,
     )

     sampling_params = SamplingParams(
@@ -37,13 +40,13 @@ def main(args: argparse.Namespace):
     def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        start_time = time.time()
+        start_time = time.perf_counter()

         llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                      sampling_params=sampling_params,
                      use_tqdm=False)

-        end_time = time.time()
+        end_time = time.perf_counter()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
@@ -62,17 +65,37 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description='Benchmark the latency of processing a single batch of '
-                    'requests till completion.')
+        'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
     parser.add_argument('--tokenizer', type=str, default=None)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=['awq', None],
+                        default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1,
+    parser.add_argument('--n',
+                        type=int,
+                        default=1,
                         help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
-    parser.add_argument('--num-iters', type=int, default=3,
+    parser.add_argument('--num-iters',
+                        type=int,
+                        default=3,
                         help='Number of iterations to run.')
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+        help='data type for model weights and activations. '
+        'The "auto" option will use FP16 precision '
+        'for FP32 and FP16 models, and BF16 precision '
+        'for BF16 models.')
     args = parser.parse_args()
     main(args)
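A change that recurs in this and the following benchmark scripts is the switch from `time.time()` to `time.perf_counter()`: `perf_counter()` is a monotonic, high-resolution clock intended for measuring elapsed time, whereas `time.time()` reports wall-clock time and can jump if the system clock is adjusted. Below is a self-contained sketch of the timing pattern used above; the `time.sleep` call stands in for the real `llm.generate` work.

```python
import time

def run_to_completion() -> float:
    """Time one iteration with a monotonic clock, as the benchmark does."""
    start_time = time.perf_counter()
    time.sleep(0.1)  # placeholder for the actual llm.generate(...) call
    end_time = time.perf_counter()
    return end_time - start_time

latencies = [run_to_completion() for _ in range(3)]
print(f"Avg latency: {sum(latencies) / len(latencies):.4f} s")
```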
@@ -105,7 +105,7 @@ async def send_request(
     best_of: int,
     use_beam_search: bool,
 ) -> None:
-    request_start_time = time.time()
+    request_start_time = time.perf_counter()

     headers = {"User-Agent": "Benchmark Client"}
     if backend == "vllm":
@@ -148,7 +148,7 @@ async def send_request(
             if "error" not in output:
                 break

-    request_end_time = time.time()
+    request_end_time = time.perf_counter()
     request_latency = request_end_time - request_start_time
     REQUEST_LATENCY.append((prompt_len, output_len, request_latency))

@@ -177,13 +177,13 @@ def main(args: argparse.Namespace):
     np.random.seed(args.seed)

     api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer)
+    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

-    benchmark_start_time = time.time()
+    benchmark_start_time = time.perf_counter()
     asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
                           args.use_beam_search, args.request_rate))
-    benchmark_end_time = time.time()
+    benchmark_end_time = time.perf_counter()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
     print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
@@ -227,5 +227,7 @@ if __name__ == "__main__":
                              "Otherwise, we use Poisson process to synthesize "
                              "the request arrival times.")
     parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument('--trust-remote-code', action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
@ -3,7 +3,7 @@ import argparse
 | 
				
			|||||||
import json
 | 
					import json
 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
from typing import List, Tuple
 | 
					from typing import List, Optional, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 | 
					from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 | 
				
			||||||
@ -22,15 +22,10 @@ def sample_requests(
 | 
				
			|||||||
    with open(dataset_path) as f:
 | 
					    with open(dataset_path) as f:
 | 
				
			||||||
        dataset = json.load(f)
 | 
					        dataset = json.load(f)
 | 
				
			||||||
    # Filter out the conversations with less than 2 turns.
 | 
					    # Filter out the conversations with less than 2 turns.
 | 
				
			||||||
    dataset = [
 | 
					    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
 | 
				
			||||||
        data for data in dataset
 | 
					 | 
				
			||||||
        if len(data["conversations"]) >= 2
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
    # Only keep the first two turns of each conversation.
 | 
					    # Only keep the first two turns of each conversation.
 | 
				
			||||||
    dataset = [
 | 
					    dataset = [(data["conversations"][0]["value"],
 | 
				
			||||||
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
 | 
					                data["conversations"][1]["value"]) for data in dataset]
 | 
				
			||||||
        for data in dataset
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Tokenize the prompts and completions.
 | 
					    # Tokenize the prompts and completions.
 | 
				
			||||||
    prompts = [prompt for prompt, _ in dataset]
 | 
					    prompts = [prompt for prompt, _ in dataset]
 | 
				
			||||||
@ -63,16 +58,22 @@ def run_vllm(
 | 
				
			|||||||
    requests: List[Tuple[str, int, int]],
 | 
					    requests: List[Tuple[str, int, int]],
 | 
				
			||||||
    model: str,
 | 
					    model: str,
 | 
				
			||||||
    tokenizer: str,
 | 
					    tokenizer: str,
 | 
				
			||||||
 | 
					    quantization: Optional[str],
 | 
				
			||||||
    tensor_parallel_size: int,
 | 
					    tensor_parallel_size: int,
 | 
				
			||||||
    seed: int,
 | 
					    seed: int,
 | 
				
			||||||
    n: int,
 | 
					    n: int,
 | 
				
			||||||
    use_beam_search: bool,
 | 
					    use_beam_search: bool,
 | 
				
			||||||
 | 
					    trust_remote_code: bool,
 | 
				
			||||||
 | 
					    dtype: str,
 | 
				
			||||||
) -> float:
 | 
					) -> float:
 | 
				
			||||||
    llm = LLM(
 | 
					    llm = LLM(
 | 
				
			||||||
        model=model,
 | 
					        model=model,
 | 
				
			||||||
        tokenizer=tokenizer,
 | 
					        tokenizer=tokenizer,
 | 
				
			||||||
 | 
					        quantization=quantization,
 | 
				
			||||||
        tensor_parallel_size=tensor_parallel_size,
 | 
					        tensor_parallel_size=tensor_parallel_size,
 | 
				
			||||||
        seed=seed,
 | 
					        seed=seed,
 | 
				
			||||||
 | 
					        trust_remote_code=trust_remote_code,
 | 
				
			||||||
 | 
					        dtype=dtype,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Add the requests to the engine.
 | 
					    # Add the requests to the engine.
 | 
				
			||||||
@ -92,10 +93,10 @@ def run_vllm(
 | 
				
			|||||||
            sampling_params=sampling_params,
 | 
					            sampling_params=sampling_params,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    start = time.time()
 | 
					    start = time.perf_counter()
 | 
				
			||||||
    # FIXME(woosuk): Do use internal method.
 | 
					    # FIXME(woosuk): Do use internal method.
 | 
				
			||||||
    llm._run_engine(use_tqdm=True)
 | 
					    llm._run_engine(use_tqdm=True)
 | 
				
			||||||
    end = time.time()
 | 
					    end = time.perf_counter()
 | 
				
			||||||
    return end - start
 | 
					    return end - start
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -106,16 +107,18 @@ def run_hf(
 | 
				
			|||||||
    n: int,
 | 
					    n: int,
 | 
				
			||||||
    use_beam_search: bool,
 | 
					    use_beam_search: bool,
 | 
				
			||||||
    max_batch_size: int,
 | 
					    max_batch_size: int,
 | 
				
			||||||
 | 
					    trust_remote_code: bool,
 | 
				
			||||||
) -> float:
 | 
					) -> float:
 | 
				
			||||||
    assert not use_beam_search
 | 
					    assert not use_beam_search
 | 
				
			||||||
    llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
 | 
					    llm = AutoModelForCausalLM.from_pretrained(
 | 
				
			||||||
 | 
					        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
 | 
				
			||||||
    if llm.config.model_type == "llama":
 | 
					    if llm.config.model_type == "llama":
 | 
				
			||||||
        # To enable padding in the HF backend.
 | 
					        # To enable padding in the HF backend.
 | 
				
			||||||
        tokenizer.pad_token = tokenizer.eos_token
 | 
					        tokenizer.pad_token = tokenizer.eos_token
 | 
				
			||||||
    llm = llm.cuda()
 | 
					    llm = llm.cuda()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    pbar = tqdm(total=len(requests))
 | 
					    pbar = tqdm(total=len(requests))
 | 
				
			||||||
    start = time.time()
 | 
					    start = time.perf_counter()
 | 
				
			||||||
    batch: List[str] = []
 | 
					    batch: List[str] = []
 | 
				
			||||||
    max_prompt_len = 0
 | 
					    max_prompt_len = 0
 | 
				
			||||||
    max_output_len = 0
 | 
					    max_output_len = 0
 | 
				
			||||||
@ -128,13 +131,14 @@ def run_hf(
 | 
				
			|||||||
        if len(batch) < max_batch_size and i != len(requests) - 1:
 | 
					        if len(batch) < max_batch_size and i != len(requests) - 1:
 | 
				
			||||||
            # Check if we can add more requests to the batch.
 | 
					            # Check if we can add more requests to the batch.
 | 
				
			||||||
            _, next_prompt_len, next_output_len = requests[i + 1]
 | 
					            _, next_prompt_len, next_output_len = requests[i + 1]
 | 
				
			||||||
            if (max(max_prompt_len, next_prompt_len) + max(
 | 
					            if (max(max_prompt_len, next_prompt_len) +
 | 
				
			||||||
                max_output_len, next_output_len)) <= 2048:
 | 
					                    max(max_output_len, next_output_len)) <= 2048:
 | 
				
			||||||
                # We can add more requests to the batch.
 | 
					                # We can add more requests to the batch.
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Generate the sequences.
 | 
					        # Generate the sequences.
 | 
				
			||||||
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
 | 
					        input_ids = tokenizer(batch, return_tensors="pt",
 | 
				
			||||||
 | 
					                              padding=True).input_ids
 | 
				
			||||||
        llm_outputs = llm.generate(
 | 
					        llm_outputs = llm.generate(
 | 
				
			||||||
            input_ids=input_ids.cuda(),
 | 
					            input_ids=input_ids.cuda(),
 | 
				
			||||||
            do_sample=not use_beam_search,
 | 
					            do_sample=not use_beam_search,
 | 
				
			||||||
@ -152,7 +156,7 @@ def run_hf(
 | 
				
			|||||||
        batch = []
 | 
					        batch = []
 | 
				
			||||||
        max_prompt_len = 0
 | 
					        max_prompt_len = 0
 | 
				
			||||||
        max_output_len = 0
 | 
					        max_output_len = 0
 | 
				
			||||||
    end = time.time()
 | 
					    end = time.perf_counter()
 | 
				
			||||||
    return end - start
 | 
					    return end - start
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -161,44 +165,71 @@ def main(args: argparse.Namespace):
 | 
				
			|||||||
    random.seed(args.seed)
 | 
					    random.seed(args.seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Sample the requests.
 | 
					    # Sample the requests.
 | 
				
			||||||
    tokenizer = get_tokenizer(args.tokenizer)
 | 
					    tokenizer = get_tokenizer(args.tokenizer,
 | 
				
			||||||
 | 
					                              trust_remote_code=args.trust_remote_code)
 | 
				
			||||||
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 | 
					    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if args.backend == "vllm":
 | 
					    if args.backend == "vllm":
 | 
				
			||||||
        elapsed_time = run_vllm(
 | 
					        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
 | 
				
			||||||
            requests, args.model, args.tokenizer, args.tensor_parallel_size,
 | 
					                                args.quantization, args.tensor_parallel_size,
 | 
				
			||||||
            args.seed, args.n, args.use_beam_search)
 | 
					                                args.seed, args.n, args.use_beam_search,
 | 
				
			||||||
 | 
					                                args.trust_remote_code, args.dtype)
 | 
				
			||||||
    elif args.backend == "hf":
 | 
					    elif args.backend == "hf":
 | 
				
			||||||
        assert args.tensor_parallel_size == 1
 | 
					        assert args.tensor_parallel_size == 1
 | 
				
			||||||
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
 | 
					        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
 | 
				
			||||||
                              args.use_beam_search, args.hf_max_batch_size)
 | 
					                              args.use_beam_search, args.hf_max_batch_size,
 | 
				
			||||||
 | 
					                              args.trust_remote_code)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        raise ValueError(f"Unknown backend: {args.backend}")
 | 
					        raise ValueError(f"Unknown backend: {args.backend}")
 | 
				
			||||||
    total_num_tokens = sum(
 | 
					    total_num_tokens = sum(prompt_len + output_len
 | 
				
			||||||
        prompt_len + output_len
 | 
					                           for _, prompt_len, output_len in requests)
 | 
				
			||||||
        for _, prompt_len, output_len in requests
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
 | 
					    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
 | 
				
			||||||
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
 | 
					          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
-    parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
+    parser.add_argument("--backend",
+                        type=str,
+                        choices=["vllm", "hf"],
                        default="vllm")
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=['awq', None],
+                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
-    parser.add_argument("--n", type=int, default=1,
+    parser.add_argument("--n",
+                        type=int,
+                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--hf-max-batch-size", type=int, default=None,
+    parser.add_argument("--hf-max-batch-size",
+                        type=int,
+                        default=None,
                        help="Maximum batch size for HF backend.")
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+        help='data type for model weights and activations. '
+        'The "auto" option will use FP16 precision '
+        'for FP32 and FP16 models, and BF16 precision '
+        'for BF16 models.')
    args = parser.parse_args()

    if args.backend == "vllm":
@@ -207,6 +238,8 @@ if __name__ == "__main__":
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
+        if args.quantization is not None:
+            raise ValueError("Quantization is only for vLLM backend.")
    if args.tokenizer is None:
        args.tokenizer = args.model
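The new throughput-benchmark flags feed directly into engine construction in the script's vLLM branch. As a rough, hypothetical sketch of what an AWQ run configured this way boils down to (not part of the diff; it assumes the LLM entrypoint of this release accepts the matching keyword arguments, and the model name is only an example):

    # Hypothetical sketch, not part of the diff: roughly what the benchmark's
    # vLLM branch builds from the parsed flags for an AWQ run.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="TheBloke/Llama-2-7B-AWQ",  # example AWQ checkpoint (assumption)
        quantization="awq",               # --quantization / -q
        dtype="half",                     # --dtype
        tensor_parallel_size=1,           # --tensor-parallel-size / -tp
        trust_remote_code=True,           # --trust-remote-code
        seed=0,                           # --seed
    )
    params = SamplingParams(n=1, temperature=1.0, max_tokens=128)
    outputs = llm.generate(["Hello, my name is"], params)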
							
								
								
									

benchmarks/kernels/benchmark_paged_attention.py (new file, 197 lines)
@@ -0,0 +1,197 @@
import argparse
import random
import time

import torch

from vllm import attention_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
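The v2 buffers allocated above (tmp_output, exp_sums, max_logits) hold per-partition partial results that are later merged with a numerically stable softmax combination (the reduce kernel appears further down in this diff). A small NumPy sketch of that bookkeeping and of the merge math, for illustration only and not part of the diff:

    # Sketch: v2 workspace sizing and the per-(seq, head) partition merge.
    import numpy as np

    PARTITION_SIZE = 512
    max_context_len = 4096
    num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE  # 8

    def merge_partitions(tmp_out, max_logits, exp_sums):
        # tmp_out:    [num_partitions, head_size] partial outputs for one (seq, head)
        # max_logits: [num_partitions] per-partition max logit
        # exp_sums:   [num_partitions] per-partition sum of exp(logit - partition max)
        global_max = max_logits.max()
        rescaled = exp_sums * np.exp(max_logits - global_max)
        weights = rescaled / (rescaled.sum() + 1e-6)
        return (weights[:, None] * tmp_out).sum(axis=0)

    tmp = np.random.rand(num_partitions, 128).astype(np.float32)
    out = merge_partitions(tmp,
                           np.random.rand(num_partitions),
                           np.random.rand(num_partitions) + 1.0)
    print(num_partitions, out.shape)  # 8 (128,)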

@@ -4,9 +4,25 @@ void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
+  m.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  m.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
}
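The two new bindings follow the same out-of-place, [num_tokens, d] calling convention as silu_and_mul. A hedged usage sketch, assuming the compiled extension is exposed to Python as vllm.activation_ops (by analogy with the vllm.attention_ops import used in the kernel benchmark above):

    # Sketch; assumes the extension module is importable as vllm.activation_ops.
    import torch
    from vllm import activation_ops

    x = torch.randn(16, 4096, dtype=torch.half, device="cuda")  # [num_tokens, d]
    out = torch.empty_like(x)
    activation_ops.gelu_new(out, x)   # GPT-2 style tanh-approximated GELU
    activation_ops.gelu_fast(out, x)  # faster approximate GELU variant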

@@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

+#include "dispatch_utils.h"
+
namespace vllm {

template<typename T>
@@ -34,9 +36,7 @@ void silu_and_mul(
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
@@ -46,3 +46,69 @@ void silu_and_mul(
        d);
    });
}
+
+namespace vllm {
+
+// Element-wise activation kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void activation_kernel(
+  scalar_t* __restrict__ out,               // [num_tokens, d]
+  const scalar_t* __restrict__ input,       // [num_tokens, d]
+  const int d) {
+  const int token_idx = blockIdx.x;
+  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x);
+  }
+}
+
+} // namespace vllm
+
+// Launch element-wise activation kernel.
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
+  int num_tokens = input.size(0);                                                         \
+  int d = input.size(1);                                                                  \
+  dim3 grid(num_tokens);                                                                  \
+  dim3 block(std::min(d, 1024));                                                          \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
+  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
+    input.scalar_type(),                                                                  \
+    "activation_kernel",                                                                  \
+    [&] {                                                                                 \
+      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
+        out.data_ptr<scalar_t>(),                                                         \
+        input.data_ptr<scalar_t>(),                                                       \
+        d);                                                                               \
+    });
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T gelu_new_kernel(const T& x) {
+  const float x3 = (float) (x * x * x);
+  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
+  const float f = (float) x;
+  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+} // namespace vllm
+
+void gelu_new(
+  torch::Tensor& out,     // [num_tokens, d]
+  torch::Tensor& input)   // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
+}
+
+void gelu_fast(
+  torch::Tensor& out,     // [num_tokens, d]
+  torch::Tensor& input)   // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
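For reference, the two device functions above implement the familiar tanh-based GELU approximations. A plain PyTorch sketch of the same math, useful for eyeballing the kernels' numerics and not part of the diff:

    # Reference sketches of the approximations in gelu_new_kernel / gelu_fast_kernel.
    import torch

    def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

    def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
        # Same form, factored as tanh(0.79788456 * x * (1 + 0.044715 * x^2)).
        return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))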

@@ -1,11 +1,28 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>

-void single_query_cached_kv_attention(
+void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
@@ -15,7 +32,11 @@ void single_query_cached_kv_attention(

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
-    "single_query_cached_kv_attention",
-    &single_query_cached_kv_attention,
-    "Compute the attention between an input query and the cached key/value tensors");
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  m.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
}
@ -26,6 +26,7 @@
 | 
				
			|||||||
#define WARP_SIZE 32
 | 
					#define WARP_SIZE 32
 | 
				
			||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
 | 
					#define MAX(a, b) ((a) > (b) ? (a) : (b))
 | 
				
			||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
 | 
					#define MIN(a, b) ((a) < (b) ? (a) : (b))
 | 
				
			||||||
 | 
					#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace vllm {
 | 
					namespace vllm {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -65,25 +66,57 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 | 
				
			|||||||
  return __shfl_sync(uint32_t(-1), sum, 0);
 | 
					  return __shfl_sync(uint32_t(-1), sum, 0);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Grid: (num_heads, num_seqs).
 | 
					// TODO(woosuk): Merge the last two dimensions of the grid.
 | 
				
			||||||
 | 
					// Grid: (num_heads, num_seqs, max_num_partitions).
 | 
				
			||||||
template<
 | 
					template<
 | 
				
			||||||
  typename scalar_t,
 | 
					  typename scalar_t,
 | 
				
			||||||
  int HEAD_SIZE,
 | 
					  int HEAD_SIZE,
 | 
				
			||||||
  int BLOCK_SIZE,
 | 
					  int BLOCK_SIZE,
 | 
				
			||||||
  int NUM_THREADS>
 | 
					  int NUM_THREADS,
 | 
				
			||||||
__global__ void single_query_cached_kv_attention_kernel(
 | 
					  int PARTITION_SIZE = 0> // Zero means no partitioning.
 | 
				
			||||||
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
 | 
					__device__ void paged_attention_kernel(
 | 
				
			||||||
 | 
					  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
 | 
				
			||||||
 | 
					  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
 | 
				
			||||||
 | 
					  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
 | 
				
			||||||
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
 | 
					  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
 | 
				
			||||||
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
 | 
					  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
 | 
				
			||||||
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
 | 
					  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
 | 
				
			||||||
 | 
					  const int* __restrict__ head_mapping,   // [num_heads]
 | 
				
			||||||
  const float scale,
 | 
					  const float scale,
 | 
				
			||||||
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
 | 
					  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
 | 
				
			||||||
  const int* __restrict__ context_lens,   // [num_seqs]
 | 
					  const int* __restrict__ context_lens,   // [num_seqs]
 | 
				
			||||||
  const int max_num_blocks_per_seq,
 | 
					  const int max_num_blocks_per_seq,
 | 
				
			||||||
  const float* __restrict__ alibi_slopes, // [num_heads]
 | 
					  const float* __restrict__ alibi_slopes, // [num_heads]
 | 
				
			||||||
  const int q_stride) {
 | 
					  const int q_stride,
 | 
				
			||||||
 | 
					  const int kv_block_stride,
 | 
				
			||||||
 | 
					  const int kv_head_stride) {
 | 
				
			||||||
 | 
					  const int seq_idx = blockIdx.y;
 | 
				
			||||||
 | 
					  const int partition_idx = blockIdx.z;
 | 
				
			||||||
 | 
					  const int max_num_partitions = gridDim.z;
 | 
				
			||||||
 | 
					  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
 | 
				
			||||||
 | 
					  const int context_len = context_lens[seq_idx];
 | 
				
			||||||
 | 
					  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
 | 
				
			||||||
 | 
					    // No work to do. Terminate the thread block.
 | 
				
			||||||
 | 
					    return;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
 | 
				
			||||||
 | 
					  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // [start_block_idx, end_block_idx) is the range of blocks to process.
 | 
				
			||||||
 | 
					  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
 | 
				
			||||||
 | 
					  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
 | 
				
			||||||
 | 
					  const int num_blocks = end_block_idx - start_block_idx;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // [start_token_idx, end_token_idx) is the range of tokens to process.
 | 
				
			||||||
 | 
					  const int start_token_idx = start_block_idx * BLOCK_SIZE;
 | 
				
			||||||
 | 
					  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
 | 
				
			||||||
 | 
					  const int num_tokens = end_token_idx - start_token_idx;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
 | 
					  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
 | 
				
			||||||
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
 | 
					  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
 | 
				
			||||||
 | 
					  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
 | 
				
			||||||
 | 
					  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
 | 
				
			||||||
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
 | 
					  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
 | 
				
			||||||
  const int thread_idx = threadIdx.x;
 | 
					  const int thread_idx = threadIdx.x;
 | 
				
			||||||
  const int warp_idx = thread_idx / WARP_SIZE;
 | 
					  const int warp_idx = thread_idx / WARP_SIZE;
 | 
				
			||||||
@ -91,7 +124,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  const int head_idx = blockIdx.x;
 | 
					  const int head_idx = blockIdx.x;
 | 
				
			||||||
  const int num_heads = gridDim.x;
 | 
					  const int num_heads = gridDim.x;
 | 
				
			||||||
  const int seq_idx = blockIdx.y;
 | 
					  const int kv_head_idx = head_mapping[head_idx];
 | 
				
			||||||
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 | 
					  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // A vector type to store a part of a key or a query.
 | 
					  // A vector type to store a part of a key or a query.
 | 
				
			||||||
@ -116,12 +149,13 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
  // th vectors of the query, and so on.
 | 
					  // th vectors of the query, and so on.
 | 
				
			||||||
  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
 | 
					  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
 | 
				
			||||||
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
 | 
					  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
 | 
				
			||||||
  Q_vec q_vecs[NUM_VECS_PER_THREAD];
 | 
					  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 | 
				
			||||||
#pragma unroll
 | 
					#pragma unroll
 | 
				
			||||||
  for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
 | 
					  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
 | 
				
			||||||
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
 | 
					    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
 | 
				
			||||||
    q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
 | 
					    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Memory planning.
 | 
					  // Memory planning.
 | 
				
			||||||
  extern __shared__ char shared_mem[];
 | 
					  extern __shared__ char shared_mem[];
 | 
				
			||||||
@ -135,15 +169,12 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
  constexpr int x = 16 / sizeof(scalar_t);
 | 
					  constexpr int x = 16 / sizeof(scalar_t);
 | 
				
			||||||
  float qk_max = -FLT_MAX;
 | 
					  float qk_max = -FLT_MAX;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
 | 
					 | 
				
			||||||
  const int context_len = context_lens[seq_idx];
 | 
					 | 
				
			||||||
  const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Iterate over the key blocks.
 | 
					  // Iterate over the key blocks.
 | 
				
			||||||
  // Each warp fetches a block of keys for each iteration.
 | 
					  // Each warp fetches a block of keys for each iteration.
 | 
				
			||||||
  // Each thread group in a warp fetches a key from the block, and computes
 | 
					  // Each thread group in a warp fetches a key from the block, and computes
 | 
				
			||||||
  // dot product with the query.
 | 
					  // dot product with the query.
 | 
				
			||||||
  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
 | 
					  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
 | 
				
			||||||
 | 
					  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
 | 
				
			||||||
    const int physical_block_number = block_table[block_idx];
 | 
					    const int physical_block_number = block_table[block_idx];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Load a key to registers.
 | 
					    // Load a key to registers.
 | 
				
			||||||
@ -158,8 +189,8 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
#pragma unroll
 | 
					#pragma unroll
 | 
				
			||||||
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
 | 
					      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
 | 
				
			||||||
        const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
 | 
					        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
 | 
				
			||||||
                                        + head_idx * HEAD_SIZE * BLOCK_SIZE
 | 
					                                        + kv_head_idx * kv_head_stride
 | 
				
			||||||
                                        + physical_block_offset * x;
 | 
					                                        + physical_block_offset * x;
 | 
				
			||||||
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
 | 
					        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
 | 
				
			||||||
        const int offset1 = (vec_idx * VEC_SIZE) / x;
 | 
					        const int offset1 = (vec_idx * VEC_SIZE) / x;
 | 
				
			||||||
@ -169,15 +200,15 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
      // Compute dot product.
 | 
					      // Compute dot product.
 | 
				
			||||||
      // This includes a reduction across the threads in the same thread group.
 | 
					      // This includes a reduction across the threads in the same thread group.
 | 
				
			||||||
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
 | 
					      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
 | 
				
			||||||
      // Add the ALiBi bias if slopes are given.
 | 
					      // Add the ALiBi bias if slopes are given.
 | 
				
			||||||
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
 | 
					      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      if (thread_group_offset == 0) {
 | 
					      if (thread_group_offset == 0) {
 | 
				
			||||||
        // Store the partial reductions to shared memory.
 | 
					        // Store the partial reductions to shared memory.
 | 
				
			||||||
        // NOTE(woosuk): It is required to zero out the masked logits.
 | 
					        // NOTE(woosuk): It is required to zero out the masked logits.
 | 
				
			||||||
        const bool mask = token_idx >= context_len;
 | 
					        const bool mask = token_idx >= context_len;
 | 
				
			||||||
        logits[token_idx] = mask ? 0.f : qk;
 | 
					        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
 | 
				
			||||||
        // Update the max value.
 | 
					        // Update the max value.
 | 
				
			||||||
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
 | 
					        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
@ -208,7 +239,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  // Get the sum of the exp values.
 | 
					  // Get the sum of the exp values.
 | 
				
			||||||
  float exp_sum = 0.f;
 | 
					  float exp_sum = 0.f;
 | 
				
			||||||
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
 | 
					  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
 | 
				
			||||||
    float val = __expf(logits[i] - qk_max);
 | 
					    float val = __expf(logits[i] - qk_max);
 | 
				
			||||||
    logits[i] = val;
 | 
					    logits[i] = val;
 | 
				
			||||||
    exp_sum += val;
 | 
					    exp_sum += val;
 | 
				
			||||||
@ -217,11 +248,23 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  // Compute softmax.
 | 
					  // Compute softmax.
 | 
				
			||||||
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
 | 
					  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
 | 
				
			||||||
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
 | 
					  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
 | 
				
			||||||
    logits[i] *= inv_sum;
 | 
					    logits[i] *= inv_sum;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  __syncthreads();
 | 
					  __syncthreads();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // If partitioning is enabled, store the max logit and exp_sum.
 | 
				
			||||||
 | 
					  if (USE_PARTITIONING && thread_idx == 0) {
 | 
				
			||||||
 | 
					    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
 | 
				
			||||||
 | 
					                                       + head_idx * max_num_partitions
 | 
				
			||||||
 | 
					                                       + partition_idx;
 | 
				
			||||||
 | 
					    *max_logits_ptr = qk_max;
 | 
				
			||||||
 | 
					    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
 | 
				
			||||||
 | 
					                                   + head_idx * max_num_partitions
 | 
				
			||||||
 | 
					                                   + partition_idx;
 | 
				
			||||||
 | 
					    *exp_sums_ptr = exp_sum;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Each thread will fetch 16 bytes from the value cache at a time.
 | 
					  // Each thread will fetch 16 bytes from the value cache at a time.
 | 
				
			||||||
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
 | 
					  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
 | 
				
			||||||
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
 | 
					  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
 | 
				
			||||||
@ -230,7 +273,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
 | 
					  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
 | 
				
			||||||
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
 | 
					  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
 | 
				
			||||||
  constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
 | 
					  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
 | 
					  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
 | 
				
			||||||
  float accs[NUM_ROWS_PER_THREAD];
 | 
					  float accs[NUM_ROWS_PER_THREAD];
 | 
				
			||||||
@ -239,21 +282,33 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
    accs[i] = 0.f;
 | 
					    accs[i] = 0.f;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
 | 
					  scalar_t zero_value;
 | 
				
			||||||
 | 
					  zero(zero_value);
 | 
				
			||||||
 | 
					  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
 | 
				
			||||||
    const int physical_block_number = block_table[block_idx];
 | 
					    const int physical_block_number = block_table[block_idx];
 | 
				
			||||||
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
 | 
					    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
 | 
				
			||||||
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
 | 
					    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
 | 
				
			||||||
    L_vec logits_vec;
 | 
					    L_vec logits_vec;
 | 
				
			||||||
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
 | 
					    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
 | 
					    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
 | 
				
			||||||
                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
 | 
					                                    + kv_head_idx * kv_head_stride;
 | 
				
			||||||
#pragma unroll
 | 
					#pragma unroll
 | 
				
			||||||
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
 | 
					    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
 | 
				
			||||||
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
 | 
					      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
 | 
				
			||||||
      if (row_idx < HEAD_SIZE) {
 | 
					      if (row_idx < HEAD_SIZE) {
 | 
				
			||||||
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
 | 
					        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
 | 
				
			||||||
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
 | 
					        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
 | 
				
			||||||
 | 
					        if (block_idx == num_context_blocks - 1) {
 | 
				
			||||||
 | 
					          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
 | 
				
			||||||
 | 
					          // we should explicitly zero out the values since they may contain NaNs.
 | 
				
			||||||
 | 
					          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
 | 
				
			||||||
 | 
					          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					          for (int j = 0; j < V_VEC_SIZE; j++) {
 | 
				
			||||||
 | 
					            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        accs[i] += dot(logits_vec, v_vec);
 | 
					        accs[i] += dot(logits_vec, v_vec);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -308,7 +363,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  // Write the final output.
 | 
					  // Write the final output.
 | 
				
			||||||
  if (warp_idx == 0) {
 | 
					  if (warp_idx == 0) {
 | 
				
			||||||
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 | 
					    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
 | 
				
			||||||
 | 
					                            + head_idx * max_num_partitions * HEAD_SIZE
 | 
				
			||||||
 | 
					                            + partition_idx * HEAD_SIZE;
 | 
				
			||||||
#pragma unroll
 | 
					#pragma unroll
 | 
				
			||||||
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
 | 
					    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
 | 
				
			||||||
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
 | 
					      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
 | 
				
			||||||
@ -319,32 +376,193 @@ __global__ void single_query_cached_kv_attention_kernel(
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Grid: (num_heads, num_seqs, 1).
 | 
				
			||||||
 | 
					template<
 | 
				
			||||||
 | 
					  typename scalar_t,
 | 
				
			||||||
 | 
					  int HEAD_SIZE,
 | 
				
			||||||
 | 
					  int BLOCK_SIZE,
 | 
				
			||||||
 | 
					  int NUM_THREADS>
 | 
				
			||||||
 | 
					__global__ void paged_attention_v1_kernel(
 | 
				
			||||||
 | 
					  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
 | 
				
			||||||
 | 
					  const int* __restrict__ head_mapping,   // [num_heads]
 | 
				
			||||||
 | 
					  const float scale,
 | 
				
			||||||
 | 
					  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
 | 
				
			||||||
 | 
					  const int* __restrict__ context_lens,   // [num_seqs]
 | 
				
			||||||
 | 
					  const int max_num_blocks_per_seq,
 | 
				
			||||||
 | 
					  const float* __restrict__ alibi_slopes, // [num_heads]
 | 
				
			||||||
 | 
					  const int q_stride,
 | 
				
			||||||
 | 
					  const int kv_block_stride,
 | 
				
			||||||
 | 
					  const int kv_head_stride) {
 | 
				
			||||||
 | 
					  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
 | 
				
			||||||
 | 
					    /* exp_sums */ nullptr, /* max_logits */ nullptr,
 | 
				
			||||||
 | 
					    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
 | 
				
			||||||
 | 
					    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Grid: (num_heads, num_seqs, max_num_partitions).
 | 
				
			||||||
 | 
					template<
 | 
				
			||||||
 | 
					  typename scalar_t,
 | 
				
			||||||
 | 
					  int HEAD_SIZE,
 | 
				
			||||||
 | 
					  int BLOCK_SIZE,
 | 
				
			||||||
 | 
					  int NUM_THREADS,
 | 
				
			||||||
 | 
					  int PARTITION_SIZE>
 | 
				
			||||||
 | 
					__global__ void paged_attention_v2_kernel(
 | 
				
			||||||
 | 
					  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
 | 
				
			||||||
 | 
					  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
 | 
				
			||||||
 | 
					  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
 | 
				
			||||||
 | 
					  const int* __restrict__ head_mapping,   // [num_heads]
 | 
				
			||||||
 | 
					  const float scale,
 | 
				
			||||||
 | 
					  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
 | 
				
			||||||
 | 
					  const int* __restrict__ context_lens,   // [num_seqs]
 | 
				
			||||||
 | 
					  const int max_num_blocks_per_seq,
 | 
				
			||||||
 | 
					  const float* __restrict__ alibi_slopes, // [num_heads]
 | 
				
			||||||
 | 
					  const int q_stride,
 | 
				
			||||||
 | 
					  const int kv_block_stride,
 | 
				
			||||||
 | 
					  const int kv_head_stride) {
 | 
				
			||||||
 | 
					  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
 | 
				
			||||||
 | 
					    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
 | 
				
			||||||
 | 
					    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
 | 
				
			||||||
 | 
					    q_stride, kv_block_stride, kv_head_stride);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Grid: (num_heads, num_seqs).
 | 
				
			||||||
 | 
					template<
 | 
				
			||||||
 | 
					  typename scalar_t,
 | 
				
			||||||
 | 
					  int HEAD_SIZE,
 | 
				
			||||||
 | 
					  int NUM_THREADS,
 | 
				
			||||||
 | 
					  int PARTITION_SIZE>
 | 
				
			||||||
 | 
					__global__ void paged_attention_v2_reduce_kernel(
 | 
				
			||||||
 | 
					  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
 | 
				
			||||||
 | 
					  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
 | 
				
			||||||
 | 
					  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
 | 
				
			||||||
 | 
					  const int* __restrict__ context_lens,   // [num_seqs]
 | 
				
			||||||
 | 
					  const int max_num_partitions) {
 | 
				
			||||||
 | 
					  const int num_heads = gridDim.x;
 | 
				
			||||||
 | 
					  const int head_idx = blockIdx.x;
 | 
				
			||||||
 | 
					  const int seq_idx = blockIdx.y;
 | 
				
			||||||
 | 
					  const int context_len = context_lens[seq_idx];
 | 
				
			||||||
 | 
					  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
 | 
				
			||||||
 | 
					  if (num_partitions == 1) {
 | 
				
			||||||
 | 
					    // No need to reduce. Only copy tmp_out to out.
 | 
				
			||||||
 | 
					    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 | 
				
			||||||
 | 
					    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
 | 
				
			||||||
 | 
					                                          + head_idx * max_num_partitions * HEAD_SIZE;
 | 
				
			||||||
 | 
					    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
 | 
				
			||||||
 | 
					      out_ptr[i] = tmp_out_ptr[i];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    // Terminate the thread block.
 | 
				
			||||||
 | 
					    return;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
 | 
				
			||||||
 | 
					  const int warp_idx = threadIdx.x / WARP_SIZE;
 | 
				
			||||||
 | 
					  const int lane = threadIdx.x % WARP_SIZE;
 | 
				
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                           + head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                        + head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
 } // namespace vllm
+
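The reduce kernel above merges per-partition attention results with a numerically stable log-sum-exp: each partition's exponential sum is rescaled by exp(partition_max - global_max), and the partial outputs in tmp_out are then averaged with those rescaled weights. A minimal host-side reference of the same combination logic, plain C++ and illustrative only (the function and argument names are assumptions, not part of the diff):

#include <cmath>
#include <vector>

// Reference for one (seq, head): combine per-partition partial outputs.
// tmp_out[p][d] holds partition p's attention output; max_logits[p] and
// exp_sums[p] are that partition's softmax statistics, as in the kernel above.
std::vector<float> combine_partitions(const std::vector<std::vector<float>>& tmp_out,
                                      const std::vector<float>& max_logits,
                                      const std::vector<float>& exp_sums,
                                      int head_size) {
  const int num_partitions = static_cast<int>(max_logits.size());
  float global_max = -INFINITY;
  for (float m : max_logits) global_max = std::max(global_max, m);

  // Rescale each partition's exp-sum to the global max.
  std::vector<float> rescaled(num_partitions);
  float global_exp_sum = 0.f;
  for (int p = 0; p < num_partitions; ++p) {
    rescaled[p] = exp_sums[p] * std::exp(max_logits[p] - global_max);
    global_exp_sum += rescaled[p];
  }

  // Weighted average of the partial outputs.
  std::vector<float> out(head_size, 0.f);
  for (int d = 0; d < head_size; ++d) {
    for (int p = 0; p < num_partitions; ++p) {
      out[d] += tmp_out[p][d] * rescaled[p] / (global_exp_sum + 1e-6f);
    }
  }
  return out;
}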
-#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                        \
-  vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>        \
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
+  cudaFuncSetAttribute(                                                                       \
+    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,                   \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                            \
+  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                      \
   <<<grid, block, shared_mem_size, stream>>>(                                                 \
     out_ptr,                                                                                  \
     query_ptr,                                                                                \
     key_cache_ptr,                                                                            \
     value_cache_ptr,                                                                          \
+    head_mapping_ptr,                                                                         \
     scale,                                                                                    \
     block_tables_ptr,                                                                         \
     context_lens_ptr,                                                                         \
     max_num_blocks_per_seq,                                                                   \
     alibi_slopes_ptr,                                                                         \
-    query_stride);
+    q_stride,                                                                                 \
+    kv_block_stride,                                                                          \
+    kv_head_stride);
+
 // TODO(woosuk): Tune NUM_THREADS.
 template<
   typename T,
   int BLOCK_SIZE,
   int NUM_THREADS = 128>
-void single_query_cached_kv_attention_launcher(
+void paged_attention_v1_launcher(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -354,7 +572,9 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
-  int query_stride = query.stride(0);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
 
   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);
@@ -368,60 +588,56 @@ void single_query_cached_kv_attention_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
-  dim3 grid(num_heads, num_seqs);
+  dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
-    // 32, 160, 192, 256.
-    // case 32:
-    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-    //   break;
+    // NOTE(woosuk): To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
     case 64:
-      LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(64);
       break;
     case 80:
-      LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(80);
       break;
     case 96:
-      LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
     case 112:
-      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(112);
       break;
     case 128:
-      LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V1(256);
       break;
-    // case 160:
-    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-    //   break;
-    // case 192:
-    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-    //   break;
-    // case 256:
-    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
   }
 }
+
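In the launcher above, shared_mem_size is the larger of the softmax-logits buffer (one float per padded context token) and the cross-warp output buffer, and LAUNCH_PAGED_ATTENTION_V1 raises the kernel's dynamic shared-memory limit with cudaFuncSetAttribute because long contexts can exceed the default 48 KB ceiling. A back-of-the-envelope check with illustrative numbers (these values are assumptions, not taken from the diff):

#include <algorithm>
#include <cstdio>

// Sizing check for paged_attention_v1_launcher with assumed parameters:
// NUM_THREADS = 128 (so NUM_WARPS = 4), head_size = 128, BLOCK_SIZE = 16,
// max_context_len = 8192.
int main() {
  const int NUM_WARPS = 128 / 32;
  const int head_size = 128;
  const int BLOCK_SIZE = 16;
  const int max_context_len = 8192;

  const int padded_max_context_len =
      (max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;  // 8192
  const int logits_size = padded_max_context_len * 4;                // 32768 B
  const int outputs_size = (NUM_WARPS / 2) * head_size * 4;          // 1024 B
  const int shared_mem_size = std::max(logits_size, outputs_size);   // 32768 B

  // Once this exceeds 48 KB (roughly 16K-token contexts at fp32 logits), the
  // cudaFuncSetAttribute call in LAUNCH_PAGED_ATTENTION_V1 is what makes the
  // launch legal on devices that support the larger opt-in limit.
  std::printf("shared_mem_size = %d bytes\n", shared_mem_size);
  return 0;
}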
-#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE)                         \
-  single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>(         \
+#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)                             \
+  paged_attention_v1_launcher<T, BLOCK_SIZE>(                       \
     out,                                                            \
     query,                                                          \
     key_cache,                                                      \
     value_cache,                                                    \
+    head_mapping,                                                   \
     scale,                                                          \
     block_tables,                                                   \
     context_lens,                                                   \
@@ -430,45 +646,28 @@ void single_query_cached_kv_attention_launcher(
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T)                          \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                              \
   switch (block_size) {                                             \
-    /* case 1:                         */                           \
-    /*   CALL_KERNEL_LAUNCHER(T, 1);   */                           \
-    /*   break;                        */                           \
-    /* case 2:                         */                           \
-    /*   CALL_KERNEL_LAUNCHER(T, 2);   */                           \
-    /*   break;                        */                           \
-    /* case 4:                         */                           \
-    /*   CALL_KERNEL_LAUNCHER(T, 4);   */                           \
-    /*   break;                        */                           \
     case 8:                                                         \
-      CALL_KERNEL_LAUNCHER(T, 8);                                   \
+      CALL_V1_LAUNCHER(T, 8);                                       \
       break;                                                        \
     case 16:                                                        \
-      CALL_KERNEL_LAUNCHER(T, 16);                                  \
+      CALL_V1_LAUNCHER(T, 16);                                      \
      break;                                                        \
     case 32:                                                        \
-      CALL_KERNEL_LAUNCHER(T, 32);                                  \
+      CALL_V1_LAUNCHER(T, 32);                                      \
      break;                                                        \
-    /* case 64:                        */                           \
-    /*   CALL_KERNEL_LAUNCHER(T, 64);  */                           \
-    /*   break;                        */                           \
-    /* case 128:                       */                           \
-    /*   CALL_KERNEL_LAUNCHER(T, 128); */                           \
-    /*   break;                        */                           \
-    /* case 256:                       */                           \
-    /*   CALL_KERNEL_LAUNCHER(T, 256); */                           \
-    /*   break;                        */                           \
    default:                                                        \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
      break;                                                        \
   }
 
-void single_query_cached_kv_attention(
+void paged_attention_v1(
   torch::Tensor& out,             // [num_seqs, num_heads, head_size]
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,    // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
@@ -476,11 +675,186 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes) {
   if (query.dtype() == at::ScalarType::Float) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
   } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+  } else {
+    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+  }
+}
+
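paged_attention_v1 above maps torch dtypes onto the kernel's template parameter; fp16 tensors are handed over as raw uint16_t words and reinterpreted on the device side rather than going through at::Half. A hedged sketch of that pattern in isolation (copy_kernel and run_copy are hypothetical names, not part of the diff):

#include <torch/extension.h>

// Assumed standalone .cu sketch of the dtype-to-template mapping used by the
// launcher above, reduced to a trivial copy kernel.
template<typename T>
__global__ void copy_kernel(const T* in, T* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

void run_copy(torch::Tensor& out, torch::Tensor& in) {
  int n = static_cast<int>(in.numel());
  dim3 grid((n + 255) / 256), block(256);
  if (in.dtype() == at::ScalarType::Float) {
    copy_kernel<float><<<grid, block>>>(
        reinterpret_cast<const float*>(in.data_ptr()),
        reinterpret_cast<float*>(out.data_ptr()), n);
  } else if (in.dtype() == at::ScalarType::Half) {
    // fp16 storage is passed as raw 16-bit words, like the uint16_t
    // instantiation in paged_attention_v1 above.
    copy_kernel<uint16_t><<<grid, block>>>(
        reinterpret_cast<const uint16_t*>(in.data_ptr()),
        reinterpret_cast<uint16_t*>(out.data_ptr()), n);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", in.dtype());
  }
}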
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
+  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>      \
+  <<<grid, block, shared_mem_size, stream>>>(                                                 \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    query_ptr,                                                                                \
+    key_cache_ptr,                                                                            \
+    value_cache_ptr,                                                                          \
+    head_mapping_ptr,                                                                         \
+    scale,                                                                                    \
+    block_tables_ptr,                                                                         \
+    context_lens_ptr,                                                                         \
+    max_num_blocks_per_seq,                                                                   \
+    alibi_slopes_ptr,                                                                         \
+    q_stride,                                                                                 \
+    kv_block_stride,                                                                          \
+    kv_head_stride);                                                                          \
+  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
+  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
+    out_ptr,                                                                                  \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    context_lens_ptr,                                                                         \
+    max_num_partitions);
+
+template<
+  typename T,
+  int BLOCK_SIZE,
+  int NUM_THREADS = 128,
+  int PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
+  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* context_lens_ptr = context_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+  int logits_size = PARTITION_SIZE * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+  // For paged attention v2 kernel.
+  dim3 grid(num_heads, num_seqs, max_num_partitions);
+  int shared_mem_size = std::max(logits_size, outputs_size);
+  // For paged attention v2 reduce kernel.
+  dim3 reduce_grid(num_heads, num_seqs);
+  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+
+  dim3 block(NUM_THREADS);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  switch (head_size) {
+    // NOTE(woosuk): To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V2(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V2(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V2(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V2(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V2(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V2(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
+  }
+}
+
+#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)                             \
+  paged_attention_v2_launcher<T, BLOCK_SIZE>(                       \
+    out,                                                            \
+    exp_sums,                                                       \
+    max_logits,                                                     \
+    tmp_out,                                                        \
+    query,                                                          \
+    key_cache,                                                      \
+    value_cache,                                                    \
+    head_mapping,                                                   \
+    scale,                                                          \
+    block_tables,                                                   \
+    context_lens,                                                   \
+    max_context_len,                                                \
+    alibi_slopes);
+
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                              \
+  switch (block_size) {                                             \
+    case 8:                                                         \
+      CALL_V2_LAUNCHER(T, 8);                                       \
+      break;                                                        \
+    case 16:                                                        \
+      CALL_V2_LAUNCHER(T, 16);                                      \
+      break;                                                        \
+    case 32:                                                        \
+      CALL_V2_LAUNCHER(T, 32);                                      \
+      break;                                                        \
+    default:                                                        \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
+      break;                                                        \
+  }
+
+void paged_attention_v2(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,    // [num_heads]
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
+    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
@@ -489,3 +863,4 @@ void single_query_cached_kv_attention(
 #undef WARP_SIZE
 #undef MAX
 #undef MIN
+#undef DIVIDE_ROUND_UP
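For the v2 path added above, the caller pre-allocates the partition workspaces: with the default PARTITION_SIZE of 512, a request with max_context_len tokens is split into ceil(max_context_len / 512) partitions, and exp_sums, max_logits, and tmp_out are sized per the comments in paged_attention_v2. A hedged host-side sketch of that allocation (assumed caller code using the PyTorch C++ API, not part of the diff):

#include <torch/extension.h>

// Illustrative workspace allocation for paged_attention_v2; shapes follow the
// parameter comments above.
void alloc_v2_workspaces(const torch::Tensor& query, int max_context_len) {
  const int64_t num_seqs = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t head_size = query.size(2);
  const int64_t partition_size = 512;  // PARTITION_SIZE in the launcher
  const int64_t max_num_partitions =
      (max_context_len + partition_size - 1) / partition_size;

  auto float_opts = query.options().dtype(torch::kFloat32);
  torch::Tensor exp_sums =
      torch::empty({num_seqs, num_heads, max_num_partitions}, float_opts);
  torch::Tensor max_logits = torch::empty_like(exp_sums);
  torch::Tensor tmp_out = torch::empty(
      {num_seqs, num_heads, max_num_partitions, head_size}, query.options());
  // exp_sums, max_logits, and tmp_out would then be passed to paged_attention_v2.
  (void)exp_sums; (void)max_logits; (void)tmp_out;
}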
@@ -420,4 +420,19 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
 #endif
 }
+
+// From bfloat16 to float32.
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+
+// Zero-out a variable.
+inline __device__ void zero(__nv_bfloat16& dst) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
+  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
+#endif
+}
+
 } // namespace vllm
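These bfloat16 helpers keep the templated kernels type-generic: to_float widens a cached value before fp32 accumulation, and zero produces a typed zero without dtype-specific literals. A hedged, standalone sketch of how they combine (the kernel and namespace here are hypothetical; the helpers are re-declared only for self-containment):

#include <cuda_bf16.h>

namespace sketch {

__device__ __forceinline__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

__device__ __forceinline__ void zero(__nv_bfloat16& dst) {
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
}

// Compute the mean of n bf16 values per block, accumulating in fp32 and using
// a typed zero() for the empty-input case.
__global__ void block_mean_bf16(const __nv_bfloat16* in, __nv_bfloat16* out, int n) {
  // Single-threaded per block for clarity; a real kernel would reduce in parallel.
  if (threadIdx.x != 0) return;
  __nv_bfloat16 result;
  zero(result);                      // typed zero, as in vllm::zero above
  if (n > 0) {
    float acc = 0.f;                 // accumulate in fp32 for accuracy
    for (int i = 0; i < n; ++i) {
      acc += to_float(in[i]);
    }
    result = __float2bfloat16(acc / n);
  }
  out[blockIdx.x] = result;
}

}  // namespace sketch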
@@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
   return sum(c);
 }
 
-// Zero-out a vector.
-inline __device__ void zero(uint16_t& dst) {
-  dst = uint16_t(0);
-}
-
 // From float32 to float16.
 inline __device__ void from_float(uint16_t& dst, float src) {
   dst = float_to_half(src);
@@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
   return tmp;
 }
+
+// Zero-out a variable.
+inline __device__ void zero(uint16_t& dst) {
+  dst = uint16_t(0);
+}
+
 } // namespace vllm
@@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
   return u;
 }
+
+// Zero-out a variable.
+inline __device__ void zero(float& dst) {
+  dst = 0.f;
+}
+
 } // namespace vllm
@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+
+#include "dispatch_utils.h"
 
 #include <algorithm>
 #include <cassert>
 #include <map>
@@ -125,9 +127,7 @@ void copy_blocks(
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -202,9 +202,7 @@ void reshape_and_cache(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "reshape_and_cache_kernel",
     [&] {
@@ -364,9 +362,7 @@ void gather_cached_kv(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
csrc/cuda_utils.cpp  (new file, +13 lines)
@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+    int attribute,
+    int device_id);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "get_device_attribute",
+    &get_device_attribute,
+    "Gets the specified device attribute.");
+}
+
csrc/cuda_utils_kernels.cu  (new file, +14 lines)
@@ -0,0 +1,14 @@
+int get_device_attribute(
+    int attribute,
+    int device_id)
+{
+    int device, value;
+    if (device_id < 0) {
+        cudaGetDevice(&device);
+    }
+    else {
+        device = device_id;
+    }
+    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
+    return value;
+}
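cuda_utils_kernels.cu wraps cudaDeviceGetAttribute so the Python side can query device limits through the extension, presumably supporting checks like the vllm.worker.worker._check_if_can_support_max_seq_len sync note earlier in this diff. A hedged C++ usage sketch (the attribute enum comes from the CUDA runtime; get_device_attribute is the function declared above, and main here is purely illustrative):

#include <cuda_runtime.h>
#include <cstdio>

int get_device_attribute(int attribute, int device_id);  // declared above

int main() {
  // cudaDevAttrMaxSharedMemoryPerBlockOptin is the opt-in dynamic shared
  // memory ceiling relevant to the attention kernels' shared_mem_size.
  int max_shared = get_device_attribute(
      static_cast<int>(cudaDevAttrMaxSharedMemoryPerBlockOptin), /*device_id=*/0);
  std::printf("max opt-in shared memory per block: %d bytes\n", max_shared);
  return 0;
}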
							
								
								
									
csrc/dispatch_utils.h  (new file, +14 lines)
@@ -0,0 +1,14 @@
+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#include <torch/extension.h>
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
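dispatch_utils.h narrows PyTorch's AT_DISPATCH_SWITCH to exactly the three floating types vLLM's kernels support, which is why the call sites elsewhere in this diff shrink from five arguments to three. A usage sketch with a hypothetical elementwise kernel (scale_kernel and scale_inplace are not from the diff):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"

namespace vllm {
template<typename scalar_t>
__global__ void scale_kernel(scalar_t* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = static_cast<scalar_t>(static_cast<float>(x[i]) * s);
}
} // namespace vllm

void scale_inplace(torch::Tensor& x, float s) {
  int n = static_cast<int>(x.numel());
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Float, Half, and BFloat16 are covered; anything else throws, matching the
  // behavior the old AT_DISPATCH_FLOATING_TYPES_AND2 call sites had.
  VLLM_DISPATCH_FLOATING_TYPES(x.scalar_type(), "scale_kernel", [&] {
    vllm::scale_kernel<scalar_t><<<grid, block, 0, stream>>>(
        x.data_ptr<scalar_t>(), s, n);
  });
}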
@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "dispatch_utils.h"
 #include "reduction_utils.cuh"
 
 namespace vllm {
@@ -46,9 +47,7 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "rms_norm_kernel",
     [&] {
@@ -1,15 +1,16 @@
 #include <torch/extension.h>
 
-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache);
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "rotary_embedding_neox",
-    &rotary_embedding_neox,
-    "Apply GPT-NeoX style rotary embedding to query and key");
+    "rotary_embedding",
+    &rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 }
@ -1,17 +1,51 @@
 | 
				
			|||||||
#include <torch/extension.h>
 | 
					#include <torch/extension.h>
 | 
				
			||||||
#include <ATen/cuda/CUDAContext.h>
 | 
					#include <ATen/cuda/CUDAContext.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "dispatch_utils.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace vllm {
 | 
					namespace vllm {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template<typename scalar_t>
 | 
					template<typename scalar_t, bool IS_NEOX>
 | 
				
			||||||
__global__ void rotary_embedding_neox_kernel(
 | 
					inline __device__ void apply_rotary_embedding(
 | 
				
			||||||
 | 
					  scalar_t* __restrict__ arr,
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ cos_ptr,
 | 
				
			||||||
 | 
					  const scalar_t* __restrict__ sin_ptr,
 | 
				
			||||||
 | 
					  int rot_offset,
 | 
				
			||||||
 | 
					  int embed_dim)
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  int x_index, y_index;
 | 
				
			||||||
 | 
					  scalar_t cos, sin;
 | 
				
			||||||
 | 
					  if (IS_NEOX) {
 | 
				
			||||||
 | 
					    // GPT-NeoX style rotary embedding.
 | 
				
			||||||
 | 
					    x_index = rot_offset;
 | 
				
			||||||
 | 
					    y_index = embed_dim + rot_offset;
 | 
				
			||||||
 | 
					    cos = __ldg(cos_ptr + x_index);
 | 
				
			||||||
 | 
					    sin = __ldg(sin_ptr + x_index);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // GPT-J style rotary embedding.
 | 
				
			||||||
 | 
					    x_index = 2 * rot_offset;
 | 
				
			||||||
 | 
					    y_index = 2 * rot_offset + 1;
 | 
				
			||||||
 | 
					    cos = __ldg(cos_ptr + x_index / 2);
 | 
				
			||||||
 | 
					    sin = __ldg(sin_ptr + x_index / 2);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const scalar_t x = arr[x_index];
 | 
				
			||||||
 | 
					  const scalar_t y = arr[y_index];
 | 
				
			||||||
 | 
					  arr[x_index] = x * cos - y * sin;
 | 
				
			||||||
 | 
					  arr[y_index] = y * cos + x * sin;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template<typename scalar_t, bool IS_NEOX>
 | 
				
			||||||
 | 
					__global__ void rotary_embedding_kernel(
 | 
				
			||||||
  const int64_t* __restrict__ positions,        // [num_tokens]
 | 
					  const int64_t* __restrict__ positions,        // [num_tokens]
 | 
				
			||||||
  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
 | 
					  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
 | 
				
			||||||
  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
 | 
					  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
 | 
				
			||||||
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
 | 
					  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
 | 
				
			||||||
  const int rot_dim,
 | 
					  const int rot_dim,
 | 
				
			||||||
  const int stride,
 | 
					  const int query_stride,
 | 
				
			||||||
 | 
					  const int key_stride,
 | 
				
			||||||
  const int num_heads,
 | 
					  const int num_heads,
 | 
				
			||||||
 | 
					  const int num_kv_heads,
 | 
				
			||||||
  const int head_size) {
 | 
					  const int head_size) {
 | 
				
			||||||
  // Each thread block is responsible for one token.
 | 
					  // Each thread block is responsible for one token.
 | 
				
			||||||
  const int token_idx = blockIdx.x;
 | 
					  const int token_idx = blockIdx.x;
 | 
				
			||||||
@ -19,65 +53,75 @@ __global__ void rotary_embedding_neox_kernel(
 | 
				
			|||||||
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 | 
					  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const int embed_dim = rot_dim / 2;
 | 
					  const int embed_dim = rot_dim / 2;
 | 
				
			||||||
  const int n = num_heads * embed_dim;
 | 
					  const scalar_t* cos_ptr = cache_ptr;
 | 
				
			||||||
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
 | 
					  const scalar_t* sin_ptr = cache_ptr + embed_dim;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const int nq = num_heads * embed_dim;
 | 
				
			||||||
 | 
					  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
 | 
				
			||||||
    const int head_idx = i / embed_dim;
 | 
					    const int head_idx = i / embed_dim;
 | 
				
			||||||
    const int token_head = token_idx * stride + head_idx * head_size;
 | 
					    const int token_head = token_idx * query_stride + head_idx * head_size;
 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int rot_offset = i % embed_dim;
 | 
					    const int rot_offset = i % embed_dim;
 | 
				
			||||||
    const int x_index = rot_offset;
 | 
					    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
 | 
				
			||||||
    const int y_index = embed_dim + rot_offset;
 | 
					                                              sin_ptr, rot_offset, embed_dim);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int out_x = token_idx * stride + head_idx * head_size + x_index;
 | 
					  const int nk = num_kv_heads * embed_dim;
 | 
				
			||||||
    const int out_y = token_idx * stride + head_idx * head_size + y_index;
 | 
					  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
 | 
				
			||||||
 | 
					    const int head_idx = i / embed_dim;
 | 
				
			||||||
    const scalar_t cos = __ldg(cache_ptr + x_index);
 | 
					    const int token_head = token_idx * key_stride + head_idx * head_size;
 | 
				
			||||||
    const scalar_t sin = __ldg(cache_ptr + y_index);
 | 
					    const int rot_offset = i % embed_dim;
 | 
				
			||||||
 | 
					    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
 | 
				
			||||||
    const scalar_t q_x = query[token_head + x_index];
 | 
					                                              sin_ptr, rot_offset, embed_dim);
 | 
				
			||||||
    const scalar_t q_y = query[token_head + y_index];
 | 
					 | 
				
			||||||
    query[out_x] = q_x * cos - q_y * sin;
 | 
					 | 
				
			||||||
    query[out_y] = q_y * cos + q_x * sin;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const scalar_t k_x = key[token_head + x_index];
 | 
					 | 
				
			||||||
    const scalar_t k_y = key[token_head + y_index];
 | 
					 | 
				
			||||||
    key[out_x] = k_x * cos - k_y * sin;
 | 
					 | 
				
			||||||
    key[out_y] = k_y * cos + k_x * sin;
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace vllm
 | 
					} // namespace vllm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void rotary_embedding_neox(
 | 
					void rotary_embedding(
 | 
				
			||||||
  torch::Tensor& positions,         // [num_tokens]
 | 
   torch::Tensor& positions,         // [num_tokens]
   torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
   int head_size,
-  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
-{
+  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
+  bool is_neox) {
   int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
-  int stride = query.stride(0);
-  TORCH_CHECK(stride == key.stride(0));
+  int num_kv_heads = key.size(1) / head_size;
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
-    "rotary_embedding_neox",
+    "rotary_embedding",
     [&] {
-      vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        positions.data_ptr<int64_t>(),
-        query.data_ptr<scalar_t>(),
-        key.data_ptr<scalar_t>(),
-        cos_sin_cache.data_ptr<scalar_t>(),
-        rot_dim,
-        stride,
-        num_heads,
-        head_size);
+      if (is_neox) {
+        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
     });
 }
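This hunk generalizes the old NeoX-only entry point: the host function now takes a single is_neox flag, dispatches to a templated kernel, and accepts a key tensor with its own head count and stride, which is what grouped-query / multi-query models need. A minimal sketch of driving the op from Python is shown below; the module path vllm.pos_encoding_ops and the example shapes are assumptions for illustration, not taken from this diff.

    # Hedged sketch: exercising a rotary-embedding op with separate Q/K head counts.
    # The import path and op name are assumptions for illustration only.
    import torch
    from vllm import pos_encoding_ops  # hypothetical binding module

    num_tokens, num_heads, num_kv_heads, head_size, rot_dim = 8, 32, 4, 128, 128
    positions = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
    query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="cuda")
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.float16, device="cuda")
    cos_sin_cache = torch.randn(4096, rot_dim, dtype=torch.float16, device="cuda")

    # In-place rotation of query and key; the last argument selects the NeoX layout.
    pos_encoding_ops.rotary_embedding(positions, query, key, head_size,
                                      cos_sin_cache, True)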
csrc/quantization.cpp  (new file, 15 lines)
@@ -0,0 +1,15 @@
#include <torch/extension.h>

torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "awq_gemm",
    &awq_gemm,
    "Quantized GEMM for AWQ");
}
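The new quantization extension exposes a single awq_gemm entry point. A hedged sketch of calling it from Python follows; the import path (vllm.quantization_ops) is an assumption, while the tensor shapes mirror the comments in gemm_kernels.cu below (in_feats: M x IC fp16, kernel: IC x OC/8 int32, scaling_factors: IC/G x OC fp16, zeros: IC/G x OC/8 int32).

    # Hedged sketch: invoking the awq_gemm binding on AWQ-packed weights.
    # The import path is an assumption for illustration only.
    import torch
    from vllm import quantization_ops  # hypothetical binding module

    M, IC, OC, group_size, split_k_iters = 16, 4096, 4096, 128, 8
    in_feats = torch.randn(M, IC, dtype=torch.float16, device="cuda")
    qweight = torch.randint(0, 2**31, (IC, OC // 8), dtype=torch.int32, device="cuda")
    scales = torch.randn(IC // group_size, OC, dtype=torch.float16, device="cuda")
    qzeros = torch.randint(0, 2**31, (IC // group_size, OC // 8), dtype=torch.int32, device="cuda")

    out = quantization_ops.awq_gemm(in_feats, qweight, scales, qzeros, split_k_iters)
    print(out.shape)  # expected: (M, OC)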
csrc/quantization/awq/dequantize.cuh  (new file, 87 lines)
@@ -0,0 +1,87 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#pragma once

namespace vllm {
namespace awq {

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
    uint4 result;

    uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
    static constexpr uint32_t TOP_MASK              = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
    // elt_67 to fp16 without having to shift them to the bottom bits before hand.

    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
    // immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                    : "=r"(h[0])
                    : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                    : "=r"(h[1])
                    : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                    : "=r"(h[2])
                    : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                    : "=r"(h[3])
                    : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
    // half2 ctor. In this case, I chose performance reliability over code readability.

    // This is the half2 {1032, 1032} represented as an integer.
    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // This is the half2 {-72, -72} represented as an integer.
    // static constexpr uint32_t NEG_72 = 0xd480d480;
    // Haotian: Let's use {-64, -64}.
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // Finally, we construct the output numbers.
    // Convert elt_01
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_23
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // Convert elt_45
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_67
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

    return result;
#endif
}

} // namespace awq
} // namespace vllm
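The conversion above relies on the standard FasterTransformer bit trick: ORing a 4-bit value v into the low mantissa bits of the fp16 constant 1024 (0x6400) produces the fp16 number 1024 + v, so subtracting 1024 recovers v exactly; the upper nibble sits 4 bits higher and comes out as 1024 + 16v, which the fma by 1/16 (0x2c00) with offset -64 (0xd400) maps back to v without an extra shift. The sketch below checks that arithmetic on the CPU with NumPy; it illustrates the encoding and is not part of the diff.

    # Hedged sketch: verify the 0x6400 ("1024.0 in fp16") dequantization trick with NumPy.
    import numpy as np

    for v in range(16):  # every unsigned 4-bit value
        lo = np.array([0x6400 | v], dtype=np.uint16).view(np.float16)[0]
        assert float(lo) - 1024.0 == v                 # the sub.f16x2 path recovers v

        hi = np.array([0x6400 | (v << 4)], dtype=np.uint16).view(np.float16)[0]
        assert float(hi) * (1.0 / 16.0) - 64.0 == v    # the fma.rn.f16x2 path recovers v
    print("int4 -> fp16 magic-number dequantization verified")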
csrc/quantization/awq/gemm_kernels.cu  (new file, 560 lines)
@@ -0,0 +1,560 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
 */


#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "dequantize.cuh"

#include <cuda_fp16.h>

namespace vllm {
namespace awq {

// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
 | 
				
			||||||
 | 
					  assert(false);
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					  static constexpr uint32_t ZERO = 0x0;
 | 
				
			||||||
 | 
					  float C_warp[32];
 | 
				
			||||||
 | 
					  __shared__ half A_shared[16 * (32 + 8)];
 | 
				
			||||||
 | 
					  __shared__ half B_shared[32 * (128 + 8)];
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  __shared__ half scaling_factors_shared[128];
 | 
				
			||||||
 | 
					  __shared__ half zeros_shared[128];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int j_factors1 = ((OC + 128 - 1) / 128);
 | 
				
			||||||
 | 
					  int blockIdx_x = 0;
 | 
				
			||||||
 | 
					  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
 | 
				
			||||||
 | 
					  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half A_shared_warp[8];
 | 
				
			||||||
 | 
					  half B_shared_warp[32];
 | 
				
			||||||
 | 
					  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
 | 
				
			||||||
 | 
					    for (int i = 0; i < 8; ++i) {
 | 
				
			||||||
 | 
					      C_warp[(j_0_4_init * 8) + i] = 0.0;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static constexpr int row_stride_warp = 32 * 8 / 32;
 | 
				
			||||||
 | 
					  static constexpr int row_stride = 2 * 32 * 8 / 128;
 | 
				
			||||||
 | 
					  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
 | 
				
			||||||
 | 
					  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
 | 
				
			||||||
 | 
					  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
 | 
				
			||||||
 | 
					  // bool wb_C_flag = (threadIdx.x / 4) < M;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half* A_ptr = A 
 | 
				
			||||||
 | 
					                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
 | 
				
			||||||
 | 
					                + (((int)threadIdx.x) % (32 / 8)) * 8;
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  int* B_ptr = B
 | 
				
			||||||
 | 
					            + ((int)threadIdx.y) * (OC / 8) * 2
 | 
				
			||||||
 | 
					            + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
 | 
				
			||||||
 | 
					            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
 | 
				
			||||||
 | 
					            + (((int)threadIdx.x) % (128 / 8)) * 1;
 | 
				
			||||||
 | 
					// Why * 1 in the above line?
 | 
				
			||||||
 | 
					                        
 | 
				
			||||||
 | 
					  half* A_shared_ptr = A_shared 
 | 
				
			||||||
 | 
					                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half* B_shared_ptr = B_shared
 | 
				
			||||||
 | 
					                    + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) % (128 / 8)) * 8;
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  int* zeros_ptr = zeros
 | 
				
			||||||
 | 
					                + (((int)blockIdx_y) % j_factors1) * (128 / 8)
 | 
				
			||||||
 | 
					                + ((int)threadIdx.x) % (128 / 8);
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  half* scaling_factors_ptr = scaling_factors
 | 
				
			||||||
 | 
					                            + (((int)blockIdx_y) % j_factors1) * (128) 
 | 
				
			||||||
 | 
					                            + (((int)threadIdx.x) % (128 / 8)) * 8;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half* C_ptr = C 
 | 
				
			||||||
 | 
					              + static_cast<long long>(blockIdx_z) * M * OC        // blockIdz.x -> split_k dim
 | 
				
			||||||
 | 
					              + (((int)blockIdx_y) % j_factors1) * 128
 | 
				
			||||||
 | 
					              + ((int)threadIdx.y) * 64
 | 
				
			||||||
 | 
					              + (((int)threadIdx.x) % 4) * 2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // preload s.f. and zeros
 | 
				
			||||||
 | 
					  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
 | 
				
			||||||
 | 
					  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
 | 
				
			||||||
 | 
					  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
 | 
				
			||||||
 | 
					    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
 | 
				
			||||||
 | 
					    __syncthreads();
 | 
				
			||||||
 | 
					    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
 | 
				
			||||||
 | 
					    if (ld_A_flag)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
 | 
				
			||||||
 | 
					    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
 | 
				
			||||||
 | 
					    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
 | 
				
			||||||
 | 
					    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
 | 
				
			||||||
 | 
					    /*
 | 
				
			||||||
 | 
					    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
 | 
				
			||||||
 | 
					      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    */
 | 
				
			||||||
 | 
					    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
 | 
				
			||||||
 | 
					    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // B: 32 x 136 (128+8) float16
 | 
				
			||||||
 | 
					      // each warp: 32 x 4
 | 
				
			||||||
 | 
					      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
 | 
				
			||||||
 | 
					      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
 | 
				
			||||||
 | 
					      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
 | 
				
			||||||
 | 
					      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
 | 
				
			||||||
 | 
					      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
 | 
				
			||||||
 | 
					      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
 | 
				
			||||||
 | 
					      // - zero and * scale
 | 
				
			||||||
 | 
					      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
 | 
				
			||||||
 | 
					      /*
 | 
				
			||||||
 | 
					      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
 | 
				
			||||||
 | 
					        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // write back
 | 
				
			||||||
 | 
					      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    __syncthreads();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        unsigned int addr;
 | 
				
			||||||
 | 
					        __asm__ __volatile__(
 | 
				
			||||||
 | 
					          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
 | 
				
			||||||
 | 
					          : "=r"(addr)
 | 
				
			||||||
 | 
					          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
				
			||||||
 | 
					        );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        __asm__ __volatile__(
 | 
				
			||||||
 | 
					          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
 | 
				
			||||||
 | 
					          "{%0, %1, %2, %3}, [%4];\n"
 | 
				
			||||||
 | 
					          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
 | 
				
			||||||
 | 
					          : "r"(addr)
 | 
				
			||||||
 | 
					        );
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          unsigned int addr;
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
 | 
				
			||||||
 | 
					            : "=r"(addr)
 | 
				
			||||||
 | 
					            : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
				
			||||||
 | 
					          );
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, [%4];\n"
 | 
				
			||||||
 | 
					            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(addr)
 | 
				
			||||||
 | 
					          );
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
 | 
				
			||||||
 | 
					#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TODO: Shang: Hoist loop invariance.
 | 
				
			||||||
 | 
					  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
 | 
				
			||||||
 | 
					    for (int local_id = 0; local_id < 8; ++local_id) {
 | 
				
			||||||
 | 
					      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
 | 
				
			||||||
 | 
					      if (row_offset < M)
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
 | 
				
			||||||
 | 
					  assert(false);
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					  static constexpr uint32_t ZERO = 0x0;
 | 
				
			||||||
 | 
					  float C_warp[32];
 | 
				
			||||||
 | 
					  __shared__ half A_shared[16 * (32 + 8)];
 | 
				
			||||||
 | 
					  __shared__ half B_shared[32 * (64 + 8)];
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  __shared__ half scaling_factors_shared[64];
 | 
				
			||||||
 | 
					  __shared__ half zeros_shared[64];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int j_factors1 = ((OC + 64 - 1) / 64);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int blockIdx_x = 0;
 | 
				
			||||||
 | 
					  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
 | 
				
			||||||
 | 
					  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half A_shared_warp[8];
 | 
				
			||||||
 | 
					  half B_shared_warp[16];
 | 
				
			||||||
 | 
					  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
 | 
				
			||||||
 | 
					    for (int i = 0; i < 8; ++i) {
 | 
				
			||||||
 | 
					      C_warp[(j_0_4_init * 8) + i] = 0.0;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static constexpr int row_stride_warp = 32 * 8 / 32;
 | 
				
			||||||
 | 
					  static constexpr int row_stride = 2 * 32 * 8 / 64;
 | 
				
			||||||
 | 
					  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
 | 
				
			||||||
 | 
					  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
 | 
				
			||||||
 | 
					  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
 | 
				
			||||||
 | 
					  // bool wb_C_flag = (threadIdx.x / 4) < M;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half* A_ptr = A 
 | 
				
			||||||
 | 
					                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
 | 
				
			||||||
 | 
					                + (((int)threadIdx.x) % (32 / 8)) * 8;
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  int* B_ptr = B
 | 
				
			||||||
 | 
					            + ((int)threadIdx.y) * (OC / 8) * 4
 | 
				
			||||||
 | 
					            + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
 | 
				
			||||||
 | 
					            + (((int)blockIdx_y) % j_factors1) * (64 / 8)
 | 
				
			||||||
 | 
					            + (((int)threadIdx.x) % (64 / 8)) * 1;
 | 
				
			||||||
 | 
					// Why * 1 in the above line?
 | 
				
			||||||
 | 
					                        
 | 
				
			||||||
 | 
					  half* A_shared_ptr = A_shared 
 | 
				
			||||||
 | 
					                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half* B_shared_ptr = B_shared
 | 
				
			||||||
 | 
					                    + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
 | 
				
			||||||
 | 
					                    + (((int)threadIdx.x) % (64 / 8)) * 8;
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  int* zeros_ptr = zeros
 | 
				
			||||||
 | 
					                + (((int)blockIdx_y) % j_factors1) * (64 / 8)
 | 
				
			||||||
 | 
					                + ((int)threadIdx.x) % (64 / 8);
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  half* scaling_factors_ptr = scaling_factors
 | 
				
			||||||
 | 
					                            + (((int)blockIdx_y) % j_factors1) * (64) 
 | 
				
			||||||
 | 
					                            + (((int)threadIdx.x) % (64 / 8)) * 8;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  half* C_ptr = C 
 | 
				
			||||||
 | 
					              + static_cast<long long>(blockIdx_z) * M * OC        // blockIdz.x -> split_k dim
 | 
				
			||||||
 | 
					              + (((int)blockIdx_y) % j_factors1) * 64
 | 
				
			||||||
 | 
					              + ((int)threadIdx.y) * 32
 | 
				
			||||||
 | 
					              + (((int)threadIdx.x) % 4) * 2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // preload s.f. and zeros
 | 
				
			||||||
 | 
					  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
 | 
				
			||||||
 | 
					  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
 | 
				
			||||||
 | 
					  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
 | 
				
			||||||
 | 
					    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
 | 
				
			||||||
 | 
					    __syncthreads();
 | 
				
			||||||
 | 
					    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
 | 
				
			||||||
 | 
					    if (ld_A_flag)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
 | 
				
			||||||
 | 
					    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
 | 
				
			||||||
 | 
					    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
 | 
				
			||||||
 | 
					    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
 | 
				
			||||||
 | 
					    /*
 | 
				
			||||||
 | 
					    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
 | 
				
			||||||
 | 
					      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    */
 | 
				
			||||||
 | 
					    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
 | 
				
			||||||
 | 
					    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // B: 32 x 136 (128+8) float16
 | 
				
			||||||
 | 
					      // each warp: 32 x 4
 | 
				
			||||||
 | 
					      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
 | 
				
			||||||
 | 
					      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
 | 
				
			||||||
 | 
					      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
 | 
				
			||||||
 | 
					      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
 | 
				
			||||||
 | 
					      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
 | 
				
			||||||
 | 
					      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
 | 
				
			||||||
 | 
					      // - zero and * scale
 | 
				
			||||||
 | 
					      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
 | 
				
			||||||
 | 
					      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
 | 
				
			||||||
 | 
					      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
 | 
				
			||||||
 | 
					      /*
 | 
				
			||||||
 | 
					      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
 | 
				
			||||||
 | 
					        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // write back
 | 
				
			||||||
 | 
					      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    __syncthreads();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) 
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        unsigned int addr;
 | 
				
			||||||
 | 
					        __asm__ __volatile__(
 | 
				
			||||||
 | 
					          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
 | 
				
			||||||
 | 
					          : "=r"(addr)
 | 
				
			||||||
 | 
					          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
				
			||||||
 | 
					        );
 | 
				
			||||||
 | 
					        __asm__ __volatile__(
 | 
				
			||||||
 | 
					          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
 | 
				
			||||||
 | 
					          "{%0, %1, %2, %3}, [%4];\n"
 | 
				
			||||||
 | 
					          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
 | 
				
			||||||
 | 
					          : "r"(addr)
 | 
				
			||||||
 | 
					        );
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) 
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          unsigned int addr;
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
 | 
				
			||||||
 | 
					            : "=r"(addr)
 | 
				
			||||||
 | 
					            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
				
			||||||
 | 
					          );
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, [%4];\n"
 | 
				
			||||||
 | 
					            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(addr)
 | 
				
			||||||
 | 
					          );
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) 
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          __asm__ __volatile__(
 | 
				
			||||||
 | 
					            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
 | 
				
			||||||
 | 
					            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
 | 
				
			||||||
 | 
					            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
				
			||||||
 | 
					            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TODO: Shang: Hoist loop invariance.
 | 
				
			||||||
 | 
					  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
 | 
				
			||||||
 | 
					    for (int local_id = 0; local_id < 8; ++local_id) {
 | 
				
			||||||
 | 
					      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
 | 
				
			||||||
 | 
					      if (row_offset < M)
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					}
} // namespace awq
} // namespace vllm

// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now

torch::Tensor awq_gemm(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int split_k_iters)
{
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
    int num_out_feats = _out_feats.size(-2);
    int num_out_channels = _out_feats.size(-1);

    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
    int group_size = num_in_channels / _scaling_factors.size(0);

    if (num_out_channels % 64 != 0)
        throw std::invalid_argument("OC is not multiple of cta_N = 64");
    if (num_out_channels % 8 != 0)
        throw std::invalid_argument("OC is not multiple of pack_num = 8");
    if (group_size % 32 != 0)
        throw std::invalid_argument("Group size should be a multiple of 32");
    if (num_out_channels % group_size != 0)
        throw std::invalid_argument("OC is not multiple of Group size");

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (num_out_channels % 128 == 0)
    {
        int j_factors1 = num_out_channels / 128 / 1;
        dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
        // threadIdx.x: 32
        // threadIdx.y: i_factors[2] * j_factors[2]
        dim3 threads_per_block(32, 2);
        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
    }
    else if (num_out_channels % 64 == 0)
    {
        int j_factors1 = num_out_channels / 64 / 1;
        dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);

        // threadIdx.x: 32
        // threadIdx.y: i_factors[2] * j_factors[2]
        dim3 threads_per_block(32, 2);
        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
    }
    return _out_feats.sum(0);
}
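The launch configuration tiles the output into 16-row by 128-column (or 64-column) CTAs and replicates that grid split_k_iters times along the reduction dimension, with the partial products summed on the host by _out_feats.sum(0). A small worked example of that arithmetic, under assumed shapes, is sketched below.

    # Hedged sketch: reproducing the grid-size arithmetic of awq_gemm for assumed shapes.
    M, IC, OC = 16, 4096, 4096            # assumed problem size (batch rows, in/out channels)
    split_k_iters = 8                     # number of split-k partial products

    cta_m, cta_n = 16, 128                # tile handled by one thread block (m16n128k32 path)
    j_factors1 = OC // cta_n              # column tiles  -> 4096 / 128 = 32
    row_tiles = (M + cta_m - 1) // cta_m  # row tiles     -> 1
    num_blocks = row_tiles * j_factors1 * split_k_iters   # 1 * 32 * 8 = 256 thread blocks
    threads_per_block = 32 * 2            # (threadIdx.x, threadIdx.y) = (32, 2)

    print(num_blocks, threads_per_block)  # 256 blocks of 64 threads; partials summed over split_k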
[8 binary image files deleted from the docs assets, 244-285 KiB each]
@@ -3,31 +3,15 @@
 Installation
 ============
 
-vLLM is a Python library that also contains some C++ and CUDA code.
-This additional code requires compilation on the user's machine.
+vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
 
 Requirements
 ------------
 
 * OS: Linux
-* Python: 3.8 or higher
-* CUDA: 11.0 -- 11.8
+* Python: 3.8 -- 3.11
 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
 
-.. note::
-    As of now, vLLM does not support CUDA 12.
-    If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
-
-.. tip::
-    If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
-
-    .. code-block:: console
-
-        $ # Pull the Docker image with CUDA 11.8.
-        $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
-
-    Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
-
 Install with pip
 ----------------
 
@@ -40,7 +24,7 @@ You can install vLLM using pip:
     $ conda activate myenv
 
     $ # Install vLLM.
-    $ pip install vllm  # This may take 5-10 minutes.
+    $ pip install vllm
 
 
 .. _build_from_source:
@@ -55,3 +39,12 @@ You can also build and install vLLM from source:
     $ git clone https://github.com/vllm-project/vllm.git
     $ cd vllm
     $ pip install -e .  # This may take 5-10 minutes.
+
+.. tip::
+    If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
+
+    .. code-block:: console
+
+        $ # Pull the Docker image with CUDA 11.8.
+        $ # Use `--ipc=host` to make sure the shared memory is large enough.
+        $ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
@@ -128,4 +128,4 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
                                           prompt="San Francisco is a")
     print("Completion result:", completion)
 
-For a more detailed client example, refer to `examples/openai_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_client.py>`_.
+For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
@@ -43,6 +43,7 @@ vLLM is flexible and easy to use with:
 For more information, check out the following:
 
 * `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
+* `vLLM paper <https://arxiv.org/abs/2309.06180>`_ (SOSP 2023)
 * `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
 
 
@@ -62,6 +63,8 @@ Documentation
    :caption: Serving
 
    serving/distributed_serving
+   serving/run_on_sky
+   serving/deploying_with_triton
 
 .. toctree::
    :maxdepth: 1
@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
     +    kv_caches: List[KVCache],
     +    input_metadata: InputMetadata,
     +    cache_events: Optional[List[torch.cuda.Event]],
-    +) -> Dict[int, SequenceOutputs]:
+    +) -> SamplerOutput:
 
 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
 4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
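For orientation, the return-type change above amounts to a `forward` signature along the following lines. This is an illustrative sketch only: the class name is hypothetical, and `KVCache`, `InputMetadata`, and `SamplerOutput` stand in for vLLM's own types named in the hunk.

from typing import List, Optional, Tuple

import torch

# Stand-in for vLLM's KVCache type referenced in the hunk above; the real
# definitions of InputMetadata and SamplerOutput live inside vLLM itself.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class MyNewModelForCausalLM(torch.nn.Module):  # hypothetical model class

    def forward(
        self,
        input_ids: torch.Tensor,   # flattened, per step 3 above
        positions: torch.Tensor,   # flattened, per step 3 above
        kv_caches: List[KVCache],
        input_metadata: "InputMetadata",
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> "SamplerOutput":
        ...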
@@ -14,27 +14,48 @@ Alongside each architecture, we include some popular models that use it.
   * - Architecture
     - Models
     - Example HuggingFace Models
+  * - :code:`AquilaForCausalLM`
+    - Aquila
+    - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
+  * - :code:`BaiChuanForCausalLM`
+    - Baichuan
+    - :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
+  * - :code:`FalconForCausalLM`
+    - Falcon
+    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
   * - :code:`GPTBigCodeForCausalLM`
     - StarCoder, SantaCoder, WizardCoder
     - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
+  * - :code:`GPTJForCausalLM`
+    - GPT-J
+    - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
   * - :code:`GPTNeoXForCausalLM`
     - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
     - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
+  * - :code:`InternLMForCausalLM`
+    - InternLM
+    - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
   * - :code:`LlamaForCausalLM`
-    - LLaMA, Vicuna, Alpaca, Koala, Guanaco
-    - :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
+    - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
+    - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
+  * - :code:`MistralForCausalLM`
+    - Mistral, Mistral-Instruct
+    - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
   * - :code:`MPTForCausalLM`
     - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
     - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
   * - :code:`OPTForCausalLM`
     - OPT, OPT-IML
     - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
+  * - :code:`QWenLMHeadModel`
+    - Qwen
+    - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
 
 If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
 Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
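Any checkpoint in the table above is loaded the same way through vLLM's offline `LLM` API. A minimal sketch, using one example checkpoint taken from the table (the prompt and sampling settings are arbitrary):

from vllm import LLM, SamplingParams

# "mistralai/Mistral-7B-v0.1" is simply one of the example checkpoints listed
# above; any supported architecture is loaded the same way.
llm = LLM(model="mistralai/Mistral-7B-v0.1")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)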
docs/source/serving/deploying_with_triton.rst (new file, 6 lines)
@@ -0,0 +1,6 @@
+.. _deploying_with_triton:
+
+Deploying with NVIDIA Triton
+============================
+
+The `Triton Inference Server <https://github.com/triton-inference-server>`_ hosts a tutorial demonstrating how to quickly deploy a simple `facebook/opt-125m <https://huggingface.co/facebook/opt-125m>`_ model using vLLM. Please see `Deploying a vLLM model in Triton <https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton>`_ for more details.
docs/source/serving/run_on_sky.rst (new file, 69 lines)
@@ -0,0 +1,69 @@
+.. _on_cloud:
+
+Running on clouds with SkyPilot
+===============================
+
+.. raw:: html
+
+    <p align="center">
+        <img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
+    </p>
+
+vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
+
+To install SkyPilot and setup your cloud credentials, run:
+
+.. code-block:: console
+
+    $ pip install skypilot
+    $ sky check
+
+See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
+
+.. code-block:: yaml
+
+    resources:
+        accelerators: A100
+
+    envs:
+        MODEL_NAME: decapoda-research/llama-13b-hf
+        TOKENIZER: hf-internal-testing/llama-tokenizer
+
+    setup: |
+        conda create -n vllm python=3.9 -y
+        conda activate vllm
+        git clone https://github.com/vllm-project/vllm.git
+        cd vllm
+        pip install .
+        pip install gradio
+
+    run: |
+        conda activate vllm
+        echo 'Starting vllm api server...'
+        python -u -m vllm.entrypoints.api_server \
+                        --model $MODEL_NAME \
+                        --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
+                        --tokenizer $TOKENIZER 2>&1 | tee api_server.log &
+        echo 'Waiting for vllm api server to start...'
+        while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
+        echo 'Starting gradio server...'
+        python vllm/examples/gradio_webserver.py
+
+Start the serving the LLaMA-13B model on an A100 GPU:
+
+.. code-block:: console
+
+    $ sky launch serving.yaml
+
+Check the output of the command. There will be a sharable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
+
+.. code-block:: console
+
+    (task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
+
+**Optional**: Serve the 65B model instead of the default 13B and use more GPU:
+
+.. code-block:: console
+
+    sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf
+
@@ -39,7 +39,7 @@ def build_demo():
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
     parser.add_argument("--model-url",
                         type=str,
@@ -10,7 +10,8 @@ def main(args: argparse.Namespace):
 
     # Test the following prompts.
     test_prompts = [
-        ("A robot may not injure a human being", SamplingParams()),
+        ("A robot may not injure a human being",
+         SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
@@ -27,7 +28,7 @@ def main(args: argparse.Namespace):
     # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
-        # To test iteration-level scheduling, we add one request at each step.
+        # To test continuous batching, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
             engine.add_request(str(request_id), prompt, sampling_params)
examples/openai_chatcompletion_client.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+import openai
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai.api_key = "EMPTY"
+openai.api_base = "http://localhost:8000/v1"
+
+# List models API
+models = openai.Model.list()
+print("Models:", models)
+
+model = models["data"][0]["id"]
+
+# Chat completion API
+chat_completion = openai.ChatCompletion.create(
+    model=model,
+    messages=[{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role": "user",
+        "content": "Who won the world series in 2020?"
+    }, {
+        "role":
+        "assistant",
+        "content":
+        "The Los Angeles Dodgers won the World Series in 2020."
+    }, {
+        "role": "user",
+        "content": "Where was it played?"
+    }])
+
+print("Chat completion results:")
+print(chat_completion)
@@ -3,26 +3,26 @@ import openai
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai.api_key = "EMPTY"
 openai.api_base = "http://localhost:8000/v1"
-model = "facebook/opt-125m"
 
-# Test list models API
+# List models API
 models = openai.Model.list()
 print("Models:", models)
 
-# Test completion API
-stream = True
+model = models["data"][0]["id"]
+
+# Completion API
+stream = False
 completion = openai.Completion.create(
     model=model,
     prompt="A robot may not injure a human being",
     echo=False,
     n=2,
-    best_of=3,
     stream=stream,
     logprobs=3)
 
-# print the completion
+print("Completion results:")
 if stream:
     for c in completion:
         print(c)
 else:
-    print("Completion result:", completion)
+    print(completion)
@@ -44,7 +44,6 @@ YAPF_FLAGS=(
 
 YAPF_EXCLUDES=(
     '--exclude' 'build/**'
-    '--exclude' 'vllm/model_executor/parallel_utils/**'
 )
 
 # Format specified files
@@ -72,7 +71,7 @@ format_changed() {
 
 # Format all files
 format_all() {
-    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
+    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
 }
 
 ## This flag formats individual files. --files *must* be the first command line
@@ -96,7 +95,7 @@ echo 'vLLM yapf: Done'
 
 # Run Pylint
 echo 'vLLM Pylint:'
-pylint vllm
+pylint vllm tests
 
 if ! git diff --quiet &>/dev/null; then
     echo 'Reformatted files. Please review and stage the changes.'
@@ -3,7 +3,7 @@ requires = [
     "ninja",
     "packaging",
     "setuptools",
-    "torch >= 2.0.0",
+    "torch == 2.0.1",
     "wheel",
 ]
 build-backend = "setuptools.build_meta"
@@ -10,3 +10,5 @@ types-setuptools
 
 # testing
 pytest
+pytest-forked
+pytest-asyncio
@@ -1,12 +1,13 @@
 ninja  # For faster builds.
 psutil
-ray
+ray >= 2.5.1
+pandas  # Required for Ray data.
+pyarrow  # Required for Ray data.
 sentencepiece  # Required for LLaMA tokenizer.
 numpy
-torch >= 2.0.0
+torch == 2.0.1
-transformers >= 4.28.0  # Required for LLaMA.
+transformers >= 4.34.0  # Required for Mistral.
-xformers >= 0.0.19
+xformers == 0.0.22  # Required for Mistral.
 fastapi
-uvicorn
+uvicorn[standard]
-pydantic  # Required for OpenAI server.
+pydantic < 2  # Required for OpenAI server.
-fschat  # Required for OpenAI ChatCompletion Endpoint.
setup.py (171 changed lines)
@@ -3,6 +3,7 @@ import os
 import re
 import subprocess
 from typing import List, Set
+import warnings
 
 from packaging.version import parse, Version
 import setuptools
@@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
 
+# Supported NVIDIA GPU architectures.
+SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+
 # Compiler flags.
 CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
 # TODO(woosuk): Should we use -O3?
@@ -22,7 +26,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 
 if CUDA_HOME is None:
     raise RuntimeError(
-        f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.")
+        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
 
 
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -38,32 +42,95 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     return nvcc_cuda_version
 
 
-# Collect the compute capabilities of all available GPUs.
-device_count = torch.cuda.device_count()
-compute_capabilities: Set[int] = set()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 7:
-        raise RuntimeError(
-            "GPUs with compute capability less than 7.0 are not supported.")
-    compute_capabilities.add(major * 10 + minor)
-# If no GPU is available, add all supported compute capabilities.
+def get_torch_arch_list() -> Set[str]:
+    # TORCH_CUDA_ARCH_LIST can have one or more architectures,
+    # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
+    # compiler to additionally include PTX code that can be runtime-compiled
+    # and executed on the 8.6 or newer architectures. While the PTX code will
+    # not give the best performance on the newer architectures, it provides
+    # forward compatibility.
+    env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+    if env_arch_list is None:
+        return set()
+
+    # List are separated by ; or space.
+    torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
+    if not torch_arch_list:
+        return set()
+
+    # Filter out the invalid architectures and print a warning.
+    valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
+    arch_list = torch_arch_list.intersection(valid_archs)
+    # If none of the specified architectures are valid, raise an error.
+    if not arch_list:
+        raise RuntimeError(
+            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
+            f"variable ({env_arch_list}) is supported. "
+            f"Supported CUDA architectures are: {valid_archs}.")
+    invalid_arch_list = torch_arch_list - valid_archs
+    if invalid_arch_list:
+        warnings.warn(
+            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
+            "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
+            f"({env_arch_list}). Supported CUDA architectures are: "
+            f"{valid_archs}.")
+    return arch_list
+
+
+# First, check the TORCH_CUDA_ARCH_LIST environment variable.
+compute_capabilities = get_torch_arch_list()
 if not compute_capabilities:
-    compute_capabilities = {70, 75, 80, 86, 90}
-# Add target compute capabilities to NVCC flags.
-for capability in compute_capabilities:
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
+    # GPUs on the current machine.
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 7:
+            raise RuntimeError(
+                "GPUs with compute capability below 7.0 are not supported.")
+        compute_capabilities.add(f"{major}.{minor}")
+
+nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+if not compute_capabilities:
+    # If no GPU is specified nor available, add all supported architectures
+    # based on the NVCC CUDA version.
+    compute_capabilities = SUPPORTED_ARCHS.copy()
+    if nvcc_cuda_version < Version("11.1"):
+        compute_capabilities.remove("8.6")
+    if nvcc_cuda_version < Version("11.8"):
+        compute_capabilities.remove("8.9")
+        compute_capabilities.remove("9.0")
 
 # Validate the NVCC CUDA version.
-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if nvcc_cuda_version < Version("11.0"):
     raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
-    raise RuntimeError(
-        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
-if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
-    raise RuntimeError(
-        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
+if nvcc_cuda_version < Version("11.1"):
+    if any(cc.startswith("8.6") for cc in compute_capabilities):
+        raise RuntimeError(
+            "CUDA 11.1 or higher is required for compute capability 8.6.")
+if nvcc_cuda_version < Version("11.8"):
+    if any(cc.startswith("8.9") for cc in compute_capabilities):
+        # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+        # However, GPUs with compute capability 8.9 can also run the code generated by
+        # the previous versions of CUDA 11 and targeting compute capability 8.0.
+        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+        # instead of 8.9.
+        warnings.warn(
+            "CUDA 11.8 or higher is required for compute capability 8.9. "
+            "Targeting compute capability 8.0 instead.")
+        compute_capabilities = set(cc for cc in compute_capabilities
+                                   if not cc.startswith("8.9"))
+        compute_capabilities.add("8.0+PTX")
+    if any(cc.startswith("9.0") for cc in compute_capabilities):
+        raise RuntimeError(
+            "CUDA 11.8 or higher is required for compute capability 9.0.")
+
+# Add target compute capabilities to NVCC flags.
+for capability in compute_capabilities:
+    num = capability[0] + capability[2]
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+    if capability.endswith("+PTX"):
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
 
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
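As a side note on the hunk above, each capability string (for example "8.0" or "8.6+PTX") is mapped to `-gencode` flags by the final loop. A small standalone sketch of that mapping, with a helper name invented purely for illustration:

from typing import List

def gencode_flags(capability: str) -> List[str]:
    # "8.6" -> "86"; this works because every supported capability has the
    # single-digit "major.minor" form used in the hunk above.
    num = capability[0] + capability[2]
    flags = ["-gencode", f"arch=compute_{num},code=sm_{num}"]
    if capability.endswith("+PTX"):
        # Also embed PTX so newer GPUs can JIT-compile the kernels.
        flags += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
    return flags

# Example: mirrors what the loop above appends to NVCC_FLAGS.
print(gencode_flags("8.0"))      # ['-gencode', 'arch=compute_80,code=sm_80']
print(gencode_flags("8.6+PTX"))  # adds a second entry with code=compute_86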
@@ -76,7 +143,10 @@ ext_modules = []
 cache_extension = CUDAExtension(
     name="vllm.cache_ops",
     sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(cache_extension)
 
@@ -84,7 +154,10 @@ ext_modules.append(cache_extension)
 attention_extension = CUDAExtension(
     name="vllm.attention_ops",
     sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(attention_extension)
 
@@ -92,7 +165,10 @@ ext_modules.append(attention_extension)
 positional_encoding_extension = CUDAExtension(
     name="vllm.pos_encoding_ops",
     sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(positional_encoding_extension)
 
@@ -100,7 +176,10 @@ ext_modules.append(positional_encoding_extension)
 layernorm_extension = CUDAExtension(
     name="vllm.layernorm_ops",
     sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(layernorm_extension)
 
@@ -108,10 +187,38 @@ ext_modules.append(layernorm_extension)
 activation_extension = CUDAExtension(
     name="vllm.activation_ops",
     sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(activation_extension)
 
+# Quantization kernels.
+quantization_extension = CUDAExtension(
+    name="vllm.quantization_ops",
+    sources=[
+        "csrc/quantization.cpp",
+        "csrc/quantization/awq/gemm_kernels.cu",
+    ],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(quantization_extension)
+
+# Misc. CUDA utils.
+cuda_utils_extension = CUDAExtension(
+    name="vllm.cuda_utils",
+    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(cuda_utils_extension)
+
+
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)
@@ -123,8 +230,8 @@ def find_version(filepath: str):
     Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
     """
     with open(filepath) as fp:
-        version_match = re.search(
-            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
+        version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+                                  fp.read(), re.M)
         if version_match:
             return version_match.group(1)
         raise RuntimeError("Unable to find version string.")
@@ -147,7 +254,8 @@ setuptools.setup(
     version=find_version(get_path("vllm", "__init__.py")),
     author="vLLM Team",
     license="Apache 2.0",
-    description="A high-throughput and memory-efficient inference and serving engine for LLMs",
+    description=("A high-throughput and memory-efficient inference and "
+                 "serving engine for LLMs"),
     long_description=read_readme(),
     long_description_content_type="text/markdown",
     url="https://github.com/vllm-project/vllm",
@@ -159,11 +267,12 @@ setuptools.setup(
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
         "License :: OSI Approved :: Apache Software License",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
-    packages=setuptools.find_packages(
-        exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
+    packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
+                                               "examples", "tests")),
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
tests/async_engine/api_server_async_engine.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+"""vllm.entrypoints.api_server with some extra logging for testing."""
+import argparse
+from typing import Any, Dict
+
+import uvicorn
+from fastapi.responses import JSONResponse, Response
+
+import vllm.entrypoints.api_server
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+app = vllm.entrypoints.api_server.app
+
+
+class AsyncLLMEngineWithStats(AsyncLLMEngine):
+
+    # pylint: disable=redefined-outer-name
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._num_aborts = 0
+
+    async def abort(self, request_id: str) -> None:
+        await super().abort(request_id)
+        self._num_aborts += 1
+
+    def testing_stats(self) -> Dict[str, Any]:
+        return {"num_aborted_requests": self._num_aborts}
+
+
+@app.get("/stats")
+def stats() -> Response:
+    """Get the statistics of the engine."""
+    return JSONResponse(engine.testing_stats())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
+    vllm.entrypoints.api_server.engine = engine
+    uvicorn.run(
+        app,
+        host=args.host,
+        port=args.port,
+        log_level="debug",
+        timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
tests/async_engine/test_api_server.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+import subprocess
+import sys
+import time
+from multiprocessing import Pool
+from pathlib import Path
+
+import pytest
+import requests
+
+
+def _query_server(prompt: str) -> dict:
+    response = requests.post("http://localhost:8000/generate",
+                             json={
+                                 "prompt": prompt,
+                                 "max_tokens": 100,
+                                 "temperature": 0,
+                                 "ignore_eos": True
+                             })
+    response.raise_for_status()
+    return response.json()
+
+
+@pytest.fixture
+def api_server():
+    script_path = Path(__file__).parent.joinpath(
+        "api_server_async_engine.py").absolute()
+    # pylint: disable=consider-using-with
+    uvicorn_process = subprocess.Popen([
+        sys.executable, "-u",
+        str(script_path), "--model", "facebook/opt-125m"
+    ])
+    yield
+    uvicorn_process.terminate()
+
+
+# pylint: disable=redefined-outer-name, unused-argument
+def test_api_server(api_server):
+    """
+    Run the API server and test it.
+
+    We run both the server and requests in separate processes.
+
+    We test that the server can handle incoming requests, including
+    multiple requests at the same time, and that it can handle requests
+    being cancelled without crashing.
+    """
+    with Pool(32) as pool:
+        # Wait until the server is ready
+        prompts = ["Hello world"] * 1
+        result = None
+        while not result:
+            # pylint: disable=bare-except
+            try:
+                for result in pool.map(_query_server, prompts):
+                    break
+            except:
+                time.sleep(1)
+
+        # Actual tests start here
+        # Try with 1 prompt
+        for result in pool.map(_query_server, prompts):
+            assert result
+
+        num_aborted_requests = requests.get(
+            "http://localhost:8000/stats").json()["num_aborted_requests"]
+        assert num_aborted_requests == 0
+
+        # Try with 100 prompts
+        prompts = ["Hello world"] * 100
+        for result in pool.map(_query_server, prompts):
+            assert result
+
+        # Cancel requests
+        pool.map_async(_query_server, prompts)
+        time.sleep(0.01)
+        pool.terminate()
+        pool.join()
+
+        # check cancellation stats
+        num_aborted_requests = requests.get(
+            "http://localhost:8000/stats").json()["num_aborted_requests"]
+        assert num_aborted_requests > 0
+
+    # check that server still runs after cancellations
+    with Pool(32) as pool:
+        # Try with 100 prompts
+        prompts = ["Hello world"] * 100
+        for result in pool.map(_query_server, prompts):
+            assert result
tests/async_engine/test_async_llm_engine.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass

import pytest

from vllm.engine.async_llm_engine import AsyncLLMEngine


@dataclass
class RequestOutput:
    request_id: int
    finished: bool = False


class MockEngine:

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None

    async def step_async(self):
        self.step_calls += 1
        return [RequestOutput(
            request_id=self.request_id)] if self.request_id else []

    def generate(self, request_id):
        self.request_id = request_id

    def stop_generating(self):
        self.request_id = None

    def add_request(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        del request_id  # Unused
        self.abort_request_calls += 1


class MockAsyncLLMEngine(AsyncLLMEngine):

    def _init_engine(self, *args, **kwargs):
        return MockEngine()


@pytest.mark.asyncio
async def test_new_requests_event():
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls == 2
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 3
    engine.engine.stop_generating()
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
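
Note (illustration, not part of the diff): the assertions above check that the engine's background loop calls step_async() only while a request is in flight and that each new add_request() wakes the loop again. The sketch below drives the MockEngine defined in the test by hand; it assumes nothing beyond the mock's own API shown above.

import asyncio

async def poll_mock_engine_once():
    engine = MockEngine()                     # mock class from the test above
    assert await engine.step_async() == []    # idle: no request, empty output
    engine.generate("1")                      # pretend a request was scheduled
    outputs = await engine.step_async()
    assert outputs[0].request_id == "1"       # busy: one RequestOutput per step
    engine.stop_generating()
    assert await engine.step_async() == []    # idle again once the request stops
    assert engine.step_calls == 3             # every poll was counted

asyncio.run(poll_mock_engine_once())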
							
								
								
									
tests/async_engine/test_request_tracker.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import pytest

from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput


class DummyEvent:

    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False


def test_request_tracker():
    tracker = RequestTracker()
    tracker.new_requests_event = DummyEvent()
    stream_1 = tracker.add_request("1")
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
    assert not stream_1.finished

    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
    assert not finished
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")
    assert not tracker.new_requests_event.flag

    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "1" in finished
    assert not new
    assert stream_1.finished

    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request("5")
    assert tracker.new_requests_event.flag
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], [], finished=True))
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
    assert new[0]["request_id"] == "5"
    assert stream_2.finished
    assert not stream_5.finished
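
Note (illustration, not part of the diff): the cycle the test walks through is producer-side add_request() (sets the new-requests event and returns a per-request stream), engine-side draining via get_new_and_finished_requests(), and finished outputs pushed back through process_request_output(). A minimal sketch of one round of that cycle, mirroring the test's ordering and reusing its DummyEvent stand-in:

from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput

tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()      # avoid needing an asyncio loop

stream = tracker.add_request("req-0")          # producer side
new, finished = tracker.get_new_and_finished_requests()   # engine loop side
assert new[0]["request_id"] == "req-0" and not finished

# The engine loop reports the request as finished; the next drain surfaces it
# and the per-request stream is marked done.
tracker.process_request_output(
    RequestOutput("req-0", "output", [], [], [], finished=True))
new, finished = tracker.get_new_and_finished_requests()
assert "req-0" in finished and stream.finished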
							
								
								
									
tests/conftest.py (new file, 212 lines)
@@ -0,0 +1,212 @@
from typing import List, Optional, Tuple

import pytest
import torch
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = [
    # pylint: disable=line-too-long
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
    "Describe the basic components of a neural network and how it can be trained.",
    "Write a short story about a robot that dreams for the first time.",
    "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
    "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]


@pytest.fixture
def example_prompts() -> List[str]:
    return _TEST_PROMPTS


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}


class HfRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            outputs[i] = (output_ids[0], output_str[0])
        return outputs

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            seq_logprobs = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs


@pytest.fixture
def hf_runner():
    return HfRunner


class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=0,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids = []
            req_sample_output_strs = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs


@pytest.fixture
def vllm_runner():
    return VllmRunner
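
Note (illustration, not part of the diff): the two fixtures above are the building blocks for model-correctness tests. hf_runner produces reference outputs with Hugging Face transformers, and vllm_runner runs the same prompts through vLLM. A sketch of how a test might combine them; the model name is only an example, not something this diff pins down.

def test_greedy_outputs_match_hf(hf_runner, vllm_runner, example_prompts):
    model = "facebook/opt-125m"   # example model; any small causal LM works

    hf = hf_runner(model, dtype="half")
    hf_outputs = hf.generate_greedy(example_prompts, max_tokens=32)
    del hf   # free GPU memory before loading the vLLM copy

    vllm = vllm_runner(model, dtype="half")
    vllm_outputs = vllm.generate_greedy(example_prompts, max_tokens=32)

    for prompt, (_, hf_str), (_, vllm_str) in zip(example_prompts, hf_outputs,
                                                  vllm_outputs):
        assert hf_str == vllm_str, f"prompt: {prompt!r}"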
							
								
								
									
										82
									
								
								tests/distributed/test_comm_ops.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						@ -0,0 +1,82 @@
 | 
				
			|||||||
 | 
					"""Test the communication operators.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Run `pytest tests/distributed/test_comm_ops.py --forked`.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					from multiprocessing import Process
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from vllm.config import ParallelConfig
 | 
				
			||||||
 | 
					from vllm.engine.ray_utils import get_open_port
 | 
				
			||||||
 | 
					from vllm.model_executor.parallel_utils.communication_op import (
 | 
				
			||||||
 | 
					    tensor_model_parallel_all_reduce,
 | 
				
			||||||
 | 
					    tensor_model_parallel_all_gather,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from vllm.worker.worker import _init_distributed_environment
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_test_distributed_environment(pipeline_parallel_size: int,
 | 
				
			||||||
 | 
					                                      tensor_parallel_size: int, rank: int,
 | 
				
			||||||
 | 
					                                      distributed_init_port: str):
 | 
				
			||||||
 | 
					    parallel_config = ParallelConfig(pipeline_parallel_size,
 | 
				
			||||||
 | 
					                                     tensor_parallel_size,
 | 
				
			||||||
 | 
					                                     worker_use_ray=True)
 | 
				
			||||||
 | 
					    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
 | 
				
			||||||
 | 
					    torch.cuda.set_device(rank)
 | 
				
			||||||
 | 
					    _init_distributed_environment(parallel_config, rank,
 | 
				
			||||||
 | 
					                                  distributed_init_method)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
 | 
				
			||||||
 | 
					                           distributed_init_port: str):
 | 
				
			||||||
 | 
					    init_test_distributed_environment(1, tensor_parallel_size, rank,
 | 
				
			||||||
 | 
					                                      distributed_init_port)
 | 
				
			||||||
 | 
					    num_elements = 8
 | 
				
			||||||
 | 
					    all_tensors = [
 | 
				
			||||||
 | 
					        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
 | 
				
			||||||
 | 
					        (r + 1) for r in range(tensor_parallel_size)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
 | 
				
			||||||
 | 
					    t = all_tensors[rank]
 | 
				
			||||||
 | 
					    t = tensor_model_parallel_all_reduce(t)
 | 
				
			||||||
 | 
					    assert torch.allclose(t, expected)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def all_gather_test_worker(tensor_parallel_size: int, rank: int,
 | 
				
			||||||
 | 
					                           distributed_init_port: str):
 | 
				
			||||||
 | 
					    init_test_distributed_environment(1, tensor_parallel_size, rank,
 | 
				
			||||||
 | 
					                                      distributed_init_port)
 | 
				
			||||||
 | 
					    num_dimensions = 3
 | 
				
			||||||
 | 
					    tensor_size = list(range(2, num_dimensions + 2))
 | 
				
			||||||
 | 
					    total_size = 1
 | 
				
			||||||
 | 
					    for s in tensor_size:
 | 
				
			||||||
 | 
					        total_size *= s
 | 
				
			||||||
 | 
					    for all_gather_dimension in range(num_dimensions):
 | 
				
			||||||
 | 
					        all_tensors = [
 | 
				
			||||||
 | 
					            torch.arange(total_size, dtype=torch.float32,
 | 
				
			||||||
 | 
					                         device="cuda").reshape(tensor_size) * (r + 1)
 | 
				
			||||||
 | 
					            for r in range(tensor_parallel_size)
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					        expected = torch.cat(all_tensors, dim=all_gather_dimension)
 | 
				
			||||||
 | 
					        t = all_tensors[rank]
 | 
				
			||||||
 | 
					        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
 | 
				
			||||||
 | 
					        assert torch.allclose(t, expected)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.skipif(torch.cuda.device_count() < 2,
 | 
				
			||||||
 | 
					                    reason="Need at least 2 GPUs to run the test.")
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("tensor_parallel_size", [2])
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("test_target",
 | 
				
			||||||
 | 
					                         [all_reduce_test_worker, all_gather_test_worker])
 | 
				
			||||||
 | 
					def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
 | 
				
			||||||
 | 
					    distributed_init_port = get_open_port()
 | 
				
			||||||
 | 
					    processes = []
 | 
				
			||||||
 | 
					    for rank in range(tensor_parallel_size):
 | 
				
			||||||
 | 
					        p = Process(target=test_target,
 | 
				
			||||||
 | 
					                    args=(tensor_parallel_size, rank, distributed_init_port))
 | 
				
			||||||
 | 
					        p.start()
 | 
				
			||||||
 | 
					        processes.append(p)
 | 
				
			||||||
 | 
					    for p in processes:
 | 
				
			||||||
 | 
					        p.join()
 | 
				
			||||||
 | 
					    assert all(p.exitcode == 0 for p in processes)
 | 
				
			||||||
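
Note (illustration, not part of the diff): the docstring asks for `--forked`, presumably so each parametrized case runs in its own subprocess with fresh CUDA/NCCL state. The identity each all-reduce worker asserts is easy to restate on CPU tensors; a self-contained sketch:

import torch

world_size = 2
rank_tensors = [
    torch.arange(8, dtype=torch.float32) * (r + 1) for r in range(world_size)
]
# After a sum all-reduce, every rank should end up holding this tensor.
expected = torch.stack(rank_tensors, dim=0).sum(dim=0)
assert torch.equal(expected, torch.arange(8, dtype=torch.float32) * 3)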
							
								
								
									
tests/engine/test_detokenize.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import pytest

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import detokenize_incrementally

TRUTH = [
    # pylint: disable=line-too-long
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
    "我很感谢你的热情"
]
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "codellama/CodeLlama-7b-hf",
]


def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool):
    decoded_text = ""
    offset = 0
    token_offset = 0
    prev_tokens = None
    for i in range(len(all_input_ids)):
        new_tokens, text, offset, token_offset = detokenize_incrementally(
            tokenizer,
            all_input_ids[:i + 1],
            prev_tokens,
            offset,
            token_offset,
            skip_special_tokens=skip_special_tokens)
        decoded_text += text
        if prev_tokens is None:
            prev_tokens = new_tokens
        else:
            prev_tokens += new_tokens
    return decoded_text


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
    if skip_special_tokens:
        all_input_ids = ([tokenizer.bos_token_id]
                         if tokenizer.bos_token_id is not None else
                         []) + all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)

    assert decoded_text == truth
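
Note (illustration, not part of the diff): the property being checked is that decoding token by token, as a streaming server would, yields exactly the string produced by decoding the whole sequence at once. A sketch of the same invariant outside pytest, reusing _run_incremental_decode from the file above and one of the listed tokenizers:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Hello here, this is a simple test"
ids = tokenizer(text, add_special_tokens=False)["input_ids"]

streamed = _run_incremental_decode(tokenizer, ids, skip_special_tokens=False)
assert streamed == text == tokenizer.decode(ids)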
							
								
								
									
tests/kernels/conftest.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from typing import List, Tuple

import pytest
import torch


def create_kv_caches(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=dtype,
                                device='cuda')
        key_cache.uniform_(-scale, scale)
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=dtype,
                                  device='cuda')
        value_cache.uniform_(-scale, scale)
        value_caches.append(value_cache)
    return key_caches, value_caches


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches
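
Note (illustration, not part of the diff): the fixture hands the factory function itself to the test, so each test chooses its own cache geometry. A sketch of a consumer; the shape assertions follow directly from create_kv_caches above.

import torch

def test_kv_cache_shapes(kv_cache_factory):
    num_blocks, block_size, num_layers = 128, 16, 1
    num_heads, head_size, dtype, seed = 8, 64, torch.half, 0
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, dtype, seed)
    x = 16 // torch.tensor([], dtype=dtype).element_size()   # 8 for fp16
    assert key_caches[0].shape == (num_blocks, num_heads, head_size // x,
                                   block_size, x)
    assert value_caches[0].shape == (num_blocks, num_heads, head_size,
                                     block_size)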
@@ -1,30 +1,75 @@
+import pytest
 import torch
 import torch.nn.functional as F
+from transformers.activations import get_activation
 
 from vllm import activation_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
+SEEDS = [0]
+
 
 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_silu_and_mul(
+def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.silu_and_mul(out, x)
     ref_out = ref_silu_and_mul(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-def test_silu_and_mul() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_silu_and_mul(num_tokens, d, dtype)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_gelu_new(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
+    activation_ops.gelu_new(out, x)
+    ref_out = get_activation("gelu_new")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+def test_gelu_fast(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
+    activation_ops.gelu_fast(out, x)
+    ref_out = get_activation("gelu_fast")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
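
Note (illustration, not part of the diff): the hunk above replaces the hand-rolled nested loops with stacked pytest.mark.parametrize decorators, so every dtype/shape/seed combination becomes an independent test case. A quick check of the resulting grid size, re-declaring the constants introduced in the diff:

import itertools

import torch

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
SEEDS = [0]

cases = list(itertools.product(NUM_TOKENS, D, DTYPES, SEEDS))
assert len(cases) == 3 * 4 * 3 * 1  # 36 cases per test function
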
@ -1,14 +1,29 @@
 | 
				
			|||||||
import random
 | 
					import random
 | 
				
			||||||
from typing import List, Optional
 | 
					from typing import List, Optional, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from xformers import ops as xops
 | 
					from xformers import ops as xops
 | 
				
			||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 | 
					from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from vllm import attention_ops
 | 
					from vllm import attention_ops
 | 
				
			||||||
 | 
					from vllm.utils import get_max_shared_memory_bytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
MAX_SEQ_LEN = 4096
 | 
					FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 | 
				
			||||||
TEST_SEED = 0
 | 
					# This will change depending on the compute capability.
 | 
				
			||||||
 | 
					# - 512 as a buffer
 | 
				
			||||||
 | 
					MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 | 
				
			||||||
 | 
					NUM_BLOCKS = 128  # Arbitrary values for testing
 | 
				
			||||||
 | 
					PARTITION_SIZE = 512
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DTYPES = [torch.half, torch.bfloat16, torch.float]
 | 
				
			||||||
 | 
					NUM_GEN_SEQS = [7]  # Arbitrary values for testing
 | 
				
			||||||
 | 
					NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 | 
				
			||||||
 | 
					NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 | 
				
			||||||
 | 
					HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 | 
				
			||||||
 | 
					BLOCK_SIZES = [16, 32]
 | 
				
			||||||
 | 
					USE_ALIBI = [False, True]
 | 
				
			||||||
 | 
					SEEDS = [0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ref_masked_attention(
 | 
					def ref_masked_attention(
 | 
				
			||||||
@ -18,29 +33,34 @@ def ref_masked_attention(
 | 
				
			|||||||
    scale: float,
 | 
					    scale: float,
 | 
				
			||||||
    attn_mask: Optional[torch.Tensor] = None,
 | 
					    attn_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
) -> torch.Tensor:
 | 
					) -> torch.Tensor:
 | 
				
			||||||
    query = query * scale
 | 
					    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
 | 
				
			||||||
    attn = torch.einsum('qhd,khd->hqk', query, key)
 | 
					 | 
				
			||||||
    if attn_mask is not None:
 | 
					    if attn_mask is not None:
 | 
				
			||||||
        attn = attn + attn_mask
 | 
					        attn_weights = attn_weights + attn_mask.float()
 | 
				
			||||||
    attn = torch.softmax(attn, dim=-1)
 | 
					    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
 | 
				
			||||||
    out = torch.einsum('hqk,khd->qhd', attn, value)
 | 
					    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
 | 
				
			||||||
    return out
 | 
					    return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ref_single_query_cached_kv_attention(
 | 
					def ref_single_query_cached_kv_attention(
 | 
				
			||||||
    output: torch.Tensor,
 | 
					    output: torch.Tensor,
 | 
				
			||||||
    query: torch.Tensor,
 | 
					    query: torch.Tensor,
 | 
				
			||||||
 | 
					    num_queries_per_kv: int,
 | 
				
			||||||
    key_cache: torch.Tensor,
 | 
					    key_cache: torch.Tensor,
 | 
				
			||||||
    value_cache: torch.Tensor,
 | 
					    value_cache: torch.Tensor,
 | 
				
			||||||
    block_tables: torch.Tensor,
 | 
					    block_tables: torch.Tensor,
 | 
				
			||||||
    context_lens: torch.Tensor,
 | 
					    context_lens: torch.Tensor,
 | 
				
			||||||
 | 
					    scale: float,
 | 
				
			||||||
 | 
					    alibi_slopes: Optional[torch.Tensor],
 | 
				
			||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    num_heads = value_cache.shape[1]
 | 
					    num_query_heads = query.shape[1]
 | 
				
			||||||
 | 
					    num_kv_heads = value_cache.shape[1]
 | 
				
			||||||
    head_size = value_cache.shape[2]
 | 
					    head_size = value_cache.shape[2]
 | 
				
			||||||
    block_size = value_cache.shape[3]
 | 
					    block_size = value_cache.shape[3]
 | 
				
			||||||
 | 
					    num_seqs = query.shape[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    num_input_tokens = query.shape[0]
 | 
					    block_tables = block_tables.cpu().tolist()
 | 
				
			||||||
    for i in range(num_input_tokens):
 | 
					    context_lens = context_lens.cpu().tolist()
 | 
				
			||||||
 | 
					    for i in range(num_seqs):
 | 
				
			||||||
        q = query[i].unsqueeze(0)
 | 
					        q = query[i].unsqueeze(0)
 | 
				
			||||||
        block_table = block_tables[i]
 | 
					        block_table = block_tables[i]
 | 
				
			||||||
        context_len = int(context_lens[i])
 | 
					        context_len = int(context_lens[i])
 | 
				
			||||||
@ -52,30 +72,175 @@ def ref_single_query_cached_kv_attention(
 | 
				
			|||||||
            block_offset = j % block_size
 | 
					            block_offset = j % block_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            k = key_cache[block_number, :, :, block_offset, :]
 | 
					            k = key_cache[block_number, :, :, block_offset, :]
 | 
				
			||||||
            k = k.reshape(num_heads, head_size)
 | 
					            k = k.reshape(num_kv_heads, head_size)
 | 
				
			||||||
            keys.append(k)
 | 
					            keys.append(k)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            v = value_cache[block_number, :, :, block_offset]
 | 
					            v = value_cache[block_number, :, :, block_offset]
 | 
				
			||||||
            values.append(v)
 | 
					            values.append(v)
 | 
				
			||||||
        keys = torch.stack(keys, dim=0)
 | 
					        keys = torch.stack(keys, dim=0)
 | 
				
			||||||
        values = torch.stack(values, dim=0)
 | 
					        values = torch.stack(values, dim=0)
 | 
				
			||||||
 | 
					        if num_queries_per_kv > 1:
 | 
				
			||||||
 | 
					            # Handle MQA and GQA
 | 
				
			||||||
 | 
					            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
 | 
				
			||||||
 | 
					            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        scale = 1.0 / (head_size**0.5)
 | 
					        alibi_bias = None
 | 
				
			||||||
        out = ref_masked_attention(q, keys, values, scale)
 | 
					        if alibi_slopes is not None:
 | 
				
			||||||
        out = out.view(num_heads, head_size)
 | 
					            # Create the ALiBi bias used in the paged attention kernel.
 | 
				
			||||||
 | 
					            position_ids = torch.arange(context_len, device="cuda").int()
 | 
				
			||||||
 | 
					            alibi_bias = (position_ids - context_len + 1).float()
 | 
				
			||||||
 | 
					            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
 | 
				
			||||||
 | 
					                1, 1, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
 | 
				
			||||||
 | 
					        out = out.view(num_query_heads, head_size)
 | 
				
			||||||
        output[i].copy_(out, non_blocking=True)
 | 
					        output[i].copy_(out, non_blocking=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("version", ["v1", "v2"])
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("num_heads", NUM_HEADS)
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("head_size", HEAD_SIZES)
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("use_alibi", USE_ALIBI)
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("block_size", BLOCK_SIZES)
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("dtype", DTYPES)
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("seed", SEEDS)
 | 
				
			||||||
 | 
					def test_paged_attention(
 | 
				
			||||||
 | 
					    kv_cache_factory,
 | 
				
			||||||
 | 
					    version: str,
 | 
				
			||||||
 | 
					    num_seqs: int,
 | 
				
			||||||
 | 
					    num_heads: Tuple[int, int],
 | 
				
			||||||
 | 
					    head_size: int,
 | 
				
			||||||
 | 
					    use_alibi: bool,
 | 
				
			||||||
 | 
					    block_size: int,
 | 
				
			||||||
 | 
					    dtype: torch.dtype,
 | 
				
			||||||
 | 
					    seed: int,
 | 
				
			||||||
 | 
					) -> None:
 | 
				
			||||||
 | 
					    random.seed(seed)
 | 
				
			||||||
 | 
					    torch.random.manual_seed(seed)
 | 
				
			||||||
 | 
					    torch.cuda.manual_seed(seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    scale = float(1.0 / (head_size**0.5))
 | 
				
			||||||
 | 
					    num_query_heads, num_kv_heads = num_heads
 | 
				
			||||||
 | 
					    query = torch.empty(num_seqs,
 | 
				
			||||||
 | 
					                        num_query_heads,
 | 
				
			||||||
 | 
					                        head_size,
 | 
				
			||||||
 | 
					                        dtype=dtype,
 | 
				
			||||||
 | 
					                        device="cuda")
 | 
				
			||||||
 | 
					    query.uniform_(-scale, scale)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert num_query_heads % num_kv_heads == 0
 | 
				
			||||||
 | 
					    num_queries_per_kv = num_query_heads // num_kv_heads
 | 
				
			||||||
 | 
					    head_mapping = torch.repeat_interleave(
 | 
				
			||||||
 | 
					        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
 | 
				
			||||||
 | 
					        num_queries_per_kv)
 | 
				
			||||||
 | 
					    alibi_slopes = None
 | 
				
			||||||
 | 
					    if use_alibi:
 | 
				
			||||||
 | 
					        alibi_slopes = torch.randn(num_query_heads,
 | 
				
			||||||
 | 
					                                   dtype=torch.float,
 | 
				
			||||||
 | 
					                                   device="cuda")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
 | 
				
			||||||
 | 
					    context_lens[-1] = MAX_SEQ_LEN
 | 
				
			||||||
 | 
					    max_context_len = max(context_lens)
 | 
				
			||||||
 | 
					    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create the block tables.
 | 
				
			||||||
 | 
					    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
 | 
				
			||||||
 | 
					    block_tables = []
 | 
				
			||||||
 | 
					    for _ in range(num_seqs):
 | 
				
			||||||
 | 
					        block_table = [
 | 
				
			||||||
 | 
					            random.randint(0, NUM_BLOCKS - 1)
 | 
				
			||||||
 | 
					            for _ in range(max_num_blocks_per_seq)
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					        block_tables.append(block_table)
 | 
				
			||||||
 | 
					    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create the KV caches.
 | 
				
			||||||
 | 
					    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
 | 
				
			||||||
 | 
					                                                num_kv_heads, head_size, dtype,
 | 
				
			||||||
 | 
					                                                seed)
 | 
				
			||||||
 | 
					    key_cache, value_cache = key_caches[0], value_caches[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Call the paged attention kernel.
 | 
				
			||||||
 | 
					    output = torch.empty_like(query)
 | 
				
			||||||
 | 
					    if version == "v1":
 | 
				
			||||||
 | 
					        attention_ops.paged_attention_v1(
 | 
				
			||||||
 | 
					            output,
 | 
				
			||||||
 | 
					            query,
 | 
				
			||||||
 | 
					            key_cache,
 | 
				
			||||||
 | 
					            value_cache,
 | 
				
			||||||
 | 
					            head_mapping,
 | 
				
			||||||
 | 
					            scale,
 | 
				
			||||||
 | 
					            block_tables,
 | 
				
			||||||
 | 
					            context_lens,
 | 
				
			||||||
 | 
					            block_size,
 | 
				
			||||||
 | 
					            max_context_len,
 | 
				
			||||||
 | 
					            alibi_slopes,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    elif version == "v2":
 | 
				
			||||||
 | 
					        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
 | 
				
			||||||
 | 
					                          PARTITION_SIZE)
 | 
				
			||||||
 | 
					        assert PARTITION_SIZE % block_size == 0
 | 
				
			||||||
 | 
					        num_seqs, num_heads, head_size = output.shape
 | 
				
			||||||
 | 
					        tmp_output = torch.empty(
 | 
				
			||||||
 | 
					            size=(num_seqs, num_heads, num_partitions, head_size),
 | 
				
			||||||
 | 
					            dtype=output.dtype,
 | 
				
			||||||
 | 
					            device=output.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        exp_sums = torch.empty(
 | 
				
			||||||
 | 
					            size=(num_seqs, num_heads, num_partitions),
 | 
				
			||||||
 | 
					            dtype=torch.float32,
 | 
				
			||||||
 | 
					            device=output.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        max_logits = torch.empty_like(exp_sums)
 | 
				
			||||||
 | 
					        attention_ops.paged_attention_v2(
 | 
				
			||||||
 | 
					            output,
 | 
				
			||||||
 | 
					            exp_sums,
 | 
				
			||||||
 | 
					            max_logits,
 | 
				
			||||||
 | 
					            tmp_output,
 | 
				
			||||||
 | 
					            query,
 | 
				
			||||||
 | 
					            key_cache,
 | 
				
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    else:
+        assert False, f"Unknown version: {version}"
+
+    # Run the reference implementation.
+    ref_output = torch.empty_like(query)
+    ref_single_query_cached_kv_attention(
+        ref_output,
+        query,
+        num_queries_per_kv,
+        key_cache,
+        value_cache,
+        block_tables,
+        context_lens,
+        scale,
+        alibi_slopes,
+    )
+
+    # NOTE(woosuk): Due to the kernel-level differences in the two
+    # implementations, there is a small numerical difference in the two
+    # outputs. Thus, we use a relaxed tolerance for the test.
+    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+
+
 def ref_multi_query_kv_attention(
     cu_seq_lens: List[int],
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    scale: float,
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    head_size = query.shape[-1]
-    scale = 1.0 / (head_size**0.5)
-
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs = []
     for i in range(num_seqs):

@@ -87,7 +252,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
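Note (illustrative, not part of the diff): the reference attention above hides future tokens by adding a very large negative value above the diagonal before the softmax. A minimal, self-contained sketch of that additive causal mask:

import torch

# Sketch of the additive causal mask built with triu(..., diagonal=1) * finfo.min.
seq_len, dtype = 4, torch.float32
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
mask = mask * torch.finfo(dtype).min  # future positions become ~-inf
scores = torch.zeros(seq_len, seq_len)  # stand-in for scaled Q @ K^T
probs = torch.softmax(scores + mask, dim=-1)
# Row i attends only to positions <= i; e.g. the first row is [1, 0, 0, 0].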
@@ -101,161 +266,47 @@ def ref_multi_query_kv_attention(
     return ref_output
 
 
-def ref_multi_query_cached_kv_attention(
-    cu_query_lens: List[int],
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    block_tables: torch.Tensor,
-    context_lens: torch.Tensor,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    num_heads = value_cache.shape[1]
-    head_size = value_cache.shape[2]
-    block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size**0.5)
-
-    num_queries = len(cu_query_lens) - 1
-    ref_outputs = []
-    for i in range(num_queries):
-        start_idx = cu_query_lens[i]
-        end_idx = cu_query_lens[i + 1]
-        query_len = end_idx - start_idx
-        context_len = int(context_lens[i])
-        block_table = block_tables[i]
-
-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(query_len, context_len),
-                               diagonal=context_len - query_len + 1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-
-        keys = []
-        values = []
-        for j in range(context_len):
-            block_number = int(block_table[j // block_size])
-            block_offset = j % block_size
-
-            k = key_cache[block_number, :, :, block_offset, :]
-            k = k.reshape(num_heads, head_size)
-            keys.append(k)
-
-            v = value_cache[block_number, :, :, block_offset]
-            values.append(v)
-        keys = torch.stack(keys, dim=0)
-        values = torch.stack(values, dim=0)
-
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            keys,
-            values,
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-    return ref_output
-
-
+# TODO(woosuk): Add tests for USE_ALIBI=True.
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_single_query_cached_kv_attention(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-) -> None:
-    qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, _, _ = qkv.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
-                            dtype=dtype,
-                            device='cuda')
-    key_cache.uniform_(-1e-3, 1e-3)
-    value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
-                              dtype=dtype,
-                              device='cuda')
-    value_cache.uniform_(-1e-3, 1e-3)
-
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
-
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
-    block_tables = []
-    for _ in range(num_tokens):
-        block_table = [
-            random.randint(0, num_blocks - 1)
-            for _ in range(max_num_blocks_per_seq)
-        ]
-        block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
-
-    scale = float(1.0 / (head_size**0.5))
-    output = torch.empty(num_tokens,
-                         num_heads,
-                         head_size,
-                         dtype=dtype,
-                         device='cuda')
-    attention_ops.single_query_cached_kv_attention(
-        output,
-        query,
-        key_cache,
-        value_cache,
-        scale,
-        block_tables,
-        context_lens,
-        block_size,
-        max_context_len,
-        None,  # ALiBi slopes.
-    )
-
-    ref_output = torch.empty_like(query)
-    ref_single_query_cached_kv_attention(
-        ref_output,
-        query,
-        key_cache,
-        value_cache,
-        block_tables,
-        context_lens,
-    )
-    # NOTE(woosuk): Due to the difference in the data types the two
-    # implementations use for attention softmax logits and accumulation,
-    # there is a small difference in the final outputs.
-    # We should use a relaxed tolerance for the test.
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
-
-
-@torch.inference_mode()
-def run_multi_query_kv_attention(
+def test_multi_query_kv_attention(
     num_seqs: int,
-    num_heads: int,
+    num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
+    # As the xformers library is already tested with its own tests, we can use
+    # a smaller MAX_SEQ_LEN here.
+    max_len = min(MAX_SEQ_LEN, 4096)
+    seq_lens = random.sample(range(1, max_len), num_seqs)
     num_tokens = sum(seq_lens)
 
     scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
     qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
+                      num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, key, value = qkv.unbind(dim=1)
+                      device="cuda")
+    qkv.uniform_(-scale, scale)
+    query, key, value = qkv.split(
+        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
 
-    attn_op = xops.fmha.cutlass.FwOp()
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    if num_queries_per_kv > 1:
+        # Handle MQA and GQA
+        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
+        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
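For illustration only (not part of the diff): the `# Handle MQA and GQA` branch above duplicates each KV head with `torch.repeat_interleave` so the xformers kernel sees one key/value head per query head. A small sketch of that expansion; the shapes below are made up for the example:

import torch

# Sketch of the MQA/GQA head expansion used above (example shapes only).
num_tokens, num_query_heads, num_kv_heads, head_size = 6, 8, 2, 4
num_queries_per_kv = num_query_heads // num_kv_heads  # 4 query heads share a KV head

key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Repeat each KV head consecutively so every group of query heads gets a copy.
key_expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value_expanded = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

assert key_expanded.shape == (num_tokens, num_query_heads, head_size)
# Query heads 0-3 all read from the original KV head 0.
assert torch.equal(key_expanded[:, 0], key[:, 0])
assert torch.equal(key_expanded[:, 3], key[:, 0])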
@@ -264,7 +315,6 @@ def run_multi_query_kv_attention(
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
-        op=attn_op,
     )
     output = output.squeeze(0)
 

@@ -276,40 +326,7 @@ def run_multi_query_kv_attention(
         query,
         key,
         value,
+        scale,
         dtype,
     )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
-
-
-def test_single_query_cached_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for block_size in [8, 16, 32]:
-            for head_size in [64, 80, 96, 128]:
-                print(f'Testing single_query_cached_kv_attention with '
-                      f'dtype={dtype}, block_size={block_size}, '
-                      f'head_size={head_size}')
-                run_single_query_cached_kv_attention(
-                    num_tokens=37,
-                    num_heads=3,
-                    head_size=head_size,
-                    block_size=block_size,
-                    num_blocks=1024,
-                    dtype=dtype,
-                )
-
-
-def test_multi_query_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [64, 80, 96, 128]:
-            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
-                  f'head_size={head_size}')
-            run_multi_query_kv_attention(
-                num_seqs=5,
-                num_heads=3,
-                head_size=head_size,
-                dtype=dtype,
-            )
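Aside, not part of the diff: the paged-attention test above hands the kernel a `block_tables` tensor and per-sequence `context_lens`, and the reference code (the removed multi-query helper shows the same pattern) resolves a token position to a physical cache slot through the block table. A rough sketch using the `[num_blocks, num_heads, head_size, block_size]` value-cache layout from these tests; all numbers are made up:

import torch

# Sketch: fetching one token's value vector from a paged KV cache via a block table.
num_blocks, num_heads, head_size, block_size = 16, 2, 8, 4
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size)

block_table = [5, 9, 3]   # physical block ids assigned to one sequence
token_idx = 6             # logical position within the sequence

block_number = block_table[token_idx // block_size]  # block_table[1] == 9
block_offset = token_idx % block_size                # 2
v = value_cache[block_number, :, :, block_offset]    # [num_heads, head_size]
assert v.shape == (num_heads, head_size)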
@@ -1,12 +1,32 @@
 import random
 
+import pytest
 import torch
 
 from vllm import cache_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_HEADS = [8]  # Arbitrary values for testing
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+BLOCK_SIZES = [8, 16, 32]
+NUM_BLOCKS = [1024]  # Arbitrary values for testing
+NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+SEEDS = [0]
+
+
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_layers", NUM_LAYERS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_copy_blocks(
+def test_copy_blocks(
+    kv_cache_factory,
     num_mappings: int,
     num_layers: int,
     num_heads: int,

@@ -14,48 +34,43 @@ def run_copy_blocks(
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    # Generate random block mappings.
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Generate random block mappings where each source block is mapped to two
+    # destination blocks.
+    assert 2 * num_mappings <= num_blocks
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
-    dst_blocks = random.sample(remainig_blocks, num_mappings)
-    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
+    dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
+    block_mapping = {}
+    for i in range(num_mappings):
+        src = src_blocks[i]
+        dst1 = dst_blocks[2 * i]
+        dst2 = dst_blocks[2 * i + 1]
+        block_mapping[src] = [dst1, dst2]
 
-    # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.randn(size=key_cache_shape,
-                                dtype=dtype,
-                                device='cuda')
-        key_caches.append(key_cache)
-    cloned_key_caches = []
-    for key_cache in key_caches:
-        cloned_key_caches.append(key_cache.clone())
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.randn(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device='cuda')
-        value_caches.append(value_cache)
-    cloned_value_caches = []
-    for value_cache in value_caches:
-        cloned_value_caches.append(value_cache.clone())
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
+                                                num_layers, num_heads,
+                                                head_size, dtype, seed)
+
+    # Clone the KV caches.
+    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
+    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
 
     # Call the copy blocks kernel.
     cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
 
-    # Reference implementation.
+    # Run the reference implementation.
     for src, dsts in block_mapping.items():
         for dst in dsts:
-            for key_cache, cloned_key_cache in zip(key_caches,
-                                                   cloned_key_caches):
+            for cloned_key_cache in cloned_key_caches:
                 cloned_key_cache[dst] = cloned_key_cache[src]
-            for value_cache, cloned_value_cache in zip(value_caches,
-                                                       cloned_value_caches):
+            for cloned_value_cache in cloned_value_caches:
                 cloned_value_cache[dst] = cloned_value_cache[src]
 
     # Compare the results.
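Illustrative sketch (not part of the diff): the updated `test_copy_blocks` maps every source block to two destination blocks and then replays the copy with plain indexing as the reference. A toy version of that reference loop; the cache shape and the mapping values below are made up:

import torch

# Toy version of the copy_blocks reference logic above.
num_blocks, block_numel = 8, 4
key_cache = torch.arange(num_blocks * block_numel, dtype=torch.float32)
key_cache = key_cache.reshape(num_blocks, block_numel)

# Each source block is copied to two destination blocks (made-up ids).
block_mapping = {0: [5, 6], 2: [7, 3]}

for src, dsts in block_mapping.items():
    for dst in dsts:
        key_cache[dst] = key_cache[src]

assert torch.equal(key_cache[5], key_cache[0])
assert torch.equal(key_cache[7], key_cache[2])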
@@ -66,131 +81,66 @@ def run_copy_blocks(
         assert torch.allclose(value_cache, cloned_value_cache)
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_reshape_and_cache(
+def test_reshape_and_cache(
+    kv_cache_factory,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
 
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device='cuda')
+                      device="cuda")
     _, key, value = qkv.unbind(dim=1)
 
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
-    cloned_key_cache = key_cache.clone()
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
-    cloned_value_cache = value_cache.clone()
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
+                                                num_heads, head_size, dtype,
+                                                seed)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Clone the KV caches.
+    cloned_key_cache = key_cache.clone()
+    cloned_value_cache = value_cache.clone()
 
+    # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                 slot_mapping)
 
+    # Run the reference implementation.
+    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    block_indicies = block_indicies.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets = block_offsets.cpu().tolist()
     for i in range(num_tokens):
-        reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
+        block_idx = block_indicies[i]
+        block_offset = block_offsets[i]
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
 
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
-
-
-@torch.inference_mode()
-def run_gather_cached_kv(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-) -> None:
-    num_slots = block_size * num_blocks
-    slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
-
-    qkv = torch.randn(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    _, key, value = qkv.unbind(dim=1)
-
-    qkv_clone = qkv.clone()
-    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
-
-    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
-                               slot_mapping)
-
-    # Reference implementation.
-    for i in range(num_tokens):
-        reshaped_key = cloned_key.reshape(num_tokens, num_heads,
-                                          head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
-        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
-        cloned_value[i] = value_cache[block_idx, :, :, block_offset]
-
-    assert torch.allclose(key, cloned_key)
-    assert torch.allclose(value, cloned_value)
-
-
-def test_copy_blocks() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_copy_blocks(num_mappings=23,
-                        num_layers=7,
-                        num_heads=17,
-                        head_size=16,
-                        block_size=8,
-                        num_blocks=1024,
-                        dtype=dtype)
-
-
-def test_reshape_and_cache() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_reshape_and_cache(num_tokens=3,
-                              num_heads=2,
-                              head_size=16,
-                              block_size=8,
-                              num_blocks=2,
-                              dtype=dtype)
-
-
-def test_gather_cached_kv() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_gather_cached_kv(num_tokens=3,
-                             num_heads=2,
-                             head_size=16,
-                             block_size=8,
-                             num_blocks=2,
-                             dtype=dtype)
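Note (illustrative, not part of the diff): the reference path of `test_reshape_and_cache` splits each slot index into a block index and an in-block offset with a floor division and a modulo. A tiny worked example with arbitrary values:

import torch

# Worked example of the slot -> (block index, block offset) split used above.
block_size = 16
slot_mapping = torch.tensor([3, 16, 37], dtype=torch.int)

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

assert block_indices.tolist() == [0, 1, 2]  # 3 // 16, 16 // 16, 37 // 16
assert block_offsets.tolist() == [3, 0, 5]  # 3 % 16, 16 % 16, 37 % 16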
@@ -1,35 +1,50 @@
+import pytest
 import torch
 import torch.nn as nn
 
 from vllm import layernorm_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+SEEDS = [0]
+
 
 class RefRMSNorm(nn.Module):
 
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         weight = torch.empty(hidden_size)
-        weight.uniform_(-1e-3, 1e-3)
+        weight.normal_(mean=1.0, std=0.1)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
-                                                               keepdim=True)
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance +
                                                     self.variance_epsilon)
-        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
+        return self.weight * hidden_states.to(input_dtype)
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rms_norm(
+def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(hidden_size**-0.5)
+    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x.uniform_(-scale, scale)
     ref = RefRMSNorm(hidden_size).to(dtype).cuda()
 
     out = torch.empty_like(x)

@@ -40,17 +55,4 @@ def run_rms_norm(
         ref.variance_epsilon,
     )
     ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
-
-
-def test_rms_norm() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 128, 2048]:
-            for hidden_size in [13, 64, 1024, 5120]:
-                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
-                      f'{num_tokens}, hidden_size={hidden_size}')
-                run_rms_norm(
-                    num_tokens=num_tokens,
-                    hidden_size=hidden_size,
-                    dtype=dtype,
-                )
+    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
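For illustration (not part of the diff): the updated `RefRMSNorm.forward` computes y = weight * x / sqrt(mean(x**2) + eps), accumulating in float32 and casting back to the input dtype only at the end. A minimal standalone version of that reference computation:

import torch

# Minimal sketch of the RMSNorm reference math above, with float32 accumulation.
def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 eps: float = 1e-6) -> torch.Tensor:
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)

x = torch.randn(2, 8).to(torch.half)
weight = torch.ones(8, dtype=torch.half)
out = rms_norm_ref(x, weight)
assert out.shape == x.shape and out.dtype == torch.half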
@@ -1,47 +1,70 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from vllm import pos_encoding_ops
 
+IS_NEOX_STYLE = [True, False]
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
+NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
+NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+SEEDS = [0]
 
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
+
+def rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(
+def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rope(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    is_neox_style: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
+    q_embed = (q * cos) + (rotate_fn(q) * sin)
+    k_embed = (k * cos) + (rotate_fn(k) * sin)
     return q_embed, k_embed
 
 
-class RefRotaryEmbeddingNeox(nn.Module):
-    """Reference implementation of the GPT-NeoX style rotary embedding."""
+class RefRotaryEmbedding(nn.Module):
+    """Reference implementation of rotary embedding."""
 
     def __init__(
         self,
         dim: int,
-        max_position_embeddings: int = 2048,
+        is_neox_style: bool,
+        max_position_embeddings: int = 8192,
         base: int = 10000,
     ) -> None:
         super().__init__()
         self.rotary_dim = dim
+        self.is_neox_style = is_neox_style
         self.max_position_embeddings = max_position_embeddings
 
         # Create cos and sin embeddings.
         inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
         t = torch.arange(max_position_embeddings).float()
         freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
+        if is_neox_style:
+            emb = torch.cat((freqs, freqs), dim=-1)
+        else:
+            emb = torch.repeat_interleave(freqs, 2, -1)
         cos = emb.cos().to(dtype=inv_freq.dtype)
         sin = emb.sin().to(dtype=inv_freq.dtype)
         self.register_buffer("cos_cached", cos, persistent=False)

@@ -53,7 +76,6 @@ class RefRotaryEmbeddingNeox(nn.Module):
         query: torch.Tensor,  # [num_tokens, num_heads, head_size]
         key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
         key_rot = key[..., :self.rotary_dim]

@@ -63,7 +85,9 @@ class RefRotaryEmbeddingNeox(nn.Module):
         key_rot = key_rot.transpose(0, 1)
         cos = F.embedding(positions, self.cos_cached)
         sin = F.embedding(positions, self.sin_cached)
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
+                                        self.is_neox_style)
         query_rot = query_rot.transpose(0, 1).contiguous()
         key_rot = key_rot.transpose(0, 1).contiguous()
 

@@ -74,52 +98,69 @@ class RefRotaryEmbeddingNeox(nn.Module):
         return query, key
 
 
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rotary_embedding_neox(
+def test_rotary_embedding(
+    is_neox_style: bool,
     num_tokens: int,
     num_heads: int,
     head_size: int,
-    max_position: int,
-    rotary_dim: int,
+    rotary_dim: Optional[int],
     dtype: torch.dtype,
+    seed: int,
+    max_position: int = 8192,
     base: int = 10000,
 ) -> None:
-    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    if rotary_dim is None:
+        rotary_dim = head_size
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
     query = torch.randn(num_tokens,
                         num_heads * head_size,
                         dtype=dtype,
-                        device='cuda')
+                        device="cuda")
     key = torch.randn(num_tokens,
                       num_heads * head_size,
                       dtype=dtype,
-                      device='cuda')
+                      device="cuda")
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
 
     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
     out_key = key.clone()
-    pos_encoding_ops.rotary_embedding_neox(
+    pos_encoding_ops.rotary_embedding(
         positions,
         out_query,
        out_key,
         head_size,
         cos_sin_cache,
+        is_neox_style,
     )
 
     # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbeddingNeox(
+    ref_rotary_embedding = RefRotaryEmbedding(
        dim=rotary_dim,
+        is_neox_style=is_neox_style,
         max_position_embeddings=max_position,
         base=base,
-    ).to(dtype=dtype, device='cuda')
+    ).to(dtype=dtype, device="cuda")
     ref_query, ref_key = ref_rotary_embedding(
         positions,
         query.view(num_tokens, num_heads, head_size),

@@ -129,19 +170,5 @@ def run_rotary_embedding_neox(
     ref_key = ref_key.view(num_tokens, num_heads * head_size)
 
     # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
-
-
-def test_rotary_embedding_neox() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
-            print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            run_rotary_embedding_neox(
-                num_tokens=2145,
-                num_heads=5,
-                head_size=head_size,
-                max_position=8192,
-                rotary_dim=head_size,
-                dtype=dtype,
-            )
+    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
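Aside, not part of the diff: `rotate_neox` and `rotate_gptj` above differ only in how they pair channels. NeoX style rotates the two halves of the head dimension, while GPT-J style rotates interleaved even/odd pairs. A tiny demonstration on a 1-D tensor:

import torch

# Demonstration of the two rotation styles defined in the diff above.
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

x = torch.tensor([1., 2., 3., 4.])
assert rotate_neox(x).tolist() == [-2., 1., -4., 3.] or True  # see below
# Concretely: second half negated and moved to the front vs. pairwise rotation.
assert rotate_neox(x).tolist() == [-3., -4., 1., 2.]
assert rotate_gptj(x).tolist() == [-2., 1., -4., 3.]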
							
								
								
									
tests/models/test_models.py  (new file, 45 lines)

@@ -0,0 +1,45 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+Run `pytest tests/models/test_models.py --forked`.
+"""
+import pytest
+
+MODELS = [
+    "facebook/opt-125m",
+    "gpt2",
+    "bigcode/tiny_starcoder_py",
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m",
+    "bigscience/bloom-560m",
+    "mosaicml/mpt-7b",
+    "tiiuae/falcon-7b",
+    "meta-llama/Llama-2-7b-hf",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        assert hf_output_str == vllm_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+        assert hf_output_ids == vllm_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
							
								
								
									
tests/samplers/test_beam_search.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""Compare the outputs of HF and vLLM when using beam search.

Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import pytest

# FIXME(zhuohan): The test can not pass if we:
#   1. Increase max_tokens to 256.
#   2. Increase beam_width to 8.
#   3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [128]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
                                               max_tokens)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
                                                   max_tokens)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, _ = hf_outputs[i]
        vllm_output_ids, _ = vllm_outputs[i]
        assert len(hf_output_ids) == len(vllm_output_ids)
        for j in range(len(hf_output_ids)):
            assert hf_output_ids[j] == vllm_output_ids[j], (
                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                f"vLLM: {vllm_output_ids}")
tests/samplers/test_logprobs.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import pytest
import torch

from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_get_prompt_logprobs(
    hf_runner,
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    max_tokens = 5
    hf_model = hf_runner(model, dtype=dtype)
    hf_logprobs = hf_model.generate_greedy_logprobs(
        example_prompts,
        max_tokens=max_tokens,
    )
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                          logprobs=5,
                                          prompt_logprobs=5,
                                          temperature=0.0)
    vllm_results = vllm_model.model.generate(
        example_prompts, sampling_params=vllm_sampling_params)

    # Test whether logprobs are included in the results.
    for result in vllm_results:
        assert result.prompt_logprobs is not None
        assert result.outputs[0].logprobs is not None

    # Test whether prompt logprobs are consistent with HF
    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
        # Check prompt logprobs
        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
            for token_id, logprob in vllm_prompt_logprob_dict.items():
                torch.testing.assert_close(logprob,
                                           hf_logprob[0][i][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
        for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
            for token_id, logprob in vllm_sample_logprob_dict.items():
                torch.testing.assert_close(logprob,
                                           hf_logprob[i][-1][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
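The prompt-logprob behavior checked above is also reachable through the public offline API. A minimal usage sketch, assuming the same OPT model the test uses; the output fields mirror the assertions in the test (one {token_id: logprob} dict per token):

    from vllm import LLM, SamplingParams

    # Sketch only: greedy decoding while returning the top-5 logprobs for both
    # the prompt tokens and the generated tokens.
    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(max_tokens=5,
                            logprobs=5,
                            prompt_logprobs=5,
                            temperature=0.0)
    for out in llm.generate(["Hello, my name is"], sampling_params=params):
        print(out.prompt_logprobs)      # one dict per prompt token; the test skips the first entry
        print(out.outputs[0].logprobs)  # one dict per generated token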
							
								
								
									
tests/samplers/test_sampler.py (new file, 185 lines)
@@ -0,0 +1,185 @@
# pylint: disable=protected-access
import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x):
            with patch("vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
                return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(32000, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker


RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0, ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
    # no assertion here as I am not sure how to determine whether
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens
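For reference, the three regimes mixed together in `test_sampler_mixed` map onto `SamplingParams` as follows; this is a sketch with arbitrary illustrative values, not taken from the test:

    from vllm import SamplingParams

    greedy = SamplingParams(temperature=0)          # always pick the argmax token
    random_params = SamplingParams(temperature=0.8,  # stochastic sampling,
                                   top_p=0.95,       # optionally nucleus/top-k limited,
                                   top_k=-1,         # -1 disables the top-k cutoff
                                   n=4)              # n completions per request
    beam = SamplingParams(temperature=0,
                          use_beam_search=True,
                          best_of=2)                 # number of beams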
vllm/__init__.py
@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 
-__version__ = "0.1.2"
+__version__ = "0.2.1"
 
 __all__ = [
     "LLM",
vllm/config.py (193 lines changed)
@@ -20,15 +20,34 @@ class ModelConfig:
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
             available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+            downloading the model and tokenizer.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
-        use_np_weights: Save a numpy copy of model weights for faster loading.
-            This can increase the disk usage by up to 2x.
-        use_dummy_weights: Use dummy values for model weights (for profiling).
+        load_format: The format of the model weights to load:
+            "auto" will try to load the weights in the safetensors format and
+                fall back to the pytorch bin format if safetensors format is
+                not available.
+            "pt" will load the weights in the pytorch bin format.
+            "safetensors" will load the weights in the safetensors format.
+            "npcache" will load the weights in pytorch format and store
+                a numpy cache to speed up the loading.
+            "dummy" will initialize the weights with random values, which is
+                mainly for profiling.
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id. If unspecified, will use the default
+            version.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id. If unspecified, will use
+            the default version.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
+        quantization: Quantization method that was used to quantize the model
+            weights. If None, we assume the model weights are not quantized.
     """
 
     def __init__(
@@ -36,23 +55,44 @@ class ModelConfig:
         model: str,
         tokenizer: str,
         tokenizer_mode: str,
+        trust_remote_code: bool,
         download_dir: Optional[str],
-        use_np_weights: bool,
-        use_dummy_weights: bool,
+        load_format: str,
         dtype: str,
         seed: int,
+        revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
+        max_model_len: Optional[int] = None,
+        quantization: Optional[str] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
+        self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
-        self.use_np_weights = use_np_weights
-        self.use_dummy_weights = use_dummy_weights
+        self.load_format = load_format
         self.seed = seed
+        self.revision = revision
+        self.tokenizer_revision = tokenizer_revision
+        self.quantization = quantization
 
-        self.hf_config = get_config(model)
+        self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self.max_model_len = _get_and_verify_max_len(self.hf_config,
+                                                     max_model_len)
+        self._verify_load_format()
         self._verify_tokenizer_mode()
+        self._verify_quantization()
+
+    def _verify_load_format(self) -> None:
+        load_format = self.load_format.lower()
+        if load_format not in [
+                "auto", "pt", "safetensors", "npcache", "dummy"
+        ]:
+            raise ValueError(
+                f"Unknown load format: {self.load_format}. Must be one of "
+                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        self.load_format = load_format
+
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
@@ -62,6 +102,17 @@ class ModelConfig:
                 "either 'auto' or 'slow'.")
         self.tokenizer_mode = tokenizer_mode
 
+    def _verify_quantization(self) -> None:
+        supported_quantization = ["awq"]
+        if self.quantization is None:
+            return
+        quantization = self.quantization.lower()
+        if quantization not in supported_quantization:
+            raise ValueError(
+                f"Unknown quantization: {self.quantization}. Must be one of "
+                f"{supported_quantization}.")
+        self.quantization = quantization
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
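The updated constructor and verification hooks above can be exercised directly. A hand-built sketch with placeholder values (in normal use these come from the engine's argument parsing rather than being written by hand):

    from vllm.config import ModelConfig

    model_config = ModelConfig(
        model="facebook/opt-125m",
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        load_format="auto",    # one of "auto", "pt", "safetensors", "npcache", "dummy"
        dtype="auto",
        seed=0,
        max_model_len=None,    # derived from the HF config when None
        quantization=None,     # or "awq" for AWQ-quantized weights
    )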
@@ -89,7 +140,32 @@ class ModelConfig:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
-    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU worker."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False))
+        if not new_decoder_arch_falcon and getattr(self.hf_config,
+                                                   "multi_query", False):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+        # For Falcon:
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return (self.hf_config.n_head_kv //
+                    parallel_config.tensor_parallel_size)
+        if getattr(self.hf_config, "num_kv_heads", None) is not None:
+            return (self.hf_config.num_kv_heads //
+                    parallel_config.tensor_parallel_size)
+        # For LLaMA-2:
+        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
+            return (self.hf_config.num_key_value_heads //
+                    parallel_config.tensor_parallel_size)
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
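A concrete illustration of the fallthrough above, using illustrative numbers for a grouped-query model such as Llama-2-70B, whose HF config sets num_key_value_heads=8:

    # Illustrative arithmetic only: with num_key_value_heads=8 and
    # tensor_parallel_size=4, get_num_kv_heads() returns 8 // 4 = 2 KV heads
    # per GPU worker, while a multi-query model (multi_query=True and not the
    # new Falcon decoder architecture) always returns 1.
    num_key_value_heads = 8
    tensor_parallel_size = 4
    assert num_key_value_heads // tensor_parallel_size == 2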
@@ -113,10 +189,12 @@ class CacheConfig:
         block_size: int,
         gpu_memory_utilization: float,
         swap_space: int,
+        sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GB
+        self.sliding_window = sliding_window
         self._verify_args()
 
         # Will be set after profiling.
@@ -188,15 +266,40 @@ class SchedulerConfig:
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single
             iteration.
-        max_seq_len: Maximum length of a sequence (including prompt
+        max_model_len: Maximum length of a sequence (including prompt
             and generated text).
     """
 
-    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_seq_len: int) -> None:
-        self.max_num_batched_tokens = max_num_batched_tokens
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
+            self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
-        self.max_seq_len = max_seq_len
+        self.max_model_len = max_model_len
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.max_num_batched_tokens < self.max_model_len:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
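The new _verify_args check above can be seen with illustrative values: a token budget smaller than the context length is rejected, while leaving the budget unset derives it from the context length.

    from vllm.config import SchedulerConfig

    # Sketch only: a 2048-token batch budget vs. a 4096-token context -> ValueError.
    try:
        SchedulerConfig(max_num_batched_tokens=2048,
                        max_num_seqs=256,
                        max_model_len=4096)
    except ValueError as err:
        print(err)

    # With no explicit budget, max(max_model_len, 2048) = 4096 is used.
    cfg = SchedulerConfig(max_num_batched_tokens=None,
                          max_num_seqs=256,
                          max_model_len=4096)
    assert cfg.max_num_batched_tokens == 4096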
@@ -242,13 +345,57 @@ def _get_and_verify_dtype(
             # Casting between float16 and bfloat16 is allowed with a warning.
             logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
 
-    # Check if the GPU supports the dtype.
-    if torch_dtype == torch.bfloat16:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            raise ValueError(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
+
+
+def _get_and_verify_max_len(
+    hf_config: PretrainedConfig,
+    max_model_len: Optional[int],
+) -> int:
+    """Get and verify the model's maximum length."""
+    derived_max_model_len = float("inf")
+    possible_keys = [
+        # OPT
+        "max_position_embeddings",
+        # GPT-2
+        "n_positions",
+        # MPT
+        "max_seq_len",
+        # Others
+        "max_sequence_length",
+        "max_seq_length",
+        "seq_len",
+    ]
+    for key in possible_keys:
+        max_len_key = getattr(hf_config, key, None)
+        if max_len_key is not None:
+            derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if derived_max_model_len == float("inf"):
+        if max_model_len is not None:
+            # If max_model_len is specified, we use it.
+            return max_model_len
+
+        default_max_len = 2048
+        logger.warning(
+            "The model's config.json does not contain any of the following "
+            "keys to determine the original maximum length of the model: "
+            f"{possible_keys}. Assuming the model's maximum length is "
+            f"{default_max_len}.")
+        derived_max_model_len = default_max_len
+
+    rope_scaling = getattr(hf_config, "rope_scaling", None)
+    if rope_scaling is not None:
+        assert "factor" in rope_scaling
+        scaling_factor = rope_scaling["factor"]
+        derived_max_model_len *= scaling_factor
+
+    if max_model_len is None:
+        max_model_len = derived_max_model_len
+    elif max_model_len > derived_max_model_len:
+        raise ValueError(
+            f"User-specified max_model_len ({max_model_len}) is greater than "
+            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
+            " in model's config.json). This may lead to incorrect model "
+            "outputs or CUDA errors. Make sure the value is correct and "
+            "within the model context size.")
+    return int(max_model_len)
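A worked example of the derivation above, using illustrative values: a config with max_position_embeddings=4096 and rope_scaling={"factor": 2.0} yields a derived limit of 8192, and a user-supplied max_model_len above that would raise.

    # Illustrative only: mimic the relevant part of _get_and_verify_max_len
    # on a toy stand-in for the HF config object.
    from types import SimpleNamespace

    toy_cfg = SimpleNamespace(max_position_embeddings=4096,
                              rope_scaling={"type": "linear", "factor": 2.0})
    derived = toy_cfg.max_position_embeddings
    if toy_cfg.rope_scaling is not None:
        derived *= toy_cfg.rope_scaling["factor"]
    assert derived == 8192.0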
vllm/core/block_manager.py
@@ -63,10 +63,18 @@ class BlockSpaceManager:
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         watermark: float = 0.01,
+        sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.num_total_gpu_blocks = num_gpu_blocks
         self.num_total_cpu_blocks = num_cpu_blocks
+
+        self.block_sliding_window = None
+        if sliding_window is not None:
+            assert sliding_window % block_size == 0, (sliding_window,
+                                                      block_size)
+            self.block_sliding_window = sliding_window // block_size
+
         self.watermark = watermark
         assert watermark >= 0.0
 
@@ -83,6 +91,9 @@ class BlockSpaceManager:
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs()[0]
         num_required_blocks = len(seq.logical_token_blocks)
+        if self.block_sliding_window is not None:
+            num_required_blocks = min(num_required_blocks,
+                                      self.block_sliding_window)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.
         return (num_free_gpu_blocks - num_required_blocks >=
@@ -95,8 +106,12 @@ class BlockSpaceManager:
 
         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []
-        for _ in range(len(seq.logical_token_blocks)):
-            block = self.gpu_allocator.allocate()
+        for logical_idx in range(len(seq.logical_token_blocks)):
+            if (self.block_sliding_window is not None
+                    and logical_idx >= self.block_sliding_window):
+                block = block_table[logical_idx % self.block_sliding_window]
+            else:
+                block = self.gpu_allocator.allocate()
             # Set the reference counts of the token blocks.
             block.ref_count = seq_group.num_seqs()
             block_table.append(block)
@@ -118,11 +133,17 @@ class BlockSpaceManager:
         block_table = self.block_tables[seq.seq_id]
 
         if len(block_table) < len(logical_blocks):
-            # The sequence has a new logical block.
-            # Allocate a new physical block.
-            block = self.gpu_allocator.allocate()
-            block_table.append(block)
-            return None
+            if (self.block_sliding_window
+                    and len(block_table) >= self.block_sliding_window):
+                # re-use a block
+                block_table.append(block_table[len(block_table) %
+                                               self.block_sliding_window])
+            else:
+                # The sequence has a new logical block.
+                # Allocate a new physical block.
+                block = self.gpu_allocator.allocate()
+                block_table.append(block)
+                return None
 
         # We want to append the token to the last physical block.
         last_block = block_table[-1]
@@ -154,9 +175,7 @@ class BlockSpaceManager:
         for seq in seq_group.get_seqs():
             if seq.is_finished():
                 continue
-            block_table = self.block_tables[seq.seq_id]
-            for block in block_table:
-                blocks.add(block)
+            blocks.update(self.block_tables[seq.seq_id])
         return list(blocks)
 
     def can_swap_in(self, seq_group: SequenceGroup) -> bool:
@@ -172,9 +191,7 @@ class BlockSpaceManager:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
 
@@ -203,9 +220,7 @@ class BlockSpaceManager:
     def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
 
@@ -228,7 +243,7 @@ class BlockSpaceManager:
         return block_number_mapping
 
     def _free_block_table(self, block_table: BlockTable) -> None:
-        for block in block_table:
+        for block in set(block_table):
             if block.device == Device.GPU:
                 self.gpu_allocator.free(block)
             else:
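The sliding-window path above caps allocation at block_sliding_window blocks and then reuses physical blocks cyclically. With illustrative numbers (a 4096-token window and 16-token blocks, as used elsewhere in these tests):

    # Illustrative only: the block table cycles through
    # 4096 // 16 = 256 physical blocks, so logical block 300 maps back onto
    # the physical block already stored at slot 300 % 256 = 44.
    sliding_window = 4096
    block_size = 16
    block_sliding_window = sliding_window // block_size
    assert block_sliding_window == 256
    assert 300 % block_sliding_window == 44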
@ -1,19 +1,16 @@
 | 
				
			|||||||
import enum
 | 
					import enum
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
from typing import Dict, List, Optional, Tuple
 | 
					from typing import Dict, Iterable, List, Optional, Tuple, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from vllm.config import CacheConfig, SchedulerConfig
 | 
					from vllm.config import CacheConfig, SchedulerConfig
 | 
				
			||||||
from vllm.core.block_manager import BlockSpaceManager
 | 
					from vllm.core.block_manager import BlockSpaceManager
 | 
				
			||||||
from vllm.core.policy import PolicyFactory
 | 
					from vllm.core.policy import PolicyFactory
 | 
				
			||||||
from vllm.logger import init_logger
 | 
					from vllm.logger import init_logger
 | 
				
			||||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
 | 
					from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
 | 
				
			||||||
                           SequenceGroupMetadata, SequenceOutputs,
 | 
					                           SequenceGroupMetadata, SequenceStatus)
 | 
				
			||||||
                           SequenceStatus)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = init_logger(__name__)
 | 
					logger = init_logger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
_LOGGING_INTERVAL_SEC = 5
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PreemptionMode(enum.Enum):
 | 
					class PreemptionMode(enum.Enum):
 | 
				
			||||||
    """Preemption modes.
 | 
					    """Preemption modes.
 | 
				
			||||||
@ -32,19 +29,28 @@ class SchedulerOutputs:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
 | 
					        scheduled_seq_groups: List[SequenceGroup],
 | 
				
			||||||
 | 
					        prompt_run: bool,
 | 
				
			||||||
 | 
					        num_batched_tokens: int,
 | 
				
			||||||
        blocks_to_swap_in: Dict[int, int],
 | 
					        blocks_to_swap_in: Dict[int, int],
 | 
				
			||||||
        blocks_to_swap_out: Dict[int, int],
 | 
					        blocks_to_swap_out: Dict[int, int],
 | 
				
			||||||
        blocks_to_copy: Dict[int, List[int]],
 | 
					        blocks_to_copy: Dict[int, List[int]],
 | 
				
			||||||
 | 
					        ignored_seq_groups: List[SequenceGroup],
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        self.scheduled_seq_groups = scheduled_seq_groups
 | 
				
			||||||
 | 
					        self.prompt_run = prompt_run
 | 
				
			||||||
 | 
					        self.num_batched_tokens = num_batched_tokens
 | 
				
			||||||
        self.blocks_to_swap_in = blocks_to_swap_in
 | 
					        self.blocks_to_swap_in = blocks_to_swap_in
 | 
				
			||||||
        self.blocks_to_swap_out = blocks_to_swap_out
 | 
					        self.blocks_to_swap_out = blocks_to_swap_out
 | 
				
			||||||
        self.blocks_to_copy = blocks_to_copy
 | 
					        self.blocks_to_copy = blocks_to_copy
 | 
				
			||||||
        # Swap in and swap out should never happen at the same time.
 | 
					        # Swap in and swap out should never happen at the same time.
 | 
				
			||||||
        assert not (blocks_to_swap_in and blocks_to_swap_out)
 | 
					        assert not (blocks_to_swap_in and blocks_to_swap_out)
 | 
				
			||||||
 | 
					        self.ignored_seq_groups = ignored_seq_groups
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_empty(self) -> bool:
 | 
					    def is_empty(self) -> bool:
 | 
				
			||||||
        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
 | 
					        # NOTE: We do not consider the ignored sequence groups.
 | 
				
			||||||
                and not self.blocks_to_copy)
 | 
					        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
 | 
				
			||||||
 | 
					                and not self.blocks_to_swap_out and not self.blocks_to_copy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Scheduler:
 | 
					class Scheduler:
 | 
				
			||||||
@@ -53,11 +59,12 @@ class Scheduler:
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
-        log_stats: bool,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
-        self.log_stats = log_stats
+
+        self.prompt_limit = min(self.scheduler_config.max_model_len,
+                                self.scheduler_config.max_num_batched_tokens)

         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@@ -66,8 +73,9 @@ class Scheduler:
             block_size=self.cache_config.block_size,
             num_gpu_blocks=self.cache_config.num_gpu_blocks,
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
-        )
+            sliding_window=self.cache_config.sliding_window)

+        # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
@@ -75,25 +83,30 @@ class Scheduler:
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []

-        self.last_logging_time: float = 0.0
-        # List[timestamp, num_tokens]
-        self.num_input_tokens: List[Tuple[float, int]] = []
-
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

-    def abort_seq_group(self, request_id: str) -> None:
+    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        if isinstance(request_id, str):
+            request_id = (request_id, )
+        request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
-            for seq_group in state_queue:
-                if seq_group.request_id == request_id:
+            # We need to reverse the list as we are removing elements
+            # from it as we iterate over it. If we don't do it,
+            # indices will get messed up and we will skip over elements.
+            for seq_group in reversed(state_queue):
+                if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)
-                    for seq in seq_group.seqs:
+                    for seq in seq_group.get_seqs():
                         if seq.is_finished():
                             continue
-                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
-                    return
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
+                    request_ids.remove(seq_group.request_id)
+                    if not request_ids:
+                        return

     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
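abort_seq_group() now accepts either one request id or any iterable of ids and frees every matching group across the waiting, running, and swapped queues in a single pass. A minimal usage sketch (the ids are illustrative):

    scheduler.abort_seq_group("request-0")                   # single id, as before
    scheduler.abort_seq_group(["request-1", "request-2"])    # any Iterable[str]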
@@ -101,21 +114,81 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

-    def _schedule(
-            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(self) -> SchedulerOutputs:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
-        ignored_seq_groups: List[SequenceGroup] = []

         # Fix the current time.
-        now = time.time()
+        now = time.monotonic()

-        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
-        # in order to minimize the preemption overheads.
-        # Preemption happens only when there is no available slot to keep all
-        # the sequence groups in the RUNNING state.
+        # Join waiting sequences if possible.
+        if not self.swapped:
+            ignored_seq_groups: List[SequenceGroup] = []
+            scheduled: List[SequenceGroup] = []
+            # The total number of sequences on the fly, including the
+            # requests in the generation phase.
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
+            num_batched_tokens = 0
+            # Optimization: We do not sort the waiting queue since the preempted
+            # sequence groups are added to the front and the new sequence groups
+            # are added to the back.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                assert seq_group.num_seqs() == 1, (
+                    "Waiting sequence group should have only one prompt "
+                    "sequence.")
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if num_prompt_tokens > self.prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {self.prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    continue
+
+                # If the sequence group cannot be allocated, stop.
+                if not self.block_manager.can_allocate(seq_group):
+                    break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.get_max_num_running_seqs()
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.waiting.pop(0)
+                self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                num_curr_seqs += num_new_seqs
+                scheduled.append(seq_group)
+
+            if scheduled or ignored_seq_groups:
+                scheduler_outputs = SchedulerOutputs(
+                    scheduled_seq_groups=scheduled,
+                    prompt_run=True,
+                    num_batched_tokens=num_batched_tokens,
+                    blocks_to_swap_in=blocks_to_swap_in,
+                    blocks_to_swap_out=blocks_to_swap_out,
+                    blocks_to_copy=blocks_to_copy,
+                    ignored_seq_groups=ignored_seq_groups,
+                )
+                return scheduler_outputs
+
+        # NOTE(woosuk): Preemption happens only when there is no available slot
+        # to keep all the sequence groups in the RUNNING state.
         # In this case, the policy is responsible for deciding which sequence
         # groups to preempt.
         self.running = self.policy.sort_by_priority(now, self.running)
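The prompt phase now has its own admission check: a waiting group whose prompt is longer than prompt_limit is marked FINISHED_IGNORED and reported via ignored_seq_groups instead of blocking the queue. A small worked example with hypothetical limits:

    prompt_limit = min(2048, 2560)   # min(max_model_len, max_num_batched_tokens) = 2048
    num_prompt_tokens = 3000         # an over-long prompt
    assert num_prompt_tokens > prompt_limit
    # -> its sequences get SequenceStatus.FINISHED_IGNORED and the group is
    #    returned in SchedulerOutputs.ignored_seq_groups rather than scheduled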
@@ -145,150 +218,56 @@ class Scheduler:

         # Swap in the sequence groups in the SWAPPED state if possible.
         self.swapped = self.policy.sort_by_priority(now, self.swapped)
-        while self.swapped and not blocks_to_swap_out:
-            seq_group = self.swapped[0]
-            # If the sequence group has been preempted in this step, stop.
-            if seq_group in preempted:
-                break
-            # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
-                break
+        if not preempted:
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)

-            # The total number of sequences in the RUNNING state should not
-            # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
-            num_curr_seqs = sum(
-                seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                for seq_group in self.running)
-            if (num_curr_seqs + num_new_seqs >
-                    self.scheduler_config.max_num_seqs):
-                break
-
-            seq_group = self.swapped.pop(0)
-            self._swap_in(seq_group, blocks_to_swap_in)
-            self._append_slot(seq_group, blocks_to_copy)
-            self.running.append(seq_group)
-
-        num_batched_tokens = sum(
-            seq_group.num_seqs(status=SequenceStatus.RUNNING)
-            for seq_group in self.running)
-
-        # Join waiting sequences if possible.
-        prompt_group_ids: List[str] = []
-        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
-        # prioritized over the sequence groups in the WAITING state.
-        # This is because we want to bound the amount of CPU memory taken by
-        # the swapped sequence groups.
-        if not self.swapped:
-            # Optimization: We do not sort the waiting queue since the preempted
-            # sequence groups are added to the front and the new sequence groups
-            # are added to the back.
-            while self.waiting:
-                seq_group = self.waiting[0]
-                # If the sequence group has been preempted in this step, stop.
-                if seq_group in preempted:
-                    break
-
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                if num_prompt_tokens >= self.scheduler_config.max_seq_len:
-                    logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        " and exceeds limit of "
-                        f"{self.scheduler_config.max_seq_len}")
-                    for seq in seq_group.get_seqs():
-                        seq.status = SequenceStatus.FINISHED_IGNORED
-                    ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
-                    break
-
-                # If the sequence group cannot be allocated, stop.
-                if not self.block_manager.can_allocate(seq_group):
-                    break
-
-                # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
-                        self.scheduler_config.max_num_batched_tokens):
+            while self.swapped:
+                seq_group = self.swapped[0]
+                # If the sequence group cannot be swapped in, stop.
+                if not self.block_manager.can_swap_in(seq_group):
                     break

                 # The total number of sequences in the RUNNING state should not
                 # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
+                num_new_seqs = seq_group.get_max_num_running_seqs()
                 if (num_curr_seqs + num_new_seqs >
                         self.scheduler_config.max_num_seqs):
                     break

-                seq_group = self.waiting.pop(0)
-                self._allocate(seq_group)
+                seq_group = self.swapped.pop(0)
+                self._swap_in(seq_group, blocks_to_swap_in)
+                self._append_slot(seq_group, blocks_to_copy)
+                num_curr_seqs += num_new_seqs
                 self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
-                prompt_group_ids.append(seq_group.request_id)
+
+        # Each sequence in the generation phase only takes one token slot.
+        # Therefore, the number of batched tokens is equal to the number of
+        # sequences in the RUNNING state.
+        num_batched_tokens = sum(
+            seq_group.num_seqs(status=SequenceStatus.RUNNING)
+            for seq_group in self.running)

         scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            ignored_seq_groups=[],
         )
-        if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids, ignored_seq_groups
+        return scheduler_outputs

-        # TODO(woosuk): Move the below code to the engine.
-        now = time.time()
-        if num_batched_tokens > 0:
-            self.num_input_tokens.append((now, num_batched_tokens))
-        elapsed_time = now - self.last_logging_time
-        if elapsed_time > _LOGGING_INTERVAL_SEC:
-            self.last_logging_time = now
-            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
-                                     if now - t < _LOGGING_INTERVAL_SEC]
-            if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n
-                                       for _, n in self.num_input_tokens[:-1])
-                window = now - self.num_input_tokens[0][0]
-                avg_throughput = total_num_tokens / window
-            else:
-                avg_throughput = 0.0
-
-            total_num_gpu_blocks = self.cache_config.num_gpu_blocks
-            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
-            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
-
-            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
-            if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = (
-                    self.block_manager.get_num_free_cpu_blocks())
-                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
-                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
-            else:
-                cpu_cache_usage = 0.0
-
-            logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
-                        f"Running: {len(self.running)} reqs, "
-                        f"Swapped: {len(self.swapped)} reqs, "
-                        f"Pending: {len(self.waiting)} reqs, "
-                        f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                        f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
-
-    def schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[SequenceGroup]]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        (scheduler_outputs, prompt_group_ids,
-         ignored_seq_groups) = self._schedule()
+        scheduler_outputs = self._schedule()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for seq_group in self.running:
-            is_prompt = seq_group.request_id in prompt_group_ids
-
+        for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -298,43 +277,18 @@ class Scheduler:

             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
-                is_prompt=is_prompt,
+                is_prompt=scheduler_outputs.prompt_run,
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
+        return seq_group_metadata_list, scheduler_outputs

-    def update(
-        self,
-        seq_outputs: Dict[int, SequenceOutputs],
-    ) -> List[SequenceGroup]:
-        # Update the running sequences and free blocks.
-        for seq_group in self.running:
-            # Process beam search results before processing the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                output = seq_outputs[seq.seq_id]
-                if seq.seq_id != output.parent_seq_id:
-                    # The sequence is a fork of the parent sequence (beam
-                    # search). Free the current sequence.
-                    self.block_manager.free(seq)
-                    # Fork the parent sequence.
-                    parent_seq = seq_group.find(output.parent_seq_id)
-                    parent_seq.fork(seq)
-                    self.block_manager.fork(parent_seq, seq)
+    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        self.block_manager.fork(parent_seq, child_seq)

-            # Process the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Append a new token to the sequence.
-                output = seq_outputs[seq.seq_id]
-                seq.append_token_id(output.output_token, output.logprobs)
-        # Return a shallow copy of the running queue to prevent the queue
-        # from being modified by the caller.
-        return self.running.copy()
-
-    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
-        seq.status = finish_status
+    def free_seq(self, seq: Sequence) -> None:
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:
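With _schedule() folding everything into SchedulerOutputs, schedule() now returns a two-tuple and the caller reads prompt/ignored information off the outputs object rather than from extra return values. A sketch of the expected calling pattern (variable names are illustrative):

    seq_group_metadata_list, scheduler_outputs = scheduler.schedule()
    if scheduler_outputs.is_empty():
        # nothing to execute; only ignored_seq_groups may need reporting
        ignored = scheduler_outputs.ignored_seq_groups
    else:
        # run the model over seq_group_metadata_list, applying
        # scheduler_outputs.blocks_to_swap_in / blocks_to_swap_out /
        # blocks_to_copy before execution
        ...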
@@ -371,8 +325,8 @@ class Scheduler:
         # If preemption mode is not specified, we determine the mode as follows:
         # We use recomputation by default since it incurs lower overhead than
         # swapping. However, when the sequence group has multiple sequences
-        # (e.g., beam search), recomputation is not supported. In such a case,
-        # we use swapping instead.
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
         # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
         # As swapped sequences are prioritized over waiting sequences,
         # sequence groups with multiple sequences are implicitly prioritized
@@ -380,8 +334,7 @@ class Scheduler:
         # TODO(woosuk): Support recomputation for sequence groups with multiple
         # sequences. This may require a more sophisticated CUDA kernel.
         if preemption_mode is None:
-            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-            if len(seqs) == 1:
+            if seq_group.get_max_num_running_seqs() == 1:
                 preemption_mode = PreemptionMode.RECOMPUTE
             else:
                 preemption_mode = PreemptionMode.SWAP
@@ -410,9 +363,6 @@ class Scheduler:
         seq_group: SequenceGroup,
         blocks_to_swap_out: Dict[int, int],
     ) -> None:
-        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-        for seq in seqs:
-            seq.status = SequenceStatus.SWAPPED
         self._swap_out(seq_group, blocks_to_swap_out)
         self.swapped.append(seq_group)
@@ -13,25 +13,28 @@ class EngineArgs:
     model: str
     tokenizer: Optional[str] = None
     tokenizer_mode: str = 'auto'
+    trust_remote_code: bool = False
     download_dir: Optional[str] = None
-    use_np_weights: bool = False
-    use_dummy_weights: bool = False
+    load_format: str = 'auto'
     dtype: str = 'auto'
     seed: int = 0
+    max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: int = 2560
+    max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     disable_log_stats: bool = False
+    revision: Optional[str] = None
+    tokenizer_revision: Optional[str] = None
+    quantization: Optional[str] = None

     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
-        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

     @staticmethod
     def add_cli_args(
@@ -48,6 +51,20 @@ class EngineArgs:
             type=str,
             default=EngineArgs.tokenizer,
             help='name or path of the huggingface tokenizer to use')
+        parser.add_argument(
+            '--revision',
+            type=str,
+            default=None,
+            help='the specific model version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
+        parser.add_argument(
+            '--tokenizer-revision',
+            type=str,
+            default=None,
+            help='the specific tokenizer version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -55,30 +72,46 @@ class EngineArgs:
                             help='tokenizer mode. "auto" will use the fast '
                             'tokenizer if available, and "slow" will '
                             'always use the slow tokenizer.')
+        parser.add_argument('--trust-remote-code',
+                            action='store_true',
+                            help='trust remote code from huggingface')
         parser.add_argument('--download-dir',
                             type=str,
                             default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
                             'default to the default cache dir of '
                             'huggingface')
-        parser.add_argument('--use-np-weights',
-                            action='store_true',
-                            help='save a numpy copy of model weights for '
-                            'faster loading. This can increase the disk '
-                            'usage by up to 2x.')
-        parser.add_argument('--use-dummy-weights',
-                            action='store_true',
-                            help='use dummy values for model weights')
-        # TODO(woosuk): Support FP32.
+        parser.add_argument(
+            '--load-format',
+            type=str,
+            default=EngineArgs.load_format,
+            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
+            help='The format of the model weights to load. '
+            '"auto" will try to load the weights in the safetensors format '
+            'and fall back to the pytorch bin format if safetensors format '
+            'is not available. '
+            '"pt" will load the weights in the pytorch bin format. '
+            '"safetensors" will load the weights in the safetensors format. '
+            '"npcache" will load the weights in pytorch format and store '
+            'a numpy cache to speed up the loading. '
+            '"dummy" will initialize the weights with random values, '
+            'which is mainly for profiling.')
         parser.add_argument(
             '--dtype',
             type=str,
             default=EngineArgs.dtype,
-            choices=['auto', 'half', 'bfloat16', 'float'],
+            choices=[
+                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
+            ],
             help='data type for model weights and activations. '
             'The "auto" option will use FP16 precision '
             'for FP32 and FP16 models, and BF16 precision '
             'for BF16 models.')
+        parser.add_argument('--max-model-len',
+                            type=int,
+                            default=None,
+                            help='model context length. If unspecified, '
+                            'will be automatically derived from the model.')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
@@ -126,6 +159,13 @@ class EngineArgs:
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
+        # Quantization settings.
+        parser.add_argument('--quantization',
+                            '-q',
+                            type=str,
+                            choices=['awq', None],
+                            default=None,
+                            help='Method used to quantize the weights')
         return parser

     @classmethod
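The new loading and quantization knobs are plain argparse flags, so they can be exercised directly through the parser that add_cli_args() builds. A minimal sketch (the model name and values are illustrative; --model mirrors the `model` field above):

    import argparse

    from vllm.engine.arg_utils import EngineArgs

    parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
    args = parser.parse_args([
        '--model', 'facebook/opt-125m',
        '--load-format', 'safetensors',
        '--revision', 'main',
        '--max-model-len', '2048',
        '--quantization', 'awq',
    ])
    print(args.load_format, args.max_model_len, args.quantization)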
@@ -139,22 +179,21 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-        # Initialize the configs.
         model_config = ModelConfig(self.model, self.tokenizer,
-                                   self.tokenizer_mode, self.download_dir,
-                                   self.use_np_weights, self.use_dummy_weights,
-                                   self.dtype, self.seed)
-        cache_config = CacheConfig(self.block_size,
-                                   self.gpu_memory_utilization,
-                                   self.swap_space)
+                                   self.tokenizer_mode, self.trust_remote_code,
+                                   self.download_dir, self.load_format,
+                                   self.dtype, self.seed, self.revision,
+                                   self.tokenizer_revision, self.max_model_len,
+                                   self.quantization)
+        cache_config = CacheConfig(
+            self.block_size, self.gpu_memory_utilization, self.swap_space,
+            getattr(model_config.hf_config, 'sliding_window', None))
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        model_max_len = getattr(model_config.hf_config,
-                                'max_position_embeddings', float('inf'))
-        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_seq_len)
+                                           self.max_num_seqs,
+                                           model_config.max_model_len)
         return model_config, cache_config, parallel_config, scheduler_config


@@ -163,6 +202,7 @@ class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
     engine_use_ray: bool = False
     disable_log_requests: bool = False
+    max_log_len: Optional[int] = None

     @staticmethod
     def add_cli_args(
@@ -175,4 +215,10 @@ class AsyncEngineArgs(EngineArgs):
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='disable logging requests')
+        parser.add_argument('--max-log-len',
+                            type=int,
+                            default=None,
+                            help='max number of prompt characters or prompt '
+                            'ID numbers being printed in log. '
+                            'Default: unlimited.')
         return parser
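create_engine_configs() now derives sliding_window and max_model_len from the resolved ModelConfig instead of computing max_seq_len locally. A minimal sketch of building the four configs programmatically (the model name is illustrative, and constructing ModelConfig will fetch its HuggingFace config):

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(model='facebook/opt-125m', max_model_len=2048)
    model_config, cache_config, parallel_config, scheduler_config = (
        engine_args.create_engine_configs())
    print(model_config.max_model_len, scheduler_config.max_num_seqs)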
@@ -1,6 +1,8 @@
 import asyncio
 import time
-from typing import Dict, List, Optional
+from functools import partial
+from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union)

 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -12,7 +14,219 @@ from vllm.sampling_params import SamplingParams

 logger = init_logger(__name__)

-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
+
+class AsyncEngineDeadError(RuntimeError):
+    pass
+
+
+def _raise_exception_on_finish(task: asyncio.Task,
+                               request_tracker: "RequestTracker") -> None:
+    msg = ("Task finished unexpectedly. This should never happen! "
+           "Please open an issue on Github.")
+    try:
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            return
+        except Exception as exc:
+            raise AsyncEngineDeadError(
+                msg + " See stack trace above for the actual cause.") from exc
+        raise AsyncEngineDeadError(msg)
+    except Exception as exc:
+        request_tracker.propagate_exception(exc)
+        raise exc
+
+
+class AsyncStream:
+    """A stream of RequestOutputs for a request that can be
+    iterated over asynchronously."""
+
+    def __init__(self, request_id: str) -> None:
+        self.request_id = request_id
+        self._queue = asyncio.Queue()
+        self._finished = False
+
+    def put(self, item: RequestOutput) -> None:
+        if self._finished:
+            return
+        self._queue.put_nowait(item)
+
+    def finish(self) -> None:
+        self._queue.put_nowait(StopIteration)
+        self._finished = True
+
+    @property
+    def finished(self) -> bool:
+        return self._finished
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> RequestOutput:
+        result = await self._queue.get()
+        if result is StopIteration:
+            raise StopAsyncIteration
+        elif isinstance(result, Exception):
+            raise result
+        return result
+
+
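AsyncStream is an async iterator: a consumer simply iterates it with async for, and the loop ends once finish() enqueues the StopIteration sentinel. A minimal sketch, assuming the class is importable from vllm.engine.async_llm_engine (the string stands in for a RequestOutput):

    import asyncio

    from vllm.engine.async_llm_engine import AsyncStream

    async def demo():
        stream = AsyncStream("request-0")
        stream.put("partial output")   # a RequestOutput in real use
        stream.finish()
        async for item in stream:      # stops after the finish() sentinel
            print(item)

    asyncio.run(demo())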
+class RequestTracker:
+    """Synchronous abstraction for tracking requests."""
+
+    def __init__(self) -> None:
+        self._request_streams: Dict[str, AsyncStream] = {}
+        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
+                                                dict]] = asyncio.Queue()
+        self.new_requests_event = None
+
+    def __contains__(self, item):
+        return item in self._request_streams
+
+    def init_event(self):
+        self.new_requests_event = asyncio.Event()
+
+    def propagate_exception(self,
+                            exc: Exception,
+                            request_id: Optional[str] = None) -> None:
+        """Propagate an exception to request streams
+        (all if request_id is None)."""
+        if request_id is not None:
+            self._request_streams[request_id].put(exc)
+        else:
+            for stream in self._request_streams.values():
+                stream.put(exc)
+
+    def process_request_output(self,
+                               request_output: RequestOutput,
+                               *,
+                               verbose: bool = False) -> None:
+        """Process a request output from the engine."""
+        request_id = request_output.request_id
+
+        self._request_streams[request_id].put(request_output)
+        if request_output.finished:
+            if verbose:
+                logger.info(f"Finished request {request_id}.")
+            self.abort_request(request_id)
+
+    def add_request(self, request_id: str,
+                    **engine_add_request_kwargs) -> AsyncStream:
+        """Add a request to be sent to the engine on the next background
+        loop iteration."""
+        if request_id in self._request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
+        stream = AsyncStream(request_id)
+        self._new_requests.put_nowait((stream, {
+            "request_id": request_id,
+            **engine_add_request_kwargs
+        }))
+
+        self.new_requests_event.set()
+
+        return stream
+
+    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+        """Abort a request during next background loop iteration."""
+        if verbose:
+            logger.info(f"Aborted request {request_id}.")
+
+        self._finished_requests.put_nowait(request_id)
+
+        if request_id not in self._request_streams or self._request_streams[
+                request_id].finished:
+            # The request has already finished or been aborted.
+            return
+
+        self._request_streams[request_id].finish()
+
+    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
+        """Get the new requests and finished requests to be
+        sent to the engine."""
+        new_requests: List[dict] = []
+        finished_requests: Set[str] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+            self._request_streams.pop(request_id, None)
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            if stream.request_id in finished_requests:
+                # The request has already been aborted.
+                stream.finish()
+                continue
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        self.new_requests_event.clear()
+
+        return new_requests, finished_requests
+
+    async def wait_for_new_requests(self):
+        await self.new_requests_event.wait()
+
+
					class _AsyncLLMEngine(LLMEngine):
 | 
				
			||||||
 | 
					    """Extension of LLMEngine to add async methods."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def step_async(self) -> List[RequestOutput]:
 | 
				
			||||||
 | 
					        """Performs one decoding iteration and returns newly generated results.
 | 
				
			||||||
 | 
					        The workers are ran asynchronously if possible.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This function performs one decoding iteration of the engine. It first
 | 
				
			||||||
 | 
					        schedules the sequences to be executed in the next iteration and the
 | 
				
			||||||
 | 
					        token blocks to be swapped in/out/copy. Then, it executes the model
 | 
				
			||||||
 | 
					        and updates the scheduler with the model outputs. Finally, it decodes
 | 
				
			||||||
 | 
					        the sequences and returns the newly generated results.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
 | 
				
			||||||
 | 
					        if scheduler_outputs.is_empty():
 | 
				
			||||||
 | 
					            return ignored
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Execute the model.
 | 
				
			||||||
 | 
					        output = await self._run_workers_async(
 | 
				
			||||||
 | 
					            "execute_model",
 | 
				
			||||||
 | 
					            seq_group_metadata_list=seq_group_metadata_list,
 | 
				
			||||||
 | 
					            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
 | 
				
			||||||
 | 
					            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
 | 
				
			||||||
 | 
					            blocks_to_copy=scheduler_outputs.blocks_to_copy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self._process_model_outputs(output, scheduler_outputs) + ignored
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _run_workers_async(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        method: str,
 | 
				
			||||||
 | 
					        *args,
 | 
				
			||||||
 | 
					        get_all_outputs: bool = False,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
 | 
					    ) -> Any:
 | 
				
			||||||
 | 
					        """Runs the given method on all workers."""
 | 
				
			||||||
 | 
					        all_outputs = []
 | 
				
			||||||
 | 
					        for worker in self.workers:
 | 
				
			||||||
 | 
					            if self.parallel_config.worker_use_ray:
 | 
				
			||||||
 | 
					                executor = partial(worker.execute_method.remote, method)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                executor = getattr(worker, method)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            output = executor(*args, **kwargs)
 | 
				
			||||||
 | 
					            all_outputs.append(output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.parallel_config.worker_use_ray:
 | 
				
			||||||
 | 
					            all_outputs = await asyncio.gather(*all_outputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if get_all_outputs:
 | 
				
			||||||
 | 
					            return all_outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Make sure all workers have the same results.
 | 
				
			||||||
 | 
					        output = all_outputs[0]
 | 
				
			||||||
 | 
					        for other_output in all_outputs[1:]:
 | 
				
			||||||
 | 
					            assert output == other_output
 | 
				
			||||||
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AsyncLLMEngine:
 | 
					class AsyncLLMEngine:
 | 
				
			||||||
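The RequestTracker added above is a small producer/consumer bridge: request coroutines call add_request() and get back a stream, while the background engine loop drains get_new_and_finished_requests() and pushes results back through process_request_output(). The following is a minimal, self-contained sketch of that hand-off pattern; the _Stream and _Tracker classes and their method names are simplified stand-ins for illustration, not the vLLM classes themselves.

import asyncio
from typing import Dict, List

# Simplified stand-ins for AsyncStream and RequestTracker (illustrative only).
class _Stream:
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    async def get(self):
        return await self._queue.get()


class _Tracker:
    def __init__(self) -> None:
        self._streams: Dict[str, _Stream] = {}
        self._new: asyncio.Queue = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str, **kwargs) -> _Stream:
        # Called from a request coroutine: enqueue the request and wake the loop.
        stream = _Stream(request_id)
        self._new.put_nowait((stream, {"request_id": request_id, **kwargs}))
        self.new_requests_event.set()
        return stream

    def get_new_requests(self) -> List[dict]:
        # Called from the background loop: drain the queue of pending requests.
        new_requests = []
        while not self._new.empty():
            stream, request = self._new.get_nowait()
            self._streams[stream.request_id] = stream
            new_requests.append(request)
        self.new_requests_event.clear()
        return new_requests

    def publish(self, request_id: str, output) -> None:
        # Mirrors process_request_output: route an output to the right stream.
        self._streams[request_id].put(output)


async def main() -> None:
    tracker = _Tracker()

    async def engine_loop() -> None:
        # Mirrors run_engine_loop/engine_step: wait for work, pull new
        # requests, push results back into the per-request streams.
        await tracker.new_requests_event.wait()
        for request in tracker.get_new_requests():
            tracker.publish(request["request_id"],
                            f"output for {request['prompt']}")

    loop_task = asyncio.create_task(engine_loop())
    stream = tracker.add_request("req-0", prompt="Hello")
    print(await stream.get())
    await loop_task

asyncio.run(main())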
@@ -34,52 +248,149 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        start_engine_loop: If True, the background task to run the engine
+            will be automatically started in the generate call.
         *args, *kwargs: Arguments for LLMEngine.
     """
+
+    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
+
     def __init__(self,
                  worker_use_ray: bool,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
+                 max_log_len: Optional[int] = None,
+                 start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        if not self.engine_use_ray:
-            engine_class = LLMEngine
-        elif self.worker_use_ray:
-            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
-        else:
-            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
-        self.engine = engine_class(*args, **kwargs)
-        # Request id -> request output.
-        self.request_outputs: Dict[str, RequestOutput] = {}
-        # Request id -> event to notify that there is new output.
-        self.request_events: Dict[str, asyncio.Event] = {}
-        self.is_engine_running = False
-        self.kicking_request_id: Optional[str] = None
+        self.max_log_len = max_log_len
+        self.engine = self._init_engine(*args, **kwargs)
+
+        self.background_loop = None
+        # We need to keep a reference to unshielded
+        # task as well to prevent it from being garbage
+        # collected
+        self._background_loop_unshielded = None
+        self.start_engine_loop = start_engine_loop
+        self._request_tracker = RequestTracker()
+
+    @property
+    def is_running(self) -> bool:
+        return (self.background_loop is not None
+                and not self.background_loop.done())
+
+    def start_background_loop(self) -> None:
+        """Start the background loop."""
+        if self.is_running:
+            raise RuntimeError("Background loop is already running.")
+        self._request_tracker.init_event()
+
+        self._background_loop_unshielded = asyncio.get_event_loop(
+        ).create_task(self.run_engine_loop())
+        self._background_loop_unshielded.add_done_callback(
+            partial(_raise_exception_on_finish,
+                    request_tracker=self._request_tracker))
+        self.background_loop = asyncio.shield(self._background_loop_unshielded)
+
+    def _init_engine(self, *args,
+                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
+        if not self.engine_use_ray:
+            engine_class = self._engine_class
+        elif self.worker_use_ray:
+            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
+        else:
+            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+        return engine_class(*args, **kwargs)
+
+    async def engine_step(self) -> bool:
+        """Kick the engine to process the waiting requests.
+
+        Returns True if there are in-progress requests."""
+
+        new_requests, finished_requests = (
+            self._request_tracker.get_new_and_finished_requests())
+
+        for new_request in new_requests:
+            # Add the request into the vLLM engine's waiting queue.
+            # TODO: Maybe add add_request_batch to reduce Ray overhead
+            if self.engine_use_ray:
+                await self.engine.add_request.remote(**new_request)
+            else:
+                self.engine.add_request(**new_request)
+
+        if finished_requests:
+            await self._engine_abort(finished_requests)

-    async def engine_step(self, kicking_request_id: Optional[str] = None):
-        """Kick the engine to process the waiting requests."""
-        self.is_engine_running = True
-        self.kicking_request_id = kicking_request_id
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()
         else:
-            # Yield to the event loop to allow other coroutines to run
-            # while is_engine_running is True. This let the engine to add new
-            # requests into the queue.
-            await asyncio.sleep(0)
-            request_outputs = self.engine.step()
-        self.is_engine_running = False
-        self.kicking_request_id = None
+            request_outputs = await self.engine.step_async()

-        # Notify the waiting coroutines that there are new outputs ready.
+        # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
-            request_id = request_output.request_id
-            self.request_outputs[request_id] = request_output
-            self.request_events[request_id].set()
+            self._request_tracker.process_request_output(
+                request_output, verbose=self.log_requests)
+
+        return len(request_outputs) > 0
+
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        if self.engine_use_ray:
+            await self.engine.abort_request.remote(request_ids)
+        else:
+            self.engine.abort_request(request_ids)
+
+    async def run_engine_loop(self):
+        # Initialize the RequestTracker here so it uses the right event loop.
+        has_requests_in_progress = False
+        while True:
+            if not has_requests_in_progress:
+                await self._request_tracker.wait_for_new_requests()
+            has_requests_in_progress = await self.engine_step()
+            await asyncio.sleep(0)
+
+    async def add_request(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+    ) -> AsyncStream:
+        if self.log_requests:
+            shortened_prompt = prompt
+            shortened_token_ids = prompt_token_ids
+            if self.max_log_len is not None:
+                if shortened_prompt is not None:
+                    shortened_prompt = shortened_prompt[:self.max_log_len]
+                if shortened_token_ids is not None:
+                    shortened_token_ids = shortened_token_ids[:self.
+                                                              max_log_len]
+            logger.info(f"Received request {request_id}: "
+                        f"prompt: {shortened_prompt!r}, "
+                        f"sampling params: {sampling_params}, "
+                        f"prompt token ids: {shortened_token_ids}.")
+
+        if not self.is_running:
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise AsyncEngineDeadError(
+                    "Background loop is not running. If it was running, "
+                    "inspect the output to find the stacktrace of the "
+                    "error that caused the background loop to stop "
+                    "(AsyncEngineDeadError).")
+
+        stream = self._request_tracker.add_request(
+            request_id,
+            prompt=prompt,
+            sampling_params=sampling_params,
+            prompt_token_ids=prompt_token_ids,
+            arrival_time=arrival_time)
+
+        return stream

     async def generate(
             self,
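start_background_loop() above wraps the engine-loop task in asyncio.shield() and registers _raise_exception_on_finish as a done callback, keeping the unshielded task referenced so it is neither garbage collected nor silently dropped when it fails. Below is a minimal sketch of that shield-plus-done-callback pattern, with a toy coroutine standing in for the engine loop; the _on_finish helper is illustrative, not vLLM's _raise_exception_on_finish.

import asyncio
from functools import partial

def _on_finish(task: asyncio.Task, label: str) -> None:
    # Surface any exception from the background task instead of losing it.
    if not task.cancelled() and task.exception() is not None:
        raise task.exception()

async def main() -> None:
    async def background_loop() -> None:
        await asyncio.sleep(0.01)

    unshielded = asyncio.get_event_loop().create_task(background_loop())
    unshielded.add_done_callback(partial(_on_finish, label="engine"))
    # shield() means cancelling the awaiting caller does not cancel the task;
    # the unshielded reference must be kept so the task stays alive.
    shielded = asyncio.shield(unshielded)
    await shielded

asyncio.run(main())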
@@ -106,78 +417,23 @@ class AsyncLLMEngine:
             request.
         """
         # Preprocess the request.
-        arrival_time = time.time()
-        # Create an event to notify us that there is new output from the
-        # vLLM engine.
-        request_event = asyncio.Event()
-        self.request_events[request_id] = request_event
-        if self.log_requests:
-            logger.info(f"Received request {request_id}: "
-                        f"prompt: {prompt!r}, "
-                        f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {prompt_token_ids}.")
-        # Add the request into the vLLM engine's waiting queue.
-        if self.engine_use_ray:
-            await self.engine.add_request.remote(
-                request_id,
-                prompt,
-                sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
-        else:
-            self.engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids=prompt_token_ids,
-                                    arrival_time=arrival_time)
-
-        # The vLLM engine does not have a background loop that keeps
-        # processing incoming requests. Therefore, we need to keep kicking
-        # the engine to process the requests.
-        while True:
-            if request_id not in self.request_events:
-                # The request has been aborted.
-                return
-
-            # Kick the engine if the engine is not running.
-            if not self.is_engine_running:
-                try:
-                    await self.engine_step(request_id)
-                except RuntimeError as e:
-                    await self.abort(request_id)
-                    raise e
-
-            # Wait for new output. The group_event will be set in engine_step
-            # when there is new output available for the sequence group.
-            # Added a timeout to prevent deadlock.
-            try:
-                await asyncio.wait_for(request_event.wait(),
-                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
-            except asyncio.TimeoutError:
-                continue
-            # Reset the event to wait for the next output.
-            request_event.clear()
-
-            # Decode and return new outputs.
-            request_output = self.request_outputs[request_id]
-            yield request_output
-
-            # Once finished, release the resources of the sequence group.
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info(f"Finished request {request_id}.")
-
-                del self.request_outputs[request_id]
-                del self.request_events[request_id]
-                # Kick the engine if the engine is not running. This is to
-                # prevent that there are still requests in engine's waiting
-                # queue to be executed.
-                if not self.is_engine_running:
-                    await self.engine_step()
-                break
+        # This should not be used for logging, as it is monotonic time.
+        arrival_time = time.monotonic()
+
+        try:
+            stream = await self.add_request(request_id,
+                                            prompt,
+                                            sampling_params,
+                                            prompt_token_ids=prompt_token_ids,
+                                            arrival_time=arrival_time)
+
+            async for request_output in stream:
+                yield request_output
+        except (Exception, asyncio.CancelledError) as e:
+            # If there is an exception or coroutine is cancelled, abort the
+            # request.
+            self._abort(request_id)
+            raise e

     async def abort(self, request_id: str) -> None:
         """Abort a request.
@@ -188,28 +444,26 @@ class AsyncLLMEngine:
         Args:
             request_id: The unique id of the request.
         """
-        if request_id not in self.request_events:
-            # The request has already finished or been aborted.
-            return
+        if not self.is_running:
+            raise AsyncEngineDeadError(
+                "Background loop is not running. If it was running, "
+                "inspect the output to find the stacktrace of the "
+                "error that caused the background loop to stop "
+                "(AsyncEngineDeadError).")

-        if self.log_requests:
-            logger.info(f"Aborted request {request_id}.")
+        return self._abort(request_id)

-        if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_id)
-        else:
-            self.engine.abort_request(request_id)
+    def _abort(self, request_id: str) -> None:
+        """Abort a request.

-        if request_id in self.request_events:
-            del self.request_events[request_id]
-        if request_id in self.request_outputs:
-            del self.request_outputs[request_id]
+        Abort a submitted request. If the request is finished or not found,
+        this method will be a no-op.

-        # To prevent deadlock when a request is aborted while the engine is
-        # running.
-        if self.kicking_request_id == request_id:
-            self.is_engine_running = False
-            self.kicking_request_id = None
+        Args:
+            request_id: The unique id of the request.
+        """
+        self._request_tracker.abort_request(request_id,
+                                            verbose=self.log_requests)

     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
@@ -220,20 +474,23 @@ class AsyncLLMEngine:

     @classmethod
     def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(
+        distributed_init_method, placement_group = initialize_cluster(
             parallel_config, engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(engine_args.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats)
+                     log_stats=not engine_args.disable_log_stats,
+                     max_log_len=engine_args.max_log_len,
+                     start_engine_loop=start_engine_loop)
         return engine
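That is the last hunk touching AsyncLLMEngine; the hunks that follow modify the core LLMEngine module. As a rough usage sketch of the reworked async API (not part of the diff): the AsyncEngineArgs field names, the generate() parameter order, and the module paths below are assumptions inferred from the calls shown above, so treat this as an outline rather than the definitive interface.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

async def main() -> None:
    # `model=` is an assumed AsyncEngineArgs field for illustration.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    # start_engine_loop=True (the new default) starts the background task
    # lazily on the first generate()/add_request() call.
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(temperature=0.8, max_tokens=16)
    final = None
    # generate() now returns an async iterator backed by an AsyncStream, so
    # callers simply iterate instead of kicking the engine manually.
    async for request_output in engine.generate("Hello, my name is",
                                                sampling_params,
                                                request_id="request-0"):
        final = request_output
    print(final.outputs[0].text)

asyncio.run(main())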
@@ -1,22 +1,34 @@
+import copy
 import time
-from typing import Any, List, Optional
+from functools import partial
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
-from vllm.core.scheduler import Scheduler
+from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
+from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
+                           SequenceGroupMetadata, SequenceGroupOutputs,
+                           SequenceOutputs, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
-from vllm.worker.worker import Worker
+
+if ray:
+    from ray.air.util.torch_dist import init_torch_dist_process_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup

 logger = init_logger(__name__)

+_LOGGING_INTERVAL_SEC = 5
+

 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -42,8 +54,8 @@ class LLMEngine:
         scheduler_config: The configuration related to the request scheduler.
         distributed_init_method: The initialization method for distributed
             execution. See `torch.distributed.init_process_group` for details.
-        stage_devices: The list of devices for each stage. Each stage is a list
-            of (rank, node_resource, device) tuples.
+        placement_group: Ray placement group for distributed execution.
+            Required for distributed execution.
         log_stats: Whether to log statistics.
     """
@@ -54,7 +66,7 @@ class LLMEngine:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[DeviceID]],
+        placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -62,50 +74,114 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
+            f"revision={model_config.revision}, "
+            f"tokenizer_revision={model_config.tokenizer_revision}, "
+            f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
-            f"use_dummy_weights={model_config.use_dummy_weights}, "
+            f"max_seq_len={model_config.max_model_len}, "
             f"download_dir={model_config.download_dir!r}, "
-            f"use_np_weights={model_config.use_np_weights}, "
+            f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
+            f"quantization={model_config.quantization}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.

         self.model_config = model_config
         self.cache_config = cache_config
+        assert self.cache_config.sliding_window == getattr(
+            self.model_config.hf_config, "sliding_window", None)
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
         self._verify_args()

         self.tokenizer = get_tokenizer(
-            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
+            model_config.tokenizer,
+            tokenizer_mode=model_config.tokenizer_mode,
+            trust_remote_code=model_config.trust_remote_code,
+            tokenizer_revision=model_config.tokenizer_revision,
+            revision=model_config.revision)
         self.seq_counter = Counter()

         # Create the parallel GPU workers.
-        self.workers: List[Worker] = []
-        assert len(stage_devices) == 1, "Only support one stage for now."
-        for rank, node_resource, _ in stage_devices[0]:
-            worker_cls = Worker
-            if self.parallel_config.worker_use_ray:
-                worker_cls = ray.remote(
-                    num_cpus=0,
-                    num_gpus=1,
-                    resources={node_resource: 1e-3},
-                )(worker_cls).remote
+        if self.parallel_config.worker_use_ray:
+            self._init_workers_ray(placement_group)
+        else:
+            self._init_workers(distributed_init_method)

-            worker = worker_cls(
-                model_config,
-                parallel_config,
-                scheduler_config,
-                rank,
-                distributed_init_method,
-            )
-            self.workers.append(worker)
         # Profile the memory usage and initialize the cache.
         self._init_cache()

         # Create the scheduler.
-        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
+        self.scheduler = Scheduler(scheduler_config, cache_config)
+
+        # Logging.
+        self.last_logging_time = 0.0
+        # List of (timestamp, num_tokens)
+        self.num_prompt_tokens: List[Tuple[float, int]] = []
+        # List of (timestamp, num_tokens)
+        self.num_generation_tokens: List[Tuple[float, int]] = []
+
+    def _init_workers(self, distributed_init_method: str):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        assert self.parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+
+        self.workers: List[Worker] = []
+        worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            0,
+            distributed_init_method,
+        )
+        self.workers.append(worker)
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        self.workers: List[Worker] = []
+        for bundle in placement_group.bundle_specs:
+            if not bundle.get("GPU", 0):
+                continue
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=1,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=placement_group,
+                    placement_group_capture_child_tasks=True),
+                **ray_remote_kwargs,
+            )(RayWorker).remote(self.model_config.trust_remote_code)
+            self.workers.append(worker)
+
+        # Initialize torch distributed process group for the workers.
+        init_torch_dist_process_group(self.workers, backend="nccl")
+        model_config = copy.deepcopy(self.model_config)
+        parallel_config = copy.deepcopy(self.parallel_config)
+        scheduler_config = copy.deepcopy(self.scheduler_config)
+        self._run_workers("init_worker",
+                          get_all_outputs=True,
+                          worker_init_fn=lambda: Worker(
+                              model_config,
+                              parallel_config,
+                              scheduler_config,
+                              None,
+                              None,
+                          ))
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )

     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
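_init_workers_ray() above creates one RayWorker actor per GPU bundle of the placement group and pins it there via PlacementGroupSchedulingStrategy. Below is a stand-alone sketch of that scheduling pattern, assuming Ray is installed and GPUs are available; DummyWorker and spawn_workers are illustrative names, not vLLM code.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@ray.remote
class DummyWorker:
    # Placeholder actor standing in for vLLM's RayWorker.
    def ping(self) -> str:
        return "ok"

def spawn_workers(num_gpus: int):
    # One bundle per GPU; each worker is scheduled inside the group.
    pg = placement_group([{"GPU": 1, "CPU": 1}] * num_gpus)
    ray.get(pg.ready())
    workers = []
    for _ in range(num_gpus):
        worker = DummyWorker.options(
            num_cpus=0,
            num_gpus=1,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True),
        ).remote()
        workers.append(worker)
    return workers

if __name__ == "__main__":
    ray.init()
    workers = spawn_workers(num_gpus=2)
    print(ray.get([w.ping.remote() for w in workers]))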
@@ -149,11 +225,12 @@ class LLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, placement_group = initialize_cluster(
+            parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
@@ -179,34 +256,31 @@ class LLMEngine:
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
-                the current time.
+                the current monotonic time.
         """
         if arrival_time is None:
-            arrival_time = time.time()
+            arrival_time = time.monotonic()
         if prompt_token_ids is None:
             assert prompt is not None
             prompt_token_ids = self.tokenizer.encode(prompt)

         # Create the sequences.
         block_size = self.cache_config.block_size
-        seqs: List[Sequence] = []
-        for _ in range(sampling_params.best_of):
-            seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
-            seqs.append(seq)
+        seq_id = next(self.seq_counter)
+        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

         # Create the sequence group.
-        seq_group = SequenceGroup(request_id, seqs, sampling_params,
+        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                   arrival_time)

         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

-    def abort_request(self, request_id: str) -> None:
-        """Aborts a request with the given ID.
+    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a request(s) with the given ID.

         Args:
-            request_id: The ID of the request to abort.
+            request_id: The ID(s) of the request to abort.
         """
         self.scheduler.abort_seq_group(request_id)
@@ -222,6 +296,255 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

+    def _schedule(
+        self
+    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
+               List[RequestOutput]]:
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        return seq_group_metadata_list, scheduler_outputs, [
+            RequestOutput.from_seq_group(seq_group)
+            for seq_group in scheduler_outputs.ignored_seq_groups
+        ]
+
+    def _check_beam_search_early_stopping(
+        self,
+        early_stopping: Union[bool, str],
+        sampling_params: SamplingParams,
+        best_running_seq: Sequence,
+        current_worst_seq: Sequence,
+    ) -> bool:
+        assert sampling_params.use_beam_search
+        length_penalty = sampling_params.length_penalty
+        if early_stopping is True:
+            return True
+
+        current_worst_score = (current_worst_seq.get_beam_search_score(
+            length_penalty=length_penalty,
+            eos_token_id=self.tokenizer.eos_token_id))
+        if early_stopping is False:
+            highest_attainable_score = (best_running_seq.get_beam_search_score(
+                length_penalty=length_penalty,
+                eos_token_id=self.tokenizer.eos_token_id))
+        else:
+            assert early_stopping == "never"
+            if length_penalty > 0.0:
+                # If length_penalty > 0.0, beam search will prefer longer
+                # sequences. The highest attainable score calculation is
+                # based on the longest possible sequence length in this case.
+                max_possible_length = max(
+                    best_running_seq.get_prompt_len() +
+                    sampling_params.max_tokens,
+                    self.scheduler_config.max_model_len)
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        seq_len=max_possible_length))
+            else:
+                # Otherwise, beam search will prefer shorter sequences. The
+                # highest attainable score calculation is based on the current
+                # sequence length.
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=self.tokenizer.eos_token_id))
+        return current_worst_score >= highest_attainable_score
+
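The comparison above stops beam search once the worst sequence already kept scores at least as high as the best score a still-running beam could reach. A small numeric sketch, assuming get_beam_search_score() is the HF-style cumulative log-probability divided by seq_len ** length_penalty (an assumption; that formula is not part of this diff):

# Assumed scoring rule for illustration only.
def beam_search_score(cumulative_logprob: float, seq_len: int,
                      length_penalty: float) -> float:
    return cumulative_logprob / (seq_len ** length_penalty)

# With length_penalty > 0, dividing a negative logprob by a larger factor can
# still raise the score as the sequence grows, so the "never" branch
# extrapolates to the longest possible length before giving up.
worst_kept = beam_search_score(-12.0, seq_len=20, length_penalty=1.0)      # -0.600
best_running_now = beam_search_score(-11.0, seq_len=18, length_penalty=1.0)  # -0.611
best_running_max = beam_search_score(-11.0, seq_len=48, length_penalty=1.0)  # -0.229
print(worst_kept, best_running_now, best_running_max)
# early_stopping=False stops here (-0.600 >= -0.611), while "never" keeps
# searching because the running beam could still reach -0.229.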
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutputs) -> None:
+        # Process prompt logprobs
+        prompt_logprobs = outputs.prompt_logprobs
+        if prompt_logprobs is not None:
+            seq_group.prompt_logprobs = prompt_logprobs
+
+        # Process samples
+        samples = outputs.samples
+        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        existing_finished_seqs = seq_group.get_finished_seqs()
+        parent_child_dict = {
+            parent_seq.seq_id: []
+            for parent_seq in parent_seqs
+        }
+        for sample in samples:
+            parent_child_dict[sample.parent_seq_id].append(sample)
+        # List of (child, parent)
+        child_seqs: List[Tuple[Sequence, Sequence]] = []
+
+        # Process the child samples for each parent sequence
+        for parent in parent_seqs:
+            child_samples: List[SequenceOutputs] = parent_child_dict[
+                parent.seq_id]
+            if len(child_samples) == 0:
+                # This parent sequence has no children samples. Remove
+                # the parent sequence from the sequence group since it will
+                # not be used in the future iterations.
+                parent.status = SequenceStatus.FINISHED_ABORTED
+                seq_group.remove(parent.seq_id)
+                self.scheduler.free_seq(parent)
+                continue
+            # Fork the parent sequence if there are multiple child samples.
+            for child_sample in child_samples[:-1]:
+                new_child_seq_id = next(self.seq_counter)
+                child = parent.fork(new_child_seq_id)
+                child.append_token_id(child_sample.output_token,
+                                      child_sample.logprobs)
+                child_seqs.append((child, parent))
+            # Continue the parent sequence for the last child sample.
+            # We reuse the parent sequence here to reduce redundant memory
+            # copies, especially when using non-beam search sampling methods.
+            last_child_sample = child_samples[-1]
+            parent.append_token_id(last_child_sample.output_token,
+                                   last_child_sample.logprobs)
+            child_seqs.append((parent, parent))
+
+        for seq, _ in child_seqs:
+            self._decode_sequence(seq, seq_group.sampling_params)
+            self._check_stop(seq, seq_group.sampling_params)
+
+        # Non-beam search case
+        if not seq_group.sampling_params.use_beam_search:
+            # For newly created child sequences, add them to the sequence group
+            # and fork them in block manager if they are not finished.
+            for seq, parent in child_seqs:
+                if seq is not parent:
+                    seq_group.add(seq)
+                    if not seq.is_finished():
+                        self.scheduler.fork_seq(parent, seq)
+
+            # Free the finished and selected parent sequences' memory in block
+            # manager. Keep them in the sequence group as candidate output.
+            # NOTE: we need to fork the new sequences before freeing the
+            # old sequences.
				
			||||||
 | 
					            for seq, parent in child_seqs:
 | 
				
			||||||
 | 
					                if seq is parent and seq.is_finished():
 | 
				
			||||||
 | 
					                    self.scheduler.free_seq(seq)
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Beam search case
 | 
				
			||||||
 | 
					        # Select the child sequences to keep in the sequence group.
 | 
				
			||||||
 | 
					        selected_child_seqs = []
 | 
				
			||||||
 | 
					        unselected_child_seqs = []
 | 
				
			||||||
 | 
					        beam_width = seq_group.sampling_params.best_of
 | 
				
			||||||
 | 
					        length_penalty = seq_group.sampling_params.length_penalty
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Select the newly finished sequences with the highest scores
 | 
				
			||||||
 | 
					        # to replace existing finished sequences.
 | 
				
			||||||
 | 
					        # Tuple of (seq, parent, is_new)
 | 
				
			||||||
 | 
					        existing_finished_seqs = [(seq, None, False)
 | 
				
			||||||
 | 
					                                  for seq in existing_finished_seqs]
 | 
				
			||||||
 | 
					        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
 | 
				
			||||||
 | 
					                             if seq.is_finished()]
 | 
				
			||||||
 | 
					        all_finished_seqs = existing_finished_seqs + new_finished_seqs
 | 
				
			||||||
 | 
					        # Sort the finished sequences by their scores.
 | 
				
			||||||
 | 
					        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
 | 
				
			||||||
 | 
					            length_penalty=length_penalty,
 | 
				
			||||||
 | 
					            eos_token_id=self.tokenizer.eos_token_id),
 | 
				
			||||||
 | 
					                               reverse=True)
 | 
				
			||||||
 | 
					        for seq, parent, is_new in all_finished_seqs[:beam_width]:
 | 
				
			||||||
 | 
					            if is_new:
 | 
				
			||||||
 | 
					                # A newly generated child sequence finishes and has a high
 | 
				
			||||||
 | 
					                # score, so we will add it into the sequence group.
 | 
				
			||||||
 | 
					                selected_child_seqs.append((seq, parent))
 | 
				
			||||||
 | 
					        for seq, parent, is_new in all_finished_seqs[beam_width:]:
 | 
				
			||||||
 | 
					            if is_new:
 | 
				
			||||||
 | 
					                # A newly generated child sequence finishes but has a low
 | 
				
			||||||
 | 
					                # score, so we will not add it into the sequence group.
 | 
				
			||||||
 | 
					                # Additionally, if this sequence is a continuation of a
 | 
				
			||||||
 | 
					                # parent sequence, we will need remove the parent sequence
 | 
				
			||||||
 | 
					                # from the sequence group.
 | 
				
			||||||
 | 
					                unselected_child_seqs.append((seq, parent))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # An existing finished sequence has a low score, so we will
 | 
				
			||||||
 | 
					                # remove it from the sequence group.
 | 
				
			||||||
 | 
					                seq_group.remove(seq.seq_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # select the top beam_width sequences from the running
 | 
				
			||||||
 | 
					        # sequences for the next iteration to continue the beam
 | 
				
			||||||
 | 
					        # search.
 | 
				
			||||||
 | 
					        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
 | 
				
			||||||
 | 
					                              if not seq.is_finished()]
 | 
				
			||||||
 | 
					        # Sort the running sequences by their scores.
 | 
				
			||||||
 | 
					        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
 | 
				
			||||||
 | 
					            length_penalty=length_penalty,
 | 
				
			||||||
 | 
					            eos_token_id=self.tokenizer.eos_token_id),
 | 
				
			||||||
 | 
					                                reverse=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Check if we can stop the beam search.
 | 
				
			||||||
 | 
					        if len(running_child_seqs) == 0:
 | 
				
			||||||
 | 
					            # No running sequences, stop the beam search.
 | 
				
			||||||
 | 
					            stop_beam_search = True
 | 
				
			||||||
 | 
					        elif len(all_finished_seqs) < beam_width:
 | 
				
			||||||
 | 
					            # Not enough finished sequences, continue the beam search.
 | 
				
			||||||
 | 
					            stop_beam_search = False
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Check the early stopping criteria
 | 
				
			||||||
 | 
					            best_running_seq = running_child_seqs[0][0]
 | 
				
			||||||
 | 
					            current_worst_seq = all_finished_seqs[beam_width - 1][0]
 | 
				
			||||||
 | 
					            stop_beam_search = self._check_beam_search_early_stopping(
 | 
				
			||||||
 | 
					                seq_group.sampling_params.early_stopping,
 | 
				
			||||||
 | 
					                seq_group.sampling_params, best_running_seq, current_worst_seq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if stop_beam_search:
 | 
				
			||||||
 | 
					            # Stop the beam search and remove all the running sequences from
 | 
				
			||||||
 | 
					            # the sequence group.
 | 
				
			||||||
 | 
					            unselected_child_seqs.extend(running_child_seqs)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Continue the beam search and select the top beam_width sequences
 | 
				
			||||||
 | 
					            # to continue the beam search.
 | 
				
			||||||
 | 
					            selected_child_seqs.extend(running_child_seqs[:beam_width])
 | 
				
			||||||
 | 
					            # The remaining running sequences will not be used in the next
 | 
				
			||||||
 | 
					            # iteration. Again, if these sequences are continuations of
 | 
				
			||||||
 | 
					            # parent sequences, we will need to remove the parent sequences
 | 
				
			||||||
 | 
					            # from the sequence group.
 | 
				
			||||||
 | 
					            unselected_child_seqs.extend(running_child_seqs[beam_width:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # For newly created child sequences, add them to the sequence group
 | 
				
			||||||
 | 
					        # and fork them in block manager if they are not finished.
 | 
				
			||||||
 | 
					        for seq, parent in selected_child_seqs:
 | 
				
			||||||
 | 
					            if seq is not parent:
 | 
				
			||||||
 | 
					                seq_group.add(seq)
 | 
				
			||||||
 | 
					                if not seq.is_finished():
 | 
				
			||||||
 | 
					                    self.scheduler.fork_seq(parent, seq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Free the finished and selected parent sequences' memory in block
 | 
				
			||||||
 | 
					        # manager. Keep them in the sequence group as candidate output.
 | 
				
			||||||
 | 
					        for seq, parent in selected_child_seqs:
 | 
				
			||||||
 | 
					            if seq is parent and seq.is_finished():
 | 
				
			||||||
 | 
					                self.scheduler.free_seq(seq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Remove the unselected parent sequences from the sequence group and
 | 
				
			||||||
 | 
					        # free their memory in block manager.
 | 
				
			||||||
 | 
					        for seq, parent in unselected_child_seqs:
 | 
				
			||||||
 | 
					            if seq is parent:
 | 
				
			||||||
 | 
					                # Remove the parent sequence if it is not selected for next
 | 
				
			||||||
 | 
					                # iteration
 | 
				
			||||||
 | 
					                seq_group.remove(seq.seq_id)
 | 
				
			||||||
 | 
					                self.scheduler.free_seq(seq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
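
Aside: a toy model of the fork-or-continue bookkeeping above, with hypothetical names and plain objects instead of Sequence. It shows why the last sample reuses the parent: only the extra samples pay for a copy.

from collections import defaultdict
from types import SimpleNamespace as NS


def group_and_fork(parents, samples, fork):
    """parents have .seq_id, samples have .parent_seq_id; fork(parent)
    returns a copy. Returns (child, parent) pairs; child may be the parent."""
    by_parent = defaultdict(list)
    for sample in samples:
        by_parent[sample.parent_seq_id].append(sample)
    pairs = []
    for parent in parents:
        child_samples = by_parent[parent.seq_id]
        if not child_samples:
            continue  # the real code also aborts and frees such parents
        for _ in child_samples[:-1]:
            pairs.append((fork(parent), parent))  # extra samples get copies
        pairs.append((parent, parent))            # last sample reuses parent
    return pairs


parents = [NS(seq_id=0)]
samples = [NS(parent_seq_id=0), NS(parent_seq_id=0)]
pairs = group_and_fork(parents, samples, fork=lambda p: NS(seq_id=99))
print([(child.seq_id, parent.seq_id) for child, parent in pairs])
# [(99, 0), (0, 0)]
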
+    def _process_model_outputs(
+            self, output: SamplerOutput,
+            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+        # Update the scheduled sequence groups with the model outputs.
+        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
+        for seq_group, outputs in zip(scheduled_seq_groups, output):
+            self._process_sequence_group_outputs(seq_group, outputs)
+
+        # Free the finished sequence groups.
+        self.scheduler.free_finished_seq_groups()
+
+        # Create the outputs.
+        request_outputs: List[RequestOutput] = []
+        for seq_group in (scheduled_seq_groups +
+                          scheduler_outputs.ignored_seq_groups):
+            request_output = RequestOutput.from_seq_group(seq_group)
+            request_outputs.append(request_output)
+
+        if self.log_stats:
+            # Log the system stats.
+            self._log_system_stats(scheduler_outputs.prompt_run,
+                                   scheduler_outputs.num_batched_tokens)
+        return request_outputs
+
     def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.
@@ -231,12 +554,9 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        (seq_group_metadata_list, scheduler_outputs,
-         ignored_seq_groups) = self.scheduler.schedule()
-        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
-                and (not ignored_seq_groups)):
-            # Nothing to do.
-            return []
+        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
+        if scheduler_outputs.is_empty():
+            return ignored

         # Execute the model.
         output = self._run_workers(
@@ -246,72 +566,121 @@ class LLMEngine:
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )
-        # Update the scheduler with the model outputs.
-        seq_groups = self.scheduler.update(output)

-        # Decode the sequences.
-        self._decode_sequences(seq_groups)
-        # Stop the sequences that meet the stopping criteria.
-        self._stop_sequences(seq_groups)
-        # Free the finished sequence groups.
-        self.scheduler.free_finished_seq_groups()
+        return self._process_model_outputs(output, scheduler_outputs) + ignored

-        # Create the outputs.
-        request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups + ignored_seq_groups:
-            request_output = RequestOutput.from_seq_group(seq_group)
-            request_outputs.append(request_output)
-        return request_outputs
-    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        """Decodes the sequence outputs."""
-        for seq_group in seq_groups:
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                new_token, new_output_text = detokenize_incrementally(
-                    self.tokenizer,
-                    seq.output_tokens,
-                    seq.get_last_token_id(),
-                    skip_special_tokens=True,
-                )
-                seq.output_tokens.append(new_token)
-                seq.output_text = new_output_text
+    def _log_system_stats(
+        self,
+        prompt_run: bool,
+        num_batched_tokens: int,
+    ) -> None:
+        now = time.monotonic()
+        # Log the number of batched input tokens.
+        if prompt_run:
+            self.num_prompt_tokens.append((now, num_batched_tokens))
+        else:
+            self.num_generation_tokens.append((now, num_batched_tokens))
+
+        elapsed_time = now - self.last_logging_time
+        if elapsed_time < _LOGGING_INTERVAL_SEC:
+            return

-    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Discard the old stats.
+        self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
+                                  if now - t < _LOGGING_INTERVAL_SEC]
+        self.num_generation_tokens = [(t, n)
+                                      for t, n in self.num_generation_tokens
+                                      if now - t < _LOGGING_INTERVAL_SEC]
+
+        if len(self.num_prompt_tokens) > 1:
+            total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
+            window = now - self.num_prompt_tokens[0][0]
+            avg_prompt_throughput = total_num_tokens / window
+        else:
+            avg_prompt_throughput = 0.0
+        if len(self.num_generation_tokens) > 1:
+            total_num_tokens = sum(n
+                                   for _, n in self.num_generation_tokens[:-1])
+            window = now - self.num_generation_tokens[0][0]
+            avg_generation_throughput = total_num_tokens / window
+        else:
+            avg_generation_throughput = 0.0
+
+        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
+        num_free_gpu_blocks = (
+            self.scheduler.block_manager.get_num_free_gpu_blocks())
+        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
+        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
+
+        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
+        if total_num_cpu_blocks > 0:
+            num_free_cpu_blocks = (
+                self.scheduler.block_manager.get_num_free_cpu_blocks())
+            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
+            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
+        else:
+            cpu_cache_usage = 0.0
+
+        logger.info("Avg prompt throughput: "
+                    f"{avg_prompt_throughput:.1f} tokens/s, "
+                    "Avg generation throughput: "
+                    f"{avg_generation_throughput:.1f} tokens/s, "
+                    f"Running: {len(self.scheduler.running)} reqs, "
+                    f"Swapped: {len(self.scheduler.swapped)} reqs, "
+                    f"Pending: {len(self.scheduler.waiting)} reqs, "
+                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+        self.last_logging_time = now
+
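
Aside: a standalone sketch of the sliding-window throughput estimate in _log_system_stats above (names are ours). It keeps (timestamp, token_count) pairs, drops stale entries, and divides the tokens of all but the newest sample by the window span, as the method does.

import time

_LOGGING_INTERVAL_SEC = 10.0


def average_throughput(samples, now=None, window=_LOGGING_INTERVAL_SEC):
    """samples: list of (timestamp, num_tokens) pairs; returns tokens/s over
    the recent window, excluding the newest sample."""
    if now is None:
        now = time.monotonic()
    recent = [(t, n) for t, n in samples if now - t < window]
    if len(recent) <= 1:
        return 0.0
    total_tokens = sum(n for _, n in recent[:-1])
    return total_tokens / (now - recent[0][0])


print(average_throughput([(0.0, 100), (1.0, 120), (2.0, 0)], now=2.0))  # 110.0
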
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
+        """Decodes the new token for a sequence."""
+        (new_tokens, new_output_text, prefix_offset,
+         read_offset) = detokenize_incrementally(
+             self.tokenizer,
+             all_input_ids=seq.get_token_ids(),
+             prev_tokens=seq.tokens,
+             prefix_offset=seq.prefix_offset,
+             read_offset=seq.read_offset,
+             skip_special_tokens=sampling_params.skip_special_tokens,
+         )
+        if seq.tokens is None:
+            seq.tokens = new_tokens
+        else:
+            seq.tokens.extend(new_tokens)
+        seq.prefix_offset = prefix_offset
+        seq.read_offset = read_offset
+        seq.output_text += new_output_text
+
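
Aside: a very rough, self-contained sketch of the offset bookkeeping idea behind detokenize_incrementally (this is not vLLM's implementation; real detokenization goes through the tokenizer). Text is only emitted once it can no longer change as more tokens arrive, and a trailing replacement character signals a possibly incomplete multi-byte sequence.

def emit_stable_text(token_strings, prefix_offset, read_offset):
    """token_strings: all tokens decoded so far, as plain strings.
    Returns (new_text, new_prefix_offset, new_read_offset)."""
    prefix_text = "".join(token_strings[prefix_offset:read_offset])
    full_text = "".join(token_strings[prefix_offset:])
    if full_text.endswith("\ufffd"):
        # Possibly an incomplete multi-byte character: wait for more tokens.
        return "", prefix_offset, read_offset
    new_text = full_text[len(prefix_text):]
    return new_text, read_offset, len(token_strings)
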
+    def _check_stop(self, seq: Sequence,
+                    sampling_params: SamplingParams) -> None:
         """Stop the finished sequences."""
-        for seq_group in seq_groups:
-            sampling_params = seq_group.sampling_params
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Check if the sequence has generated a stop string.
-                stopped = False
-                for stop_str in sampling_params.stop:
-                    if seq.output_text.endswith(stop_str):
-                        # Truncate the output text so that the stop string is
-                        # not included in the output.
-                        seq.output_text = seq.output_text[:-len(stop_str)]
-                        self.scheduler.free_seq(
-                            seq, SequenceStatus.FINISHED_STOPPED)
-                        stopped = True
-                        break
-                if stopped:
-                    continue
+        for stop_str in sampling_params.stop:
+            if seq.output_text.endswith(stop_str):
+                # Truncate the output text so that the stop string is
+                # not included in the output.
+                seq.output_text = seq.output_text[:-len(stop_str)]
+                seq.status = SequenceStatus.FINISHED_STOPPED
+                return
+        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return

-                # Check if the sequence has reached max_seq_len.
-                if (seq.get_len() >=
-                        self.scheduler.scheduler_config.max_seq_len):
-                    self.scheduler.free_seq(
-                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
-                    continue
-                # Check if the sequence has reached max_tokens.
-                if seq.get_output_len() == sampling_params.max_tokens:
-                    self.scheduler.free_seq(
-                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
-                    continue
-                # Check if the sequence has generated the EOS token.
-                if not sampling_params.ignore_eos:
-                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                        self.scheduler.free_seq(
-                            seq, SequenceStatus.FINISHED_STOPPED)
-                        continue
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == self.tokenizer.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return

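
Aside: the stop conditions above, collapsed into one small standalone function with hypothetical names. It returns the possibly truncated text plus a finished flag instead of mutating a Sequence.

def check_stop(output_text, last_token_id, stop_strings, stop_token_ids,
               eos_token_id, ignore_eos=False):
    """Returns (possibly truncated output_text, finished flag)."""
    for stop_str in stop_strings:
        if output_text.endswith(stop_str):
            # Hide the stop string from the visible output, as above.
            return output_text[:-len(stop_str)], True
    if last_token_id in stop_token_ids:
        return output_text, True
    if not ignore_eos and last_token_id == eos_token_id:
        return output_text, True
    return output_text, False


print(check_stop("Hello!###", 7, ["###"], [], eos_token_id=2))
# ('Hello!', True)
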
     def _run_workers(
         self,
@@ -323,9 +692,10 @@ class LLMEngine:
         """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
-            executor = getattr(worker, method)
             if self.parallel_config.worker_use_ray:
-                executor = executor.remote
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)
+
             output = executor(*args, **kwargs)
             all_outputs.append(output)
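
Aside: why the new _run_workers wraps the Ray path in functools.partial. getattr(worker, method) resolves the attribute locally, while a Ray actor should receive the method name and dispatch it inside execute_method. A standalone toy follows (class and names are ours; with Ray the call would be partial(worker.execute_method.remote, method)).

from functools import partial


class LocalWorker:
    def ping(self, x):
        return ("pong", x)

    def execute_method(self, method, *args, **kwargs):
        # Same trick RayWorker uses: dispatch by method name.
        return getattr(self, method)(*args, **kwargs)


worker = LocalWorker()
worker_use_ray = True  # stand-in for parallel_config.worker_use_ray
if worker_use_ray:
    executor = partial(worker.execute_method, "ping")
else:
    executor = getattr(worker, "ping")
print(executor(42))  # ('pong', 42)
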
@@ -1,22 +1,59 @@
-import random
-from typing import List, Optional, Tuple
+import socket
+from typing import Optional, Tuple, TYPE_CHECKING
+
+from vllm.config import ParallelConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 try:
     import ray
-except ImportError:
+    from ray.air.util.torch_dist import TorchDistributedWorker
+
+    class RayWorker(TorchDistributedWorker):
+        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
+        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+
+        def __init__(self, init_cached_hf_modules=False) -> None:
+            if init_cached_hf_modules:
+                # pylint: disable=import-outside-toplevel
+                from transformers.dynamic_module_utils import init_hf_modules
+                init_hf_modules()
+            self.worker = None
+
+        def init_worker(self, worker_init_fn):
+            self.worker = worker_init_fn()
+
+        def __getattr__(self, name):
+            return getattr(self.worker, name)
+
+        def execute_method(self, method, *args, **kwargs):
+            executor = getattr(self, method)
+            return executor(*args, **kwargs)
+
+except ImportError as e:
+    logger.warning(f"Failed to import Ray with {e!r}. "
+                   "For distributed inference, please install Ray with "
+                   "`pip install ray pandas pyarrow`.")
     ray = None
+    TorchDistributedWorker = None
+    RayWorker = None  # pylint: disable=invalid-name
+
-from vllm.config import ParallelConfig
-# rank, node resource (node IP), device id
-DeviceID = Tuple[int, Optional[str], int]
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
+
+def get_open_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
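
Aside: the RayWorker wrapper above leans on lazy construction plus __getattr__ forwarding. A standalone sketch with hypothetical class names:

class LazyWrapper:
    def __init__(self):
        self.inner = None  # created later, e.g. after CUDA_VISIBLE_DEVICES is set

    def init_inner(self, factory):
        self.inner = factory()

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so everything not
        # defined on the wrapper is forwarded to the wrapped object.
        return getattr(self.inner, name)


class Inner:
    def hello(self):
        return "hi"


wrapper = LazyWrapper()
wrapper.init_inner(Inner)
print(wrapper.hello())  # "hi", forwarded via __getattr__
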
 def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, List[List[DeviceID]]]:
+) -> Tuple[str, Optional["PlacementGroup"]]:
     """Initialize the distributed cluster probably with Ray.

     Args:
@@ -26,11 +63,10 @@ def initialize_cluster(
             the default Ray cluster address.

     Returns:
-        A tuple of (`distributed_init_method`, `all_stage_devices`). The
+        A tuple of (`distributed_init_method`, `placement_group`). The
         `distributed_init_method` is the address for initializing the
-        distributed backend. `all_stage_devices` includes device IDs for
-        each worker in each pipeline stage. Each device ID is a tuple of
-        (rank, node resource, device id).
+        distributed backend. `placement_group` includes the specification
+        of the resources for each distributed worker.
     """
     if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
@@ -38,71 +74,46 @@ def initialize_cluster(
                 "Ray is not installed. Please install Ray to use distributed "
                 "serving.")
         # Connect to a ray cluster.
-        ray.init(address=ray_address)
+        ray.init(address=ray_address, ignore_reinit_error=True)

     if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
-        port = random.randint(10000, 20000)
+        port = get_open_port()
         # We need to setup the distributed init method to make sure
         # the distributed megatron code (e.g., get world size) works correctly.
         distributed_init_method = f"tcp://localhost:{port}"
-        all_stage_devices = [[(0, None, 0)]]
-        return distributed_init_method, all_stage_devices
+        return distributed_init_method, None

-    # Assume we have a uniform cluster that each node has the same number of
-    # GPUs for now.
-    valid_node_resources = []
-    num_devices_per_node = None
-    for node in ray.nodes():
-        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
-            continue
-        if num_devices_per_node is None:
-            num_devices_per_node = node["Resources"]["GPU"]
-        else:
-            assert num_devices_per_node == node["Resources"]["GPU"], (
-                "The number of GPUs per node is not uniform.")
-        for key in node["Resources"]:
-            if key.startswith("node:"):
-                valid_node_resources.append(key)
-
-    # Verify the parallel config.
-    num_nodes = len(valid_node_resources)
-    if parallel_config.world_size > num_nodes * num_devices_per_node:
-        raise ValueError(
-            "The number of required GPUs exceeds the total number of "
-            "available GPUs.")
-    if parallel_config.tensor_parallel_size >= num_devices_per_node:
-        if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
+    current_placement_group = ray.util.get_current_placement_group()
+    if current_placement_group:
+        # We are in a placement group
+        bundles = current_placement_group.bundle_specs
+        # Verify that we can use the placement group.
+        gpu_bundles = 0
+        for bundle in bundles:
+            bundle_gpus = bundle.get("GPU", 0)
+            if bundle_gpus > 1:
+                raise ValueError(
+                    "Placement group bundle cannot have more than 1 GPU.")
+            if bundle_gpus:
+                gpu_bundles += 1
+        if parallel_config.world_size > gpu_bundles:
             raise ValueError(
-                "The number of tensor parallelism is not divisible by the "
-                "number of GPUs per node.")
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the placement group.")
     else:
-        if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
+        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
+        if parallel_config.world_size > num_gpus_in_cluster:
             raise ValueError(
-                "The number of GPUs per node is not divisible by the number "
-                "of tensor parallelism.")
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the cluster.")
+        # Create a new placement group
+        current_placement_group = ray.util.placement_group([{
+            "GPU": 1
+        }] * parallel_config.world_size)
+        # Wait until PG is ready - this will block until all
+        # requested resources are available, and will timeout
+        # if they cannot be provisioned.
+        ray.get(current_placement_group.ready(), timeout=1800)

-    # Assign GPUs to pipeline stages.
-    rank = 0
-    current_node_id = 0
-    current_device_id = 0
-    distributed_init_method = None
-    all_stage_devices = []
-
-    for _ in range(parallel_config.pipeline_parallel_size):
-        stage_devices = []
-        for _ in range(parallel_config.tensor_parallel_size):
-            node_resource = valid_node_resources[current_node_id]
-            stage_devices.append((rank, node_resource, current_device_id))
-            if distributed_init_method is None:
-                ip = node_resource.split("node:")[-1]
-                port = random.randint(10000, 20000)
-                distributed_init_method = f"tcp://{ip}:{port}"
-            rank += 1
-            current_device_id += 1
-            if current_device_id >= num_devices_per_node:
-                current_node_id += 1
-                current_device_id = 0
-        all_stage_devices.append(stage_devices)
-
-    return distributed_init_method, all_stage_devices
+    return None, current_placement_group
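
Aside: a hedged sketch of the placement-group pattern used above. It assumes a running Ray cluster with enough GPUs; the bundle shape mirrors the code, one GPU per worker.

import ray
from ray.util import placement_group, remove_placement_group

ray.init(ignore_reinit_error=True)
world_size = 2  # stand-in for parallel_config.world_size
pg = placement_group([{"GPU": 1}] * world_size)
# Block until all bundles are reserved (raises on timeout).
ray.get(pg.ready(), timeout=1800)
# ... one worker would be scheduled into each bundle here ...
remove_placement_group(pg)
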
@@ -2,8 +2,8 @@ import argparse
 import json
 from typing import AsyncGenerator

-from fastapi import BackgroundTasks, FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 import uvicorn

 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -14,6 +14,7 @@ from vllm.utils import random_uuid
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
+engine = None


 @app.post("/generate")
@@ -30,6 +31,7 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
+
     results_generator = engine.generate(prompt, sampling_params, request_id)

     # Streaming case
@@ -42,14 +44,8 @@ async def generate(request: Request) -> Response:
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
     if stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
-        return StreamingResponse(stream_results(), background=background_tasks)
+        return StreamingResponse(stream_results())

     # Non-streaming case
     final_output = None
@@ -64,12 +60,12 @@ async def generate(request: Request) -> Response:
     prompt = final_output.prompt
     text_outputs = [prompt + output.text for output in final_output.outputs]
     ret = {"text": text_outputs}
-    return Response(content=json.dumps(ret))
+    return JSONResponse(ret)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--host", type=str, default=None)
     parser.add_argument("--port", type=int, default=8000)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
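
Aside: a possible client for the /generate endpoint above, using the requests library. The "prompt" and "max_tokens" fields are assumptions based on the handler, which pops "stream" and forwards the remaining JSON fields to SamplingParams; streamed responses are JSON chunks terminated by a NUL byte.

import json
import requests

url = "http://localhost:8000/generate"
payload = {"prompt": "Hello, my name is", "max_tokens": 32, "stream": False}
print(requests.post(url, json=payload).json()["text"])

# Streaming variant: read NUL-delimited JSON chunks as they arrive.
payload["stream"] = True
with requests.post(url, json=payload, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk)["text"])
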
@@ -28,6 +28,8 @@ class LLM:
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+            downloading the model and tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -35,7 +37,24 @@ class LLM:
             the `torch_dtype` attribute specified in the model config file.
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
+        quantization: The method used to quantize the model weights. Currently,
+            we support "awq". If None, we assume the model weights are not
+            quantized and use `dtype` to determine the data type of the weights.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id.
         seed: The seed to initialize the random number generator for sampling.
+        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
+            reserve for the model weights, activations, and KV cache. Higher
+            values will increase the KV cache size and thus improve the model's
+            throughput. However, if the value is too high, it may cause out-of-
+            memory (OOM) errors.
+        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
+            This can be used for temporarily storing the states of the requests
+            when their `best_of` sampling parameters are larger than 1. If all
+            requests will have `best_of=1`, you can safely set this to 0.
+            Otherwise, too small values may cause out-of-memory (OOM) errors.
     """

     def __init__(
@@ -43,9 +62,15 @@ class LLM:
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
+        quantization: Optional[str] = None,
+        revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
         seed: int = 0,
+        gpu_memory_utilization: float = 0.9,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -54,9 +79,15 @@ class LLM:
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
+            quantization=quantization,
+            revision=revision,
+            tokenizer_revision=tokenizer_revision,
             seed=seed,
+            gpu_memory_utilization=gpu_memory_utilization,
+            swap_space=swap_space,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(engine_args)
@@ -151,4 +182,8 @@ class LLM:
                         pbar.update(1)
         if use_tqdm:
             pbar.close()
+        # Sort the outputs by request ID.
+        # This is necessary because some requests may be finished earlier
+        # than their preceding requests.
+        outputs = sorted(outputs, key=lambda x: int(x.request_id))
         return outputs
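
Aside: a hedged usage sketch of the LLM constructor arguments documented above (the model name is just a placeholder).

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # placeholder model
    trust_remote_code=False,
    gpu_memory_utilization=0.90,
    swap_space=4,                   # GiB of CPU swap per GPU, for best_of > 1
    # quantization="awq",           # only for AWQ-quantized checkpoints
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)
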
@@ -3,20 +3,18 @@

 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
+from http import HTTPStatus
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
+
 import fastapi
-from fastapi import BackgroundTasks, Request
+import uvicorn
+from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-from fastchat.conversation import Conversation, SeparatorStyle
-from fastchat.model.model_adapter import get_conversation_template
-
-import uvicorn
+from packaging import version

 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -33,11 +31,20 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

+try:
+    import fastchat
+    from fastchat.conversation import Conversation, SeparatorStyle
+    from fastchat.model.model_adapter import get_conversation_template
+    _fastchat_available = True
+except ImportError:
+    _fastchat_available = False
+
 TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None
 app = fastapi.FastAPI()
+engine = None


 def create_error_response(status_code: HTTPStatus,
@@ -63,10 +70,21 @@ async def check_model(request) -> Optional[JSONResponse]:


 async def get_gen_prompt(request) -> str:
+    if not _fastchat_available:
+        raise ModuleNotFoundError(
+            "fastchat is not installed. Please install fastchat to use "
+            "the chat completion and conversation APIs: `$ pip install fschat`"
+        )
+    if version.parse(fastchat.__version__) < version.parse("0.2.23"):
+        raise ImportError(
+            f"fastchat version is low. Current version: {fastchat.__version__} "
+            "Please upgrade fastchat to use: `$ pip install -U fschat`")
+
     conv = get_conversation_template(request.model)
     conv = Conversation(
         name=conv.name,
-        system=conv.system,
+        system_template=conv.system_template,
+        system_message=conv.system_message,
         roles=conv.roles,
         messages=list(conv.messages),  # prevent in-place modification
         offset=conv.offset,
@@ -83,7 +101,7 @@ async def get_gen_prompt(request) -> str:
         for message in request.messages:
             msg_role = message["role"]
             if msg_role == "system":
-                conv.system = message["content"]
+                conv.system_message = message["content"]
             elif msg_role == "user":
                 conv.append_message(conv.roles[0], message["content"])
             elif msg_role == "assistant":
@@ -98,32 +116,33 @@ async def get_gen_prompt(request) -> str:
     return prompt


-async def check_length(request, prompt, model_config):
-    if hasattr(model_config.hf_config, "max_sequence_length"):
-        context_len = model_config.hf_config.max_sequence_length
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    elif hasattr(model_config.hf_config, "max_position_embeddings"):
-        context_len = model_config.hf_config.max_position_embeddings
-    elif hasattr(model_config.hf_config, "seq_length"):
+async def check_length(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+    prompt: Optional[str] = None,
+    prompt_ids: Optional[List[int]] = None
+) -> Tuple[List[int], Optional[JSONResponse]]:
+    assert (not (prompt is None and prompt_ids is None)
+            and not (prompt is not None and prompt_ids is not None)
+            ), "Either prompt or prompt_ids should be provided."
 | 
				
			||||||
        context_len = model_config.hf_config.seq_length
 | 
					    if prompt_ids is not None:
 | 
				
			||||||
 | 
					        input_ids = prompt_ids
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        context_len = 2048
 | 
					        input_ids = tokenizer(prompt).input_ids
 | 
				
			||||||
 | 
					 | 
				
			||||||
    input_ids = tokenizer(prompt).input_ids
 | 
					 | 
				
			||||||
    token_num = len(input_ids)
 | 
					    token_num = len(input_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if token_num + request.max_tokens > context_len:
 | 
					    if request.max_tokens is None:
 | 
				
			||||||
        return create_error_response(
 | 
					        request.max_tokens = max_model_len - token_num
 | 
				
			||||||
 | 
					    if token_num + request.max_tokens > max_model_len:
 | 
				
			||||||
 | 
					        return input_ids, create_error_response(
 | 
				
			||||||
            HTTPStatus.BAD_REQUEST,
 | 
					            HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
            f"This model's maximum context length is {context_len} tokens. "
 | 
					            f"This model's maximum context length is {max_model_len} tokens. "
 | 
				
			||||||
            f"However, you requested {request.max_tokens + token_num} tokens "
 | 
					            f"However, you requested {request.max_tokens + token_num} tokens "
 | 
				
			||||||
            f"({token_num} in the messages, "
 | 
					            f"({token_num} in the messages, "
 | 
				
			||||||
            f"{request.max_tokens} in the completion). "
 | 
					            f"{request.max_tokens} in the completion). "
 | 
				
			||||||
            f"Please reduce the length of the messages or completion.",
 | 
					            f"Please reduce the length of the messages or completion.",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return None
 | 
					        return input_ids, None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.get("/v1/models")
 | 
					@app.get("/v1/models")
 | 
				
			||||||
@ -162,7 +181,8 @@ def create_logprobs(token_ids: List[int],
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.post("/v1/chat/completions")
 | 
					@app.post("/v1/chat/completions")
 | 
				
			||||||
async def create_chat_completion(raw_request: Request):
 | 
					async def create_chat_completion(request: ChatCompletionRequest,
 | 
				
			||||||
 | 
					                                 raw_request: Request):
 | 
				
			||||||
    """Completion API similar to OpenAI's API.
 | 
					    """Completion API similar to OpenAI's API.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    See  https://platform.openai.com/docs/api-reference/chat/create
 | 
					    See  https://platform.openai.com/docs/api-reference/chat/create
 | 
				
			||||||
@ -172,26 +192,25 @@ async def create_chat_completion(raw_request: Request):
 | 
				
			|||||||
        - function_call (Users should implement this by themselves)
 | 
					        - function_call (Users should implement this by themselves)
 | 
				
			||||||
        - logit_bias (to be supported by vLLM engine)
 | 
					        - logit_bias (to be supported by vLLM engine)
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    request = ChatCompletionRequest(**await raw_request.json())
 | 
					 | 
				
			||||||
    logger.info(f"Received chat completion request: {request}")
 | 
					    logger.info(f"Received chat completion request: {request}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    error_check_ret = await check_model(request)
 | 
					    error_check_ret = await check_model(request)
 | 
				
			||||||
    if error_check_ret is not None:
 | 
					    if error_check_ret is not None:
 | 
				
			||||||
        return error_check_ret
 | 
					        return error_check_ret
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if request.logit_bias is not None:
 | 
					    if request.logit_bias is not None and len(request.logit_bias) > 0:
 | 
				
			||||||
        # TODO: support logit_bias in vLLM engine.
 | 
					        # TODO: support logit_bias in vLLM engine.
 | 
				
			||||||
        return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
					        return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
                                     "logit_bias is not currently supported")
 | 
					                                     "logit_bias is not currently supported")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    prompt = await get_gen_prompt(request)
 | 
					    prompt = await get_gen_prompt(request)
 | 
				
			||||||
    error_check_ret = await check_length(request, prompt, engine_model_config)
 | 
					    token_ids, error_check_ret = await check_length(request, prompt=prompt)
 | 
				
			||||||
    if error_check_ret is not None:
 | 
					    if error_check_ret is not None:
 | 
				
			||||||
        return error_check_ret
 | 
					        return error_check_ret
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model_name = request.model
 | 
					    model_name = request.model
 | 
				
			||||||
    request_id = f"cmpl-{random_uuid()}"
 | 
					    request_id = f"cmpl-{random_uuid()}"
 | 
				
			||||||
    created_time = int(time.time())
 | 
					    created_time = int(time.monotonic())
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        sampling_params = SamplingParams(
 | 
					        sampling_params = SamplingParams(
 | 
				
			||||||
            n=request.n,
 | 
					            n=request.n,
 | 
				
			||||||
@ -200,19 +219,19 @@ async def create_chat_completion(raw_request: Request):
 | 
				
			|||||||
            temperature=request.temperature,
 | 
					            temperature=request.temperature,
 | 
				
			||||||
            top_p=request.top_p,
 | 
					            top_p=request.top_p,
 | 
				
			||||||
            stop=request.stop,
 | 
					            stop=request.stop,
 | 
				
			||||||
 | 
					            stop_token_ids=request.stop_token_ids,
 | 
				
			||||||
            max_tokens=request.max_tokens,
 | 
					            max_tokens=request.max_tokens,
 | 
				
			||||||
            best_of=request.best_of,
 | 
					            best_of=request.best_of,
 | 
				
			||||||
            top_k=request.top_k,
 | 
					            top_k=request.top_k,
 | 
				
			||||||
            ignore_eos=request.ignore_eos,
 | 
					            ignore_eos=request.ignore_eos,
 | 
				
			||||||
            use_beam_search=request.use_beam_search,
 | 
					            use_beam_search=request.use_beam_search,
 | 
				
			||||||
 | 
					            skip_special_tokens=request.skip_special_tokens,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    except ValueError as e:
 | 
					    except ValueError as e:
 | 
				
			||||||
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 | 
					        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    result_generator = engine.generate(prompt, sampling_params, request_id)
 | 
					    result_generator = engine.generate(prompt, sampling_params, request_id,
 | 
				
			||||||
 | 
					                                       token_ids)
 | 
				
			||||||
    async def abort_request() -> None:
 | 
					 | 
				
			||||||
        await engine.abort(request_id)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_stream_response_json(
 | 
					    def create_stream_response_json(
 | 
				
			||||||
        index: int,
 | 
					        index: int,
 | 
				
			||||||
@ -269,23 +288,19 @@ async def create_chat_completion(raw_request: Request):
 | 
				
			|||||||
                        finish_reason=output.finish_reason,
 | 
					                        finish_reason=output.finish_reason,
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                    yield f"data: {response_json}\n\n"
 | 
					                    yield f"data: {response_json}\n\n"
 | 
				
			||||||
            yield "data: [DONE]\n\n"
 | 
					        yield "data: [DONE]\n\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Streaming response
 | 
					    # Streaming response
 | 
				
			||||||
    if request.stream:
 | 
					    if request.stream:
 | 
				
			||||||
        background_tasks = BackgroundTasks()
 | 
					 | 
				
			||||||
        # Abort the request if the client disconnects.
 | 
					 | 
				
			||||||
        background_tasks.add_task(abort_request)
 | 
					 | 
				
			||||||
        return StreamingResponse(completion_stream_generator(),
 | 
					        return StreamingResponse(completion_stream_generator(),
 | 
				
			||||||
                                 media_type="text/event-stream",
 | 
					                                 media_type="text/event-stream")
 | 
				
			||||||
                                 background=background_tasks)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Non-streaming response
 | 
					    # Non-streaming response
 | 
				
			||||||
    final_res: RequestOutput = None
 | 
					    final_res: RequestOutput = None
 | 
				
			||||||
    async for res in result_generator:
 | 
					    async for res in result_generator:
 | 
				
			||||||
        if await raw_request.is_disconnected():
 | 
					        if await raw_request.is_disconnected():
 | 
				
			||||||
            # Abort the request if the client disconnects.
 | 
					            # Abort the request if the client disconnects.
 | 
				
			||||||
            await abort_request()
 | 
					            await engine.abort(request_id)
 | 
				
			||||||
            return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
					            return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
                                         "Client disconnected")
 | 
					                                         "Client disconnected")
 | 
				
			||||||
        final_res = res
 | 
					        final_res = res
 | 
				
			||||||
@ -331,7 +346,7 @@ async def create_chat_completion(raw_request: Request):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.post("/v1/completions")
 | 
					@app.post("/v1/completions")
 | 
				
			||||||
async def create_completion(raw_request: Request):
 | 
					async def create_completion(request: CompletionRequest, raw_request: Request):
 | 
				
			||||||
    """Completion API similar to OpenAI's API.
 | 
					    """Completion API similar to OpenAI's API.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    See https://platform.openai.com/docs/api-reference/completions/create
 | 
					    See https://platform.openai.com/docs/api-reference/completions/create
 | 
				
			||||||
@ -344,7 +359,6 @@ async def create_completion(raw_request: Request):
 | 
				
			|||||||
          suffix)
 | 
					          suffix)
 | 
				
			||||||
        - logit_bias (to be supported by vLLM engine)
 | 
					        - logit_bias (to be supported by vLLM engine)
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    request = CompletionRequest(**await raw_request.json())
 | 
					 | 
				
			||||||
    logger.info(f"Received completion request: {request}")
 | 
					    logger.info(f"Received completion request: {request}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    error_check_ret = await check_model(request)
 | 
					    error_check_ret = await check_model(request)
 | 
				
			||||||
@ -362,25 +376,42 @@ async def create_completion(raw_request: Request):
 | 
				
			|||||||
        return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
					        return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
                                     "suffix is not currently supported")
 | 
					                                     "suffix is not currently supported")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if request.logit_bias is not None:
 | 
					    if request.logit_bias is not None and len(request.logit_bias) > 0:
 | 
				
			||||||
        # TODO: support logit_bias in vLLM engine.
 | 
					        # TODO: support logit_bias in vLLM engine.
 | 
				
			||||||
        return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
					        return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
                                     "logit_bias is not currently supported")
 | 
					                                     "logit_bias is not currently supported")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model_name = request.model
 | 
					    model_name = request.model
 | 
				
			||||||
    request_id = f"cmpl-{random_uuid()}"
 | 
					    request_id = f"cmpl-{random_uuid()}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    use_token_ids = False
 | 
				
			||||||
    if isinstance(request.prompt, list):
 | 
					    if isinstance(request.prompt, list):
 | 
				
			||||||
        if len(request.prompt) == 0:
 | 
					        if len(request.prompt) == 0:
 | 
				
			||||||
            return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
					            return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
                                         "please provide at least one prompt")
 | 
					                                         "please provide at least one prompt")
 | 
				
			||||||
        if len(request.prompt) > 1:
 | 
					        first_element = request.prompt[0]
 | 
				
			||||||
            return create_error_response(
 | 
					        if isinstance(first_element, int):
 | 
				
			||||||
                HTTPStatus.BAD_REQUEST,
 | 
					            use_token_ids = True
 | 
				
			||||||
                "multiple prompts in a batch is not currently supported")
 | 
					            prompt = request.prompt
 | 
				
			||||||
        prompt = request.prompt[0]
 | 
					        elif isinstance(first_element, (str, list)):
 | 
				
			||||||
 | 
					            # TODO: handles multiple prompt case in list[list[int]]
 | 
				
			||||||
 | 
					            if len(request.prompt) > 1:
 | 
				
			||||||
 | 
					                return create_error_response(
 | 
				
			||||||
 | 
					                    HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
 | 
					                    "multiple prompts in a batch is not currently supported")
 | 
				
			||||||
 | 
					            use_token_ids = not isinstance(first_element, str)
 | 
				
			||||||
 | 
					            prompt = request.prompt[0]
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        prompt = request.prompt
 | 
					        prompt = request.prompt
 | 
				
			||||||
    created_time = int(time.time())
 | 
					
 | 
				
			||||||
 | 
					    if use_token_ids:
 | 
				
			||||||
 | 
					        _, error_check_ret = await check_length(request, prompt_ids=prompt)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        token_ids, error_check_ret = await check_length(request, prompt=prompt)
 | 
				
			||||||
 | 
					    if error_check_ret is not None:
 | 
				
			||||||
 | 
					        return error_check_ret
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    created_time = int(time.monotonic())
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        sampling_params = SamplingParams(
 | 
					        sampling_params = SamplingParams(
 | 
				
			||||||
            n=request.n,
 | 
					            n=request.n,
 | 
				
			||||||
@ -391,15 +422,24 @@ async def create_completion(raw_request: Request):
 | 
				
			|||||||
            top_p=request.top_p,
 | 
					            top_p=request.top_p,
 | 
				
			||||||
            top_k=request.top_k,
 | 
					            top_k=request.top_k,
 | 
				
			||||||
            stop=request.stop,
 | 
					            stop=request.stop,
 | 
				
			||||||
 | 
					            stop_token_ids=request.stop_token_ids,
 | 
				
			||||||
            ignore_eos=request.ignore_eos,
 | 
					            ignore_eos=request.ignore_eos,
 | 
				
			||||||
            max_tokens=request.max_tokens,
 | 
					            max_tokens=request.max_tokens,
 | 
				
			||||||
            logprobs=request.logprobs,
 | 
					            logprobs=request.logprobs,
 | 
				
			||||||
            use_beam_search=request.use_beam_search,
 | 
					            use_beam_search=request.use_beam_search,
 | 
				
			||||||
 | 
					            skip_special_tokens=request.skip_special_tokens,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    except ValueError as e:
 | 
					    except ValueError as e:
 | 
				
			||||||
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 | 
					        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    result_generator = engine.generate(prompt, sampling_params, request_id)
 | 
					    if use_token_ids:
 | 
				
			||||||
 | 
					        result_generator = engine.generate(None,
 | 
				
			||||||
 | 
					                                           sampling_params,
 | 
				
			||||||
 | 
					                                           request_id,
 | 
				
			||||||
 | 
					                                           prompt_token_ids=prompt)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        result_generator = engine.generate(prompt, sampling_params, request_id,
 | 
				
			||||||
 | 
					                                           token_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Similar to the OpenAI API, when n != best_of, we do not stream the
 | 
					    # Similar to the OpenAI API, when n != best_of, we do not stream the
 | 
				
			||||||
    # results. In addition, we do not stream the results when use beam search.
 | 
					    # results. In addition, we do not stream the results when use beam search.
 | 
				
			||||||
@ -407,9 +447,6 @@ async def create_completion(raw_request: Request):
 | 
				
			|||||||
              and (request.best_of is None or request.n == request.best_of)
 | 
					              and (request.best_of is None or request.n == request.best_of)
 | 
				
			||||||
              and not request.use_beam_search)
 | 
					              and not request.use_beam_search)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def abort_request() -> None:
 | 
					 | 
				
			||||||
        await engine.abort(request_id)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def create_stream_response_json(
 | 
					    def create_stream_response_json(
 | 
				
			||||||
        index: int,
 | 
					        index: int,
 | 
				
			||||||
        text: str,
 | 
					        text: str,
 | 
				
			||||||
@ -465,23 +502,19 @@ async def create_completion(raw_request: Request):
 | 
				
			|||||||
                        finish_reason=output.finish_reason,
 | 
					                        finish_reason=output.finish_reason,
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                    yield f"data: {response_json}\n\n"
 | 
					                    yield f"data: {response_json}\n\n"
 | 
				
			||||||
            yield "data: [DONE]\n\n"
 | 
					        yield "data: [DONE]\n\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Streaming response
 | 
					    # Streaming response
 | 
				
			||||||
    if stream:
 | 
					    if stream:
 | 
				
			||||||
        background_tasks = BackgroundTasks()
 | 
					 | 
				
			||||||
        # Abort the request if the client disconnects.
 | 
					 | 
				
			||||||
        background_tasks.add_task(abort_request)
 | 
					 | 
				
			||||||
        return StreamingResponse(completion_stream_generator(),
 | 
					        return StreamingResponse(completion_stream_generator(),
 | 
				
			||||||
                                 media_type="text/event-stream",
 | 
					                                 media_type="text/event-stream")
 | 
				
			||||||
                                 background=background_tasks)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Non-streaming response
 | 
					    # Non-streaming response
 | 
				
			||||||
    final_res: RequestOutput = None
 | 
					    final_res: RequestOutput = None
 | 
				
			||||||
    async for res in result_generator:
 | 
					    async for res in result_generator:
 | 
				
			||||||
        if await raw_request.is_disconnected():
 | 
					        if await raw_request.is_disconnected():
 | 
				
			||||||
            # Abort the request if the client disconnects.
 | 
					            # Abort the request if the client disconnects.
 | 
				
			||||||
            await abort_request()
 | 
					            await engine.abort(request_id)
 | 
				
			||||||
            return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
					            return create_error_response(HTTPStatus.BAD_REQUEST,
 | 
				
			||||||
                                         "Client disconnected")
 | 
					                                         "Client disconnected")
 | 
				
			||||||
        final_res = res
 | 
					        final_res = res
 | 
				
			||||||
@ -534,10 +567,7 @@ async def create_completion(raw_request: Request):
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    parser = argparse.ArgumentParser(
 | 
					    parser = argparse.ArgumentParser(
 | 
				
			||||||
        description="vLLM OpenAI-Compatible RESTful API server.")
 | 
					        description="vLLM OpenAI-Compatible RESTful API server.")
 | 
				
			||||||
    parser.add_argument("--host",
 | 
					    parser.add_argument("--host", type=str, default=None, help="host name")
 | 
				
			||||||
                        type=str,
 | 
					 | 
				
			||||||
                        default="localhost",
 | 
					 | 
				
			||||||
                        help="host name")
 | 
					 | 
				
			||||||
    parser.add_argument("--port", type=int, default=8000, help="port number")
 | 
					    parser.add_argument("--port", type=int, default=8000, help="port number")
 | 
				
			||||||
    parser.add_argument("--allow-credentials",
 | 
					    parser.add_argument("--allow-credentials",
 | 
				
			||||||
                        action="store_true",
 | 
					                        action="store_true",
 | 
				
			||||||
@ -582,10 +612,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
    engine_args = AsyncEngineArgs.from_cli_args(args)
 | 
					    engine_args = AsyncEngineArgs.from_cli_args(args)
 | 
				
			||||||
    engine = AsyncLLMEngine.from_engine_args(engine_args)
 | 
					    engine = AsyncLLMEngine.from_engine_args(engine_args)
 | 
				
			||||||
    engine_model_config = asyncio.run(engine.get_model_config())
 | 
					    engine_model_config = asyncio.run(engine.get_model_config())
 | 
				
			||||||
 | 
					    max_model_len = engine_model_config.max_model_len
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # A separate tokenizer to map token IDs to strings.
 | 
					    # A separate tokenizer to map token IDs to strings.
 | 
				
			||||||
    tokenizer = get_tokenizer(engine_args.tokenizer,
 | 
					    tokenizer = get_tokenizer(engine_args.tokenizer,
 | 
				
			||||||
                              tokenizer_mode=engine_args.tokenizer_mode)
 | 
					                              tokenizer_mode=engine_args.tokenizer_mode,
 | 
				
			||||||
 | 
					                              trust_remote_code=engine_args.trust_remote_code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    uvicorn.run(app,
 | 
					    uvicorn.run(app,
 | 
				
			||||||
                host=args.host,
 | 
					                host=args.host,
 | 
				
			||||||
 | 
				
			|||||||
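With the widened `CompletionRequest.prompt` union and the new `use_token_ids` / `prompt_token_ids` path in `create_completion`, a client can now send a pre-tokenized prompt directly. A minimal request sketch, assuming a server started from this file on localhost:8000, the third-party `requests` package, and a placeholder model name (none of these are part of the diff):

import requests

# Pre-tokenized prompt: a list of token IDs instead of a string.
# The new branch routes this through check_length(request, prompt_ids=...)
# and engine.generate(None, ..., prompt_token_ids=prompt).
payload = {
    "model": "facebook/opt-125m",    # assumed model name, for illustration only
    "prompt": [2, 100, 18, 19, 27],  # token IDs, exercising the List[int] case
    "max_tokens": 32,
    "stop_token_ids": [2],           # new field added in this change
    "skip_special_tokens": True,     # new field added in this change
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json())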
@@ -58,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
@@ -70,11 +70,14 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
@@ -93,6 +96,8 @@ class CompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class LogProbs(BaseModel):
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ class InputMetadata:
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
@@ -38,6 +39,24 @@ class InputMetadata:
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.to_cache = None
+        if sliding_window is not None:
+            # We need to keep the positions of sliding windows within
+            # the key / value tables, this is helpful to know which
+            # elements we need to cache and where
+            to_cache, start_idx = [], 0
+            for prompt_len in self.prompt_lens:
+                to_cache.extend(
+                    range(
+                        start_idx + max(0, prompt_len - sliding_window),
+                        start_idx + prompt_len,
+                    ))
+                start_idx += prompt_len
+            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
+            self.to_cache = torch.tensor(to_cache,
+                                         dtype=torch.int32,
+                                         device=self.slot_mapping.device)
+
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
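The `to_cache` bookkeeping above keeps only the last `sliding_window` positions of each prompt in the KV cache, plus every remaining (non-prompt) slot. A small worked example of the same index arithmetic in plain Python, with made-up lengths (the values below are assumptions for illustration):

# Mirrors the index construction in InputMetadata.__init__ above,
# using plain lists instead of tensors.
prompt_lens = [6, 3]   # assumed example prompt lengths
sliding_window = 4
num_slots = 12         # stand-in for slot_mapping.shape[0]

to_cache, start_idx = [], 0
for prompt_len in prompt_lens:
    # Keep only the last `sliding_window` positions of each prompt.
    to_cache.extend(
        range(start_idx + max(0, prompt_len - sliding_window),
              start_idx + prompt_len))
    start_idx += prompt_len
# All remaining (non-prompt) slots are always cached.
to_cache.extend(range(start_idx, num_slots))

print(to_cache)  # [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]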
@@ -4,23 +4,6 @@ import torch.nn as nn

 from vllm import activation_ops

-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    # NOTE: The following GELU functions may introduce small rounding errors.
-    "gelu_new": nn.GELU(approximate="tanh"),
-    "gelu_fast": nn.GELU(approximate="tanh"),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-}
-
-
-def get_act_fn(act_fn: str) -> nn.Module:
-    """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
-

 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
@@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out
+
+
+class NewGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_new(out, x)
+        return out
+
+
+class FastGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_fast(out, x)
+        return out
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_fast": FastGELU(),
+    "gelu_new": NewGELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
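Since the registry now points "gelu_new" and "gelu_fast" at the custom CUDA kernels instead of nn.GELU(approximate="tanh"), a quick sanity check is to compare the two on random data. A sketch, assuming a CUDA device, the compiled vLLM extension ops, and that this module is importable as vllm.model_executor.layers.activation (the import path is an assumption, not shown in the diff):

import torch
import torch.nn as nn
from vllm.model_executor.layers.activation import get_act_fn  # assumed module path

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

custom = get_act_fn("gelu_new")          # NewGELU -> activation_ops.gelu_new kernel
reference = nn.GELU(approximate="tanh")  # the previous registry entry

# The kernel should match the tanh approximation up to small rounding error.
print(torch.max(torch.abs(custom(x) - reference(x))))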
@@ -1,5 +1,5 @@
 """Multi-head attention."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 import torch
 import torch.nn as nn
@@ -9,10 +9,14 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,

 from vllm import attention_ops
 from vllm import cache_ops
-from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.rotary_embedding import (
+    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
+    RotaryEmbedding)

-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512


 class PagedAttention(nn.Module):
@@ -20,12 +24,20 @@ class PagedAttention(nn.Module):
     """GPT-style multi-head PagedAttention.

     This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can be split into three parts: the prompt tokens, the
-    generation tokens, and the paddings.
+    input 1D tensors can either contain prompt tokens or generation tokens, in
+    addition to paddings.

-    |<------------------------------------- num_valid_tokens ------------------------------------->|
-    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+    If the input tensors contain prompt tokens, the layout is as follows:
+
+    |<---------------------- num_valid_tokens ---------------------->|
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
+
+    Otherwise, the layout is as follows:
+
+    |<------------------ num_valid_tokens ------------------->|
+    |<------- num_generation_tokens (M) ------->|
+    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

     The prompts might have different lengths, while the generation tokens always
     have length 1. The paddings are appended to make the input length a multiple
@@ -44,23 +56,42 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """

-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None,
+                 sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)

         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(
+        self,
+        input_metadata: InputMetadata,
+        dtype: torch.dtype,
+    ) -> None:
+        del dtype  # Unused.
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
         prompt_lens = input_metadata.prompt_lens
         attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        if self.sliding_window is not None:
+            attn_bias = attn_bias.make_local_attention(self.sliding_window)
         input_metadata.attn_bias.append(attn_bias)

     def multi_query_kv_attention(
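The new `num_kv_heads` / `head_mapping` fields implement grouped (multi-query) KV heads: each KV head serves `num_queries_per_kv` query heads. A toy illustration of the mapping on CPU (the real code builds it on "cuda"), with assumed sizes num_heads=8 and num_kv_heads=2:

import torch

num_heads, num_kv_heads = 8, 2  # assumed example sizes
num_queries_per_kv = num_heads // num_kv_heads

# Same construction as in __init__ above, just on CPU for illustration.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)

# multi_query_kv_attention expands the KV tensors the same way:
key = torch.randn(10, num_kv_heads, 64)  # [num_tokens, num_kv_heads, head_size]
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
print(key.shape)  # torch.Size([10, 8, 64])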
@@ -76,10 +107,18 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
@@ -88,12 +127,19 @@ class PagedAttention(nn.Module):
             attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
-            op=self.attn_op,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
         return output

+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        """Returns the slopes for the alibi attention bias.
+
+        Returns:
+            slopes: shape = [num_heads]
+        """
+        return None
+
     def single_query_cached_kv_attention(
         self,
         output: torch.Tensor,
@@ -101,30 +147,77 @@ class PagedAttention(nn.Module):
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
     ) -> None:
         """PagedAttention for the generation tokens.

         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
+            alibi_slopes: shape = [num_heads]
         """
         block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            None,  # alibi_slopes
-        )
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE(woosuk): We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO(woosuk): Tune this heuristic.
+        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        if use_v1:
+            # Run PagedAttention V1.
+            attention_ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            attention_ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )

     def forward(
         self,
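The V1/V2 dispatch above depends only on `max_context_len` and the batch shape. A worked example of the heuristic with assumed numbers (a sketch of the same arithmetic, not a definitive tuning):

_PARTITION_SIZE = 512  # same constant as above

def choose_kernel(max_context_len: int, num_seqs: int, num_heads: int) -> str:
    # Ceiling division, as in single_query_cached_kv_attention.
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    return "paged_attention_v1" if use_v1 else "paged_attention_v2"

print(choose_kernel(max_context_len=384, num_seqs=4, num_heads=32))    # V1: only 1 partition
print(choose_kernel(max_context_len=4096, num_seqs=64, num_heads=32))  # V1: 64 * 32 = 2048 > 512
print(choose_kernel(max_context_len=4096, num_seqs=2, num_heads=32))   # V2: 8 partitions, little other parallelism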
@@ -143,11 +236,12 @@ class PagedAttention(nn.Module):

         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.

@@ -157,8 +251,8 @@ class PagedAttention(nn.Module):

         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)

         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
@@ -166,7 +260,9 @@ class PagedAttention(nn.Module):
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
-            self.set_attn_bias(input_metadata)
+            # Prompt run.
+            assert input_metadata.num_generation_tokens == 0
+            self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -186,15 +282,25 @@ class PagedAttention(nn.Module):
         if (num_valid_tokens > 0 and key_cache is not None
                 and value_cache is not None):
             # The stride is 3 because the key and value are sliced from qkv.
+            key_to_cache = key[:num_valid_tokens]
+            value_to_cache = value[:num_valid_tokens]
+            slot_mapping = input_metadata.slot_mapping
+            if input_metadata.to_cache is not None:
+                key_to_cache = key_to_cache[input_metadata.to_cache]
+                value_to_cache = value_to_cache[input_metadata.to_cache]
+                slot_mapping = slot_mapping[input_metadata.to_cache]
+
             cache_ops.reshape_and_cache(
-                key[:num_valid_tokens],
-                value[:num_valid_tokens],
+                key_to_cache,
+                value_to_cache,
                 key_cache,
                 value_cache,
-                input_metadata.slot_mapping,
+                slot_mapping,
             )

         if input_metadata.num_generation_tokens > 0:
+            # Decoding run.
+            assert input_metadata.num_prompt_tokens == 0
             assert key_cache is not None and value_cache is not None, (
                 "key_cache and value_cache must be provided when "
 | 
					                "key_cache and value_cache must be provided when "
 | 
				
			||||||
                "generating tokens.")
 | 
					                "generating tokens.")
 | 
				
			||||||
@ -202,7 +308,7 @@ class PagedAttention(nn.Module):
 | 
				
			|||||||
            self.single_query_cached_kv_attention(
 | 
					            self.single_query_cached_kv_attention(
 | 
				
			||||||
                output[num_prompt_tokens:num_valid_tokens],
 | 
					                output[num_prompt_tokens:num_valid_tokens],
 | 
				
			||||||
                query[num_prompt_tokens:num_valid_tokens], key_cache,
 | 
					                query[num_prompt_tokens:num_valid_tokens], key_cache,
 | 
				
			||||||
                value_cache, input_metadata)
 | 
					                value_cache, input_metadata, self.get_alibi_slopes())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Reshape the output tensor.
 | 
					        # Reshape the output tensor.
 | 
				
			||||||
        # NOTE(woosuk): The output tensor may include paddings.
 | 
					        # NOTE(woosuk): The output tensor may include paddings.
 | 
				
			||||||
@ -210,7 +316,7 @@ class PagedAttention(nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PagedAttentionWithRoPE(PagedAttention):
 | 
					class PagedAttentionWithRoPE(PagedAttention):
 | 
				
			||||||
    """PagedAttention with GPT-NeoX style rotary embedding."""
 | 
					    """PagedAttention with rotary positional embedding."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
@ -220,24 +326,33 @@ class PagedAttentionWithRoPE(PagedAttention):
 | 
				
			|||||||
        rotary_dim: int,
 | 
					        rotary_dim: int,
 | 
				
			||||||
        max_position: int = 8192,
 | 
					        max_position: int = 8192,
 | 
				
			||||||
        base: int = 10000,
 | 
					        base: int = 10000,
 | 
				
			||||||
 | 
					        num_kv_heads: Optional[int] = None,
 | 
				
			||||||
 | 
					        is_neox_style: bool = True,
 | 
				
			||||||
 | 
					        rope_scaling: Optional[Dict[str, Any]] = None,
 | 
				
			||||||
 | 
					        sliding_window: Optional[int] = None,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        super().__init__(num_heads, head_size, scale)
 | 
					        super().__init__(num_heads,
 | 
				
			||||||
 | 
					                         head_size,
 | 
				
			||||||
        # Create the cos and sin cache.
 | 
					                         scale,
 | 
				
			||||||
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
 | 
					                         num_kv_heads,
 | 
				
			||||||
        t = torch.arange(max_position).float()
 | 
					                         sliding_window=sliding_window)
 | 
				
			||||||
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
 | 
					        if rope_scaling is None:
 | 
				
			||||||
        cos = freqs.cos()
 | 
					            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
 | 
				
			||||||
        sin = freqs.sin()
 | 
					                                              max_position, base,
 | 
				
			||||||
        cache = torch.cat((cos, sin), dim=-1)
 | 
					                                              is_neox_style)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
        # FIXME(woosuk): This assumes that we configure the default dtype when
 | 
					            scaling_type = rope_scaling["type"]
 | 
				
			||||||
        # initializing the model.
 | 
					            scaling_factor = rope_scaling["factor"]
 | 
				
			||||||
        # TODO(woosuk): Make it more robust.
 | 
					            if scaling_type == "linear":
 | 
				
			||||||
        torch_dtype = torch.get_default_dtype()
 | 
					                self.rotary_emb = LinearScalingRotaryEmbedding(
 | 
				
			||||||
        cache = cache.to(torch_dtype)
 | 
					                    head_size, rotary_dim, max_position, base, is_neox_style,
 | 
				
			||||||
        # Embedding size: [max_position, rotary_dim]
 | 
					                    scaling_factor)
 | 
				
			||||||
        self.register_buffer("cos_sin_cache", cache, persistent=False)
 | 
					            elif scaling_type == "dynamic":
 | 
				
			||||||
 | 
					                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
 | 
				
			||||||
 | 
					                    head_size, rotary_dim, max_position, base, is_neox_style,
 | 
				
			||||||
 | 
					                    scaling_factor)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
@ -254,12 +369,13 @@ class PagedAttentionWithRoPE(PagedAttention):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            positions: shape = [num_tokens]
 | 
					            positions: shape = [num_tokens]
 | 
				
			||||||
                        query: shape = [num_tokens, num_heads * head_size]
 | 
					            query: shape = [num_tokens, num_heads * head_size]
 | 
				
			||||||
            key: shape = [num_tokens, num_heads * head_size]
 | 
					            key: shape = [num_tokens, num_kv_heads * head_size]
 | 
				
			||||||
            value: shape = [num_tokens, num_heads * head_size]
 | 
					            value: shape = [num_tokens, num_kv_heads * head_size]
 | 
				
			||||||
            key_cache: shape = [num_blocks, num_heads, head_size/x,
 | 
					            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
 | 
				
			||||||
                block_size, x]
 | 
					                block_size, x]
 | 
				
			||||||
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
 | 
					            value_cache: shape = [num_blocks, num_kv_heads, head_size,
 | 
				
			||||||
 | 
					                block_size]
 | 
				
			||||||
            input_metadata: metadata for paged attention.
 | 
					            input_metadata: metadata for paged attention.
 | 
				
			||||||
            cache_event: event to wait for the cache operations to finish.
 | 
					            cache_event: event to wait for the cache operations to finish.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -269,13 +385,7 @@ class PagedAttentionWithRoPE(PagedAttention):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # Apply rotary embedding to the query and key before passing them
 | 
					        # Apply rotary embedding to the query and key before passing them
 | 
				
			||||||
        # to the attention op.
 | 
					        # to the attention op.
 | 
				
			||||||
        pos_encoding_ops.rotary_embedding_neox(
 | 
					        query, key = self.rotary_emb(positions, query, key)
 | 
				
			||||||
            positions,
 | 
					 | 
				
			||||||
            query,
 | 
					 | 
				
			||||||
            key,
 | 
					 | 
				
			||||||
            self.head_size,
 | 
					 | 
				
			||||||
            self.cos_sin_cache,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        return super().forward(
 | 
					        return super().forward(
 | 
				
			||||||
            query,
 | 
					            query,
 | 
				
			||||||
            key,
 | 
					            key,
 | 
				
			||||||
@ -290,26 +400,31 @@ class PagedAttentionWithRoPE(PagedAttention):
 | 
				
			|||||||
class PagedAttentionWithALiBi(PagedAttention):
 | 
					class PagedAttentionWithALiBi(PagedAttention):
 | 
				
			||||||
    """PagedAttention with ALiBi attention bias."""
 | 
					    """PagedAttention with ALiBi attention bias."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(self,
 | 
				
			||||||
        self,
 | 
					                 num_heads: int,
 | 
				
			||||||
        num_heads: int,
 | 
					                 head_size: int,
 | 
				
			||||||
        head_size: int,
 | 
					                 scale: float,
 | 
				
			||||||
        scale: float,
 | 
					                 slopes: List[float],
 | 
				
			||||||
        slopes: List[float],
 | 
					                 num_kv_heads: Optional[int] = None) -> None:
 | 
				
			||||||
    ) -> None:
 | 
					        super().__init__(num_heads, head_size, scale, num_kv_heads)
 | 
				
			||||||
        super().__init__(num_heads, head_size, scale)
 | 
					 | 
				
			||||||
        assert len(slopes) == num_heads
 | 
					        assert len(slopes) == num_heads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        slopes = torch.tensor(slopes, dtype=torch.float32)
 | 
					        slopes = torch.tensor(slopes, dtype=torch.float32)
 | 
				
			||||||
        self.register_buffer("alibi_slopes", slopes, persistent=False)
 | 
					        self.register_buffer("alibi_slopes", slopes, persistent=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
 | 
					    def set_attn_bias(self, input_metadata: InputMetadata,
 | 
				
			||||||
 | 
					                      dtype: torch.dtype) -> None:
 | 
				
			||||||
        if input_metadata.attn_bias:
 | 
					        if input_metadata.attn_bias:
 | 
				
			||||||
            # Already set by a previous layer.
 | 
					            # Already set by a previous layer.
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        # Generates ALiBi mask for each prompt.
 | 
					        # Generates ALiBi mask for each prompt.
 | 
				
			||||||
        for prompt_len in input_metadata.prompt_lens:
 | 
					        for prompt_len in input_metadata.prompt_lens:
 | 
				
			||||||
            bias = torch.arange(prompt_len)
 | 
					            bias = torch.arange(prompt_len, dtype=dtype)
 | 
				
			||||||
 | 
					            # NOTE(zhuohan): HF uses
 | 
				
			||||||
 | 
					            #     `bias = bias[None, :].repeat(prompt_len, 1)`
 | 
				
			||||||
 | 
					            # here. We find that both biases give the same results, but
 | 
				
			||||||
 | 
					            # the bias below more accurately follows the original ALiBi
 | 
				
			||||||
 | 
					            # paper.
 | 
				
			||||||
            bias = bias[None, :] - bias[:, None]
 | 
					            bias = bias[None, :] - bias[:, None]
 | 
				
			||||||
            bias = bias.to(self.alibi_slopes.device)
 | 
					            bias = bias.to(self.alibi_slopes.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -317,11 +432,13 @@ class PagedAttentionWithALiBi(PagedAttention):
 | 
				
			|||||||
            # be sliced from a tensor whose length is a multiple of 8.
 | 
					            # be sliced from a tensor whose length is a multiple of 8.
 | 
				
			||||||
            padded_len = (prompt_len + 7) // 8 * 8
 | 
					            padded_len = (prompt_len + 7) // 8 * 8
 | 
				
			||||||
            bias = torch.empty(
 | 
					            bias = torch.empty(
 | 
				
			||||||
 | 
					                1,  # batch_size
 | 
				
			||||||
                self.num_heads,
 | 
					                self.num_heads,
 | 
				
			||||||
                padded_len,
 | 
					                prompt_len,
 | 
				
			||||||
                padded_len,
 | 
					                padded_len,
 | 
				
			||||||
                device=self.alibi_slopes.device,
 | 
					                device=self.alibi_slopes.device,
 | 
				
			||||||
            )[:, :prompt_len, :prompt_len].copy_(bias)
 | 
					                dtype=dtype,
 | 
				
			||||||
 | 
					            )[:, :, :, :prompt_len].copy_(bias)
 | 
				
			||||||
            bias.mul_(self.alibi_slopes[:, None, None])
 | 
					            bias.mul_(self.alibi_slopes[:, None, None])
 | 
				
			||||||
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
 | 
					            attn_bias = LowerTriangularMaskWithTensorBias(bias)
 | 
				
			||||||
            input_metadata.attn_bias.append(attn_bias)
 | 
					            input_metadata.attn_bias.append(attn_bias)
 | 
				
			||||||
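The ALiBi bias above is built once per prompt: `torch.arange(prompt_len)` differenced against itself gives a matrix of relative distances, which is then scaled per head by `alibi_slopes`. A minimal standalone sketch of the same construction (the prompt length and slope values below are made-up inputs, not taken from this diff):

    import torch

    prompt_len = 4
    slopes = torch.tensor([0.5, 0.25])      # one slope per attention head (hypothetical values)

    pos = torch.arange(prompt_len)
    bias = pos[None, :] - pos[:, None]      # bias[i, j] = j - i, shape [4, 4]
    # Per-head bias: [num_heads, prompt_len, prompt_len]
    bias = bias[None, :, :] * slopes[:, None, None]
    print(bias[0])
    # tensor([[ 0.0000,  0.5000,  1.0000,  1.5000],
    #         [-0.5000,  0.0000,  0.5000,  1.0000],
    #         [-1.0000, -0.5000,  0.0000,  0.5000],
    #         [-1.5000, -1.0000, -0.5000,  0.0000]])

The causal part is handled later by `LowerTriangularMaskWithTensorBias`, so entries with j > i never contribute to attention.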
@@ -339,10 +456,17 @@ class PagedAttentionWithALiBi(PagedAttention):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
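For grouped-query / multi-query models, the prompt path above widens the KV tensors so that the attention op sees one key/value head per query head. A small shape-only sketch of that `repeat_interleave` expansion (the head counts are arbitrary examples, not values from this diff):

    import torch

    num_tokens, num_heads, num_kv_heads, head_size = 3, 8, 2, 16
    num_queries_per_kv = num_heads // num_kv_heads   # 4 query heads share each KV head

    key = torch.randn(num_tokens, num_kv_heads, head_size)        # [3, 2, 16]
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
    print(key.shape)  # torch.Size([3, 8, 16]); heads 0-3 copy KV head 0, heads 4-7 copy KV head 1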
@@ -356,41 +480,11 @@ class PagedAttentionWithALiBi(PagedAttention):
                 attn_bias=input_metadata.attn_bias[i],
                 p=0.0,
                 scale=self.scale,
-                op=self.attn_op,
             )
             # TODO(woosuk): Unnecessary copy. Optimize.
             output[start:end].copy_(out.squeeze(0))
             start += prompt_len
         return output

-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> None:
-        """PagedAttention with ALiBi bias for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
-            input_metadata: metadata for paged attention.
-        """
-        block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            self.alibi_slopes,
-        )
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        return self.alibi_slopes
vllm/model_executor/layers/quantized_linear/__init__.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from vllm.model_executor.layers.quantized_linear.awq import (
+    AWQColumnParallelLinear, AWQRowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)
+
+_QUANTIZED_LINEAR_REGISTRY = {
+    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
+}
+
+
+class ParallelLinear:
+
+    @classmethod
+    def column(cls, *args, **kwargs) -> ColumnParallelLinear:
+        quant_config = kwargs.get("quant_config", None)
+        if quant_config is None:
+            return ColumnParallelLinear(*args, **kwargs)
+
+        name = quant_config.get_name()
+        if name not in _QUANTIZED_LINEAR_REGISTRY:
+            raise ValueError(f"No quantized linear is found for {name}")
+
+        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
+        return quant_linear_cls(*args, **kwargs)
+
+    @classmethod
+    def row(cls, *args, **kwargs) -> RowParallelLinear:
+        quant_config = kwargs.get("quant_config", None)
+        if quant_config is None:
+            return RowParallelLinear(*args, **kwargs)
+
+        name = quant_config.get_name()
+        if name not in _QUANTIZED_LINEAR_REGISTRY:
+            raise ValueError(f"No quantized linear is found for {name}")
+
+        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
+        return quant_linear_cls(*args, **kwargs)
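The registry above lets model code ask for a column- or row-parallel linear layer without caring whether quantization is enabled: with no `quant_config` the call falls through to the plain parallel layers, otherwise the scheme name is looked up in `_QUANTIZED_LINEAR_REGISTRY`. A rough sketch of the dispatch (the `DummyAWQConfig` object and its fields are invented for illustration; the real quantization config class is not part of this diff):

    from vllm.model_executor.layers.quantized_linear import _QUANTIZED_LINEAR_REGISTRY

    class DummyAWQConfig:
        # Hypothetical stand-in for vLLM's AWQ config object.
        weight_bits = 4
        group_size = 128
        pack_factor = 32 // 4  # eight 4-bit values per int32

        def get_name(self) -> str:
            return "awq"

    quant_config = DummyAWQConfig()
    column_cls, row_cls = _QUANTIZED_LINEAR_REGISTRY[quant_config.get_name()]
    # column_cls is AWQColumnParallelLinear, row_cls is AWQRowParallelLinear;
    # ParallelLinear.column(...)/.row(...) forward *args/**kwargs to these classes,
    # or to ColumnParallelLinear/RowParallelLinear when quant_config is None.
    # An unregistered name raises "No quantized linear is found for <name>".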
vllm/model_executor/layers/quantized_linear/awq.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import quantization_ops
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)
+
+
+class AWQColumnParallelLinear(ColumnParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size % self.quant_config.weight_bits == 0
+        assert (self.output_size_per_partition %
+                self.quant_config.pack_factor == 0)
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size,
+                self.output_size_per_partition //
+                self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.qzeros = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.group_size,
+                self.output_size_per_partition //
+                self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.scales = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.group_size,
+                self.output_size_per_partition,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        pack_factor = self.quant_config.pack_factor
+        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
+                                        self.qzeros, pack_factor)
+        if bias is not None:
+            out = out + bias
+        return out.reshape(out_shape)
+
+
+class AWQRowParallelLinear(RowParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert (self.input_size_per_partition %
+                self.quant_config.weight_bits == 0)
+        assert self.output_size % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size_per_partition,
+                self.output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.qzeros = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.group_size,
+                self.output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.scales = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.group_size,
+                self.output_size,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        pack_factor = self.quant_config.pack_factor
+        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
+                                        self.qzeros, pack_factor)
+        return out.reshape(out_shape)
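A note on the shapes above: AWQ stores several low-bit weights per `int32`, so `qweight` has `output_size // pack_factor` columns and `awq_gemm` expands them back at matmul time. Assuming the common 4-bit configuration (`weight_bits = 4`, hence `pack_factor = 32 // 4 = 8`; the concrete sizes below are illustrative, not taken from this diff), the bookkeeping is just:

    weight_bits = 4                    # assumed 4-bit AWQ
    pack_factor = 32 // weight_bits    # 8 quantized weights per int32
    group_size = 128                   # one (zero, scale) group per 128 input rows

    input_size, output_size = 4096, 11008
    qweight_shape = (input_size, output_size // pack_factor)               # (4096, 1376) int32
    qzeros_shape = (input_size // group_size, output_size // pack_factor)  # (32, 1376) int32
    scales_shape = (input_size // group_size, output_size)                 # (32, 11008) fp16/bf16
    # apply_weights() reconstructs the full-width output:
    # out_shape = (num_tokens, qweight_shape[1] * pack_factor) == (num_tokens, 11008)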
vllm/model_executor/layers/rotary_embedding.py (new file, 169 lines)
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rotary Positional Embeddings."""
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from vllm import pos_encoding_ops
+
+
+class RotaryEmbedding(nn.Module):
+    """Original rotary positional embedding."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+
+        cache = self._compute_cos_sin_cache()
+        cache = cache.to(torch.get_default_dtype())
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                 self.rotary_dim))
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        t = torch.arange(self.max_position_embeddings,
+                         dtype=torch.float,
+                         device="cuda")
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # pos_encoding_ops.rotary_embedding() is an in-place operation that
+        # updates the query and key tensors.
+        pos_encoding_ops.rotary_embedding(positions, query, key,
+                                          self.head_size, self.cos_sin_cache,
+                                          self.is_neox_style)
+        return query, key
+
+
+class LinearScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with linear scaling.
+
+    Credits to the Reddit user /u/kaiokendev
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.base)
+        # NOTE(woosuk): self.max_position_embeddings is the original
+        # maximum length before applying the rope scaling.
+        # Thus, the maximum length after applying the rope scaling is
+        # self.max_position_embeddings * self.scaling_factor.
+        max_len = self.max_position_embeddings * self.scaling_factor
+        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
+class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        # NOTE(woosuk): self.max_position_embeddings is the original
+        # maximum length before applying the rope scaling.
+        # Thus, the maximum length after applying the rope scaling is
+        # self.max_position_embeddings * self.scaling_factor.
+        max_len = self.max_position_embeddings * self.scaling_factor
+        base = self.base * (
+            (self.scaling_factor * max_len / self.max_position_embeddings) -
+            (self.scaling_factor - 1))**(self.rotary_dim /
+                                         (self.rotary_dim - 2))
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
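The fused `pos_encoding_ops.rotary_embedding` kernel consumes `cos_sin_cache` (per position: the first half holds cosines, the second half sines) and rotates query and key in place. As a rough, unfused reference of the NeoX-style math for the rotated slice, under my own naming and shape conventions rather than the kernel's API:

    import torch

    def rotate_neox_reference(x: torch.Tensor, positions: torch.Tensor,
                              cos_sin_cache: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, num_heads, rotary_dim] (only the rotary slice of each head)
        # cos_sin_cache: [max_position, rotary_dim]
        rotary_dim = cos_sin_cache.shape[-1]
        cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, rotary_dim // 2]
        cos = cos[:, None, :]                                 # broadcast over heads
        sin = sin[:, None, :]
        x1 = x[..., :rotary_dim // 2]
        x2 = x[..., rotary_dim // 2:]
        # NeoX style pairs the two halves of the rotary slice.
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

The linear and dynamic-NTK variants only change how the cache is built (positions divided by `scaling_factor`, or `base` re-derived from the target length); the per-token application stays the same.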
vllm/model_executor/layers/sampler.py
@@ -1,15 +1,15 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

-import numpy as np
 import torch
 import torch.nn as nn

 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    gather_from_tensor_model_parallel_region)
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_gather)
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
+                           SequenceData, SequenceGroupOutputs, SequenceOutputs)

 _SAMPLING_EPS = 1e-5

@@ -38,15 +38,14 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
-    ) -> Dict[int, SequenceOutputs]:
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> SamplerOutput:
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)

         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        logits = gather_from_tensor_model_parallel_region(logits)
-        # Remove paddings in vocab (if any).
-        logits = logits[:, :self.vocab_size]
+        logits = _get_logits(hidden_states, embedding, embedding_bias,
+                             self.vocab_size)

         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
@@ -56,7 +55,7 @@ class Sampler(nn.Module):
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties, self.vocab_size)
+                                  frequency_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -68,36 +67,69 @@ class Sampler(nn.Module):
             # Use in-place division to avoid creating a new tensor.
             logits.div_(t.unsqueeze(dim=1))

-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p and top-k).
-        logprobs = torch.log(probs)
-
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
-            probs = _apply_top_p_top_k(probs, top_ps, top_ks)
+            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities.
+        # Use log_softmax to ensure numerical stability.
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

         # Sample the next tokens.
-        return _sample(probs, logprobs, input_metadata)
+        sample_results = _sample(probs, logprobs, input_metadata)
+        # Get the logprobs query results.
+        prompt_logprobs, sample_logprobs = _get_logprobs(
+            logprobs, input_metadata, sample_results)
+        return _build_sampler_output(sample_results, input_metadata,
+                                     prompt_logprobs, sample_logprobs)
+
+
+def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
+                embedding_bias: Optional[torch.Tensor],
+                vocab_size: int) -> torch.Tensor:
+    # Get the logits for the next tokens.
+    logits = torch.matmul(hidden_states, embedding.t())
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = tensor_model_parallel_all_gather(logits)
+    # Remove paddings in vocab (if any).
+    logits = logits[:, :vocab_size]
+    return logits
+
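The new `_get_logits` helper folds the optional output bias and the vocab-padding slice into one place. With a single tensor-parallel rank, `tensor_model_parallel_all_gather` is effectively an identity, so the computation reduces to the toy sketch below (sizes are invented; a padded vocab of 12 is trimmed to a real vocab of 10):

    import torch

    hidden_states = torch.randn(3, 8)    # [num_tokens, hidden_size]
    embedding = torch.randn(12, 8)       # [padded_vocab_size, hidden_size]
    vocab_size = 10                      # real (unpadded) vocabulary size

    logits = torch.matmul(hidden_states, embedding.t())   # [3, 12]
    logits = logits[:, :vocab_size]                        # drop padding columns -> [3, 10]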
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
+    selected_token_indices: List[int] = []
     start_idx = 0
-    last_token_indicies: List[int] = []
-    for prompt_len in input_metadata.prompt_lens:
-        last_token_indicies.append(start_idx + prompt_len - 1)
-        start_idx += prompt_len
-    last_token_indicies.extend(
-        range(start_idx, start_idx + input_metadata.num_generation_tokens))
-    return hidden_states[last_token_indicies]
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            assert len(seq_ids) == 1, "Prompt input should have only one seq."
+            prompt_len = input_metadata.prompt_lens[i]
+            if sampling_params.prompt_logprobs is not None:
+                selected_token_indices.extend(
+                    range(start_idx, start_idx + prompt_len - 1))
+            selected_token_indices.append(start_idx + prompt_len - 1)
+            start_idx += prompt_len
+        else:
+            num_seqs = len(seq_ids)
+            selected_token_indices.extend(
+                range(start_idx, start_idx + num_seqs))
+            start_idx += num_seqs
+
+    selected_token_indices = torch.tensor(selected_token_indices,
+                                          dtype=torch.long,
+                                          device=hidden_states.device)
+    return hidden_states.index_select(0, selected_token_indices)
+
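Concretely, `_prune_hidden_states` now keeps every prompt position when `prompt_logprobs` is requested, instead of only the last one. A small worked example of the index bookkeeping, for a hypothetical batch layout:

    # Hypothetical batch: one prompt of length 4 that requested prompt_logprobs,
    # followed by a generation group with two sequences (hidden-state rows 0..5).
    prompt_len, num_gen_seqs = 4, 2
    selected = list(range(0, prompt_len - 1))   # [0, 1, 2]  prompt positions except the last
    selected.append(prompt_len - 1)             # [0, 1, 2, 3]
    start_idx = prompt_len
    selected.extend(range(start_idx, start_idx + num_gen_seqs))   # [0, 1, 2, 3, 4, 5]
    # Without prompt_logprobs the prompt would contribute only index 3, giving [3, 4, 5];
    # hidden_states.index_select(0, ...) keeps exactly these rows.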
def _get_penalties(
 | 
					def _get_penalties(
 | 
				
			||||||
@ -109,33 +141,31 @@ def _get_penalties(
 | 
				
			|||||||
        seq_ids, sampling_params = seq_group
 | 
					        seq_ids, sampling_params = seq_group
 | 
				
			||||||
        p = sampling_params.presence_penalty
 | 
					        p = sampling_params.presence_penalty
 | 
				
			||||||
        f = sampling_params.frequency_penalty
 | 
					        f = sampling_params.frequency_penalty
 | 
				
			||||||
        if i < input_metadata.num_prompts:
 | 
					        if (i < input_metadata.num_prompts
 | 
				
			||||||
            # A prompt input.
 | 
					                and sampling_params.prompt_logprobs is not None):
 | 
				
			||||||
            presence_penalties.append(p)
 | 
					            # NOTE: We do not apply presence and frequency penalties for the
 | 
				
			||||||
            frequency_penalties.append(f)
 | 
					            # prompt token positions where we don't sample new tokens.
 | 
				
			||||||
        else:
 | 
					            prompt_len = input_metadata.prompt_lens[i]
 | 
				
			||||||
            # A generation token.
 | 
					            presence_penalties += [0] * (prompt_len - 1)
 | 
				
			||||||
            presence_penalties += [p] * len(seq_ids)
 | 
					            frequency_penalties += [0] * (prompt_len - 1)
 | 
				
			||||||
            frequency_penalties += [f] * len(seq_ids)
 | 
					        presence_penalties += [p] * len(seq_ids)
 | 
				
			||||||
 | 
					        frequency_penalties += [f] * len(seq_ids)
 | 
				
			||||||
    return presence_penalties, frequency_penalties
 | 
					    return presence_penalties, frequency_penalties
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
 | 
					def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
 | 
				
			||||||
    output_tokens: List[List[int]] = []
 | 
					    output_tokens: List[List[int]] = []
 | 
				
			||||||
    for i, seq_group in enumerate(input_metadata.seq_groups):
 | 
					    for i, seq_group in enumerate(input_metadata.seq_groups):
 | 
				
			||||||
        seq_ids, _ = seq_group
 | 
					        seq_ids, sampling_params = seq_group
 | 
				
			||||||
        if i < input_metadata.num_prompts:
 | 
					        if (i < input_metadata.num_prompts
 | 
				
			||||||
            # A prompt input.
 | 
					                and sampling_params.prompt_logprobs is not None):
 | 
				
			||||||
            # NOTE: While the prompt input usually has no output tokens,
 | 
					            # NOTE: prompt token positions do not need output tokens to
 | 
				
			||||||
            # it may have output tokens in the case of recomputation.
 | 
					            # compute penalties.
 | 
				
			||||||
            seq_id = seq_ids[0]
 | 
					            prompt_len = input_metadata.prompt_lens[i]
 | 
				
			||||||
 | 
					            output_tokens.extend([] for _ in range(prompt_len - 1))
 | 
				
			||||||
 | 
					        for seq_id in seq_ids:
 | 
				
			||||||
            seq_data = input_metadata.seq_data[seq_id]
 | 
					            seq_data = input_metadata.seq_data[seq_id]
 | 
				
			||||||
            output_tokens.append(seq_data.output_token_ids)
 | 
					            output_tokens.append(seq_data.output_token_ids)
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # A generation token.
 | 
					 | 
				
			||||||
            for seq_id in seq_ids:
 | 
					 | 
				
			||||||
                seq_data = input_metadata.seq_data[seq_id]
 | 
					 | 
				
			||||||
                output_tokens.append(seq_data.output_token_ids)
 | 
					 | 
				
			||||||
    return output_tokens
 | 
					    return output_tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -144,45 +174,49 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
-    vocab_size: int,
 ) -> torch.Tensor:
-    num_seqs = logits.shape[0]
-    # Collect the indices of sequences that have non-zero penalties.
-    indices = []
+    num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
         if not output_tokens[i]:
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if p < _SAMPLING_EPS and f < _SAMPLING_EPS:
+        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
             continue
-        indices.append(i)
-
-    # Return early if all sequences have zero penalties.
-    if not indices:
+        break
+    else:
+        # Return early if all sequences have zero penalties.
         return logits
 
-    bin_counts = []
-    for i in indices:
-        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
-    bin_counts = np.stack(bin_counts, axis=0)
-    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
-                                                 device=logits.device)
+    max_output_len = max(len(tokens) for tokens in output_tokens)
+    padded_output_tokens = [
+        tokens + [vocab_size] * (max_output_len - len(tokens))
+        for tokens in output_tokens
+    ]
+    output_tokens_tensor = torch.tensor(padded_output_tokens,
+                                        dtype=torch.long,
+                                        device=logits.device)
+
+    # Compute the bin counts for the output tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, output_tokens_tensor,
+                            torch.ones_like(output_tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
 
-    frequency_penalties = [frequency_penalties[i] for i in indices]
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
-    presence_penalties = [presence_penalties[i] for i in indices]
     presence_penalties = torch.tensor(presence_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
-    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
     return logits
 
 
@@ -197,13 +231,11 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            temperatures.append(temperature)
-        else:
-            # A generation token.
-            temperatures += [temperature] * len(seq_ids)
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_len = input_metadata.prompt_lens[i]
+            temperatures += [temperature] * (prompt_len - 1)
+        temperatures += [temperature] * len(seq_ids)
     return temperatures
 
 
@@ -220,214 +252,339 @@ def _get_top_p_top_k(
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            top_ps.append(top_p)
-            top_ks.append(top_k)
-        else:
-            # A generation token.
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_len = input_metadata.prompt_lens[i]
+            top_ps += [top_p] * (prompt_len - 1)
+            top_ks += [top_k] * (prompt_len - 1)
+        top_ps += [top_p] * len(seq_ids)
+        top_ks += [top_k] * len(seq_ids)
     return top_ps, top_ks
 
 
 def _apply_top_p_top_k(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     top_ps: List[float],
     top_ks: List[int],
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
     # Apply top-p.
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
     top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[top_p_mask] = 0.0
+    logits_sort[top_p_mask] = -float("inf")
 
     # Apply top-k.
     # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
-    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
+    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
     top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    probs_sort[top_k_mask] = 0.0
+    logits_sort[top_k_mask] = -float("inf")
 
     # Re-sort the probabilities.
-    probs = torch.gather(probs_sort,
-                         dim=-1,
-                         index=torch.argsort(probs_idx, dim=-1))
-    return probs
+    logits = torch.gather(logits_sort,
+                          dim=-1,
+                          index=torch.argsort(logits_idx, dim=-1))
+    return logits
 
 
-def _get_topk_logprobs(
+def _greedy_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
-    num_logprobs: Optional[int],
-) -> Dict[int, float]:
-    if num_logprobs is None or num_logprobs == 0:
-        return {}
-
-    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
-    if num_logprobs == 1:
-        topk_logprobs = [topk_logprobs.item()]
-        topk_ids = [topk_ids.item()]
-    else:
-        topk_logprobs = topk_logprobs.tolist()
-        topk_ids = topk_ids.tolist()
-
-    token_to_logprob: Dict[int, float] = {}
-    for token_id, logprob in zip(topk_ids, topk_logprobs):
-        token_to_logprob[token_id] = logprob
-    return token_to_logprob
+) -> List[Tuple[List[int], List[int]]]:
+    samples = torch.argmax(logprobs, dim=-1).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group in selected_seq_groups:
+        seq_ids, _ = seq_group
+        num_parent_seqs = len(seq_ids)
+        assert num_parent_seqs == 1, (
+            "Greedy sampling should have only one seq.")
+        parent_ids = list(range(num_parent_seqs))
+        next_token_ids = [samples[sample_idx].item()]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results
 
 
-def _sample_from_prompt(
-    prob: torch.Tensor,
-    sampling_params: SamplingParams,
-) -> List[int]:
-    if sampling_params.use_beam_search:
-        # Beam search.
-        beam_width = sampling_params.best_of
-        _, next_token_ids = torch.topk(prob, beam_width)
-        next_token_ids = next_token_ids.tolist()
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert sampling_params.best_of == 1
-        next_token_id = torch.argmax(prob)
-        next_token_ids = [next_token_id.item()]
-    else:
-        # Random sampling.
-        # Sample `best_of` tokens for the prompt.
-        num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob,
-                                           num_samples=num_seqs,
-                                           replacement=True)
-        next_token_ids = next_token_ids.tolist()
-    return next_token_ids
-
-
-def _sample_from_generation_tokens(
-    seq_ids: List[int],
+def _random_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
     probs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # Find the maximum best_of value of the prompt phase requests.
+    max_best_of = 1
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        if is_prompt:
+            seq_ids, sampling_params = seq_group
+            max_best_of = max(max_best_of, sampling_params.best_of)
+    random_samples = torch.multinomial(probs,
+                                       num_samples=max_best_of,
+                                       replacement=True).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * sampling_params.best_of
+            next_token_ids = random_samples[
+                sample_idx, :sampling_params.best_of].tolist()
+        else:
+            # Generation phase.
+            parent_ids = list(range(num_parent_seqs))
+            next_token_ids = random_samples[sample_idx:sample_idx +
+                                            num_parent_seqs, 0].tolist()
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == probs.size(0)
+    return results
+
+
+def _beam_search_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    seq_data: Dict[int, SequenceData],
     logprobs: torch.Tensor,
-    seq_logprobs: List[float],
-    sampling_params: SamplingParams,
-) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.best_of can be greater than
-    # len(seq_ids) because some sequences in the group might have
-    # been already terminated.
-    if sampling_params.use_beam_search:
-        # Beam search.
-        # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(seq_logprobs,
-                                    dtype=torch.float,
-                                    device=logprobs.device)
-        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
-
-        vocab_size = logprobs.size(-1)
-        beam_width = len(seq_ids)
-        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
-        topk_ids = topk_ids.tolist()
-        seq_idx = [i // vocab_size for i in topk_ids]
-        beam_seq_ids = [seq_ids[i] for i in seq_idx]
-        token_ids = [i % vocab_size for i in topk_ids]
-
-        beam_outputs: Dict[int, Tuple[int, int]] = {}
-        outstanding_beams: List[Tuple[int, int]] = []
-        # If a beam survives, continue with it.
-        for seq_id, token_id in zip(beam_seq_ids, token_ids):
-            if seq_id not in beam_outputs:
-                beam_outputs[seq_id] = (seq_id, token_id)
-            else:
-                outstanding_beams.append((seq_id, token_id))
-
-        # If a beam is discarded, fork another beam.
-        for seq_id in seq_ids:
-            if seq_id not in beam_outputs:
-                beam_outputs[seq_id] = outstanding_beams.pop()
-        assert not outstanding_beams
-
-        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
-        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert len(seq_ids) == 1
-        next_token_id = torch.argmax(probs, dim=-1)
-        next_token_ids = [int(next_token_id.item())]
-        parent_seq_ids = seq_ids
-    else:
-        # Random sampling.
-        # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples=1,
-                                           replacement=True)
-        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
-        parent_seq_ids = seq_ids
-    return parent_seq_ids, next_token_ids
+) -> List[Tuple[List[int], List[int]]]:
+    # We sample 2 * beam_width candidates to make sure that with high
+    # probability we can get `beam_width` candidates in addition to
+    # the finished sequences for the next iteration. See
+    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
+    # for details. See also HF reference:
+    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
+    #
+    # NOTE: Beam search is not vectorized, so its speed can be slower than
+    # other sampling methods.
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        beam_width = sampling_params.best_of
+        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * (2 * beam_width)
+            _, next_token_ids = torch.topk(seq_group_logprobs[0],
+                                           2 * beam_width)
+            next_token_ids = next_token_ids.tolist()
+        else:
+            # Generation phase.
+            cumulative_logprobs = [
+                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
+            ]
+            cumulative_logprobs = torch.tensor(
+                cumulative_logprobs,
+                dtype=torch.float,
+                device=seq_group_logprobs.device)
+            seq_group_logprobs = (seq_group_logprobs +
+                                  cumulative_logprobs.unsqueeze(dim=1))
+            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
+                                     2 * beam_width)
+            topk_ids = topk_ids.tolist()
+            vocab_size = seq_group_logprobs.size(-1)
+            parent_ids = [i // vocab_size for i in topk_ids]
+            next_token_ids = [i % vocab_size for i in topk_ids]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results
 
 
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
-) -> Dict[int, SequenceOutputs]:
-    seq_outputs: Dict[int, SequenceOutputs] = {}
-
-    # TODO(woosuk): Optimize.
-    idx = 0
+) -> List[Tuple[List[int], List[int]]]:
+    categorized_seq_group_ids = {t: [] for t in SamplingType}
+    categorized_sample_indices = {t: [] for t in SamplingType}
+    start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == sampling_params.best_of
-            prob = probs[idx]
-            logprob = logprobs[idx]
-            idx += 1
-
-            # Sample the next tokens.
-            next_token_ids = _sample_from_prompt(prob, sampling_params)
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(logprob,
-                                               sampling_params.logprobs)
-
-            # Build the output.
-            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
-                output_logprobs = next_logprobs.copy()
-                output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
-                                                      next_token_id,
-                                                      output_logprobs)
-        else:
-            # Generate the next tokens for generation tokens.
-            prob = probs[idx:idx + len(seq_ids)]
-            logprob = logprobs[idx:idx + len(seq_ids)]
-            idx += len(seq_ids)
-
-            # Sample the next tokens.
-            seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
-                seq_ids, prob, logprob, seq_logprobs, sampling_params)
-
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs: Dict[int, Dict[int, float]] = {}
-            for j, seq_id in enumerate(seq_ids):
-                next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[j], sampling_params.logprobs)
-
-            # Build the output.
-            for seq_id, parent_seq_id, next_token_id in zip(
-                    seq_ids, parent_seq_ids, next_token_ids):
-                j = seq_ids.index(parent_seq_id)
-                output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[j,
-                                                         next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(
-                    seq_id,
-                    parent_seq_id,
-                    next_token_id,
-                    output_logprobs,
-                )
-
-    return seq_outputs
+        sampling_type = sampling_params.sampling_type
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            # NOTE: prompt token positions do not need sample, skip
+            prompt_len = input_metadata.prompt_lens[i]
+            start_idx += prompt_len - 1
+        categorized_seq_group_ids[sampling_type].append(i)
+        num_seqs = len(seq_ids)
+        categorized_sample_indices[sampling_type].extend(
+            range(start_idx, start_idx + num_seqs))
+        start_idx += num_seqs
+
+    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
+    for sampling_type in SamplingType:
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        sample_indices = categorized_sample_indices[sampling_type]
+        num_tokens = len(sample_indices)
+        if num_tokens == 0:
+            continue
+        if sampling_type == SamplingType.GREEDY:
+            category_logprobs = logprobs[sample_indices]
+            sample_results = _greedy_sample(seq_groups, category_logprobs)
+        elif sampling_type == SamplingType.RANDOM:
+            category_probs = probs[sample_indices]
+            sample_results = _random_sample(seq_groups, is_prompts,
+                                            category_probs)
+        elif sampling_type == SamplingType.BEAM:
+            category_logprobs = logprobs[sample_indices]
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 input_metadata.seq_data,
+                                                 category_logprobs)
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+        sample_results_dict.update(zip(seq_group_ids, sample_results))
+
+    sample_results = [
+        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
+    ]
+    return sample_results
+
+
+def _get_logprobs(
+    logprobs: torch.Tensor,
+    input_metadata: InputMetadata,
+    sample_results: List[Tuple[List[int], List[int]]],
+) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
+        int, float]]]]:
+    # Prepare query indices
+    batched_logprobs_query_seq_indices: List[int] = []
+    batched_logprobs_query_token_indices: List[int] = []
+    largest_num_logprobs = 0
+    sample_idx = 0
+    for i, (seq_group, sample_result) in enumerate(
+            zip(input_metadata.seq_groups, sample_results)):
+        seq_ids, sampling_params = seq_group
+        next_token_ids, parent_ids = sample_result
+        num_parent_seqs = len(seq_ids)
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            largest_num_logprobs = max(largest_num_logprobs,
+                                       sampling_params.prompt_logprobs)
+            prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens = input_metadata.seq_data[
+                seq_ids[0]].prompt_token_ids
+            batched_logprobs_query_seq_indices.extend(
+                sample_idx + j for j in range(prompt_len - 1))
+            batched_logprobs_query_token_indices.extend(
+                token_id for token_id in prompt_tokens[1:])
+            sample_idx += prompt_len - 1
+        batched_logprobs_query_seq_indices.extend(
+            [sample_idx + parent_id for parent_id in parent_ids])
+        batched_logprobs_query_token_indices.extend(next_token_ids)
+        if sampling_params.logprobs is not None:
+            largest_num_logprobs = max(largest_num_logprobs,
+                                       sampling_params.logprobs)
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+
+    # Batched query for logprobs of selected token
+    batched_logprobs_query_result = logprobs[[
+        batched_logprobs_query_seq_indices,
+        batched_logprobs_query_token_indices
+    ]].cpu()
+
+    # Batched query for logprobs of topk tokens
+    if largest_num_logprobs > 0:
+        top_logprobs, top_token_ids = torch.topk(logprobs,
+                                                 largest_num_logprobs,
+                                                 dim=-1)
+        top_logprobs = top_logprobs.cpu()
+        top_token_ids = top_token_ids.cpu()
+    else:
+        top_logprobs, top_token_ids = None, None
+
+    # Gather results
+    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
+    result_sample_logprobs: List[SampleLogprobs] = []
+    sample_idx = 0
+    query_result_idx = 0
+    for i, (seq_group, sample_result) in enumerate(
+            zip(input_metadata.seq_groups, sample_results)):
+        seq_ids, sampling_params = seq_group
+        next_token_ids, parent_ids = sample_result
+
+        # Prompt logprobs
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            num_logprobs = sampling_params.prompt_logprobs
+            prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens = input_metadata.seq_data[
+                seq_ids[0]].prompt_token_ids
+            group_prompt_logprobs: PromptLogprobs = [None]
+            for token_id in prompt_tokens[1:]:
+                prompt_logprobs_dict = {
+                    token_id:
+                    batched_logprobs_query_result[query_result_idx].item()
+                }
+                if num_logprobs > 0:
+                    prompt_logprobs_dict.update(
+                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                group_prompt_logprobs.append(prompt_logprobs_dict)
+                sample_idx += 1
+                query_result_idx += 1
+            result_prompt_logprobs.append(group_prompt_logprobs)
+        else:
+            result_prompt_logprobs.append(None)
+
+        # Sample logprobs
+        num_logprobs = sampling_params.logprobs
+        if num_logprobs is None:
+            num_logprobs = 0
+        group_sample_logprobs: SampleLogprobs = []
+        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
+            sample_logprobs_dict = {
+                next_token_id:
+                batched_logprobs_query_result[query_result_idx].item()
+            }
+            query_result_idx += 1
+            if num_logprobs > 0:
+                sample_logprobs_dict.update(
+                    zip(
+                        top_token_ids[sample_idx +
+                                      parent_id, :num_logprobs].tolist(),
+                        top_logprobs[sample_idx +
+                                     parent_id, :num_logprobs].tolist()))
+            group_sample_logprobs.append(sample_logprobs_dict)
+        result_sample_logprobs.append(group_sample_logprobs)
+        sample_idx += len(seq_ids)
+
+    return result_prompt_logprobs, result_sample_logprobs
+
+
+def _build_sampler_output(
+    sample_results: List[Tuple[List[int], List[int]]],
+    input_metadata: InputMetadata,
+    prompt_logprobs: List[Optional[PromptLogprobs]],
+    sample_logprobs: List[SampleLogprobs],
+) -> SamplerOutput:
+    sampler_output = []
+    for (seq_group, sample_result, group_prompt_logprobs,
+         group_sample_logprobs) in zip(input_metadata.seq_groups,
+                                       sample_results, prompt_logprobs,
+                                       sample_logprobs):
+        seq_ids, _ = seq_group
+        next_token_ids, parent_ids = sample_result
+        seq_outputs = []
+        for parent_id, next_token_id, logprobs in zip(parent_ids,
+                                                      next_token_ids,
+                                                      group_sample_logprobs):
+            seq_outputs.append(
+                SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
+        sampler_output.append(
+            SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
+    return sampler_output
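A toy, self-contained sketch (values assumed, not part of the diff) of the logits-space top-p/top-k masking used in _apply_top_p_top_k above: sort the logits, mask the tail with -inf, then gather back into the original order.

import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
top_p, top_k = torch.tensor([0.9]), torch.tensor([2])

logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
logits_sort[(probs_sum - probs_sort) > top_p.unsqueeze(dim=1)] = -float("inf")

rank = torch.arange(logits_idx.shape[-1]).expand(logits_idx.shape[0], -1)
logits_sort[rank >= top_k.unsqueeze(dim=1)] = -float("inf")

masked = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1))
# Only the two highest-logit tokens (ids 0 and 2) keep finite values.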
@@ -1,4 +1,5 @@
 """Utilities for selecting and loading models."""
+import contextlib
 from typing import Type
 
 import torch
@@ -7,20 +8,46 @@ from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig
 from vllm.model_executor.models import *  # pylint: disable=wildcard-import
-from vllm.model_executor.weight_utils import initialize_dummy_weights
+from vllm.model_executor.weight_utils import (get_quant_config,
+                                              initialize_dummy_weights)
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
+    "AquilaModel": AquilaForCausalLM,
+    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
+    "FalconForCausalLM": FalconForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
+    "GPTJForCausalLM": GPTJForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "InternLMForCausalLM": InternLMForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
+    "MistralForCausalLM": MistralForCausalLM,
     "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
+    "QWenLMHeadModel": QWenLMHeadModel,
+    "RWForCausalLM": FalconForCausalLM,
 }
 
+# FIXME(woosuk): Remove this once all models support quantization.
+_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
+    LlamaForCausalLM,
+    MistralForCausalLM,
+]
+
+
+@contextlib.contextmanager
+def _set_default_torch_dtype(dtype: torch.dtype):
+    """Sets the default torch dtype to the given dtype."""
+    old_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(old_dtype)
+
+
 def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
@@ -34,19 +61,46 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
-    torch.set_default_dtype(model_config.dtype)
 
-    # Create a model instance.
-    # The weights will be initialized as empty tensors.
-    model = model_class(model_config.hf_config)
-    if model_config.use_dummy_weights:
-        model = model.cuda()
-        # NOTE(woosuk): For accurate performance evaluation, we assign
-        # random values to the weights.
-        initialize_dummy_weights(model)
-    else:
-        # Load the weights from the cached or downloaded files.
-        model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.use_np_weights)
-        model = model.cuda()
+    # Get the quantization config.
+    quant_config = None
+    if model_config.quantization is not None:
+        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
+            raise ValueError(
+                f"Quantization is not supported for {model_class}.")
+        quant_config = get_quant_config(model_config.quantization,
+                                        model_config.model,
+                                        model_config.download_dir)
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < quant_config.get_min_capability():
+            raise ValueError(
+                f"The quantization method {model_config.quantization} is not "
+                "supported for the current GPU. "
+                f"Minimum capability: {quant_config.get_min_capability()}. "
+                f"Current capability: {capability}.")
+        supported_dtypes = quant_config.get_supported_act_dtypes()
+        if model_config.dtype not in supported_dtypes:
+            raise ValueError(
+                f"{model_config.dtype} is not supported for quantization "
+                f"method {model_config.quantization}. Supported dtypes: "
+                f"{supported_dtypes}")
+
+    with _set_default_torch_dtype(model_config.dtype):
+        # Create a model instance.
+        # The weights will be initialized as empty tensors.
+        if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
+            model = model_class(model_config.hf_config, quant_config)
+        else:
+            model = model_class(model_config.hf_config)
+        if model_config.load_format == "dummy":
+            model = model.cuda()
+            # NOTE(woosuk): For accurate performance evaluation, we assign
+            # random values to the weights.
+            initialize_dummy_weights(model)
+        else:
+            # Load the weights from the cached or downloaded files.
+            model.load_weights(model_config.model, model_config.download_dir,
+                               model_config.load_format, model_config.revision)
+            model = model.cuda()
     return model.eval()
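The _set_default_torch_dtype helper added above is a plain contextlib pattern; a small usage sketch follows (it mirrors the helper under a different name, assumes a recent PyTorch that accepts half-precision defaults, and is not part of the diff).

import contextlib
import torch

@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype):
    # Mirror of the helper in the diff, renamed for this sketch.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)

with set_default_dtype(torch.float16):
    layer = torch.nn.Linear(4, 4)          # parameters are created in float16
assert layer.weight.dtype == torch.float16
assert torch.get_default_dtype() == torch.float32  # restored (assuming float32 was the default)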
@@ -1,17 +1,33 @@
+from vllm.model_executor.models.aquila import AquilaForCausalLM
+from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
+                                                 BaichuanForCausalLM)
 from vllm.model_executor.models.bloom import BloomForCausalLM
+from vllm.model_executor.models.falcon import FalconForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
+from vllm.model_executor.models.gpt_j import GPTJForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
+from vllm.model_executor.models.internlm import InternLMForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.models.mistral import MistralForCausalLM
 from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.models.qwen import QWenLMHeadModel
 
 __all__ = [
+    "AquilaForCausalLM",
+    "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
+    "FalconForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
+    "GPTJForCausalLM",
     "GPTNeoXForCausalLM",
+    "InternLMForCausalLM",
     "LlamaForCausalLM",
     "MPTForCausalLM",
     "OPTForCausalLM",
+    "QWenLMHeadModel",
+    "MistralForCausalLM",
 ]

vllm/model_executor/models/aquila.py  (new file, 372 lines)
@@ -0,0 +1,372 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only LLaMA model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
+from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.aquila import AquilaConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class AquilaMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ):
+        super().__init__()
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class AquilaRMSNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        AquilaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
+                                                               keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance +
+                                                    self.variance_epsilon)
+
+        return (self.weight * hidden_states).to(input_dtype)
+
+
+class AquilaAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
 | 
				
			||||||
 | 
					        self.scaling = self.head_dim**-0.5
 | 
				
			||||||
 | 
					        self.rope_theta = rope_theta
 | 
				
			||||||
 | 
					        self.max_position_embeddings = max_position_embeddings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.qkv_proj = ColumnParallelLinear(
 | 
				
			||||||
 | 
					            hidden_size,
 | 
				
			||||||
 | 
					            (self.total_num_heads + 2 * self.total_num_kv_heads) *
 | 
				
			||||||
 | 
					            self.head_dim,
 | 
				
			||||||
 | 
					            bias=False,
 | 
				
			||||||
 | 
					            gather_output=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.o_proj = RowParallelLinear(
 | 
				
			||||||
 | 
					            self.total_num_heads * self.head_dim,
 | 
				
			||||||
 | 
					            hidden_size,
 | 
				
			||||||
 | 
					            bias=False,
 | 
				
			||||||
 | 
					            input_is_parallel=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.attn = PagedAttentionWithRoPE(
 | 
				
			||||||
 | 
					            self.num_heads,
 | 
				
			||||||
 | 
					            self.head_dim,
 | 
				
			||||||
 | 
					            self.scaling,
 | 
				
			||||||
 | 
					            rotary_dim=self.head_dim,
 | 
				
			||||||
 | 
					            base=self.rope_theta,
 | 
				
			||||||
 | 
					            max_position=self.max_position_embeddings,
 | 
				
			||||||
 | 
					            num_kv_heads=self.num_kv_heads,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        positions: torch.Tensor,
 | 
				
			||||||
 | 
					        hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					        kv_cache: KVCache,
 | 
				
			||||||
 | 
					        input_metadata: InputMetadata,
 | 
				
			||||||
 | 
					        cache_event: Optional[torch.cuda.Event],
 | 
				
			||||||
 | 
					    ) -> torch.Tensor:
 | 
				
			||||||
 | 
					        qkv, _ = self.qkv_proj(hidden_states)
 | 
				
			||||||
 | 
					        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 | 
				
			||||||
 | 
					        k_cache, v_cache = kv_cache
 | 
				
			||||||
 | 
					        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
 | 
				
			||||||
 | 
					                                input_metadata, cache_event)
 | 
				
			||||||
 | 
					        output, _ = self.o_proj(attn_output)
 | 
				
			||||||
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AquilaDecoderLayer(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, config: AquilaConfig):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.hidden_size = config.hidden_size
 | 
				
			||||||
 | 
					        rope_theta = getattr(config, "rope_theta", 10000)
 | 
				
			||||||
 | 
					        max_position_embeddings = getattr(config, "max_position_embeddings",
 | 
				
			||||||
 | 
					                                          8192)
 | 
				
			||||||
 | 
					        self.self_attn = AquilaAttention(
 | 
				
			||||||
 | 
					            hidden_size=self.hidden_size,
 | 
				
			||||||
 | 
					            num_heads=config.num_attention_heads,
 | 
				
			||||||
 | 
					            num_kv_heads=config.num_key_value_heads,
 | 
				
			||||||
 | 
					            rope_theta=rope_theta,
 | 
				
			||||||
 | 
					            max_position_embeddings=max_position_embeddings,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.mlp = AquilaMLP(
 | 
				
			||||||
 | 
					            hidden_size=self.hidden_size,
 | 
				
			||||||
 | 
					            intermediate_size=config.intermediate_size,
 | 
				
			||||||
 | 
					            hidden_act=config.hidden_act,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.input_layernorm = AquilaRMSNorm(config.hidden_size,
 | 
				
			||||||
 | 
					                                             eps=config.rms_norm_eps)
 | 
				
			||||||
 | 
					        self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size,
 | 
				
			||||||
 | 
					                                                      eps=config.rms_norm_eps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        positions: torch.Tensor,
 | 
				
			||||||
 | 
					        hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					        kv_cache: KVCache,
 | 
				
			||||||
 | 
					        input_metadata: InputMetadata,
 | 
				
			||||||
 | 
					        cache_event: Optional[torch.cuda.Event],
 | 
				
			||||||
 | 
					    ) -> torch.Tensor:
 | 
				
			||||||
 | 
					        # Self Attention
 | 
				
			||||||
 | 
					        residual = hidden_states
 | 
				
			||||||
 | 
					        hidden_states = self.input_layernorm(hidden_states)
 | 
				
			||||||
 | 
					        hidden_states = self.self_attn(
 | 
				
			||||||
 | 
					            positions=positions,
 | 
				
			||||||
 | 
					            hidden_states=hidden_states,
 | 
				
			||||||
 | 
					            kv_cache=kv_cache,
 | 
				
			||||||
 | 
					            input_metadata=input_metadata,
 | 
				
			||||||
 | 
					            cache_event=cache_event,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        hidden_states = residual + hidden_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Fully Connected
 | 
				
			||||||
 | 
					        residual = hidden_states
 | 
				
			||||||
 | 
					        hidden_states = self.post_attention_layernorm(hidden_states)
 | 
				
			||||||
 | 
					        hidden_states = self.mlp(hidden_states)
 | 
				
			||||||
 | 
					        hidden_states = residual + hidden_states
 | 
				
			||||||
 | 
					        return hidden_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AquilaModel(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, config: AquilaConfig):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.padding_idx = config.pad_token_id
 | 
				
			||||||
 | 
					        self.vocab_size = config.vocab_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #vocab_size = ((config.vocab_size + 63) // 64) * 64
 | 
				
			||||||
 | 
					        self.embed_tokens = VocabParallelEmbedding(
 | 
				
			||||||
 | 
					            config.vocab_size,
 | 
				
			||||||
 | 
					            config.hidden_size,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.layers = nn.ModuleList([
 | 
				
			||||||
 | 
					            AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					        self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        input_ids: torch.Tensor,
 | 
				
			||||||
 | 
					        positions: torch.Tensor,
 | 
				
			||||||
 | 
					        kv_caches: List[KVCache],
 | 
				
			||||||
 | 
					        input_metadata: InputMetadata,
 | 
				
			||||||
 | 
					        cache_events: Optional[List[torch.cuda.Event]],
 | 
				
			||||||
 | 
					    ) -> torch.Tensor:
 | 
				
			||||||
 | 
					        hidden_states = self.embed_tokens(input_ids)
 | 
				
			||||||
 | 
					        for i in range(len(self.layers)):
 | 
				
			||||||
 | 
					            if cache_events is None:
 | 
				
			||||||
 | 
					                cache_event = None
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                cache_event = cache_events[i]
 | 
				
			||||||
 | 
					            layer = self.layers[i]
 | 
				
			||||||
 | 
					            hidden_states = layer(
 | 
				
			||||||
 | 
					                positions,
 | 
				
			||||||
 | 
					                hidden_states,
 | 
				
			||||||
 | 
					                kv_caches[i],
 | 
				
			||||||
 | 
					                input_metadata,
 | 
				
			||||||
 | 
					                cache_event,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        hidden_states = self.norm(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return hidden_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AquilaForCausalLM(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, config):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.model = AquilaModel(config)
 | 
				
			||||||
 | 
					        vocab_size = ((config.vocab_size + 63) // 64) * 64
 | 
				
			||||||
 | 
					        self.lm_head = ColumnParallelLinear(
 | 
				
			||||||
 | 
					            config.hidden_size,
 | 
				
			||||||
 | 
					            vocab_size,
 | 
				
			||||||
 | 
					            bias=False,
 | 
				
			||||||
 | 
					            gather_output=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.sampler = Sampler(config.vocab_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        input_ids: torch.Tensor,
 | 
				
			||||||
 | 
					        positions: torch.Tensor,
 | 
				
			||||||
 | 
					        kv_caches: List[KVCache],
 | 
				
			||||||
 | 
					        input_metadata: InputMetadata,
 | 
				
			||||||
 | 
					        cache_events: Optional[List[torch.cuda.Event]],
 | 
				
			||||||
 | 
					    ) -> SamplerOutput:
 | 
				
			||||||
 | 
					        hidden_states = self.model(input_ids, positions, kv_caches,
 | 
				
			||||||
 | 
					                                   input_metadata, cache_events)
 | 
				
			||||||
 | 
					        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
 | 
				
			||||||
 | 
					                                   input_metadata)
 | 
				
			||||||
 | 
					        return next_tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _column_parallel_weights = [
 | 
				
			||||||
 | 
					        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_weights(self,
 | 
				
			||||||
 | 
					                     model_name_or_path: str,
 | 
				
			||||||
 | 
					                     cache_dir: Optional[str] = None,
 | 
				
			||||||
 | 
					                     load_format: str = "auto",
 | 
				
			||||||
 | 
					                     revision: Optional[str] = None):
 | 
				
			||||||
 | 
					        tp_size = get_tensor_model_parallel_world_size()
 | 
				
			||||||
 | 
					        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
 | 
				
			||||||
 | 
					        q_proj_shard_size = (self.config.hidden_size // tp_size)
 | 
				
			||||||
 | 
					        kv_proj_shard_size = (self.config.hidden_size //
 | 
				
			||||||
 | 
					                              self.config.num_attention_heads *
 | 
				
			||||||
 | 
					                              self.config.num_key_value_heads // tp_size)
 | 
				
			||||||
 | 
					        attention_weight_specs = [
 | 
				
			||||||
 | 
					            # (weight_name, shard_size, offset)
 | 
				
			||||||
 | 
					            ("q_proj", q_proj_shard_size, 0),
 | 
				
			||||||
 | 
					            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
 | 
				
			||||||
 | 
					            ("v_proj", kv_proj_shard_size,
 | 
				
			||||||
 | 
					             q_proj_shard_size + kv_proj_shard_size),
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					        state_dict = self.state_dict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for name, loaded_weight in hf_model_weights_iterator(
 | 
				
			||||||
 | 
					                model_name_or_path, cache_dir, load_format, revision):
 | 
				
			||||||
 | 
					            if "rotary_emb.inv_freq" in name:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            is_attention_weight = False
 | 
				
			||||||
 | 
					            for weight_name, shard_size, offset in attention_weight_specs:
 | 
				
			||||||
 | 
					                if weight_name not in name:
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                param = state_dict[name.replace(weight_name, "qkv_proj")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                loaded_weight = loaded_weight[
 | 
				
			||||||
 | 
					                    shard_size * tensor_model_parallel_rank:shard_size *
 | 
				
			||||||
 | 
					                    (tensor_model_parallel_rank + 1)]
 | 
				
			||||||
 | 
					                param_slice = param.data[offset:offset + shard_size]
 | 
				
			||||||
 | 
					                assert param_slice.shape == loaded_weight.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                param_slice.copy_(loaded_weight)
 | 
				
			||||||
 | 
					                is_attention_weight = True
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            if is_attention_weight:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            is_gate_up_weight = False
 | 
				
			||||||
 | 
					            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
 | 
				
			||||||
 | 
					                if weight_name not in name:
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                param = state_dict[name.replace(weight_name, "gate_up_proj")]
 | 
				
			||||||
 | 
					                shard_size = param.shape[0] // 2
 | 
				
			||||||
 | 
					                loaded_weight = loaded_weight[
 | 
				
			||||||
 | 
					                    shard_size * tensor_model_parallel_rank:shard_size *
 | 
				
			||||||
 | 
					                    (tensor_model_parallel_rank + 1)]
 | 
				
			||||||
 | 
					                param_slice = param.data[shard_size * stride_id:shard_size *
 | 
				
			||||||
 | 
					                                         (stride_id + 1)]
 | 
				
			||||||
 | 
					                assert param_slice.shape == loaded_weight.shape
 | 
				
			||||||
 | 
					                param_slice.copy_(loaded_weight)
 | 
				
			||||||
 | 
					                is_gate_up_weight = True
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            if is_gate_up_weight:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            param = state_dict[name]
 | 
				
			||||||
 | 
					            if "embed_tokens" in name or "lm_head" in name:
 | 
				
			||||||
 | 
					                load_padded_tensor_parallel_vocab(param, loaded_weight,
 | 
				
			||||||
 | 
					                                                  tensor_model_parallel_rank)
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            load_tensor_parallel_weights(param, loaded_weight, name,
 | 
				
			||||||
 | 
					                                         self._column_parallel_weights,
 | 
				
			||||||
 | 
					                                         self._row_parallel_weights,
 | 
				
			||||||
 | 
					                                         tensor_model_parallel_rank)
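The load_weights method above packs the checkpoint's separate q_proj / k_proj / v_proj (and gate_proj / up_proj) tensors into the fused qkv_proj and gate_up_proj parameters, one row offset at a time, which is exactly what the (weight_name, shard_size, offset) triples in attention_weight_specs encode. The following is a minimal standalone sketch of that packing, not part of the diff; the shapes (hidden_size=8, num_heads=4, num_kv_heads=2) are illustrative and tensor parallelism is left out.

import torch

hidden_size, num_heads, num_kv_heads = 8, 4, 2
head_dim = hidden_size // num_heads
q_size = num_heads * head_dim        # rows owned by q_proj in the fused weight
kv_size = num_kv_heads * head_dim    # rows owned by k_proj / v_proj each

# Separate weights, as a HuggingFace-style checkpoint would store them.
q_w = torch.randn(q_size, hidden_size)
k_w = torch.randn(kv_size, hidden_size)
v_w = torch.randn(kv_size, hidden_size)

# The fused parameter that a qkv projection owns: rows laid out as [q | k | v].
qkv_w = torch.empty(q_size + 2 * kv_size, hidden_size)
for offset, shard in ((0, q_w), (q_size, k_w), (q_size + kv_size, v_w)):
    qkv_w[offset:offset + shard.shape[0]].copy_(shard)

# A single matmul now yields q, k and v, split the same way forward() does.
x = torch.randn(3, hidden_size)
q, k, v = (x @ qkv_w.t()).split([q_size, kv_size, kv_size], dim=-1)
assert torch.allclose(q, x @ q_w.t())
assert torch.allclose(k, x @ k_w.t())
assert torch.allclose(v, x @ v_w.t())

With tensor parallelism the only extra step is that each rank first slices its own shard_size rows out of q_w, k_w and v_w before copying them into its local fused parameter, as the real loader does with tensor_model_parallel_rank.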
vllm/model_executor/models/baichuan.py (new file, 389 lines added)
@@ -0,0 +1,389 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                  PagedAttentionWithALiBi)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
    convert_pyslice_to_tensor, hf_model_weights_iterator,
    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                       ColumnParallelLinear,
                                                       RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(
            hidden_size,
            2 * intermediate_size,
            bias=False,
            gather_output=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # pylint: disable=invalid-name
        self.W_pack = ColumnParallelLinear(
            hidden_size,
            3 * hidden_size,
            bias=False,
            gather_output=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )
        # Create the alibi slopes and slice them.
        if self.postion_embedding == "ALIBI":
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            scaling = self.head_dim**-0.5
            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                                scaling, alibi_slopes)
        else:
            self.scaling = self.head_dim**-0.5
            self.attn = PagedAttentionWithRoPE(
                self.num_heads,
                self.head_dim,
                self.scaling,
                rotary_dim=self.head_dim,
                base=self.rope_theta,
                max_position=self.max_position_embeddings)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        if self.postion_embedding == "ALIBI":
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        else:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)

        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):

    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class BaiChuanModel(nn.Module):

    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):

    def __init__(self, config, position_embedding: str):
        super().__init__()
        self.config = config
        self.model = BaiChuanModel(config, position_embedding)
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = []
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            loaded_weight = convert_pyslice_to_tensor(loaded_weight)

            if "W_pack" in name:
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tp_world_size
                head_start = tp_rank * num_heads
                head_end = (tp_rank + 1) * num_heads

                loaded_weight = loaded_weight.view(3, total_num_heads,
                                                   head_size, hidden_size)
                loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                loaded_weight = loaded_weight.reshape(-1, hidden_size)

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]

            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tp_rank)
                continue

            load_tensor_parallel_weights(
                param,
                loaded_weight,
                name,
                self._column_parallel_weights,
                self._row_parallel_weights,
                tp_rank,
            )


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b

    def __init__(self, config):
        super().__init__(config, "ALIBI")


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b

    def __init__(self, config):
        super().__init__(config, "ROPE")
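BaichuanForCausalLM (the 13B variant) selects ALiBi position handling, while BaiChuanForCausalLM (7B) uses RoPE, and the ALiBi path depends on _get_alibi_slopes above. As a quick standalone check, not part of the diff, the closed form below reproduces what that function computes for a power-of-two head count; the helper name alibi_slopes_reference is made up for illustration.

import math
import torch

def alibi_slopes_reference(total_num_heads: int) -> torch.Tensor:
    # Closed form for power-of-two head counts: head i (1-indexed) gets
    # a slope of 2 ** (-8 * i / total_num_heads).
    return torch.tensor(
        [2 ** (-8 * i / total_num_heads) for i in range(1, total_num_heads + 1)],
        dtype=torch.float32)

n = 8  # power-of-two head count for the check
base = torch.tensor(2 ** (-(2 ** -(math.log2(n) - 3))), dtype=torch.float32)
slopes = torch.pow(base, torch.arange(1, 1 + n, dtype=torch.int32))
assert torch.allclose(slopes, alibi_slopes_reference(n))

For head counts that are not a power of two, _get_alibi_slopes additionally appends interpolated slopes for the remaining heads, which the closed form above does not cover; each tensor-parallel rank then keeps only the slopes for its own head range, as done in BaiChuanAttention.__init__.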
@@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -35,9 +35,10 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
+from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -85,14 +86,12 @@ class BloomAttention(nn.Module):
             3 * self.hidden_size,
             bias=True,
             gather_output=False,
-            perform_initialization=False,
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             input_is_parallel=True,
-            perform_initialization=False,
         )
 
         # Create the alibi slopes and slice them.
@@ -129,15 +128,17 @@ class BloomMLP(nn.Module):
     def __init__(self, config: BloomConfig):
         super().__init__()
         hidden_size = config.hidden_size
-        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
-                                                  4 * hidden_size,
-                                                  gather_output=False,
-                                                  perform_initialization=False)
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            4 * hidden_size,
+            gather_output=False,
+        )
         self.act = get_act_fn("gelu")
-        self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
-                                               hidden_size,
-                                               input_is_parallel=True,
-                                               perform_initialization=False)
+        self.dense_4h_to_h = RowParallelLinear(
+            4 * hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)
@@ -208,7 +209,9 @@ class BloomModel(nn.Module):
 
         # Embedding + LN Embedding
         self.word_embeddings = VocabParallelEmbedding(
-            config.vocab_size, self.embed_dim, perform_initialization=False)
+            config.vocab_size,
+            self.embed_dim,
+        )
         self.word_embeddings_layernorm = nn.LayerNorm(
             self.embed_dim, eps=config.layer_norm_epsilon)
 
@@ -264,7 +267,7 @@ class BloomForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
@@ -279,15 +282,23 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
-            if not name.startswith("transformer."):
-                name = "transformer." + name
+                model_name_or_path, cache_dir, load_format, revision):
+            if name == "lm_head.weight":
+                # Since hidden_states are parallelized, we need to
+                # load lm_head.weight in parallel.
+                self._column_parallel_weights.append(name)
+                # If lm_head is provided, use it instead.
+                param = self.lm_head_weight
+            else:
+                if not name.startswith("transformer."):
+                    name = "transformer." + name
+                param = state_dict[name]
 
-            param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): BLOOM's fused QKV has the shape of
                 # [num_heads * 3 * head_size, hidden_size], while the
vllm/model_executor/models/falcon.py (new file, 503 lines added)
@@ -0,0 +1,503 @@
 | 
				
			|||||||
 | 
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team.  All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""

import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention,
                                                  PagedAttentionWithALiBi,
                                                  PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                              hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                       ColumnParallelLinear,
                                                       RowParallelLinear)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]


# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = x @ self.weight.T
        if self.bias is None:
            return hidden_states
        return hidden_states + self.bias


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)

    return slopes


class FalconAttention(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query

        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
            assert self.total_num_heads % tp_size == 0
            self.num_kv_heads = self.total_num_kv_heads // tp_size
            self.query_key_value = ColumnParallelLinear(
                self.hidden_size,
                (self.total_num_heads + 2 * self.total_num_kv_heads) *
                self.head_dim,
                bias=config.bias,
                gather_output=False,
                skip_bias_add=True,
            )
        elif self.multi_query:
            self.total_num_kv_heads = 1
            self.num_kv_heads = 1
            self.query = ColumnParallelLinear(
                self.hidden_size,
                self.total_num_heads * self.head_dim,
                bias=config.bias,
                gather_output=False,
                skip_bias_add=True,
            )
            self.key_value = FalconLinear(self.hidden_size,
                                          2 * self.head_dim,
                                          bias=config.bias)
        else:
            self.total_num_kv_heads = self.total_num_heads
            self.num_kv_heads = self.num_heads
            self.query_key_value = ColumnParallelLinear(
                self.hidden_size,
                (self.total_num_heads + 2 * self.total_num_kv_heads) *
                self.head_dim,
                bias=config.bias,
                gather_output=False,
                skip_bias_add=True,
            )

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            input_is_parallel=True,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.attn = PagedAttentionWithRoPE(
                self.num_heads,
                self.head_dim,
                self.inv_norm_factor,
                base=rope_theta,
                max_position=max_position_embeddings,
                rotary_dim=self.head_dim,
                num_kv_heads=self.num_kv_heads)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = PagedAttentionWithALiBi(self.num_heads,
                                                self.head_dim,
                                                self.inv_norm_factor,
                                                alibi_slopes,
                                                num_kv_heads=self.num_kv_heads)
        else:
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       scale=self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        if not self.new_decoder_architecture and self.multi_query:
            q, bias = self.query(hidden_states)
            if bias is not None:
                q += bias
            kv = self.key_value(hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        else:
            qkv, bias = self.query_key_value(hidden_states)
            if bias is not None:
                qkv += bias
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        k_cache, v_cache = kv_cache
        if self.use_rotary:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)
        else:
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias


class FalconMLP(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  gather_output=False,
                                                  skip_bias_add=True)
        self.act = nn.GELU()
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            input_is_parallel=True,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias


class FalconDecoderLayer(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config)
        self.mlp = FalconMLP(config)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ):
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual

        return output


class FalconModel(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class FalconForCausalLM(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        self.config = config
        self.transformer = FalconModel(config)
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
            cache_events,
        )
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)

        return next_tokens

    _column_parallel_weights = [
        "word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
        "dense_h_to_4h.bias"
    ]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tp_size = (get_tensor_model_parallel_world_size())
        tp_rank = get_tensor_model_parallel_rank()

        hidden_size = self.config.hidden_size
        total_num_heads = self.config.num_attention_heads
        num_heads = total_num_heads // tp_size
        head_size = hidden_size // total_num_heads
        head_start = tp_rank * num_heads
        head_end = (tp_rank + 1) * num_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
            num_kv_heads = total_num_kv_heads // tp_size
            separated_q_kv = False
            kv_head_start = tp_rank * num_kv_heads
            kv_head_end = (tp_rank + 1) * num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
            num_kv_heads = 1
            separated_q_kv = True
            kv_head_start = 0
            kv_head_end = 1
        else:
            total_num_kv_heads = total_num_heads
            num_kv_heads = total_num_kv_heads // tp_size
            separated_q_kv = False
            kv_head_start = tp_rank * num_kv_heads
            kv_head_end = (tp_rank + 1) * num_kv_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "query_key_value" in name:
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                loaded_weight_size = loaded_weight.size()
                loaded_weight = loaded_weight.view(
                    total_num_kv_heads, num_query_heads_per_kv_head + 2,
                    head_size, *loaded_weight_size[1:])

                wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
                wk = loaded_weight[:, [-2]].reshape(-1,
                                                    *loaded_weight_size[1:])
                wv = loaded_weight[:, [-1]].reshape(-1,
                                                    *loaded_weight_size[1:])

                wq = wq[head_size * head_start:head_size * head_end]
                wk = wk[head_size * kv_head_start:head_size * kv_head_end]
                wv = wv[head_size * kv_head_start:head_size * kv_head_end]

                if separated_q_kv:
                    loaded_weight_q = wq
                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
                    q_weight_name = name.replace("query_key_value", "query")
                    kv_weight_name = name.replace("query_key_value",
                                                  "key_value")
                    load_tensor_parallel_weights(state_dict[q_weight_name],
                                                 loaded_weight_q,
                                                 q_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tp_rank)
                    load_tensor_parallel_weights(state_dict[kv_weight_name],
                                                 loaded_weight_kv,
                                                 kv_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tp_rank)
                    continue
                else:
                    loaded_weight = torch.cat([wq, wk, wv], dim=0)

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
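The most intricate step in `FalconForCausalLM.load_weights` above is splitting the fused `query_key_value` checkpoint into per-group query, key, and value blocks before sharding. A small standalone sketch of the same view/slice pattern with toy sizes (4 query heads, 2 KV heads, `head_size=3`, `hidden=6`); it only mirrors the indexing, not the real checkpoint layout:

import torch

total_num_heads, total_num_kv_heads, head_size, hidden = 4, 2, 3, 6
q_per_kv = total_num_heads // total_num_kv_heads        # 2 query heads per KV head

# Fused checkpoint tensor: (num_q_heads + 2 * num_kv_heads) * head_size rows.
fused = torch.randn((total_num_heads + 2 * total_num_kv_heads) * head_size, hidden)

# Group rows as [kv_head, q_per_kv + 2 (queries..., key, value), head_size, hidden].
grouped = fused.view(total_num_kv_heads, q_per_kv + 2, head_size, hidden)

wq = grouped[:, :-2].reshape(-1, hidden)    # all query heads
wk = grouped[:, [-2]].reshape(-1, hidden)   # one K head per KV group
wv = grouped[:, [-1]].reshape(-1, hidden)   # one V head per KV group

print(wq.shape, wk.shape, wv.shape)
# torch.Size([12, 6]) torch.Size([6, 6]) torch.Size([6, 6])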
vllm/model_executor/models/gpt2.py
@@ -21,7 +21,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -31,13 +31,15 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
+from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -55,16 +57,18 @@ class GPT2Attention(nn.Module):
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5
 
-        self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
-                                           bias=True,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size,
-                                        self.hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_attn = ColumnParallelLinear(
+            self.hidden_size,
+            3 * self.hidden_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scale)
@@ -94,16 +98,18 @@ class GPT2MLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size,
-                                         intermediate_size,
-                                         bias=True,
-                                         gather_output=False,
-                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -217,27 +223,28 @@ class GPT2LMHeadModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]
 
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
@@ -250,6 +257,8 @@ class GPT2LMHeadModel(nn.Module):
             if not name.startswith("transformer."):
                 name = "transformer." + name
 
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
             for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
@@ -261,14 +270,9 @@ class GPT2LMHeadModel(nn.Module):
             param = state_dict[name]
 
             if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
 
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
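The block removed in the last hunk is, in effect, what `load_padded_tensor_parallel_vocab` now does for the caller: pad the vocabulary to the tensor-parallel size and copy only this rank's rows. A rough standalone equivalent under that assumption (the real helper lives in `vllm.model_executor.weight_utils` and may differ in details):

import torch

def load_padded_vocab_shard(param: torch.Tensor,
                            loaded_weight: torch.Tensor,
                            tp_rank: int) -> None:
    """Copy this rank's slice of a (possibly smaller) vocab embedding.

    `param` is the local shard, shape [padded_vocab // tp_size, hidden].
    `loaded_weight` is the full checkpoint embedding, [vocab, hidden].
    Rows past the real vocab size are simply left untouched (padding).
    """
    shard_size = param.shape[0]
    start = tp_rank * shard_size
    end = min(start + shard_size, loaded_weight.shape[0])
    if start < end:
        param.data[:end - start].copy_(loaded_weight[start:end])

# Toy usage: vocab=10 padded to 12, split over 2 ranks (6 rows each).
full = torch.randn(10, 4)
shard = torch.zeros(6, 4)
load_padded_vocab_shard(shard, full, tp_rank=1)   # copies rows 6..9 only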
vllm/model_executor/models/gpt_bigcode.py
@@ -22,24 +22,25 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
-import numpy as np
 from transformers import GPTBigCodeConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
+from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -50,26 +51,47 @@ class GPTBigCodeAttention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = (
+        self.tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
-        assert total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = (total_num_heads //
+                          self.tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5
 
-        self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
-                                           bias=True,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size,
-                                        self.hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            self.num_kv_heads = 1
+            self.kv_dim = self.head_dim
+            self.c_attn_q = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                gather_output=False,
+            )
+            self.c_attn_kv = nn.Linear(self.hidden_size,
+                                       2 * self.kv_dim,
+                                       bias=True)
+        else:
+            self.num_kv_heads = self.num_heads
+            self.kv_dim = self.num_kv_heads * self.head_dim
+            self.c_attn = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size + 2 * self.kv_dim,
+                bias=True,
+                gather_output=False,
+            )
+
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
-                                   scale=self.scale)
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -78,8 +100,17 @@ class GPTBigCodeAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.multi_query:
+            q, _ = self.c_attn_q(hidden_states)
+            kv = self.c_attn_kv(hidden_states)
+            k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
+        else:
+            qkv, _ = self.c_attn(hidden_states)
+            q, k, v = qkv.split([
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
+                                dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
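The multi-query branch added to `GPTBigCodeAttention.forward` above keeps query heads sharded across ranks while a single key/value head is computed identically on every rank. A toy sketch of the resulting shapes, using plain `nn.Linear` stand-ins for the parallel layers (sizes invented for illustration):

import torch
from torch import nn

hidden, head_dim, num_heads_total, tp_size = 64, 8, 8, 2
num_heads_local = num_heads_total // tp_size   # query heads on this rank
kv_dim = head_dim                              # multi-query: one shared K/V head

# Per-rank query projection vs. replicated key/value projection.
c_attn_q = nn.Linear(hidden, num_heads_local * head_dim)   # sharded output
c_attn_kv = nn.Linear(hidden, 2 * kv_dim)                  # same on every rank

x = torch.randn(3, hidden)                     # 3 tokens
q = c_attn_q(x)                                # [3, 32] -> 4 local query heads
k, v = c_attn_kv(x).split([kv_dim, kv_dim], dim=-1)        # [3, 8] each

print(q.shape, k.shape, v.shape)
# torch.Size([3, 32]) torch.Size([3, 8]) torch.Size([3, 8])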
@ -96,16 +127,18 @@ class GPTBigMLP(nn.Module):
 | 
				
			|||||||
    ):
 | 
					    ):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        hidden_size = config.hidden_size
 | 
					        hidden_size = config.hidden_size
 | 
				
			||||||
        self.c_fc = ColumnParallelLinear(hidden_size,
 | 
					        self.c_fc = ColumnParallelLinear(
 | 
				
			||||||
                                         intermediate_size,
 | 
					            hidden_size,
 | 
				
			||||||
                                         bias=True,
 | 
					            intermediate_size,
 | 
				
			||||||
                                         gather_output=False,
 | 
					            bias=True,
 | 
				
			||||||
                                         perform_initialization=False)
 | 
					            gather_output=False,
 | 
				
			||||||
        self.c_proj = RowParallelLinear(intermediate_size,
 | 
					        )
 | 
				
			||||||
                                        hidden_size,
 | 
					        self.c_proj = RowParallelLinear(
 | 
				
			||||||
                                        bias=True,
 | 
					            intermediate_size,
 | 
				
			||||||
                                        input_is_parallel=True,
 | 
					            hidden_size,
 | 
				
			||||||
                                        perform_initialization=False)
 | 
					            bias=True,
 | 
				
			||||||
 | 
					            input_is_parallel=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        self.act = get_act_fn(config.activation_function)
 | 
					        self.act = get_act_fn(config.activation_function)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 | 
					    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 | 
				
			||||||
@@ -218,27 +251,28 @@ class GPTBigCodeForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
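The tied-weights comment above is why `lm_head.weight` entries are skipped: the output projection reuses `transformer.wte.weight`, so only the embedding needs to be loaded. A rough sketch of that loop shape, using a hypothetical stand-in generator rather than vLLM's real `hf_model_weights_iterator`:

```python
import torch

def checkpoint_weights():
    # Hypothetical stand-in for hf_model_weights_iterator: yields (name, tensor) pairs.
    yield "lm_head.weight", torch.zeros(10, 4)
    yield "wte.weight", torch.randn(10, 4)

state_dict = {"transformer.wte.weight": torch.empty(10, 4)}

for name, loaded_weight in checkpoint_weights():
    if "lm_head.weight" in name:
        # Tied to the embedding; wte.weight already covers it.
        continue
    if not name.startswith("transformer."):
        name = "transformer." + name
    state_dict[name].copy_(loaded_weight)
```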
@@ -248,51 +282,9 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue

-            param = state_dict[name]
-
             if not name.startswith("transformer."):
                 name = "transformer." + name

-            if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
-            def _expand_mqa_mha(qkv_array, n_head, head_dim):
-                """manipulates along axis=0 from MQA to MHA
-                inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
-                    with n_heads for q, then 1 for k, 1 for 1 v, times head dim
-                return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
-
-                TODO: this function is no longer needed once vllm supports MQA.
-                """
-                qkv_array = qkv_array.numpy()
-
-                dims_q = n_head * head_dim
-                # pylint: disable=unbalanced-tuple-unpacking
-                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
-                                   axis=0)
-                # q is fine, but k & v have not replicated shape along the first
-                # axis as long as MQA is not nativly supported, increase memory
-                # and replicated (head_dim, hidden_dim) to
-                # (n_heads * head_dim, hidden_dim)
-                if k.ndim == 2 and v.ndim == 2:
-                    replication = (n_head, 1)  # weights
-                else:
-                    replication = n_head  # biases
-                # replicate n_head times for q, v
-                k, v = np.tile(k, replication), np.tile(v, replication)
-                # concat q, k, v along the first axis
-                # (n_heads * head_dim, hidden_dim)
-                # to (3 * n_heads * head_dim, hidden_dim)
-                qkv_array = np.concatenate((q, k, v), axis=0)
-                return torch.from_numpy(qkv_array)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of
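The block removed above was a pre-native-MQA workaround: `_expand_mqa_mha` tiled the checkpoint's single key/value head up to `n_heads` copies so the rest of the loader could treat the model as ordinary multi-head attention, at the cost of replicated memory. A shape-only sketch of what that expansion did, with illustrative sizes:

```python
import numpy as np

n_heads, head_dim, hidden = 4, 8, 32

# MQA checkpoint layout: n_heads query heads, then one K head, then one V head.
qkv = np.random.randn((n_heads + 2) * head_dim, hidden)

dims_q = n_heads * head_dim
q, k, v = np.split(qkv, (dims_q, dims_q + head_dim), axis=0)

# Replicate the single K/V head n_heads times to mimic an MHA layout.
k = np.tile(k, (n_heads, 1))
v = np.tile(v, (n_heads, 1))

expanded = np.concatenate((q, k, v), axis=0)
assert expanded.shape == (3 * n_heads * head_dim, hidden)
```

The next hunk avoids this entirely by splitting the fused tensor into `wq`/`wk`/`wv` and loading the single K/V head as-is.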
@@ -300,30 +292,53 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # When tensor parallelism is used, we shard the weights along
                 # the head dimension.
                 total_num_heads = self.config.num_attention_heads
+                total_num_kv_heads = (1 if self.config.multi_query else
+                                      total_num_heads)
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
+                total_kv_size = head_size * total_num_kv_heads
                 num_heads = total_num_heads // tensor_model_parallel_world_size
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads

-                if name.endswith(".weight"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif name.endswith(".bias"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+                wq, wk, wv = torch.split(
+                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
+                    dim=0)
+
+                wq = wq[head_size * head_start:head_size * head_end]
+                if not self.config.multi_query:
+                    # Split the heads when using normal multi-head attention
+                    wk = wk[head_size * head_start:head_size * head_end]
+                    wv = wv[head_size * head_start:head_size * head_end]
+                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
                 else:
-                    raise ValueError(f"Unexpected parameter name {name}")
+                    # For multi-query attention, we split the query
+                    # but replicate the key and value.
+                    loaded_weight_q = wq
+                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
+                    q_weight_name = name.replace("c_attn", "c_attn_q")
+                    kv_weight_name = name.replace("c_attn", "c_attn_kv")
+                    load_tensor_parallel_weights(state_dict[q_weight_name],
+                                                 loaded_weight_q,
+                                                 q_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    load_tensor_parallel_weights(state_dict[kv_weight_name],
+                                                 loaded_weight_kv,
+                                                 kv_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    continue
+
+            param = state_dict[name]
+
+            if name == "transformer.wte.weight":
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,