mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	Compare commits
	
		
			6 Commits
		
	
	
		
			v0.11.1rc1
			...
			v0.10.2rc3
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| da3fa78dc9 | |||
| bbb70036cb | |||
| 89da8d9d09 | |||
| 01085b134d | |||
| 66160a9943 | |||
| eaca762c18 | 
							
								
								
									
										114
									
								
								tests/models/language/pooling/test_mm_classifier_conversion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								tests/models/language/pooling/test_mm_classifier_conversion.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,114 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| from vllm.platforms import current_platform | ||||
|  | ||||
|  | ||||
| def test_idefics_multimodal( | ||||
|     vllm_runner, | ||||
|     monkeypatch, | ||||
| ) -> None: | ||||
|     if current_platform.is_rocm(): | ||||
|         # ROCm Triton FA does not currently support sliding window attention | ||||
|         # switch to use ROCm CK FA backend | ||||
|         monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") | ||||
|  | ||||
|     prompts = [ | ||||
|         "Hello, my name is", | ||||
|         "The president of the United States is", | ||||
|         "The capital of France is", | ||||
|         "The future of AI is", | ||||
|     ] | ||||
|  | ||||
|     with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3", | ||||
|                      runner="pooling", | ||||
|                      task="classify", | ||||
|                      convert="classify", | ||||
|                      load_format="dummy", | ||||
|                      max_model_len=512, | ||||
|                      enforce_eager=True, | ||||
|                      tensor_parallel_size=1, | ||||
|                      disable_log_stats=True, | ||||
|                      dtype="bfloat16") as vllm_model: | ||||
|         llm = vllm_model.get_llm() | ||||
|         outputs = llm.classify(prompts) | ||||
|         for output in outputs: | ||||
|             assert len(output.outputs.probs) == 2 | ||||
|  | ||||
|  | ||||
| def update_config(config): | ||||
|     config.text_config.update({ | ||||
|         "architectures": ["Gemma3ForSequenceClassification"], | ||||
|         "classifier_from_token": ["A", "B", "C", "D", "E"], | ||||
|         "method": | ||||
|         "no_post_processing", | ||||
|         "id2label": { | ||||
|             "A": "Chair", | ||||
|             "B": "Couch", | ||||
|             "C": "Table", | ||||
|             "D": "Bed", | ||||
|             "E": "Cupboard" | ||||
|         }, | ||||
|     }) | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def test_gemma_multimodal( | ||||
|     vllm_runner, | ||||
|     monkeypatch, | ||||
| ) -> None: | ||||
|     if current_platform.is_rocm(): | ||||
|         # ROCm Triton FA does not currently support sliding window attention | ||||
|         # switch to use ROCm CK FA backend | ||||
|         monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") | ||||
|  | ||||
|     messages = [{ | ||||
|         "role": | ||||
|         "system", | ||||
|         "content": | ||||
|         """ | ||||
|     You are a helpful assistant. You will be given a product description | ||||
|     which may also include an image. Classify the following product into | ||||
|     one of the categories: | ||||
|  | ||||
|     A = chair | ||||
|     B = couch | ||||
|     C = table | ||||
|     D = bed | ||||
|     E = cupboard | ||||
|  | ||||
|     You'll answer with exactly one letter (A, B, C, D, or E).""" | ||||
|     }, { | ||||
|         "role": | ||||
|         "user", | ||||
|         "content": [{ | ||||
|             "type": "image_url", | ||||
|             "image_url": { | ||||
|                 "url": | ||||
|                 "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg" | ||||
|             } | ||||
|         }, { | ||||
|             "type": "text", | ||||
|             "text": "A fine 19th century piece of furniture." | ||||
|         }] | ||||
|     }] | ||||
|  | ||||
|     with vllm_runner(model_name="google/gemma-3-4b-it", | ||||
|                      runner="pooling", | ||||
|                      task="classify", | ||||
|                      convert="classify", | ||||
|                      load_format="auto", | ||||
|                      hf_overrides=update_config, | ||||
|                      override_pooler_config={"pooling_type": "LAST"}, | ||||
|                      max_model_len=512, | ||||
|                      enforce_eager=True, | ||||
|                      tensor_parallel_size=1, | ||||
|                      disable_log_stats=True, | ||||
|                      dtype="bfloat16") as vllm_model: | ||||
|  | ||||
|         llm = vllm_model.get_llm() | ||||
|         prompts = llm.preprocess_chat(messages) | ||||
|  | ||||
|         result = llm.classify(prompts) | ||||
|         assert result[0].outputs.probs[0] > 0.95 | ||||
|         assert all(c < 0.05 for c in result[0].outputs.probs[1:]) | ||||
| @ -703,6 +703,106 @@ class LLM: | ||||
|  | ||||
|         return outputs | ||||
|  | ||||
|     def preprocess_chat( | ||||
|         self, | ||||
|         messages: Union[list[ChatCompletionMessageParam], | ||||
|                         list[list[ChatCompletionMessageParam]]], | ||||
|         lora_request: Optional[LoRARequest] = None, | ||||
|         chat_template: Optional[str] = None, | ||||
|         chat_template_content_format: ChatTemplateContentFormatOption = "auto", | ||||
|         add_generation_prompt: bool = True, | ||||
|         continue_final_message: bool = False, | ||||
|         tools: Optional[list[dict[str, Any]]] = None, | ||||
|         chat_template_kwargs: Optional[dict[str, Any]] = None, | ||||
|         mm_processor_kwargs: Optional[dict[str, Any]] = None, | ||||
|     ) -> list[TokensPrompt]: | ||||
|         """ | ||||
|         Generate prompt for a chat conversation. The pre-processed | ||||
|         prompt can then be used as input for the other LLM methods. | ||||
|  | ||||
|         Refer to `chat` for a complete description of the arguments. | ||||
|         Returns: | ||||
|             A list of `TokensPrompts` objects containing the tokenized | ||||
|             prompt after chat template interpolation, and the | ||||
|             pre-processed multi-modal inputs. | ||||
|         """ | ||||
|         list_of_messages: list[list[ChatCompletionMessageParam]] | ||||
|  | ||||
|         # Handle multi and single conversations | ||||
|         if is_list_of(messages, list): | ||||
|             # messages is list[list[...]] | ||||
|             list_of_messages = cast(list[list[ChatCompletionMessageParam]], | ||||
|                                     messages) | ||||
|         else: | ||||
|             # messages is list[...] | ||||
|             list_of_messages = [ | ||||
|                 cast(list[ChatCompletionMessageParam], messages) | ||||
|             ] | ||||
|  | ||||
|         tokenizer = self.get_tokenizer(lora_request) | ||||
|         model_config = self.llm_engine.get_model_config() | ||||
|         resolved_content_format = resolve_chat_template_content_format( | ||||
|             chat_template, | ||||
|             tools, | ||||
|             chat_template_content_format, | ||||
|             tokenizer, | ||||
|             model_config=model_config, | ||||
|         ) | ||||
|  | ||||
|         _chat_template_kwargs: dict[str, Any] = dict( | ||||
|             chat_template=chat_template, | ||||
|             add_generation_prompt=add_generation_prompt, | ||||
|             continue_final_message=continue_final_message, | ||||
|             tools=tools, | ||||
|         ) | ||||
|         _chat_template_kwargs.update(chat_template_kwargs or {}) | ||||
|  | ||||
|         prompts: list[TokensPrompt] = [] | ||||
|  | ||||
|         for msgs in list_of_messages: | ||||
|             # NOTE: _parse_chat_message_content_parts() currently doesn't | ||||
|             # handle mm_processor_kwargs, since there is no implementation in | ||||
|             # the chat message parsing for it. | ||||
|             conversation, mm_data, mm_uuids = parse_chat_messages( | ||||
|                 msgs, | ||||
|                 model_config, | ||||
|                 tokenizer, | ||||
|                 content_format=resolved_content_format, | ||||
|             ) | ||||
|  | ||||
|             if isinstance(tokenizer, MistralTokenizer): | ||||
|                 prompt_token_ids = apply_mistral_chat_template( | ||||
|                     tokenizer, | ||||
|                     messages=msgs, | ||||
|                     **_chat_template_kwargs, | ||||
|                 ) | ||||
|             else: | ||||
|                 prompt_str = apply_hf_chat_template( | ||||
|                     tokenizer=tokenizer, | ||||
|                     conversation=conversation, | ||||
|                     model_config=model_config, | ||||
|                     **_chat_template_kwargs, | ||||
|                 ) | ||||
|                 # Special tokens are already included in chat templates so | ||||
|                 # should not be added by the tokenizer in this case. | ||||
|                 prompt_token_ids = tokenizer.encode(prompt_str, | ||||
|                                                     add_special_tokens=False) | ||||
|  | ||||
|             prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) | ||||
|  | ||||
|             if mm_data is not None: | ||||
|                 prompt["multi_modal_data"] = mm_data | ||||
|  | ||||
|             if mm_uuids is not None: | ||||
|                 prompt["multi_modal_uuids"] = mm_uuids | ||||
|  | ||||
|             if mm_processor_kwargs is not None: | ||||
|                 prompt["mm_processor_kwargs"] = mm_processor_kwargs | ||||
|  | ||||
|             prompts.append(prompt) | ||||
|  | ||||
|         return prompts | ||||
|  | ||||
|     def chat( | ||||
|         self, | ||||
|         messages: Union[list[ChatCompletionMessageParam], | ||||
| @ -769,80 +869,18 @@ class LLM: | ||||
|             A list of `RequestOutput` objects containing the generated | ||||
|             responses in the same order as the input messages. | ||||
|         """ | ||||
|         list_of_messages: list[list[ChatCompletionMessageParam]] | ||||
|  | ||||
|         # Handle multi and single conversations | ||||
|         if is_list_of(messages, list): | ||||
|             # messages is list[list[...]] | ||||
|             list_of_messages = cast(list[list[ChatCompletionMessageParam]], | ||||
|                                     messages) | ||||
|         else: | ||||
|             # messages is list[...] | ||||
|             list_of_messages = [ | ||||
|                 cast(list[ChatCompletionMessageParam], messages) | ||||
|             ] | ||||
|  | ||||
|         tokenizer = self.get_tokenizer(lora_request) | ||||
|         model_config = self.llm_engine.get_model_config() | ||||
|         resolved_content_format = resolve_chat_template_content_format( | ||||
|             chat_template, | ||||
|             tools, | ||||
|             chat_template_content_format, | ||||
|             tokenizer, | ||||
|             model_config=model_config, | ||||
|         ) | ||||
|  | ||||
|         _chat_template_kwargs: dict[str, Any] = dict( | ||||
|         prompts = self.preprocess_chat( | ||||
|             messages=messages, | ||||
|             lora_request=lora_request, | ||||
|             chat_template=chat_template, | ||||
|             chat_template_content_format=chat_template_content_format, | ||||
|             add_generation_prompt=add_generation_prompt, | ||||
|             continue_final_message=continue_final_message, | ||||
|             tools=tools, | ||||
|             chat_template_kwargs=chat_template_kwargs, | ||||
|             mm_processor_kwargs=mm_processor_kwargs, | ||||
|         ) | ||||
|         _chat_template_kwargs.update(chat_template_kwargs or {}) | ||||
|  | ||||
|         prompts: list[Union[TokensPrompt, TextPrompt]] = [] | ||||
|  | ||||
|         for msgs in list_of_messages: | ||||
|             # NOTE: _parse_chat_message_content_parts() currently doesn't | ||||
|             # handle mm_processor_kwargs, since there is no implementation in | ||||
|             # the chat message parsing for it. | ||||
|             conversation, mm_data, mm_uuids = parse_chat_messages( | ||||
|                 msgs, | ||||
|                 model_config, | ||||
|                 tokenizer, | ||||
|                 content_format=resolved_content_format, | ||||
|             ) | ||||
|  | ||||
|             if isinstance(tokenizer, MistralTokenizer): | ||||
|                 prompt_token_ids = apply_mistral_chat_template( | ||||
|                     tokenizer, | ||||
|                     messages=msgs, | ||||
|                     **_chat_template_kwargs, | ||||
|                 ) | ||||
|             else: | ||||
|                 prompt_str = apply_hf_chat_template( | ||||
|                     tokenizer=tokenizer, | ||||
|                     conversation=conversation, | ||||
|                     model_config=model_config, | ||||
|                     **_chat_template_kwargs, | ||||
|                 ) | ||||
|                 # Special tokens are already included in chat templates so | ||||
|                 # should not be added by the tokenizer in this case. | ||||
|                 prompt_token_ids = tokenizer.encode(prompt_str, | ||||
|                                                     add_special_tokens=False) | ||||
|  | ||||
|             prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) | ||||
|  | ||||
|             if mm_data is not None: | ||||
|                 prompt["multi_modal_data"] = mm_data | ||||
|  | ||||
|             if mm_uuids is not None: | ||||
|                 prompt["multi_modal_uuids"] = mm_uuids | ||||
|  | ||||
|             if mm_processor_kwargs is not None: | ||||
|                 prompt["mm_processor_kwargs"] = mm_processor_kwargs | ||||
|  | ||||
|             prompts.append(prompt) | ||||
|  | ||||
|         return self.generate( | ||||
|             prompts, | ||||
|  | ||||
| @ -0,0 +1,146 @@ | ||||
| { | ||||
|     "1": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "2": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "4": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "8": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "16": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 5 | ||||
|     }, | ||||
|     "24": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "32": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "48": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "64": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "96": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "128": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "256": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "512": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1024": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1536": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "2048": { | ||||
|         "BLOCK_SIZE_M": 128, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "3072": { | ||||
|         "BLOCK_SIZE_M": 128, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "4096": { | ||||
|         "BLOCK_SIZE_M": 128, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 4 | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,146 @@ | ||||
| { | ||||
|     "1": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "2": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "4": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "8": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "16": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "24": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "32": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "48": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "64": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "96": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "128": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "256": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "512": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1024": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1536": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "2048": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "3072": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "4096": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,146 @@ | ||||
| { | ||||
|     "1": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "2": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "4": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "8": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "16": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "24": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "32": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "48": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "64": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "96": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "128": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "256": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "512": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "1024": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1536": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "2048": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "3072": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "4096": { | ||||
|         "BLOCK_SIZE_M": 128, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,146 @@ | ||||
| { | ||||
|     "1": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "2": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "4": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "8": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "16": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "24": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "32": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "48": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "64": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "96": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "128": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "256": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "512": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1024": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1536": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "2048": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "3072": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "4096": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,146 @@ | ||||
| { | ||||
|     "1": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "2": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "4": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "8": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "16": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "24": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "32": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "48": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "64": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "96": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "128": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "256": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "512": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "1024": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 128, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1536": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "2048": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "3072": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "4096": { | ||||
|         "BLOCK_SIZE_M": 128, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 8, | ||||
|         "num_stages": 3 | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,146 @@ | ||||
| { | ||||
|     "1": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "2": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "4": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "8": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "16": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "24": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "32": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "48": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "64": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "96": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "128": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "256": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "512": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1024": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 256, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "1536": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "2048": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "3072": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 16, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "4096": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 64, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,146 @@ | ||||
| { | ||||
|     "1": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "2": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 32, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 4 | ||||
|     }, | ||||
|     "4": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "8": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "16": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "24": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "32": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "48": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "64": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "96": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "128": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "256": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "512": { | ||||
|         "BLOCK_SIZE_M": 16, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1024": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "1536": { | ||||
|         "BLOCK_SIZE_M": 32, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 3 | ||||
|     }, | ||||
|     "2048": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 1, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "3072": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 64, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     }, | ||||
|     "4096": { | ||||
|         "BLOCK_SIZE_M": 64, | ||||
|         "BLOCK_SIZE_N": 128, | ||||
|         "BLOCK_SIZE_K": 64, | ||||
|         "GROUP_SIZE_M": 32, | ||||
|         "num_warps": 4, | ||||
|         "num_stages": 2 | ||||
|     } | ||||
| } | ||||
| @ -19,10 +19,11 @@ from vllm.logger import init_logger | ||||
| from vllm.model_executor.layers.linear import QKVCrossParallelLinear | ||||
| from vllm.model_executor.layers.quantization.base_config import ( | ||||
|     QuantizationConfig, QuantizeMethodBase) | ||||
| from vllm.model_executor.models.adapters import (as_embedding_model, | ||||
|                                                  as_reward_model, | ||||
|                                                  as_seq_cls_model) | ||||
| from vllm.model_executor.models.interfaces import SupportsQuant | ||||
| from vllm.model_executor.models.adapters import ( | ||||
|     as_embedding_model, as_reward_model, as_seq_cls_model, | ||||
|     try_create_mm_pooling_model_cls) | ||||
| from vllm.model_executor.models.interfaces import (SupportsQuant, | ||||
|                                                    supports_multimodal) | ||||
| from vllm.utils import is_pin_memory_available | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
| @ -183,6 +184,15 @@ def get_model_architecture( | ||||
|                 "performance may not be optimal.", arch) | ||||
|  | ||||
|     convert_type = model_config.convert_type | ||||
|     if convert_type != "none" and supports_multimodal(model_cls): | ||||
|         logger.debug_once("Detected conversion of Multi Modal model.") | ||||
|         converted = try_create_mm_pooling_model_cls(model_cls) | ||||
|         if converted is not None: | ||||
|             logger.debug_once("Creating wrapper class to forward pooler.") | ||||
|             return converted, arch | ||||
|         else: | ||||
|             logger.debug_once("Attempting direct conversion.") | ||||
|  | ||||
|     if convert_type == "none": | ||||
|         pass | ||||
|     elif convert_type == "embed": | ||||
|  | ||||
| @ -1,12 +1,15 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import ast | ||||
| import inspect | ||||
| from collections.abc import Iterable | ||||
| from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from vllm.config import VllmConfig | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor.layers.activation import get_act_fn | ||||
| from vllm.model_executor.models.config import VerifyAndUpdateConfig | ||||
| @ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: | ||||
|     return model_name + pooling_suffix | ||||
|  | ||||
|  | ||||
| def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: | ||||
|  | ||||
|     class CallVisitor(ast.NodeVisitor): | ||||
|  | ||||
|         def __init__(self): | ||||
|             self.calls = [] | ||||
|  | ||||
|         def visit_Call(self, node): | ||||
|             if isinstance(node.func, ast.Name): | ||||
|                 self.calls.append(node.func.id) | ||||
|             self.generic_visit(node) | ||||
|  | ||||
|     visitor = CallVisitor() | ||||
|     visitor.visit(ast.parse(inspect.getsource(orig_cls))) | ||||
|     if "init_vllm_registered_model" not in visitor.calls: | ||||
|         return None | ||||
|  | ||||
|     class ModelForPooling(orig_cls, VllmModelForPooling): | ||||
|  | ||||
|         is_pooling_model = True | ||||
|  | ||||
|         def __init__( | ||||
|             self, | ||||
|             *, | ||||
|             vllm_config: "VllmConfig", | ||||
|             prefix: str = "", | ||||
|             **kwargs: Any, | ||||
|         ) -> None: | ||||
|             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) | ||||
|  | ||||
|             self.pooler = self.get_language_model().pooler | ||||
|  | ||||
|     return ModelForPooling  # type: ignore | ||||
|  | ||||
|  | ||||
| def _create_pooling_model_cls(orig_cls: _T) -> _T: | ||||
|     # Lazy import | ||||
|     from .utils import AutoWeightsLoader, WeightsMapper | ||||
| @ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax( | ||||
|     from vllm.model_executor.models.utils import AutoWeightsLoader | ||||
|  | ||||
|     model_config = model.vllm_config.model_config | ||||
|  | ||||
|     tokens = getattr(model.config, "classifier_from_token", []) | ||||
|     tokens = cast(list[int], tokens) | ||||
|     assert len(tokens) == 2 | ||||
| @ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax( | ||||
|     if model.config.tie_word_embeddings: | ||||
|         model.lm_head = model.model.embed_tokens | ||||
|     else: | ||||
|         quant_config = model.vllm_config.quant_config | ||||
|         model.lm_head = ParallelLMHead(model.config.vocab_size, | ||||
|                                        model.config.hidden_size, | ||||
|                                        quant_config=model.quant_config) | ||||
|                                        quant_config=quant_config) | ||||
|  | ||||
|     loader = AutoWeightsLoader(model) | ||||
|     loaded_weights = loader.load_weights(weights) | ||||
| @ -452,9 +492,10 @@ def load_weights_no_post_processing(model, | ||||
|     if model.config.tie_word_embeddings: | ||||
|         model.lm_head = model.model.embed_tokens | ||||
|     else: | ||||
|         quant_config = model.vllm_config.quant_config | ||||
|         model.lm_head = ParallelLMHead(model.config.vocab_size, | ||||
|                                        model.config.hidden_size, | ||||
|                                        quant_config=model.quant_config) | ||||
|                                        quant_config=quant_config) | ||||
|  | ||||
|     loader = AutoWeightsLoader(model) | ||||
|     loaded_weights = loader.load_weights(weights) | ||||
|  | ||||
| @ -512,7 +512,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, | ||||
|             architectures=["Gemma3ForCausalLM"], | ||||
|         ) | ||||
|         logit_scale = getattr(config, "logit_scale", 1.0) | ||||
|         self.language_model.logits_processor.scale *= logit_scale | ||||
|  | ||||
|         if hasattr(self.language_model, "logits_processor"): | ||||
|             # The logits processor can be unset if we're using | ||||
|             # automatic conversion to pooling model. | ||||
|             self.language_model.logits_processor.scale *= logit_scale | ||||
|  | ||||
|         self.make_empty_intermediate_tensors = ( | ||||
|             self.language_model.make_empty_intermediate_tensors) | ||||
|  | ||||
| @ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): | ||||
|         return quant_config | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         # NOTE: hidden_states can have either 1D or 2D shape. | ||||
|         orig_shape = hidden_states.shape | ||||
|         assert hidden_states.dim( | ||||
|         ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" | ||||
|         is_input_1d = hidden_states.dim() == 1 | ||||
|         hidden_dim = hidden_states.shape[-1] | ||||
|         hidden_states = hidden_states.view(-1, hidden_dim) | ||||
|  | ||||
| @ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): | ||||
|         final_hidden_states = self.experts(hidden_states=hidden_states, | ||||
|                                            router_logits=router_logits) | ||||
|  | ||||
|         return final_hidden_states.view(orig_shape) | ||||
|         # return to 1d if input is 1d | ||||
|         return final_hidden_states.squeeze(0) if is_input_1d else \ | ||||
|             final_hidden_states | ||||
|  | ||||
|  | ||||
| class Qwen3MoeAttention(nn.Module): | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| """Inference-only Qwen3Next model.""" | ||||
| from collections.abc import Iterable | ||||
| from itertools import islice | ||||
| from typing import Optional | ||||
|  | ||||
| import torch | ||||
| @ -917,8 +918,11 @@ class Qwen3NextModel(nn.Module): | ||||
|             make_empty_intermediate_tensors_factory( | ||||
|                 ["hidden_states", "residual"], config.hidden_size)) | ||||
|  | ||||
|         self.norm = Qwen3NextRMSNorm(config.hidden_size, | ||||
|                                      eps=config.rms_norm_eps) | ||||
|         if get_pp_group().is_last_rank: | ||||
|             self.norm = Qwen3NextRMSNorm(config.hidden_size, | ||||
|                                          eps=config.rms_norm_eps) | ||||
|         else: | ||||
|             self.norm = PPMissingLayer() | ||||
|  | ||||
|     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: | ||||
|         return self.embed_tokens(input_ids) | ||||
| @ -941,7 +945,7 @@ class Qwen3NextModel(nn.Module): | ||||
|             hidden_states = intermediate_tensors["hidden_states"] | ||||
|             residual = intermediate_tensors["residual"] | ||||
|  | ||||
|         for layer in self.layers: | ||||
|         for layer in islice(self.layers, self.start_layer, self.end_layer): | ||||
|             hidden_states, residual = layer( | ||||
|                 positions=positions, | ||||
|                 hidden_states=hidden_states, | ||||
|  | ||||
| @ -209,7 +209,8 @@ class GDNAttentionMetadataBuilder( | ||||
|  | ||||
|         # prepare tensors for cudagraph | ||||
|         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0 | ||||
|                 and num_spec_decodes <= self.decode_cudagraph_max_bs): | ||||
|                 and num_spec_decodes <= self.decode_cudagraph_max_bs | ||||
|                 and m.num_actual_tokens <= self.decode_cudagraph_max_bs): | ||||
|             num_total_tokens = self.vllm_config.pad_for_cudagraph( | ||||
|                 m.num_actual_tokens) | ||||
|             batch_size = num_total_tokens // (self.num_spec + 1) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	