mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Compare commits
	
		
			9 Commits
		
	
	
		
			main
			...
			ciflow/tru
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 95ad31ab20 | |||
| ac6c125329 | |||
| 0b41d9a6e0 | |||
| 39f9c13974 | |||
| fd90e9d082 | |||
| 30c8f0f3ad | |||
| 03bd16b815 | |||
| b366fc40ce | |||
| e5bf90ba39 | 
							
								
								
									
										1414
									
								
								test/inductor/test_lookup_table.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1414
									
								
								test/inductor/test_lookup_table.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -2073,6 +2073,39 @@ write_are_deterministic_algorithms_enabled = ( | ||||
| ) | ||||
|  | ||||
|  | ||||
class lookup_table:
    # Config namespace for the template lookup-table system (see
    # torch/_inductor/lookup_table/README.md). Attributes here are plain class
    # attributes consumed by the inductor config machinery.

    # Enable/disable lookup table choices system (defaults to True)
    active: bool = True

    # Lookup table for template config overrides.
    # Maps "input_key+op_name" strings to lists of per-template config dicts,
    # each containing at least a "template_id" entry.
    table: Optional[dict[str, list[dict[str, Any]]]] = None

    # Enable template src_hash checking in lookup table to prevent using stale configs.
    # If True, configs with 'template_hash' field will be compared against the template's
    # src_hash at runtime and filtered out if they don't match. If False, no
    # hash checking is performed.
    check_src_hash: bool = True

    # Recorder configuration for capturing autotuning results
    # Master switch for recording functionality - must be True to enable any recording
    recording_active: bool = False
    # Enable emitting entries to registered emit backends (e.g., logging)
    recorder_emit: bool = True
    # Directory to record lookup tables to. If set, enables DirectoryRecordBackend
    recorder_record_dir: Optional[str] = None
    # Number of top choices to record. If None or negative, record all choices.
    # If 0, record nothing. If positive, record only the top k choices.
    recorder_topk: Optional[int] = None
    # Whether to record template hashes for templates that provide them.
    # Template hashes make the lookup table more robust (stale configs can be
    # filtered via check_src_hash) but potentially less portable across versions.
    record_template_hash: bool = True
    # Whether to include device key when recording lookup table entries.
    # When True, entries are recorded with device-specific keys (more precise but less portable).
    # When False, entries are recorded without device keys (more portable across devices).
    # During lookup, both key formats are always tried regardless of this setting.
    record_with_device_key: bool = True
|  | ||||
|  | ||||
class test_configs:
    # NOTE(review): presumably forces the extern (ATEN) kernel to be chosen when
    # multiple template candidates exist — confirm against usage sites.
    force_extern_kernel_in_multi_template: bool = False
|  | ||||
|  | ||||
							
								
								
									
										540
									
								
								torch/_inductor/lookup_table/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										540
									
								
								torch/_inductor/lookup_table/README.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,540 @@ | ||||
| # Template Lookup Table System | ||||
|  | ||||
| The template lookup table system provides a way to pre-configure kernel template parameters for specific operations and | ||||
| input configurations, bypassing the default choice generation and autotuning process. | ||||
|  | ||||
| ## Overview | ||||
|  | ||||
| The lookup table system replaces default choice generation with pre-configured template parameters for specific | ||||
| operations and input configurations. It sits orthogonal to `max-autotune(-gemm)` in the following way | ||||
|  | ||||
| If a lookup table is provided and there is a match | ||||
|  | ||||
| - We check whether the template(s) in the match are currently in use | ||||
| - If so, we use the pre-configured template(s) and config and bypass choice generation | ||||
|   - If more than one choice is provided, we run autotune among the pre-configured choices | ||||
| - If not, we fall back to the default choice generation process, including max-autotune(-gemm) logic | ||||
|  | ||||
| If there is no match, we fall back to the default choice generation process, including max-autotune(-gemm) logic | ||||
|  | ||||
| ## Configuration | ||||
|  | ||||
| Enable the system by setting both: | ||||
|  | ||||
| ```python | ||||
| from torch._inductor import config | ||||
| config.lookup_table.table = your_table_dict | ||||
| # You also need to set it as the default choice handler | ||||
| from torch._inductor.lookup_table import LookupTableChoices | ||||
| torch._inductor.V.set_choices_handler(LookupTableChoices()) | ||||
| ``` | ||||
|  | ||||
| ### Device Key Handling | ||||
|  | ||||
| The key schema format is described in detail in the [Key Schemas](#key-schemas) section below. | ||||
|  | ||||
| Configure device key behavior: | ||||
|  | ||||
| ```python | ||||
| # Control whether entries include device-specific keys for lookups | ||||
| # Device-agnostic entries work across different GPU models | ||||
| ``` | ||||
|  | ||||
| **Lookup Behavior**: During lookup, the system automatically tries both key formats: | ||||
|  | ||||
| 1. **Device-specific key** (e.g., `"NVIDIA H100+input_data+mm"`) - tried first | ||||
| 1. **Device-agnostic key** (e.g., `"input_data+mm"`) - tried if device-specific fails | ||||
|  | ||||
| **Priority**: If both device-specific and device-agnostic entries exist for the same inputs, the device-specific entry | ||||
| takes priority. | ||||
|  | ||||
| **NOTE**: Device-based keys simplify hardware-specific optimization without complex build rules. Currently limited to | ||||
| device name only. If you need additional conditional key attributes (e.g., CUDA version filtering), please file an issue | ||||
| or submit a patch. | ||||
|  | ||||
| ## Behavior | ||||
|  | ||||
| When the table is active, the following behavior occurs for all supported operations: | ||||
|  | ||||
| ### Match Found | ||||
|  | ||||
| - Uses pre-configured choices from the table instead of generating default choices | ||||
| - Bypasses autotuning if only a single choice is provided | ||||
| - If multiple choices are provided, autotuning occurs among those choices only | ||||
|  | ||||
| ### No Match Found | ||||
|  | ||||
| - Standard default behavior - generates choices using heuristics and max-autotune settings | ||||
|  | ||||
| ### Table Not Set or Inactive | ||||
|  | ||||
| - Standard default behavior - generates choices using heuristics and max-autotune settings | ||||
|  | ||||
| ## Supported Operations | ||||
|  | ||||
| Currently supports: `mm`, `addmm`, `bmm`, `mm_plus_mm`, `scaled_mm` operations with | ||||
|  | ||||
| - Triton | ||||
| - ATEN | ||||
| - DecomposeK | ||||
|  | ||||
| ## Table Format | ||||
|  | ||||
| The table is a dictionary with keys in the format: | ||||
|  | ||||
| ``` | ||||
| "input_key+op_name" | ||||
| ``` | ||||
|  | ||||
| Where: | ||||
|  | ||||
| - `input_key`: Generated from `KernelInputs.key` property, represents tensor shapes/dtypes/strides | ||||
| - `op_name`: Operation name (`"mm"`, `"addmm"`, etc.) | ||||
|  | ||||
| Each value is a list of configuration dictionaries containing: | ||||
|  | ||||
- `template_id`: Template identifier (`"triton::mm"`, `"triton::mm_persistent_tma"`, `"decompose_k"`, etc.)
| - Template-specific parameters (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`, `num_warps`, etc.) | ||||
|  | ||||
| ## Key Schemas | ||||
|  | ||||
| **NOTE**: The key schema format is subject to change as the system evolves. | ||||
|  | ||||
| The lookup table uses composite keys to match kernel configurations. See | ||||
| [Implementation Details](#implementation-details) below for more technical information about key generation. This | ||||
| section describes the structure of these keys. | ||||
|  | ||||
| ### Key Format Structure | ||||
|  | ||||
| Keys follow the pattern: | ||||
|  | ||||
| ``` | ||||
| [device_name+]input_key+[additional_params+]op_name | ||||
| ``` | ||||
|  | ||||
| Components: | ||||
|  | ||||
| - **device_name** (optional): GPU device identifier (e.g., `"NVIDIA H100"`) | ||||
|  | ||||
|   - Obtained from `torch.cuda.get_device_properties().gcnArchName` | ||||
|   - Enables device-specific optimizations | ||||
|   - When omitted, creates device-agnostic entries that work across hardware | ||||
|  | ||||
| - **input_key**: Tensor configuration representation from `KernelInputs.key` | ||||
|  | ||||
|   - Format: `((dtype, shape, stride), (dtype, shape, stride), ...)` | ||||
|   - Each tuple represents one input tensor's properties | ||||
|   - Example: `((torch.float16, [128, 256], [0, 1]), (torch.float16, [64, 256], [256, 1]))` | ||||
|   - Order matches the operation's input argument order | ||||
|  | ||||
| - **additional_params** (optional): Operation-specific parameters | ||||
|  | ||||
|   - Format: `key1=value1&key2=value2` | ||||
|   - Example: `alpha=1&beta=1` for addmm operations | ||||
|  | ||||
| - **op_name**: Operation identifier | ||||
|  | ||||
|   - Examples: `"mm"`, `"addmm"`, `"bmm"`, `"mm_plus_mm"`, `"scaled_mm"` | ||||
|  | ||||
| ### Key Examples | ||||
|  | ||||
| **Device-specific key for addmm:** | ||||
|  | ||||
| ``` | ||||
| "NVIDIA H100+((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm" | ||||
| ``` | ||||
|  | ||||
| **Device-agnostic key for mm:** | ||||
|  | ||||
| ``` | ||||
| "((torch.float16, [64, 128], [128, 1]), (torch.float16, [128, 256], [256, 1]))+mm" | ||||
| ``` | ||||
|  | ||||
| **Key with no additional parameters:** | ||||
|  | ||||
| ``` | ||||
| "((torch.float32, [512, 512], [512, 1]), (torch.float32, [512, 512], [512, 1]))+bmm" | ||||
| ``` | ||||
|  | ||||
| ### Lookup Strategy | ||||
|  | ||||
| During lookup, the system tries keys in priority order: | ||||
|  | ||||
| 1. **Device-specific key** - checked first if device information is available | ||||
| 1. **Device-agnostic key** - fallback if device-specific lookup fails | ||||
|  | ||||
| This allows tables to contain: | ||||
|  | ||||
| - Device-optimized configurations (higher priority) | ||||
| - Portable configurations that work across devices | ||||
| - Mix of both for flexible deployment | ||||
|  | ||||
| ## Example Table | ||||
|  | ||||
| This is an example table for a single input showing two configurations | ||||
|  | ||||
```python
table = {
  "((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm": [
    {
      "template_id": "triton::mm",
      "EVEN_K": True,
      "USE_FAST_ACCUM": False,
      "ACC_TYPE": "tl.float32",
      "num_stages": 2,
      "num_warps": 4,
      "BLOCK_M": 32,
      "BLOCK_N": 32,
      "BLOCK_K": 64,
      "hint_override": None,
      "GROUP_M": 8,
      "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896"
    },
    {
      "template_id": "aten::bias_addmm"
    },
  ]
}
```
|  | ||||
| ## Source Hashing Safety | ||||
|  | ||||
| The lookup table system includes source hashing to prevent using stale configurations when template code changes. | ||||
|  | ||||
| ### Configuration | ||||
|  | ||||
| - **Enabled by default**: `torch._inductor.config.lookup_table.check_src_hash = True` | ||||
| - **Optional field**: Add `"template_hash"` to table entries for enhanced safety | ||||
|  | ||||
| ### Behavior | ||||
|  | ||||
| When source hash checking is enabled: | ||||
|  | ||||
| - Template configurations with `"template_hash"` fields are validated against current template source hashes | ||||
| - Mismatched hashes indicate the template code has changed since the configuration was created | ||||
| - Stale configurations are automatically filtered out with a warning message | ||||
- Configurations without hash fields are preserved, either for backward compatibility or when the user deliberately opts out of hash checking
|  | ||||
| ### Example with Template Hash | ||||
|  | ||||
| ```python | ||||
| { | ||||
|   "template_id": "triton::mm", | ||||
|   "BLOCK_M": 32, | ||||
|   "BLOCK_N": 32, | ||||
|   "BLOCK_K": 16, | ||||
|   "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896" | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## Performance Impact | ||||
|  | ||||
| - **Lookup Hit**: Eliminates heuristic choice generation and autotuning overhead (if a single choice) | ||||
| - **Lookup Miss**: Default behavior, including heuristic choice generation and autotuning | ||||
| - **Memory**: Table stored in memory, minimal overhead for key generation and lookup | ||||
|  | ||||
| ## Implementation Details | ||||
|  | ||||
| ### Key Generation | ||||
|  | ||||
| - Device key: Uses `torch.cuda.get_device_properties().gcnArchName` (e.g., "NVIDIA H100") | ||||
| - Input key: Generated from `KernelInputs.key` containing tensor properties | ||||
|  | ||||
| ### Entry Points | ||||
|  | ||||
| The system is accessed through: | ||||
|  | ||||
| - `lookup_template_configs(kernel_inputs, op_name, template_uids)` - Main lookup function | ||||
| - `LookupTableChoices._finalize_template_configs()` - Integration point with existing choice system | ||||
|  | ||||
| ### Error Handling | ||||
|  | ||||
| - Validates config dictionaries contain required `template_id` field | ||||
- Gracefully handles non-CUDA devices by returning empty results

## Recording Lookup Tables
|  | ||||
| The system can record autotuning results to automatically build lookup tables for future use. This eliminates the need to manually create table entries and ensures optimal configurations are captured from real workloads. | ||||
|  | ||||
| ### Quick Start: Recording | ||||
|  | ||||
| 1. **Enable recording:** | ||||
|    ```python | ||||
|    from torch._inductor import config | ||||
|  | ||||
|    # Master switch - must be True to enable recording | ||||
|    config.lookup_table.recording_active = True | ||||
|  | ||||
|    # Configure recording behavior | ||||
|    config.lookup_table.recorder_topk = 5  # Record top 5 fastest choices per operation | ||||
|    config.lookup_table.recorder_record_dir = "/path/to/output"  # Save to files | ||||
|    ``` | ||||
|  | ||||
| 2. **Run your model with autotuning:** | ||||
|    ```python | ||||
|    import torch | ||||
|  | ||||
|    model = YourModel() | ||||
|    compiled_model = torch.compile(model, mode="max-autotune") | ||||
|  | ||||
|    # Recording happens automatically during compilation and execution | ||||
|    result = compiled_model(inputs) | ||||
|    ``` | ||||
|  | ||||
| 3. **Files are automatically saved** to the specified directory with timestamped filenames like `inductor_lut_20241205_143052_123.json`. | ||||
|  | ||||
| ### Recording Configuration | ||||
|  | ||||
| All recording options are available under `torch._inductor.config.lookup_table`: | ||||
|  | ||||
| ```python | ||||
| from torch._inductor import config | ||||
|  | ||||
| # Master recording switch - must be True for any recording to happen | ||||
| config.lookup_table.recording_active = True  # Default: False | ||||
|  | ||||
| # Logging and immediate emission | ||||
| config.lookup_table.recorder_emit = True  # Default: True (logs entries) | ||||
|  | ||||
| # File recording - set directory to enable file output | ||||
| config.lookup_table.recorder_record_dir = "/path/to/save/tables"  # Default: None | ||||
|  | ||||
| # Number of top choices to record per operation key | ||||
| config.lookup_table.recorder_topk = 10  # Default: None (record all) | ||||
| config.lookup_table.recorder_topk = 0   # Special case: disable recording | ||||
|  | ||||
| # Template safety and portability options | ||||
| config.lookup_table.record_template_hash = True   # Default: True (include hashes) | ||||
| config.lookup_table.record_with_device_key = True # Default: True (device-specific keys) | ||||
| ``` | ||||
|  | ||||
| ### Understanding TopK: Determinism vs. Flexibility | ||||
|  | ||||
| The `recorder_topk` setting is crucial for controlling the behavior of your recorded lookup tables: | ||||
|  | ||||
| #### TopK = 1: Maximum Performance and Determinism | ||||
| ```python | ||||
| config.lookup_table.recorder_topk = 1  # Record only the fastest choice | ||||
| ``` | ||||
|  | ||||
| **Benefits:** | ||||
| - **No autotuning overhead**: When using the recorded table, exactly one choice is available, so no autotuning occurs | ||||
| - **Perfect determinism**: Always uses the same kernel for identical inputs across runs | ||||
| - **Fastest compilation**: Minimal overhead during `torch.compile()` with the lookup table | ||||
| - **Production-ready**: Ideal for deployment where consistency and speed matter most | ||||
|  | ||||
| **Use case:** Production environments where you want maximum performance and deterministic behavior. | ||||
|  | ||||
| #### TopK > 1: Balanced Performance with Options | ||||
| ```python | ||||
| config.lookup_table.recorder_topk = 5  # Record top 5 fastest choices | ||||
| ``` | ||||
|  | ||||
| **Benefits:** | ||||
| - **Some autotuning**: When using the recorded table, autotuning occurs among the recorded choices (faster than full autotuning) | ||||
| - **Flexibility**: Multiple good options available if hardware characteristics change slightly | ||||
| - **Robustness**: Backup choices if the fastest choice becomes unavailable | ||||
|  | ||||
| **Trade-offs:** | ||||
| - **Slight overhead**: Autotuning still occurs among the recorded choices | ||||
| - **Less determinism**: May pick different choices between runs based on timing variations | ||||
|  | ||||
| **Use case:** Development/staging environments where you want good performance but retain some flexibility. | ||||
|  | ||||
| #### TopK = None: Maximum Visibility for Analysis | ||||
| ```python | ||||
| config.lookup_table.recorder_topk = None  # Record ALL choices that were tested | ||||
| ``` | ||||
|  | ||||
| **Benefits:** | ||||
| - **Complete picture**: See every template choice that was considered during autotuning | ||||
| - **Manual optimization**: Analyze all options and manually edit the table to select specific choices | ||||
| - **Debugging**: Understand what choices were available and their relative performance | ||||
|  | ||||
| **Trade-offs:** | ||||
| - **Large tables**: More storage space and memory usage | ||||
| - **Full autotuning**: When using the table, autotuning occurs among all recorded choices (no speed benefit) | ||||
|  | ||||
| **Use case:** Analysis, debugging, or when you want to manually curate the final lookup table. | ||||
|  | ||||
| #### Recommended Strategy | ||||
|  | ||||
| 1. **Start with TopK = None** for analysis: | ||||
|    ```python | ||||
|    config.lookup_table.recorder_topk = None  # See all options | ||||
|    ``` | ||||
|  | ||||
| 2. **Analyze the results** to understand choice distribution and performance gaps | ||||
|  | ||||
| 3. **Switch to TopK = 1** for production: | ||||
|    ```python | ||||
|    config.lookup_table.recorder_topk = 1  # Lock in the fastest choice | ||||
|    ``` | ||||
|  | ||||
| 4. **Validate determinism** by running the same workload multiple times and confirming identical kernels | ||||
|  | ||||
| ### Device Key Configuration for Recording | ||||
|  | ||||
| The `record_with_device_key` setting controls whether recorded entries are device-specific or portable: | ||||
|  | ||||
| ```python | ||||
| # Device-specific recording (more precise but less portable) | ||||
| config.lookup_table.record_with_device_key = True | ||||
| # Key format: "NVIDIA H100+input_shapes+operation" | ||||
| # Best for: Production environments with known hardware | ||||
|  | ||||
| # Device-agnostic recording (more portable across GPU types) | ||||
| config.lookup_table.record_with_device_key = False | ||||
| # Key format: "input_shapes+operation" | ||||
| # Best for: Development, CI/CD, mixed GPU environments | ||||
| ``` | ||||
|  | ||||
| **Note**: During lookup, both key formats are always tried regardless of this setting, with device-specific keys taking priority if both exist. | ||||
|  | ||||
| ### How Recording Works | ||||
|  | ||||
| The recording system automatically: | ||||
|  | ||||
| 1. **Captures autotuning results**: Monitors all kernel selection decisions during `torch.compile()` execution | ||||
| 2. **Filters by performance**: Records only the fastest choices (configurable via `recorder_topk`) | ||||
| 3. **Generates lookup keys**: Uses the same key format as the lookup system for consistency | ||||
| 4. **Saves incrementally**: Each autotuning session appends to timestamped JSON files | ||||
| 5. **Maintains safety**: Includes template hashes to prevent using stale configurations | ||||
|  | ||||
| ### Example: Complete Recording Workflow | ||||
|  | ||||
| ```python | ||||
| import torch | ||||
| from torch._inductor import config | ||||
| import tempfile | ||||
| import json | ||||
|  | ||||
| # Enable recording with configuration | ||||
| config.lookup_table.recording_active = True | ||||
| config.lookup_table.recorder_topk = 3 | ||||
| config.lookup_table.record_template_hash = True | ||||
|  | ||||
| # Use temporary directory for this example | ||||
| with tempfile.TemporaryDirectory() as temp_dir: | ||||
|     config.lookup_table.recorder_record_dir = temp_dir | ||||
|  | ||||
|     # Your model | ||||
|     def matmul_model(a, b): | ||||
|         return torch.mm(a, b) | ||||
|  | ||||
|     # Compile with autotuning (triggers recording) | ||||
|     compiled_model = torch.compile(matmul_model, mode="max-autotune") | ||||
|  | ||||
|     # Run the model (autotuning results are recorded automatically) | ||||
|     a = torch.randn(512, 512, device='cuda', dtype=torch.float16) | ||||
|     b = torch.randn(512, 512, device='cuda', dtype=torch.float16) | ||||
|     result = compiled_model(a, b) | ||||
|  | ||||
|     # Check recorded files | ||||
|     import os | ||||
|     files = [f for f in os.listdir(temp_dir) if f.startswith('inductor_lut_')] | ||||
|     print(f"Recorded files: {files}") | ||||
|  | ||||
|     # Load and inspect the recorded lookup table | ||||
|     if files: | ||||
|         with open(os.path.join(temp_dir, files[0])) as f: | ||||
|             recorded_table = json.load(f) | ||||
|  | ||||
|         print("Recorded entries:") | ||||
|         for key, configs in recorded_table.items(): | ||||
|             print(f"  Key: {key}") | ||||
|             for i, config in enumerate(configs): | ||||
|                 print(f"    Config {i+1}: template_id={config['template_id']}") | ||||
| ``` | ||||
|  | ||||
| ### Using Recorded Tables | ||||
|  | ||||
| Once you have recorded tables, use them for faster compilation: | ||||
|  | ||||
| ```python | ||||
| from torch._inductor import config | ||||
| from torch._inductor.lookup_table import LookupTableChoices | ||||
| from torch._inductor.virtualized import V | ||||
| import json | ||||
|  | ||||
| # Load your recorded table | ||||
| with open('inductor_lut_20241205_143052_123.json') as f: | ||||
|     lookup_table = json.load(f) | ||||
|  | ||||
| # Configure the system to use the lookup table | ||||
| config.lookup_table.table = lookup_table | ||||
| V.set_choices_handler(LookupTableChoices()) | ||||
|  | ||||
| # Now compilation will use your recorded configurations | ||||
| model = torch.compile(your_model, mode="max-autotune") | ||||
| result = model(inputs)  # Uses lookup table, skips autotuning | ||||
| ``` | ||||
|  | ||||
| ### Advanced: Custom Recording Backends | ||||
|  | ||||
| Extend the recording system with custom backends for specialized workflows: | ||||
|  | ||||
| ```python | ||||
| from torch._inductor.lookup_table import recorder | ||||
|  | ||||
| # Custom emit backend for immediate processing | ||||
| class CustomLogBackend(recorder.EmitBackend): | ||||
|     def emit(self, entry): | ||||
|         # Process each entry immediately as it's recorded | ||||
|         print(f"Recorded {entry.key} -> {entry.value['template_id']} (runtime: {entry.runtime:.4f}ms)") | ||||
|  | ||||
| # Custom record backend for batch processing | ||||
| class DatabaseRecordBackend(recorder.RecordBackend): | ||||
|     def __init__(self, connection_string): | ||||
|         self.conn = connection_string | ||||
|  | ||||
|     def dump(self, data): | ||||
|         # Save all entries to database when dump() is called | ||||
|         for key, entries in data.items(): | ||||
|             for entry in entries: | ||||
|                 self.save_to_db(key, entry.value, entry.runtime) | ||||
|  | ||||
| # Register custom backends | ||||
| recorder.add_backend(CustomLogBackend()) | ||||
| recorder.add_backend(DatabaseRecordBackend("postgresql://...")) | ||||
| ``` | ||||
|  | ||||
| ### Performance and Overhead | ||||
|  | ||||
| **Recording Performance**: | ||||
| - **Minimal overhead**: Recording adds ~1-5μs per kernel selection | ||||
| - **Fast bail**: When `recording_active=False`, overhead is ~100ns (single boolean check) | ||||
| - **Memory efficient**: Only keeps configured `topk` entries per operation in memory | ||||
|  | ||||
| **Storage**: | ||||
| - **Typical size**: 1-10KB per recorded table file | ||||
| - **Compression**: JSON format is human-readable and compresses well | ||||
| - **Incremental**: Each compilation session creates a separate timestamped file | ||||
|  | ||||
| ### Troubleshooting Recording | ||||
|  | ||||
| **No files created?** | ||||
| ```python | ||||
| # Check if recording is properly enabled | ||||
| from torch._inductor import config | ||||
| print(f"Recording active: {config.lookup_table.recording_active}") | ||||
| print(f"Record directory: {config.lookup_table.recorder_record_dir}") | ||||
| print(f"TopK setting: {config.lookup_table.recorder_topk}") | ||||
|  | ||||
| # Ensure max-autotune is enabled to trigger template selection | ||||
| compiled_model = torch.compile(model, mode="max-autotune") | ||||
| ``` | ||||
|  | ||||
| **Empty tables?** | ||||
| - Recording only captures results from operations that undergo autotuning | ||||
| - Ensure your model has matrix operations (`mm`, `addmm`, `bmm`, etc.) | ||||
| - Check that input sizes are large enough to trigger template-based kernels | ||||
| - Verify GPU kernels are being used (CPU operations aren't recorded) | ||||
|  | ||||
**Files but no entries?**
- Check `recorder_topk` isn't set to 0 (which disables recording)
- Ensure autotuning found valid template choices (not just ATEN fallbacks)
- Verify templates have the required `template_id` and parameters
							
								
								
									
										34
									
								
								torch/_inductor/lookup_table/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								torch/_inductor/lookup_table/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,34 @@ | ||||
| """ | ||||
| Template lookup table system for PyTorch Inductor. | ||||
|  | ||||
| This package provides functionality for: | ||||
| - Loading pre-configured template choices from lookup tables | ||||
| - Managing template configurations and choices | ||||
|  | ||||
| All functionality is contained within the LookupTableChoices class. | ||||
| You can customize any aspect by subclassing LookupTableChoices and overriding methods. | ||||
|  | ||||
| Usage: | ||||
|     # Basic usage | ||||
|     choices = LookupTableChoices() | ||||
|     V.set_choices_handler(choices) | ||||
|  | ||||
|     # Custom usage | ||||
|     class MyCustomChoices(LookupTableChoices): | ||||
|         def _get_lookup_table(self): | ||||
|             return my_custom_table | ||||
|  | ||||
|         def make_lookup_key(self, kernel_inputs, op_name, include_device=False): | ||||
|             return f"custom_{op_name}_{hash(str(kernel_inputs))}" | ||||
|  | ||||
|     V.set_choices_handler(MyCustomChoices()) | ||||
| """ | ||||
|  | ||||
| from . import recorder | ||||
| from .choices import LookupTableChoices | ||||
|  | ||||
|  | ||||
# Public API of the lookup_table package: the choices handler class and the
# recorder submodule (for registering custom emit/record backends).
__all__ = [
    "LookupTableChoices",
    "recorder",
]
							
								
								
									
										424
									
								
								torch/_inductor/lookup_table/choices.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										424
									
								
								torch/_inductor/lookup_table/choices.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,424 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import copy | ||||
| import logging | ||||
| from functools import lru_cache | ||||
| from typing import Any, Optional, TYPE_CHECKING, Union | ||||
|  | ||||
| import torch | ||||
| from torch._inductor import config | ||||
| from torch._inductor.choices import InductorChoices | ||||
| from torch._inductor.kernel_template_choice import KernelTemplateChoice | ||||
| from torch._inductor.template_heuristics.params import DictKernelTemplateParams | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from collections.abc import Generator | ||||
|  | ||||
|     from torch._inductor.codegen.common import KernelTemplate | ||||
|     from torch._inductor.kernel_inputs import KernelInputs | ||||
|     from torch._inductor.select_algorithm import ExternKernelChoice | ||||
|  | ||||
|  | ||||
class LookupTableChoices(InductorChoices):
    """
    InductorChoices subclass that uses lookup table when available, otherwise falls back to parent.
    All lookup functionality is contained within this class and can be customized by overriding methods.

    The lookup table maps a flattened key (optional device arch + input
    dtypes/shapes/strides/scalars + op name) to a list of per-template config
    dicts; on a hit those configs replace the heuristic-generated choices.
    """

    def _get_lookup_table(self) -> dict[str, list[dict[str, Any]]]:
        """
        Get the template lookup table from config.
        Override this method to use custom lookup table sources (database, API, etc.).

        Returns the configured table, or {} when CUDA is unavailable or no
        table is configured -- an empty dict disables the lookup system.
        """
        if not torch.cuda.is_available() or config.lookup_table.table is None:
            return {}
        return config.lookup_table.table

    @staticmethod
    @lru_cache
    def _get_device_key(device: torch.device) -> Optional[str]:
        """
        Generate a device key for lookup table indexing.
        For CPU devices, returns None.
        For CUDA devices, returns the props.gcnArchName string.
        """
        # NOTE: decorators apply bottom-up, so lru_cache wraps the raw function
        # and staticmethod wraps the cached one -- results are cached per device.
        if device.type != "cuda":
            # only cuda devices are supported, this indicates that the system is not in use
            # for this device
            return None

        # Get CUDA device properties
        props = torch.cuda.get_device_properties(device.index)
        return props.gcnArchName

    @staticmethod
    def _generate_kernel_inputs_key(kernel_inputs: KernelInputs) -> str:
        """
        Generate a key based on input node properties and scalars.
        The key includes dtype, size, and stride information for each input node,
        plus scalar values as key=value pairs separated by & signs.
        """
        # Get node information using existing methods
        # (hinted shapes/strides are concrete ints, so the key is stable text)
        dtypes = kernel_inputs.dtypes()
        shapes = kernel_inputs.shapes_hinted()
        strides = kernel_inputs.strides_hinted()

        # Create tuple of (dtype, shape_list, stride_list) for each node
        node_info = tuple(
            (dtype, list(shape), list(stride))
            for dtype, shape, stride in zip(dtypes, shapes, strides)
        )

        # Create base key from node information
        fmt_key = str(node_info)
        # Add scalar information if present
        if kernel_inputs._scalars:
            # Sort scalars for consistent key generation and join with &
            scalar_parts = [
                f"{key}={value}"
                for key, value in sorted(kernel_inputs._scalars.items())
            ]
            scalars_key = "&".join(scalar_parts)
            fmt_key = f"{fmt_key}+{scalars_key}"

        return f"{fmt_key}"

    def make_lookup_key(
        self, kernel_inputs: KernelInputs, op_name: str, include_device: bool = False
    ) -> Optional[str]:
        """
        Create a flattened lookup key from kernel inputs and operation name.
        Override this method to customize key generation.

        Args:
            kernel_inputs: KernelInputs object containing input nodes and scalars
            op_name: Operation name (e.g., "mm", "addmm")
            include_device: Whether to include device key in the generated key

        Returns:
            A string key combining device (optional), operation, and input information,
            or None when the device is unsupported (non-CUDA).
        """
        device = kernel_inputs.device()
        dev_key = self._get_device_key(device)
        if dev_key is None:
            # The system does not run when dev_key is None, regardless of
            # whether include_device is True or False
            return None
        if not include_device:
            dev_key = None

        # Generate input key using our staticmethod
        input_key = self._generate_kernel_inputs_key(kernel_inputs)

        # Create the flattened lookup key
        if dev_key is not None:
            key_parts = [dev_key, input_key, op_name]
        else:
            key_parts = [input_key, op_name]

        return "+".join(key_parts)

    def make_lookup_key_variants(
        self, kernel_inputs: KernelInputs, op_name: str
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Generate both device-specific and device-agnostic lookup keys.
        Override this method to customize key variant generation.

        Args:
            kernel_inputs: KernelInputs object containing input nodes and scalars
            op_name: Operation name (e.g., "mm", "addmm")

        Returns:
            Tuple of (device_key, device_agnostic_key). Either may be None if generation fails.
        """
        device_key = self.make_lookup_key(kernel_inputs, op_name, include_device=True)
        device_agnostic_key = self.make_lookup_key(
            kernel_inputs, op_name, include_device=False
        )

        return device_key, device_agnostic_key

    @staticmethod
    def _entry_is_valid(
        cfg: dict[str, Any],
        template_id: str,
        template_hash_map: Optional[dict[str, Optional[str]]],
    ) -> bool:
        """
        Check if a config entry is valid based on template hash validation.

        Args:
            cfg: Configuration dictionary that may contain a template_hash field
            template_id: The template identifier
            template_hash_map: Optional mapping from template_uid to src_hash for validation

        Returns:
            True if the config is valid and should be kept, False if it should be filtered out.
            A config is only rejected when both hashes exist and disagree.
        """
        # If hash checking is disabled or no hash map provided, keep the config
        if not config.lookup_table.check_src_hash or not template_hash_map:
            return True

        template_hash = template_hash_map.get(template_id)
        config_hash = cfg.get("template_hash")

        # Both hashes present - validate they match
        if template_hash is not None and config_hash is not None:
            if config_hash != template_hash:
                log.warning(
                    "Hash validation failed for template '%s': config_hash='%s' != template_hash='%s'. "
                    "Template code may have changed. Filtering out config: %s",
                    template_id,
                    config_hash,
                    template_hash,
                    {k: v for k, v in cfg.items() if k != "template_hash"},
                )
                return False
            else:
                log.debug(
                    "Hash validation passed for template '%s': hash='%s'",
                    template_id,
                    template_hash,
                )
                return True
        # Config has no hash - keep it
        elif config_hash is None:
            log.debug(
                "Config for template '%s' has no hash - keeping it (template_hash='%s')",
                template_id,
                template_hash,
            )
            return True
        # Template has no hash - keep config
        else:
            log.debug(
                "Template '%s' has no src_hash - keeping config with hash '%s'",
                template_id,
                config_hash,
            )
            return True

    def lookup_template_configs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
        template_uids: list[str],
        template_hash_map: Optional[dict[str, Optional[str]]] = None,
    ) -> dict[str, list[dict[str, Any]]]:
        """
        Unified function to look up template configurations for multiple templates.
        Override this method to customize lookup logic.

        Args:
            kernel_inputs: KernelInputs object containing input nodes and scalars
            op_name: Operation name (e.g., "mm", "addmm")
            template_uids: List of template identifiers (e.g., ["mm", "tma", "decompose_k"])
            template_hash_map: Optional mapping from template_uid to src_hash for validation

        Returns:
            {}: No lookup table in use, or no matches found for any template
            {"template_uid1": [config1, config2], ...}: Matches found, filtered configurations
        """
        lookup_table = self._get_lookup_table()
        if not lookup_table:
            log.debug("Lookup table: no table configured or CUDA unavailable")
            return {}

        # Try both key variants: device-specific first, then device-agnostic
        # If both exist, device-specific takes priority
        device_key, device_agnostic_key = self.make_lookup_key_variants(
            kernel_inputs, op_name
        )

        config_list = []

        # First key variant producing a non-empty config list wins.
        for key_type, key in [
            ("device-specific", device_key),
            ("device-agnostic", device_agnostic_key),
        ]:
            if key is not None:
                config_list = lookup_table.get(key, [])
                if config_list:
                    log.debug(
                        "Lookup table: found %d configs using %s key '%s' for %s",
                        len(config_list),
                        key_type,
                        key,
                        op_name,
                    )
                    break
        # for/else: runs only when no break fired, i.e. no key matched
        else:
            log.debug(
                "Lookup table: no match for %s (tried keys: %s, %s) (table has %d keys)",
                op_name,
                device_key,
                device_agnostic_key,
                len(lookup_table),
            )
            return {}

        log.debug(
            "Lookup table: found %d configs for %s templates %s",
            len(config_list),
            op_name,
            template_uids,
        )
        # Group configs by template_id; entries for templates outside
        # template_uids are silently ignored.
        configs_by_template: dict[str, list[dict[str, Any]]] = {}
        for cfg in config_list:
            if not isinstance(cfg, dict):
                raise ValueError(
                    f"Config for {op_name} operation is not a dictionary: {cfg}"
                )
            if "template_id" not in cfg:
                raise ValueError(
                    f"Config for {op_name} operation missing required 'template_id' field: {cfg}"
                )

            template_id = cfg["template_id"]
            if template_id in template_uids:
                if template_id not in configs_by_template:
                    configs_by_template[template_id] = []
                configs_by_template[template_id].append(cfg)

        # Check template hashes and clean up template_id field
        result = {}
        for template_id, matching_configs in configs_by_template.items():
            filtered_configs = []
            for cfg in matching_configs:
                # Check template hash using helper function
                if not self._entry_is_valid(cfg, template_id, template_hash_map):
                    continue

                # Return a copy of the config, as we don't want to modify the original
                cconfig = copy.deepcopy(cfg)
                # Lastly, we have to throw out the template_id, as it's not a valid kwarg
                # and just used to identify which template the entry belongs to
                del cconfig["template_id"]
                # Similarly, the template_hash is not a valid kwarg
                cconfig.pop("template_hash", None)
                filtered_configs.append(cconfig)

            if filtered_configs:
                result[template_id] = filtered_configs

        return result

    def _finalize_template_configs(
        self,
        template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
        kernel_inputs: KernelInputs,
        templates: list[Union[KernelTemplate, ExternKernelChoice]],
        op_name: str,
        kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
    ) -> list[KernelTemplateChoice]:
        """Check lookup table for hits, use those if found, otherwise fall back to parent."""
        # Early exit if lookup table system is disabled - call super directly
        if not config.lookup_table.active:
            return super()._finalize_template_configs(
                template_choices, kernel_inputs, templates, op_name, kwarg_overrides
            )

        # 1. Collect template src_hashes for validation
        # (templates without a src_hash attribute map to None and skip validation)
        template_uids = [template.uid for template in templates]
        template_hash_map = {}
        for template in templates:
            src_hash = getattr(template, "src_hash", None)
            template_hash_map[template.uid] = src_hash

        log.debug(
            "Choices: attempting lookup for %s with %d templates",
            op_name,
            len(template_uids),
        )

        # 2. Single batch lookup for all templates
        lookup_results = self.lookup_template_configs(
            kernel_inputs, op_name, template_uids, template_hash_map
        )

        # 3. Early exit if no lookup table or no matches
        if not lookup_results:  # Empty dict
            log.info("LookupChoices: lookup miss for %s, using fallback", op_name)
            return self._fallback(
                template_choices,
                kernel_inputs,
                templates,
                op_name,
                kwarg_overrides,
            )

        log.info(
            "LookupChoices: lookup hit for %s - found %d/%d templates: %s",
            op_name,
            len(lookup_results),
            len(template_uids),
            list(lookup_results.keys()),
        )

        # 4. Create KTCs only for templates with lookup entries
        return self._create_lookup_choices(
            lookup_results, templates, kernel_inputs, op_name
        )

    def _fallback(
        self,
        template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
        kernel_inputs: KernelInputs,
        templates: list[Union[KernelTemplate, ExternKernelChoice]],
        op_name: str,
        kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
    ) -> list[KernelTemplateChoice]:
        """Fallback to parent if no lookup table or no matches."""
        # NOTE: this is broken out, so that subclasses are able to override this
        # to handle explicitly the situations where the lookup take had a miss vs
        # overriding the entire logic
        return super()._finalize_template_configs(
            template_choices,
            kernel_inputs,
            templates,
            op_name,
            kwarg_overrides,
        )

    def _create_lookup_choices(
        self,
        lookup_results: dict[str, list[dict[str, Any]]],
        templates: list[Union[KernelTemplate, ExternKernelChoice]],
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> list[KernelTemplateChoice]:
        """Create KernelTemplateChoice objects from lookup results using parent's get_ktc method."""
        templates_by_uid = {template.uid: template for template in templates}
        lookup_choices: list[KernelTemplateChoice] = []

        for template_uid, configs in lookup_results.items():
            template = templates_by_uid[template_uid]

            # Use parent's get_ktc method to get a generator, then get the first base KTC
            # (only its template/extra_kwargs/layout/inputs are reused; params are replaced)
            ktc_generator = self.get_ktc(kernel_inputs, template, op_name)

            try:
                base_ktc = next(ktc_generator)
            except StopIteration:
                # No configs from heuristic, skip this template
                continue

            # For each lookup config, create a KTC with the override kwargs
            for c in configs:
                lookup_ktc = KernelTemplateChoice(
                    template=base_ktc.template,
                    # use the ones from the lookup table
                    params=DictKernelTemplateParams(c),
                    extra_kwargs=base_ktc.extra_kwargs,
                    layout=base_ktc.layout,
                    inputs=base_ktc.inputs,
                )
                lookup_choices.append(lookup_ktc)

        return lookup_choices
							
								
								
									
										415
									
								
								torch/_inductor/lookup_table/recorder.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										415
									
								
								torch/_inductor/lookup_table/recorder.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,415 @@ | ||||
| """ | ||||
| Lookup table recorder system for capturing autotuning results. | ||||
|  | ||||
| This module provides a system to record and emit autotuning results from kernel selection. | ||||
| It supports both immediate emission (logging) and table recording (building lookup tables). | ||||
| """ | ||||
|  | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| from abc import ABC, abstractmethod | ||||
| from dataclasses import dataclass | ||||
| from datetime import datetime, timezone | ||||
| from typing import Any, Callable, Optional | ||||
|  | ||||
| from torch._inductor import config | ||||
| from torch._inductor.ir import ChoiceCaller | ||||
| from torch._inductor.kernel_inputs import KernelInputs | ||||
| from torch._inductor.kernel_template_choice import KernelTemplateChoice | ||||
| from torch.utils._ordered_set import OrderedSet | ||||
|  | ||||
| from .choices import LookupTableChoices | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
@dataclass
class LookupTableEntry:
    """One autotuning result, keyed the same way the lookup table is indexed."""

    # Flattened lookup key: optional device key + input key + op name
    key: str
    # Serialized choice: template_id plus all template kwargs
    value: dict[str, Any]
    # Recording metadata (benchmark timing, rank among choices, etc.)
    metadata: dict[str, Any]
    # Benchmark runtime for this choice; drives top-k ordering
    runtime: float

    @classmethod
    def from_ktc_and_timing(
        cls,
        ktc: KernelTemplateChoice,
        timing: float,
        rank: int,
        op_name: str,
    ) -> Optional["LookupTableEntry"]:
        """Build an entry from a KernelTemplateChoice and its benchmark timing.

        Returns None when no lookup key can be generated (e.g. non-CUDA device).
        """
        # A template is a hard requirement for recording.
        assert ktc.template is not None, "KernelTemplateChoice must have a template"

        lookup_key = _make_lookup_key(
            ktc.inputs, op_name, config.lookup_table.record_with_device_key
        )
        if lookup_key is None:
            return None

        # Serialize the choice: template identity plus its kwargs.
        entry_value: dict[str, Any] = {
            "template_id": ktc.template.uid,
            **ktc.params.to_serializeable_dict(),
        }

        # Optionally pin the template's source hash so stale configs can be
        # detected at lookup time.
        if config.lookup_table.record_template_hash:
            src_hash = getattr(ktc.template, "src_hash", None)
            if src_hash is not None:
                entry_value["template_hash"] = src_hash

        return cls(
            key=lookup_key,
            value=entry_value,
            metadata={"timing": timing, "rank": rank},
            runtime=timing,
        )
|  | ||||
|  | ||||
def _make_lookup_key(
    kernel_inputs: KernelInputs, op_name: str, include_device: bool = False
) -> Optional[str]:
    """Generate a lookup key, preferring the installed choices handler.

    Uses V.choices_handler when it is a LookupTableChoices; otherwise a
    throwaway LookupTableChoices instance provides the key logic.
    """
    from torch._inductor.virtualized import V

    handler = None
    if hasattr(V, "choices_handler") and isinstance(
        V.choices_handler, LookupTableChoices
    ):
        handler = V.choices_handler
    if handler is None:
        # Fallback: a temporary instance is cheap and stateless for key-making.
        handler = LookupTableChoices()
    return handler.make_lookup_key(kernel_inputs, op_name, include_device)
|  | ||||
|  | ||||
class Backend:
    """Common base for recorder backends.

    Subclasses may override clear() to reset any internal state.
    """

    def clear(self) -> None:
        """Reset backend state; the default implementation does nothing."""
        return None
|  | ||||
|  | ||||
class EmitBackend(ABC, Backend):
    """Abstract backend that handles single entries as they are produced."""

    @abstractmethod
    def emit(self, entry: LookupTableEntry) -> None:
        """Process one lookup-table entry immediately."""
|  | ||||
|  | ||||
class RecordBackend(ABC, Backend):
    """Abstract backend that persists the accumulated table in one shot."""

    @abstractmethod
    def dump(self, data: dict[str, list[LookupTableEntry]]) -> None:
        """Persist the full key -> entries mapping."""
|  | ||||
|  | ||||
# Tracks (backend_class, frozen kwargs) pairs already added to the global
# recorder, so repeated get_lookup_table_recorder() calls don't double-register
_registered_backends: OrderedSet[Any] = OrderedSet()
|  | ||||
|  | ||||
| def _backend_key( | ||||
|     backend_class: type, kwargs: dict[str, Any] | ||||
| ) -> tuple[type, frozenset[Any]]: | ||||
|     """Create a unique key for backend class + kwargs""" | ||||
|     return (backend_class, frozenset(kwargs.items()) if kwargs else frozenset()) | ||||
|  | ||||
|  | ||||
class LogEmitBackend(EmitBackend):
    """Default emit backend that logs entries via the module logger."""

    def emit(self, entry: LookupTableEntry) -> None:
        """Log the entry's key and value at debug level (lazy %r formatting)."""
        log.debug("LookupTable: %r -> %r", entry.key, entry.value)
|  | ||||
|  | ||||
class DirectoryRecordBackend(RecordBackend):
    """Record backend that dumps the table to a timestamped JSON file in a directory."""

    def __init__(self, directory: str):
        self.directory = directory
        # Pick the initial output path; timestamped so concurrent runs don't collide.
        self.setup()

    def setup(self) -> None:
        """Compute a fresh UTC-timestamped file path (millisecond precision)."""
        now = datetime.now(tz=timezone.utc)
        stamp = "{}_{:03d}".format(now.strftime("%Y%m%d_%H%M%S"), now.microsecond // 1000)
        self.filepath = os.path.join(self.directory, f"inductor_lut_{stamp}.json")

    def dump(self, data: dict[str, list[LookupTableEntry]]) -> None:
        """Write only the entry values, grouped by key, overwriting the file."""
        os.makedirs(self.directory, exist_ok=True)

        # Strip entries down to their serializable `value` dicts.
        serializable = {key: [entry.value for entry in entries] for key, entries in data.items()}
        with open(self.filepath, "w") as f:
            json.dump(serializable, f, indent=2)

    def clear(self) -> None:
        """Switch to a new output file so later dumps don't mix with old data."""
        self.setup()
|  | ||||
|  | ||||
class LookupTableRecorder:
    """Collects autotuning entries and fans them out to emit/record backends."""

    def __init__(self, topk: Optional[int] = None) -> None:
        # key -> entries kept sorted ascending by runtime, trimmed to topk
        self.data: dict[str, list[LookupTableEntry]] = {}
        self.emit_backends: list[EmitBackend] = []
        self.record_backends: list[RecordBackend] = []
        # An explicit topk wins; otherwise read the configured default.
        if topk is None:
            topk = config.lookup_table.recorder_topk
        self.topk = topk

    @property
    def input_entries(self) -> int:
        """Number of distinct input keys recorded so far."""
        return len(self.data)

    def add_backend(self, backend: Backend) -> None:
        """Register a backend, routed by its type; reject anything else."""
        if isinstance(backend, EmitBackend):
            self.emit_backends.append(backend)
            return
        if isinstance(backend, RecordBackend):
            self.record_backends.append(backend)
            return
        raise ValueError(
            f"Backend must be an instance of EmitBackend or RecordBackend, "
            f"got {type(backend).__name__}"
        )

    def emit(self, entry: LookupTableEntry) -> None:
        """Send a single entry to every emit backend."""
        for emit_backend in self.emit_backends:
            emit_backend.emit(entry)

    def record(self, entry: LookupTableEntry) -> None:
        """Store an entry under its key (keeping the fastest topk) and emit it."""
        # Recording always implies emission.
        self.emit(entry)
        bucket = self.data.setdefault(entry.key, [])

        # Lists stay tiny (topk entries), so append + sort beats bisection
        # in simplicity at no real cost.
        bucket.append(entry)
        bucket.sort(key=lambda e: float("inf") if e.runtime is None else e.runtime)

        # Trim the slow tail when a positive topk is configured.
        topk = self.topk
        if topk is not None and topk > 0 and len(bucket) > topk:
            dropped = bucket[topk:]
            log.info(
                "Replacing %d entries with new entry (value: %r, runtime: %f) due to topk=%d. "
                "Replaced entries: %s",
                len(dropped),
                entry.value,
                entry.runtime,
                self.topk,
                [{"value": e.value, "runtime": e.runtime} for e in dropped],
            )
            del bucket[topk:]

    def dump(self) -> None:
        """Flush the accumulated table through every record backend."""
        for record_backend in self.record_backends:
            record_backend.dump(self.data)

    def clear(self) -> None:
        """Drop all recorded data and reset every backend."""
        self.data.clear()
        for backend in self.emit_backends + self.record_backends:
            backend.clear()
|  | ||||
|  | ||||
# Module-wide singleton recorder; created lazily by get_lookup_table_recorder()
_lookup_table_recorder: Optional[LookupTableRecorder] = None
|  | ||||
|  | ||||
def get_lookup_table_recorder() -> LookupTableRecorder:
    """Return the process-wide recorder, creating it on first use.

    Also (re-)registers any built-in backends newly enabled via config, so
    config changes made after the first call still take effect.
    """
    global _lookup_table_recorder
    if _lookup_table_recorder is None:
        _lookup_table_recorder = LookupTableRecorder()
    _register_pending_backends(_lookup_table_recorder)
    return _lookup_table_recorder
|  | ||||
|  | ||||
def add_backend(backend: Backend) -> None:
    """Register *backend* with the global lookup table recorder."""
    rec = get_lookup_table_recorder()
    # get_lookup_table_recorder always returns an instance; guard kept for safety.
    if rec is not None:
        rec.add_backend(backend)
|  | ||||
|  | ||||
def _register_pending_backends(recorder: LookupTableRecorder) -> None:
    """Register the built-in backends enabled by config, at most once each.

    Registration failures are logged and swallowed so recording setup never
    breaks compilation.
    """
    global _registered_backends

    record_dir = config.lookup_table.recorder_record_dir

    # Built-in log emitter, if enabled and not yet registered.
    if config.lookup_table.recorder_emit:
        key = _backend_key(LogEmitBackend, {})
        if key not in _registered_backends:
            try:
                recorder.add_backend(LogEmitBackend())
                _registered_backends.add(key)
                log.debug("Registered LogEmitBackend")
            except Exception as e:
                log.warning("Failed to register LogEmitBackend: %r", e)

    # Built-in directory dumper, if a target directory is configured.
    if record_dir:
        key = _backend_key(DirectoryRecordBackend, {"directory": record_dir})
        if key not in _registered_backends:
            try:
                recorder.add_backend(DirectoryRecordBackend(record_dir))
                _registered_backends.add(key)
                log.debug("Registered DirectoryRecordBackend: %s", record_dir)
            except Exception as e:
                log.warning(
                    "Failed to register DirectoryRecordBackend %s: %r", record_dir, e
                )
|  | ||||
|  | ||||
def record_topk_choices(
    timings: dict[ChoiceCaller, float],
    op_name: str,
    input_nodes: list[Any],
    choices: list[ChoiceCaller],
    profiled_time_fn: Callable[[], dict[Any, Any]],
) -> None:
    """
    Feedback function to record topk choices based on timing results.

    Args:
        timings: Mapping from choices to benchmark times
        op_name: Name of the operation (e.g. "mm", "addmm")
        input_nodes: List of input ir.Nodes
        choices: List of ChoiceCaller objects
        profiled_time_fn: Function to get profiled times (unused in this implementation)
    """
    # Cheap early exits: recording disabled outright, or topk configured to 0.
    if not config.lookup_table.recording_active:
        log.debug(
            "Recording disabled (recording_active=False) for operation %s", op_name
        )
        return
    if config.lookup_table.recorder_topk == 0:
        log.debug("Recording disabled (topk=0) for operation %s", op_name)
        return

    # Obtain the global recorder and sync its topk with the current config value.
    recorder = get_lookup_table_recorder()
    if recorder is None:
        log.warning("Failed to get lookup table recorder for operation %s", op_name)
        return
    recorder.topk = config.lookup_table.recorder_topk

    # A choice is usable only if it was benchmarked (finite time) and carries
    # a KTC annotation we can turn into a lookup-table entry.
    def _usable(candidate: ChoiceCaller) -> bool:
        return (
            candidate in timings
            and hasattr(candidate, "annotations")
            and "ktc" in candidate.annotations
            and timings[candidate] != float("inf")
        )

    usable_choices = [c for c in choices if _usable(c)]
    dropped = len(choices) - len(usable_choices)

    if dropped > 0:
        log.debug(
            "Recording %s: filtered %d/%d invalid choices",
            op_name,
            dropped,
            len(choices),
        )

    if not usable_choices:
        log.debug("Recording %s: no valid choices", op_name)
        return

    # Rank by benchmark time (fastest first), then trim to the configured topk.
    ranked = sorted(usable_choices, key=timings.__getitem__)
    topk = recorder.topk
    if topk and topk > 0 and len(ranked) > topk:
        ranked = ranked[:topk]

    # Convert each surviving choice into a lookup-table entry and record it.
    saved = 0
    for rank, choice in enumerate(ranked):
        ktc = choice.annotations["ktc"]
        if ktc is None:
            log.warning(
                "Recording %s: KTC is None for choice %s",
                op_name,
                getattr(choice, "name", choice),
            )
            continue

        entry = LookupTableEntry.from_ktc_and_timing(
            ktc=ktc, timing=timings[choice], rank=rank, op_name=op_name
        )
        if entry is not None:
            recorder.record(entry)
            saved += 1

    log.info(
        "Recording %s: saved %d/%d entries",
        op_name,
        saved,
        len(ranked),
    )
    # Any time we record the table, we continue to dump
    # The backends need to be implemented so that they can
    # accommodate progressive dumping
    recorder.dump()
|  | ||||
|  | ||||
def dump() -> None:
    """Dump the global lookup table recorder"""
    active_recorder = get_lookup_table_recorder()
    if active_recorder is None:
        # Nothing to dump if the recorder could not be obtained.
        return
    active_recorder.dump()
|  | ||||
|  | ||||
def clear() -> None:
    """Clear the global lookup table recorder"""
    active_recorder = get_lookup_table_recorder()
    if active_recorder is None:
        # No recorder available — nothing to clear.
        return
    active_recorder.clear()
|  | ||||
|  | ||||
# Auto-register the feedback function when the module is imported.
# This is a deliberate import-time side effect: once this module is loaded,
# every autotuning run routes its timing results through record_topk_choices.
try:
    from torch._inductor.select_algorithm import add_feedback_saver

    add_feedback_saver(record_topk_choices)
    log.debug("Registered lookup table recorder feedback function")
except ImportError:
    # NOTE(review): the ImportError guard suggests select_algorithm may be
    # unavailable at this point (presumably an import-order/circular-import
    # situation) — confirm; in that case recording is silently skipped
    # beyond this warning.
    log.warning(
        "Failed to register lookup table recorder feedback function - select_algorithm not available"
    )
| @ -207,9 +207,9 @@ def _choices_default(): | ||||
|  | ||||
|     We virtualize InductorChoices to allow changing inductor heuristics from out of tree. | ||||
|     """ | ||||
|     from torch._inductor.choices import InductorChoices | ||||
|     from torch._inductor.lookup_table.choices import LookupTableChoices | ||||
|  | ||||
|     rv = InductorChoices() | ||||
|     rv = LookupTableChoices() | ||||
|     setattr(threadlocal, _choices._key, rv) | ||||
|     return rv | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	