mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-26 00:24:53 +08:00 
			
		
		
		
This PR adds an intrinsics-based micro-GEMM for BF16 using the Advanced Matrix eXtension (AMX) instructions available in 4th and 5th generation Intel Xeon processors. A compilation check is added to `codecache.py` to verify that the compiler supports AMX. Also, since AMX requires an initialization step in the Linux kernel to enable the extra register state, an initialization function is added to do that and is triggered via `codecache.py`.

Performance speedups of >=10% on BF16 AMP, max_autotune vs. no autotune, measured on Intel(R) Xeon(R) Platinum 8488C:

Static shapes

Single-threaded

| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| timm_models | mixer_b16_224 | 1.54 |
| timm_models | convit_base | 1.53 |
| huggingface | MobileBertForQuestionAnswering | 1.52 |
| torchbench | fastNLP_Bert | 1.44 |
| torchbench | llama | 1.33 |
| timm_models | swin_base_patch4_window7_224 | 1.31 |
| torchbench | dlrm | 1.28 |
| torchbench | timm_vision_transformer_large | 1.28 |
| huggingface | MobileBertForMaskedLM | 1.27 |
| timm_models | vit_base_patch16_224 | 1.26 |
| timm_models | beit_base_patch16_224 | 1.23 |
| timm_models | jx_nest_base | 1.21 |
| torchbench | pyhpc_equation_of_state | 1.18 |
| huggingface | Speech2Text2ForCausalLM | 1.15 |
| timm_models | pit_b_224 | 1.14 |
| timm_models | twins_pcpvt_base | 1.14 |
| torchbench | maml_omniglot | 1.1 |
| timm_models | botnet26t_256 | 1.1 |

Multi-threaded

| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| torchbench | BERT_pytorch | 1.35 |
| torchbench | lennard_jones | 2.43 |
| torchbench | hf_Albert | 1.35 |
| torchbench | hf_T5 | 1.34 |
| torchbench | soft_actor_critic | 1.34 |
| torchbench | fastNLP_Bert | 1.28 |
| huggingface | LayoutLMForSequenceClassification | 1.26 |
| torchbench | llama | 1.24 |
| huggingface | GPT2ForSequenceClassification | 1.19 |
| torchbench | hf_Bart | 1.17 |
| torchbench | hf_Bert_large | 1.16 |
| torchbench | hf_GPT2 | 1.16 |
| timm_models | gmixer_24_224 | 1.16 |
| torchbench | hf_GPT2_large | 1.15 |
| torchbench | maml_omniglot | 1.14 |
| torchbench | hf_Bert | 1.13 |
| torchbench | hf_DistilBert | 1.13 |
| torchbench | hf_T5_large | 1.12 |
| huggingface | MT5ForConditionalGeneration | 1.11 |

Dynamic shapes

Single-threaded

| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| timm_models | mixer_b16_224 | 1.52 |
| timm_models | convit_base | 1.5 |
| huggingface | MobileBertForQuestionAnswering | 1.49 |
| torchbench | fastNLP_Bert | 1.42 |
| torchbench | timm_vision_transformer_large | 1.28 |
| timm_models | swin_base_patch4_window7_224 | 1.27 |
| torchbench | llama | 1.26 |
| huggingface | MobileBertForMaskedLM | 1.25 |
| timm_models | vit_base_patch16_224 | 1.25 |
| timm_models | beit_base_patch16_224 | 1.24 |
| timm_models | jx_nest_base | 1.2 |
| torchbench | dlrm | 1.19 |
| timm_models | pit_b_224 | 1.13 |
| timm_models | twins_pcpvt_base | 1.13 |
| torchbench | hf_Bert_large | 1.12 |
| torchbench | hf_BigBird | 1.11 |
| huggingface | Speech2Text2ForCausalLM | 1.11 |
| timm_models | eca_botnext26ts_256 | 1.11 |
| timm_models | botnet26t_256 | 1.1 |

Multi-threaded

| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| torchbench | BERT_pytorch | 1.18 |
| torchbench | lennard_jones | 2.18 |
| torchbench | hf_Albert | 1.37 |
| torchbench | soft_actor_critic | 1.31 |
| huggingface | GPT2ForSequenceClassification | 1.29 |
| torchbench | hf_T5 | 1.28 |
| torchbench | fastNLP_Bert | 1.27 |
| torchbench | hf_Bart | 1.21 |
| torchbench | hf_Bert_large | 1.19 |
| torchbench | hf_T5_large | 1.19 |
| torchbench | hf_Bert | 1.16 |
| torchbench | hf_GPT2 | 1.16 |
| huggingface | CamemBert | 1.16 |
| torchbench | hf_GPT2_large | 1.13 |
| torchbench | functorch_maml_omniglot | 1.12 |
| huggingface | BertForMaskedLM | 1.12 |
| huggingface | MT5ForConditionalGeneration | 1.12 |
| torchbench | hf_DistilBert | 1.11 |
| timm_models | mixnet_l | 1.11 |
| timm_models | tf_mixnet_l | 1.11 |

No perf regressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127195
Approved by: https://github.com/jansel
		
			
				
	
	
		
			180 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			180 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-defs
 | |
| r"""
 | |
| This package implements abstractions found in ``torch.cuda``
 | |
| to facilitate writing device-agnostic code.
 | |
| """
 | |
| 
 | |
| from contextlib import AbstractContextManager
 | |
| from typing import Any, Optional, Union
 | |
| 
 | |
| import torch
 | |
| 
 | |
| from .. import device as _device
 | |
| from . import amp
 | |
| 
 | |
| 
 | |
| __all__ = [
 | |
|     "is_available",
 | |
|     "synchronize",
 | |
|     "current_device",
 | |
|     "current_stream",
 | |
|     "stream",
 | |
|     "set_device",
 | |
|     "device_count",
 | |
|     "Stream",
 | |
|     "StreamContext",
 | |
|     "Event",
 | |
| ]
 | |
| 
 | |
| _device_t = Union[_device, str, int, None]
 | |
| 
 | |
| 
 | |
| def _is_cpu_support_avx2() -> bool:
 | |
|     r"""Returns a bool indicating if CPU supports AVX2."""
 | |
|     return torch._C._cpu._is_cpu_support_avx2()
 | |
| 
 | |
| 
 | |
| def _is_cpu_support_avx512() -> bool:
 | |
|     r"""Returns a bool indicating if CPU supports AVX512."""
 | |
|     return torch._C._cpu._is_cpu_support_avx512()
 | |
| 
 | |
| 
 | |
| def _is_cpu_support_vnni() -> bool:
 | |
|     r"""Returns a bool indicating if CPU supports VNNI."""
 | |
|     # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later.
 | |
|     return torch._C._cpu._is_cpu_support_avx512_vnni()
 | |
| 
 | |
| 
 | |
| def _is_cpu_support_amx_tile() -> bool:
 | |
|     r"""Returns a bool indicating if CPU supports AMX_TILE."""
 | |
|     return torch._C._cpu._is_cpu_support_amx_tile()
 | |
| 
 | |
| 
 | |
| def _init_amx() -> bool:
 | |
|     r"""Initializes AMX instructions."""
 | |
|     return torch._C._cpu._init_amx()
 | |
| 
 | |
| 
 | |
def is_available() -> bool:
    r"""Tell callers whether the CPU device can be used; the answer is always ``True``.

    N.B. This function only exists to facilitate device-agnostic code

    Returns:
        bool: unconditionally ``True``.
    """
    cpu_is_available = True
    return cpu_is_available
 | |
| 
 | |
| 
 | |
def synchronize(device: _device_t = None) -> None:
    r"""CPU counterpart of a device-wide synchronize; performs no work.

    The body is empty, so the call returns immediately.

    Args:
        device (torch.device, str or int, optional): ignored, there's only
            one CPU device.

    N.B. This function only exists to facilitate device-agnostic code.
    """
    return None
 | |
| 
 | |
| 
 | |
class Stream:
    """
    Placeholder stream type for the CPU backend; it holds no state.

    N.B. This class only exists to facilitate device-agnostic code
    """

    def __init__(self, priority: int = -1) -> None:
        """Accept ``priority`` for signature compatibility; nothing is stored."""

    def wait_stream(self, stream) -> None:
        """Do nothing; ``stream`` is ignored."""
 | |
| 
 | |
| 
 | |
class Event:
    """Placeholder event type for the CPU backend.

    ``query`` always reports completion and every other method is a no-op.
    """

    def query(self) -> bool:
        """Return ``True`` unconditionally."""
        return True

    def record(self, stream=None) -> None:
        """Do nothing; ``stream`` is ignored."""

    def synchronize(self) -> None:
        """Do nothing and return immediately."""

    def wait(self, stream=None) -> None:
        """Do nothing; ``stream`` is ignored."""
 | |
| 
 | |
| 
 | |
# One shared Stream instance serves as both the default stream and the
# initially selected "current" stream; StreamContext rebinds _current_stream.
_default_cpu_stream = _current_stream = Stream()
 | |
| 
 | |
| 
 | |
def current_stream(device: _device_t = None) -> Stream:
    r"""Return the :class:`Stream` currently held in the module-level ``_current_stream``.

    Args:
        device (torch.device, str or int, optional): Ignored.

    N.B. This function only exists to facilitate device-agnostic code

    Returns:
        Stream: the currently selected stream object.
    """
    selected = _current_stream
    return selected
 | |
| 
 | |
| 
 | |
class StreamContext(AbstractContextManager):
    r"""Context manager that makes a given stream the module's current stream.

    On entry the module-level ``_current_stream`` is remembered and replaced
    by the wrapped stream; on exit the remembered value is put back. When the
    wrapped stream is ``None`` both steps are skipped.

    N.B. This class only exists to facilitate device-agnostic code

    """

    cur_stream: Optional[Stream]

    def __init__(self, stream):
        # Until __enter__ runs, the stream to restore is the default one.
        self.prev_stream = _default_cpu_stream
        self.stream = stream

    def __enter__(self):
        global _current_stream

        target = self.stream
        if target is not None:
            # Remember what was selected so __exit__ can put it back.
            self.prev_stream = _current_stream
            _current_stream = target

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        global _current_stream

        if self.stream is not None:
            _current_stream = self.prev_stream
 | |
| 
 | |
| 
 | |
def stream(stream: Stream) -> AbstractContextManager:
    r"""Build a :class:`StreamContext` that selects ``stream`` while active.

    N.B. This function only exists to facilitate device-agnostic code

    Args:
        stream (Stream): the stream handed to :class:`StreamContext`.

    Returns:
        AbstractContextManager: a new :class:`StreamContext` wrapping ``stream``.
    """
    context = StreamContext(stream)
    return context
 | |
| 
 | |
| 
 | |
def device_count() -> int:
    r"""Report how many CPU devices (not cores) exist, which is always 1.

    N.B. This function only exists to facilitate device-agnostic code

    Returns:
        int: the constant ``1``.
    """
    cpu_device_total = 1
    return cpu_device_total
 | |
| 
 | |
| 
 | |
def set_device(device: _device_t) -> None:
    r"""Accept a device selection and do nothing with it, since CPU needs no switching.

    Args:
        device (torch.device, str or int, optional): ignored.

    N.B. This function only exists to facilitate device-agnostic code
    """
    return None
 | |
| 
 | |
| 
 | |
def current_device() -> str:
    r"""Name the current device for the CPU backend, which is always ``'cpu'``.

    N.B. This function only exists to facilitate device-agnostic code

    Returns:
        str: the literal string ``"cpu"``.
    """
    device_name = "cpu"
    return device_name
 |