Compare commits


9 Commits

Author SHA1 Message Date
95ad31ab20 Update on "[inductor][choices] use LookupTableChoices by default 3/3"
# why

- an empty table is a noop
- reduce friction for users to use the table

# what

- virtualized uses LookupTableChoices as the default InductorChoices
- lookup table behavior is gated by an explicit bool that defaults to
  True, in addition to having any entries in the table at all

# testing

- existing unit tests

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-10-27 12:33:02 -07:00
ac6c125329 Update base for Update on "[inductor][choices] use LookupTableChoices by default 3/3"
# why

- an empty table is a noop
- reduce friction for users to use the table

# what

- virtualized uses LookupTableChoices as the default InductorChoices
- lookup table behavior is gated by an explicit bool that defaults to
  True, in addition to having any entries in the table at all

# testing

- existing unit tests

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-10-27 12:33:02 -07:00
0b41d9a6e0 Update on "[inductor][choices] use LookupTableChoices by default 3/3"
# why

- an empty table is a noop
- reduce friction for users to use the table

# what

- virtualized uses LookupTableChoices as the default InductorChoices
- lookup table behavior is gated by an explicit bool that defaults to
  True, in addition to having any entries in the table at all

# testing

- existing unit tests

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-10-27 10:22:03 -07:00
39f9c13974 Update base for Update on "[inductor][choices] use LookupTableChoices by default 3/3"
# why

- an empty table is a noop
- reduce friction for users to use the table

# what

- virtualized uses LookupTableChoices as the default InductorChoices
- lookup table behavior is gated by an explicit bool that defaults to
  True, in addition to having any entries in the table at all

# testing

- existing unit tests

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-10-27 10:22:03 -07:00
fd90e9d082 Update on "[inductor][choices] use LookupTableChoices by default 3/3"
# why

- an empty table is a noop
- reduce friction for users to use the table

# what

- virtualized uses LookupTableChoices as the default InductorChoices
- lookup table behavior is gated by an explicit bool that defaults to
  True, in addition to having any entries in the table at all

# testing

- existing unit tests

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-10-21 15:12:48 -07:00
30c8f0f3ad Update base for Update on "[inductor][choices] use LookupTableChoices by default 3/3"
# why

- an empty table is a noop
- reduce friction for users to use the table

# what

- virtualized uses LookupTableChoices as the default InductorChoices
- lookup table behavior is gated by an explicit bool that defaults to
  True, in addition to having any entries in the table at all

# testing

- existing unit tests

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-10-21 15:12:48 -07:00
03bd16b815 [inductor][choices] use LookupTableChoices by default 3/3
# why

- an empty table is a noop
- reduce friction for users to use the table

# what

- virtualized uses LookupTableChoices as the default InductorChoices
- lookup table behavior is gated by an explicit bool that defaults to
  True, in addition to having any entries in the table at all

# testing

- existing unit tests

[ghstack-poisoned]
2025-10-08 12:13:06 -07:00
b366fc40ce [inductor][lookup table] add recorder 2/3
# why

- enable users to record full tables to plug into the system

# what

- recorder system
- default recorders that
  - write individual entries into debug logs
  - dump table progressively into output directory by setting
    `TORCH_INDUCTOR_LOOKUP_TABLE_RECORD_DIR`

# testing

- added new unit tests

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -v
```

[ghstack-poisoned]
2025-10-08 12:13:01 -07:00
e5bf90ba39 [inductor][choices] lookup table choices 1/3
# why

- enable users to control which choices get used on which inputs
- reduce lowering time, and pin kernel selection, by selecting
  them for the inputs

# what

- a new InductorChoices subclass that implements a lookup table
- a README explaining the usage
- corresponding testing

- currently only supports templates that go through
  `V.choices.get_template_configs`

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -v
```

[ghstack-poisoned]
2025-10-08 12:12:56 -07:00
7 changed files with 2862 additions and 2 deletions

File diff suppressed because it is too large


@@ -2073,6 +2073,39 @@ write_are_deterministic_algorithms_enabled = (
)


class lookup_table:
    # Enable/disable lookup table choices system (defaults to True)
    active: bool = True
    # Lookup table for template config overrides
    table: Optional[dict[str, list[dict[str, Any]]]] = None
    # Enable template src_hash checking in lookup table to prevent using stale configs.
    # If True, configs with a 'template_hash' field will be compared against the template's
    # src_hash at runtime and filtered out if they don't match. If False, no
    # hash checking is performed.
    check_src_hash: bool = True

    # Recorder configuration for capturing autotuning results
    # Master switch for recording functionality - must be True to enable any recording
    recording_active: bool = False
    # Enable emitting entries to registered emit backends (e.g., logging)
    recorder_emit: bool = True
    # Directory to record lookup tables to. If set, enables DirectoryRecordBackend
    recorder_record_dir: Optional[str] = None
    # Number of top choices to record. If None or negative, record all choices.
    # If 0, record nothing. If positive, record only the top k choices.
    recorder_topk: Optional[int] = None
    # Whether to record template hashes for templates that provide them.
    # Template hashes make the lookup table more robust but potentially less portable
    record_template_hash: bool = True
    # Whether to include the device key when recording lookup table entries.
    # When True, entries are recorded with device-specific keys (more precise but less portable).
    # When False, entries are recorded without device keys (more portable across devices).
    # During lookup, both key formats are always tried regardless of this setting.
    record_with_device_key: bool = True


class test_configs:
    force_extern_kernel_in_multi_template: bool = False


@@ -0,0 +1,540 @@
# Template Lookup Table System
The template lookup table system provides a way to pre-configure kernel template parameters for specific operations and
input configurations, bypassing the default choice generation and autotuning process.
## Overview
The lookup table system replaces default choice generation with pre-configured template parameters for specific
operations and input configurations. It sits orthogonal to `max-autotune(-gemm)` in the following way:

If a lookup table is provided and there is a match:

- We check whether the template(s) in the match are currently in use
  - If so, we use the pre-configured template(s) and config and bypass choice generation
    - If more than one choice is provided, we autotune among the pre-configured choices
  - If not, we fall back to the default choice generation process, including max-autotune(-gemm) logic

If there is no match, we fall back to the default choice generation process, including max-autotune(-gemm) logic.
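The match/fallback logic above can be sketched as follows. This is an illustrative simplification, not the actual implementation; `table`, `requested_templates`, and `default_choices` are hypothetical stand-ins:

```python
def choose(table, key, requested_templates, default_choices):
    # No table entry for this key -> default generation (incl. max-autotune logic)
    entries = table.get(key)
    if not entries:
        return default_choices()
    # Keep only entries whose template is currently in use
    usable = [e for e in entries if e["template_id"] in requested_templates]
    if not usable:
        return default_choices()
    # One entry -> used directly; several -> autotuned among themselves
    return usable
```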
## Configuration
Enable the system by setting both:
```python
from torch._inductor import config
config.lookup_table.table = your_table_dict

# You also need to set it as the default choice handler
from torch._inductor.lookup_table import LookupTableChoices
from torch._inductor.virtualized import V
V.set_choices_handler(LookupTableChoices())
```
### Device Key Handling
The key schema format is described in detail in the [Key Schemas](#key-schemas) section below.
Configure device key behavior:
```python
from torch._inductor import config

# Control whether recorded entries include device-specific keys.
# Device-agnostic entries work across different GPU models.
config.lookup_table.record_with_device_key = True  # False -> portable, device-agnostic keys
```
**Lookup Behavior**: During lookup, the system automatically tries both key formats:
1. **Device-specific key** (e.g., `"NVIDIA H100+input_data+mm"`) - tried first
2. **Device-agnostic key** (e.g., `"input_data+mm"`) - tried if the device-specific lookup fails
**Priority**: If both device-specific and device-agnostic entries exist for the same inputs, the device-specific entry
takes priority.
**NOTE**: Device-based keys simplify hardware-specific optimization without complex build rules. Currently limited to
device name only. If you need additional conditional key attributes (e.g., CUDA version filtering), please file an issue
or submit a patch.
## Behavior
When the table is active, the following behavior occurs for all supported operations:
### Match Found
- Uses pre-configured choices from the table instead of generating default choices
- Bypasses autotuning if only a single choice is provided
- If multiple choices are provided, autotuning occurs among those choices only
### No Match Found
- Standard default behavior - generates choices using heuristics and max-autotune settings
### Table Not Set or Inactive
- Standard default behavior - generates choices using heuristics and max-autotune settings
## Supported Operations
Currently supports: `mm`, `addmm`, `bmm`, `mm_plus_mm`, `scaled_mm` operations with
- Triton
- ATEN
- DecomposeK
## Table Format
The table is a dictionary with keys in the format:
```
"input_key+op_name"
```
Where:
- `input_key`: Generated from `KernelInputs.key` property, represents tensor shapes/dtypes/strides
- `op_name`: Operation name (`"mm"`, `"addmm"`, etc.)
Each value is a list of configuration dictionaries containing:
- `template_id`: Template identifier (`"triton::mm"`, `"triton::mm_persistent_tma"`, `"decompose_k"`, etc.)
- Template-specific parameters (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`, `num_warps`, etc.)
## Key Schemas
**NOTE**: The key schema format is subject to change as the system evolves.
The lookup table uses composite keys to match kernel configurations. See
[Implementation Details](#implementation-details) below for more technical information about key generation. This
section describes the structure of these keys.
### Key Format Structure
Keys follow the pattern:
```
[device_name+]input_key+[additional_params+]op_name
```
Components:
- **device_name** (optional): GPU device identifier (e.g., `"NVIDIA H100"`)
- Obtained from `torch.cuda.get_device_properties().gcnArchName`
- Enables device-specific optimizations
- When omitted, creates device-agnostic entries that work across hardware
- **input_key**: Tensor configuration representation from `KernelInputs.key`
- Format: `((dtype, shape, stride), (dtype, shape, stride), ...)`
- Each tuple represents one input tensor's properties
- Example: `((torch.float16, [128, 256], [0, 1]), (torch.float16, [64, 256], [256, 1]))`
- Order matches the operation's input argument order
- **additional_params** (optional): Operation-specific parameters
- Format: `key1=value1&key2=value2`
- Example: `alpha=1&beta=1` for addmm operations
- **op_name**: Operation identifier
- Examples: `"mm"`, `"addmm"`, `"bmm"`, `"mm_plus_mm"`, `"scaled_mm"`
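The components above compose into a key as follows. A minimal sketch of the assembly; the helper name and argument shapes are illustrative, not the real API:

```python
def make_key(input_key, op_name, device_name=None, params=None):
    # Assembles "[device_name+]input_key+[additional_params+]op_name"
    parts = [device_name] if device_name else []
    parts.append(input_key)
    if params:
        # Additional params are sorted key=value pairs joined with &
        parts.append("&".join(f"{k}={v}" for k, v in sorted(params.items())))
    parts.append(op_name)
    return "+".join(parts)
```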
### Key Examples
**Device-specific key for addmm:**
```
"NVIDIA H100+((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm"
```
**Device-agnostic key for mm:**
```
"((torch.float16, [64, 128], [128, 1]), (torch.float16, [128, 256], [256, 1]))+mm"
```
**Key with no additional parameters:**
```
"((torch.float32, [512, 512], [512, 1]), (torch.float32, [512, 512], [512, 1]))+bmm"
```
### Lookup Strategy
During lookup, the system tries keys in priority order:
1. **Device-specific key** - checked first if device information is available
2. **Device-agnostic key** - fallback if the device-specific lookup fails
This allows tables to contain:
- Device-optimized configurations (higher priority)
- Portable configurations that work across devices
- Mix of both for flexible deployment
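The priority order can be sketched as below; a hypothetical helper, not the actual implementation:

```python
def lookup(table, device_key, device_agnostic_key):
    # Device-specific entry wins when both exist
    for key in (device_key, device_agnostic_key):
        if key is not None and table.get(key):
            return table[key]
    return None  # no match -> caller falls back to default generation
```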
## Example Table
This is an example table for a single input, showing two configurations:
```python
table = {
    "((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm": [
        {
            "template_id": "triton::mm",
            "EVEN_K": True,
            "USE_FAST_ACCUM": False,
            "ACC_TYPE": "tl.float32",
            "num_stages": 2,
            "num_warps": 4,
            "BLOCK_M": 32,
            "BLOCK_N": 32,
            "BLOCK_K": 64,
            "hint_override": None,
            "GROUP_M": 8,
            "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896",
        },
        {
            "template_id": "aten::bias_addmm",
        },
    ]
}
```
## Source Hashing Safety
The lookup table system includes source hashing to prevent using stale configurations when template code changes.
### Configuration
- **Enabled by default**: `torch._inductor.config.lookup_table.check_src_hash = True`
- **Optional field**: Add `"template_hash"` to table entries for enhanced safety
### Behavior
When source hash checking is enabled:
- Template configurations with `"template_hash"` fields are validated against current template source hashes
- Mismatched hashes indicate the template code has changed since the configuration was created
- Stale configurations are automatically filtered out with a warning message
- Configurations without hash fields are kept, for backward compatibility or for users who explicitly prefer looser checking
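The filtering step can be sketched roughly like this; a simplification of the behavior described above, with illustrative names:

```python
def filter_stale(configs, current_src_hash):
    # Drop configs whose recorded hash no longer matches the template source.
    # Configs without a hash, or templates without a src_hash, are kept.
    kept = []
    for cfg in configs:
        recorded = cfg.get("template_hash")
        if recorded is not None and current_src_hash is not None and recorded != current_src_hash:
            continue  # stale: template code changed since this config was created
        kept.append(cfg)
    return kept
```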
### Example with Template Hash
```python
{
    "template_id": "triton::mm",
    "BLOCK_M": 32,
    "BLOCK_N": 32,
    "BLOCK_K": 16,
    "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896",
}
```
## Performance Impact
- **Lookup Hit**: Eliminates heuristic choice generation and autotuning overhead (if a single choice)
- **Lookup Miss**: Default behavior, including heuristic choice generation and autotuning
- **Memory**: Table stored in memory, minimal overhead for key generation and lookup
## Implementation Details
### Key Generation
- Device key: Uses `torch.cuda.get_device_properties().gcnArchName` (e.g., "NVIDIA H100")
- Input key: Generated from `KernelInputs.key` containing tensor properties
### Entry Points
The system is accessed through:
- `lookup_template_configs(kernel_inputs, op_name, template_uids)` - Main lookup function
- `LookupTableChoices._finalize_template_configs()` - Integration point with existing choice system
### Error Handling
- Validates config dictionaries contain required `template_id` field
- Gracefully handles non-CUDA devices by returning empty results
## Recording Lookup Tables
The system can record autotuning results to automatically build lookup tables for future use. This eliminates the need to manually create table entries and ensures optimal configurations are captured from real workloads.
### Quick Start: Recording
1. **Enable recording:**
```python
from torch._inductor import config
# Master switch - must be True to enable recording
config.lookup_table.recording_active = True
# Configure recording behavior
config.lookup_table.recorder_topk = 5 # Record top 5 fastest choices per operation
config.lookup_table.recorder_record_dir = "/path/to/output" # Save to files
```
2. **Run your model with autotuning:**
```python
import torch
model = YourModel()
compiled_model = torch.compile(model, mode="max-autotune")
# Recording happens automatically during compilation and execution
result = compiled_model(inputs)
```
3. **Files are automatically saved** to the specified directory with timestamped filenames like `inductor_lut_20241205_143052_123.json`.
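The naming scheme appears to be `inductor_lut_<date>_<time>_<millis>.json`. A hypothetical reconstruction of that scheme, not the actual recorder code:

```python
from datetime import datetime

def lut_filename(now=None):
    # e.g. inductor_lut_20241205_143052_123.json (date, time, milliseconds)
    now = now or datetime.now()
    return f"inductor_lut_{now.strftime('%Y%m%d_%H%M%S')}_{now.microsecond // 1000:03d}.json"
```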
### Recording Configuration
All recording options are available under `torch._inductor.config.lookup_table`:
```python
from torch._inductor import config
# Master recording switch - must be True for any recording to happen
config.lookup_table.recording_active = True # Default: False
# Logging and immediate emission
config.lookup_table.recorder_emit = True # Default: True (logs entries)
# File recording - set directory to enable file output
config.lookup_table.recorder_record_dir = "/path/to/save/tables" # Default: None
# Number of top choices to record per operation key
config.lookup_table.recorder_topk = 10 # Default: None (record all)
config.lookup_table.recorder_topk = 0 # Special case: disable recording
# Template safety and portability options
config.lookup_table.record_template_hash = True # Default: True (include hashes)
config.lookup_table.record_with_device_key = True # Default: True (device-specific keys)
```
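The `recorder_topk` semantics (None or negative records all, 0 records nothing, positive records the top k) can be sketched as follows, assuming each choice carries a measured `runtime`; this is illustrative, not the recorder's code:

```python
def take_topk(choices, topk):
    # Choices are ranked fastest-first by measured runtime
    ordered = sorted(choices, key=lambda c: c["runtime"])
    if topk is None or topk < 0:
        return ordered  # record everything
    return ordered[:topk]  # topk == 0 records nothing
```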
### Understanding TopK: Determinism vs. Flexibility
The `recorder_topk` setting is crucial for controlling the behavior of your recorded lookup tables:
#### TopK = 1: Maximum Performance and Determinism
```python
config.lookup_table.recorder_topk = 1 # Record only the fastest choice
```
**Benefits:**
- **No autotuning overhead**: When using the recorded table, exactly one choice is available, so no autotuning occurs
- **Perfect determinism**: Always uses the same kernel for identical inputs across runs
- **Fastest compilation**: Minimal overhead during `torch.compile()` with the lookup table
- **Production-ready**: Ideal for deployment where consistency and speed matter most
**Use case:** Production environments where you want maximum performance and deterministic behavior.
#### TopK > 1: Balanced Performance with Options
```python
config.lookup_table.recorder_topk = 5 # Record top 5 fastest choices
```
**Benefits:**
- **Some autotuning**: When using the recorded table, autotuning occurs among the recorded choices (faster than full autotuning)
- **Flexibility**: Multiple good options available if hardware characteristics change slightly
- **Robustness**: Backup choices if the fastest choice becomes unavailable
**Trade-offs:**
- **Slight overhead**: Autotuning still occurs among the recorded choices
- **Less determinism**: May pick different choices between runs based on timing variations
**Use case:** Development/staging environments where you want good performance but retain some flexibility.
#### TopK = None: Maximum Visibility for Analysis
```python
config.lookup_table.recorder_topk = None # Record ALL choices that were tested
```
**Benefits:**
- **Complete picture**: See every template choice that was considered during autotuning
- **Manual optimization**: Analyze all options and manually edit the table to select specific choices
- **Debugging**: Understand what choices were available and their relative performance
**Trade-offs:**
- **Large tables**: More storage space and memory usage
- **Full autotuning**: When using the table, autotuning occurs among all recorded choices (no speed benefit)
**Use case:** Analysis, debugging, or when you want to manually curate the final lookup table.
#### Recommended Strategy
1. **Start with TopK = None** for analysis:
```python
config.lookup_table.recorder_topk = None # See all options
```
2. **Analyze the results** to understand choice distribution and performance gaps
3. **Switch to TopK = 1** for production:
```python
config.lookup_table.recorder_topk = 1 # Lock in the fastest choice
```
4. **Validate determinism** by running the same workload multiple times and confirming identical kernels
### Device Key Configuration for Recording
The `record_with_device_key` setting controls whether recorded entries are device-specific or portable:
```python
# Device-specific recording (more precise but less portable)
config.lookup_table.record_with_device_key = True
# Key format: "NVIDIA H100+input_shapes+operation"
# Best for: Production environments with known hardware
# Device-agnostic recording (more portable across GPU types)
config.lookup_table.record_with_device_key = False
# Key format: "input_shapes+operation"
# Best for: Development, CI/CD, mixed GPU environments
```
**Note**: During lookup, both key formats are always tried regardless of this setting, with device-specific keys taking priority if both exist.
### How Recording Works
The recording system automatically:
1. **Captures autotuning results**: Monitors all kernel selection decisions during `torch.compile()` execution
2. **Filters by performance**: Records only the fastest choices (configurable via `recorder_topk`)
3. **Generates lookup keys**: Uses the same key format as the lookup system for consistency
4. **Saves incrementally**: Each autotuning session appends to timestamped JSON files
5. **Maintains safety**: Includes template hashes to prevent using stale configurations
### Example: Complete Recording Workflow
```python
import torch
from torch._inductor import config
import tempfile
import json
import os

# Enable recording with configuration
config.lookup_table.recording_active = True
config.lookup_table.recorder_topk = 3
config.lookup_table.record_template_hash = True

# Use a temporary directory for this example
with tempfile.TemporaryDirectory() as temp_dir:
    config.lookup_table.recorder_record_dir = temp_dir

    # Your model
    def matmul_model(a, b):
        return torch.mm(a, b)

    # Compile with autotuning (triggers recording)
    compiled_model = torch.compile(matmul_model, mode="max-autotune")

    # Run the model (autotuning results are recorded automatically)
    a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
    b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
    result = compiled_model(a, b)

    # Check recorded files
    files = [f for f in os.listdir(temp_dir) if f.startswith('inductor_lut_')]
    print(f"Recorded files: {files}")

    # Load and inspect the recorded lookup table
    if files:
        with open(os.path.join(temp_dir, files[0])) as f:
            recorded_table = json.load(f)
        print("Recorded entries:")
        # note: `cfgs`/`cfg` avoid shadowing the imported `config` module
        for key, cfgs in recorded_table.items():
            print(f"  Key: {key}")
            for i, cfg in enumerate(cfgs):
                print(f"    Config {i+1}: template_id={cfg['template_id']}")
```
### Using Recorded Tables
Once you have recorded tables, use them for faster compilation:
```python
import json
import torch
from torch._inductor import config
from torch._inductor.lookup_table import LookupTableChoices
from torch._inductor.virtualized import V

# Load your recorded table
with open('inductor_lut_20241205_143052_123.json') as f:
    lookup_table = json.load(f)

# Configure the system to use the lookup table
config.lookup_table.table = lookup_table
V.set_choices_handler(LookupTableChoices())

# Now compilation will use your recorded configurations
model = torch.compile(your_model, mode="max-autotune")
result = model(inputs)  # Uses lookup table, skips autotuning
```
### Advanced: Custom Recording Backends
Extend the recording system with custom backends for specialized workflows:
```python
from torch._inductor.lookup_table import recorder

# Custom emit backend for immediate processing
class CustomLogBackend(recorder.EmitBackend):
    def emit(self, entry):
        # Process each entry immediately as it's recorded
        print(f"Recorded {entry.key} -> {entry.value['template_id']} (runtime: {entry.runtime:.4f}ms)")

# Custom record backend for batch processing
class DatabaseRecordBackend(recorder.RecordBackend):
    def __init__(self, connection_string):
        self.conn = connection_string

    def dump(self, data):
        # Save all entries to the database when dump() is called
        for key, entries in data.items():
            for entry in entries:
                self.save_to_db(key, entry.value, entry.runtime)

# Register custom backends
recorder.add_backend(CustomLogBackend())
recorder.add_backend(DatabaseRecordBackend("postgresql://..."))
```
### Performance and Overhead
**Recording Performance**:
- **Minimal overhead**: Recording adds ~1-5μs per kernel selection
- **Fast bail**: When `recording_active=False`, overhead is ~100ns (single boolean check)
- **Memory efficient**: Only keeps configured `topk` entries per operation in memory
**Storage**:
- **Typical size**: 1-10KB per recorded table file
- **Compression**: JSON format is human-readable and compresses well
- **Incremental**: Each compilation session creates a separate timestamped file
### Troubleshooting Recording
**No files created?**
```python
# Check if recording is properly enabled
from torch._inductor import config
print(f"Recording active: {config.lookup_table.recording_active}")
print(f"Record directory: {config.lookup_table.recorder_record_dir}")
print(f"TopK setting: {config.lookup_table.recorder_topk}")
# Ensure max-autotune is enabled to trigger template selection
compiled_model = torch.compile(model, mode="max-autotune")
```
**Empty tables?**
- Recording only captures results from operations that undergo autotuning
- Ensure your model has matrix operations (`mm`, `addmm`, `bmm`, etc.)
- Check that input sizes are large enough to trigger template-based kernels
- Verify GPU kernels are being used (CPU operations aren't recorded)
**Files but no entries?**
- Check `recorder_topk` isn't set to 0 (which disables recording)
- Ensure autotuning found valid template choices (not just ATEN fallbacks)
- Verify templates have the required `template_id` and parameters


@@ -0,0 +1,34 @@
"""
Template lookup table system for PyTorch Inductor.

This package provides functionality for:
- Loading pre-configured template choices from lookup tables
- Managing template configurations and choices

All functionality is contained within the LookupTableChoices class.
You can customize any aspect by subclassing LookupTableChoices and overriding methods.

Usage:
    # Basic usage
    choices = LookupTableChoices()
    V.set_choices_handler(choices)

    # Custom usage
    class MyCustomChoices(LookupTableChoices):
        def _get_lookup_table(self):
            return my_custom_table

        def make_lookup_key(self, kernel_inputs, op_name, include_device=False):
            return f"custom_{op_name}_{hash(str(kernel_inputs))}"

    V.set_choices_handler(MyCustomChoices())
"""

from . import recorder
from .choices import LookupTableChoices

__all__ = [
    "LookupTableChoices",
    "recorder",
]


@ -0,0 +1,424 @@
from __future__ import annotations
import copy
import logging
from functools import lru_cache
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
from torch._inductor import config
from torch._inductor.choices import InductorChoices
from torch._inductor.kernel_template_choice import KernelTemplateChoice
from torch._inductor.template_heuristics.params import DictKernelTemplateParams
log = logging.getLogger(__name__)
if TYPE_CHECKING:
from collections.abc import Generator
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.kernel_inputs import KernelInputs
from torch._inductor.select_algorithm import ExternKernelChoice
class LookupTableChoices(InductorChoices):
"""
InductorChoices subclass that uses lookup table when available, otherwise falls back to parent.
All lookup functionality is contained within this class and can be customized by overriding methods.
"""
def _get_lookup_table(self) -> dict[str, list[dict[str, Any]]]:
"""
Get the template lookup table from config.
Override this method to use custom lookup table sources (database, API, etc.).
"""
if not torch.cuda.is_available() or config.lookup_table.table is None:
return {}
return config.lookup_table.table
@staticmethod
@lru_cache
def _get_device_key(device: torch.device) -> Optional[str]:
"""
Generate a device key for lookup table indexing.
For CPU devices, returns None.
For CUDA devices, returns the props.gcnArchName string.
"""
if device.type != "cuda":
# only cuda devices are supported, this indicates that the system is not in use
# for this device
return None
# Get CUDA device properties
props = torch.cuda.get_device_properties(device.index)
return props.gcnArchName
@staticmethod
def _generate_kernel_inputs_key(kernel_inputs: KernelInputs) -> str:
"""
Generate a key based on input node properties and scalars.
The key includes dtype, size, and stride information for each input node,
plus scalar values as key=value pairs separated by & signs.
"""
# Get node information using existing methods
dtypes = kernel_inputs.dtypes()
shapes = kernel_inputs.shapes_hinted()
strides = kernel_inputs.strides_hinted()
# Create tuple of (dtype, shape_list, stride_list) for each node
node_info = tuple(
(dtype, list(shape), list(stride))
for dtype, shape, stride in zip(dtypes, shapes, strides)
)
# Create base key from node information
fmt_key = str(node_info)
# Add scalar information if present
if kernel_inputs._scalars:
# Sort scalars for consistent key generation and join with &
scalar_parts = [
f"{key}={value}"
for key, value in sorted(kernel_inputs._scalars.items())
]
scalars_key = "&".join(scalar_parts)
fmt_key = f"{fmt_key}+{scalars_key}"
return f"{fmt_key}"
def make_lookup_key(
self, kernel_inputs: KernelInputs, op_name: str, include_device: bool = False
) -> Optional[str]:
"""
Create a flattened lookup key from kernel inputs and operation name.
Override this method to customize key generation.
Args:
kernel_inputs: KernelInputs object containing input nodes and scalars
op_name: Operation name (e.g., "mm", "addmm")
include_device: Whether to include device key in the generated key
Returns:
A string key combining device (optional), operation, and input information
"""
device = kernel_inputs.device()
dev_key = self._get_device_key(device)
if dev_key is None:
# The system does not run when dev_key is None, regardless of
# whether include_device is True or False
return None
if not include_device:
dev_key = None
# Generate input key using our staticmethod
input_key = self._generate_kernel_inputs_key(kernel_inputs)
# Create the flattened lookup key
if dev_key is not None:
key_parts = [dev_key, input_key, op_name]
else:
key_parts = [input_key, op_name]
return "+".join(key_parts)
def make_lookup_key_variants(
self, kernel_inputs: KernelInputs, op_name: str
) -> tuple[Optional[str], Optional[str]]:
"""
Generate both device-specific and device-agnostic lookup keys.
Override this method to customize key variant generation.
Args:
kernel_inputs: KernelInputs object containing input nodes and scalars
op_name: Operation name (e.g., "mm", "addmm")
Returns:
Tuple of (device_key, device_agnostic_key). Either may be None if generation fails.
"""
device_key = self.make_lookup_key(kernel_inputs, op_name, include_device=True)
device_agnostic_key = self.make_lookup_key(
kernel_inputs, op_name, include_device=False
)
return device_key, device_agnostic_key
@staticmethod
def _entry_is_valid(
cfg: dict[str, Any],
template_id: str,
template_hash_map: Optional[dict[str, Optional[str]]],
) -> bool:
"""
Check if a config entry is valid based on template hash validation.
Args:
cfg: Configuration dictionary that may contain a template_hash field
template_id: The template identifier
template_hash_map: Optional mapping from template_uid to src_hash for validation
Returns:
True if the config is valid and should be kept, False if it should be filtered out
"""
# If hash checking is disabled or no hash map provided, keep the config
if not config.lookup_table.check_src_hash or not template_hash_map:
return True
template_hash = template_hash_map.get(template_id)
config_hash = cfg.get("template_hash")
# Both hashes present - validate they match
if template_hash is not None and config_hash is not None:
if config_hash != template_hash:
log.warning(
"Hash validation failed for template '%s': config_hash='%s' != template_hash='%s'. "
"Template code may have changed. Filtering out config: %s",
template_id,
config_hash,
template_hash,
{k: v for k, v in cfg.items() if k != "template_hash"},
)
return False
else:
log.debug(
"Hash validation passed for template '%s': hash='%s'",
template_id,
template_hash,
)
return True
# Config has no hash - keep it
elif config_hash is None:
log.debug(
"Config for template '%s' has no hash - keeping it (template_hash='%s')",
template_id,
template_hash,
)
return True
# Template has no hash - keep config
else:
log.debug(
"Template '%s' has no src_hash - keeping config with hash '%s'",
template_id,
config_hash,
)
return True
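The branching above reduces to a small truth table: only a present-and-mismatching hash pair rejects a config. A condensed sketch of that rule:

```python
def entry_is_valid(config_hash, template_hash, check_enabled=True):
    # Reject only when both hashes are present and disagree; a missing
    # hash on either side keeps the config.
    if not check_enabled:
        return True
    if config_hash is not None and template_hash is not None:
        return config_hash == template_hash
    return True
```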
def lookup_template_configs(
self,
kernel_inputs: KernelInputs,
op_name: str,
template_uids: list[str],
template_hash_map: Optional[dict[str, Optional[str]]] = None,
) -> dict[str, list[dict[str, Any]]]:
"""
Unified function to look up template configurations for multiple templates.
Override this method to customize lookup logic.
Args:
kernel_inputs: KernelInputs object containing input nodes and scalars
op_name: Operation name (e.g., "mm", "addmm")
template_uids: List of template identifiers (e.g., ["mm", "tma", "decompose_k"])
template_hash_map: Optional mapping from template_uid to src_hash for validation
Returns:
{}: No lookup table in use, or no matches found for any template
{"template_uid1": [config1, config2], ...}: Matches found, filtered configurations
"""
lookup_table = self._get_lookup_table()
if not lookup_table:
log.debug("Lookup table: no table configured or CUDA unavailable")
return {}
# Try both key variants: device-specific first, then device-agnostic
# If both exist, device-specific takes priority
device_key, device_agnostic_key = self.make_lookup_key_variants(
kernel_inputs, op_name
)
config_list = []
for key_type, key in [
("device-specific", device_key),
("device-agnostic", device_agnostic_key),
]:
if key is not None:
config_list = lookup_table.get(key, [])
if config_list:
log.debug(
"Lookup table: found %d configs using %s key '%s' for %s",
len(config_list),
key_type,
key,
op_name,
)
break
else:
log.debug(
"Lookup table: no match for %s (tried keys: %s, %s) (table has %d keys)",
op_name,
device_key,
device_agnostic_key,
len(lookup_table),
)
return {}
log.debug(
"Lookup table: found %d configs for %s templates %s",
len(config_list),
op_name,
template_uids,
)
# Group configs by template_id
configs_by_template: dict[str, list[dict[str, Any]]] = {}
for cfg in config_list:
if not isinstance(cfg, dict):
raise ValueError(
f"Config for {op_name} operation is not a dictionary: {cfg}"
)
if "template_id" not in cfg:
raise ValueError(
f"Config for {op_name} operation missing required 'template_id' field: {cfg}"
)
template_id = cfg["template_id"]
if template_id in template_uids:
if template_id not in configs_by_template:
configs_by_template[template_id] = []
configs_by_template[template_id].append(cfg)
# Check template hashes and clean up template_id field
result = {}
for template_id, matching_configs in configs_by_template.items():
filtered_configs = []
for cfg in matching_configs:
# Check template hash using helper function
if not self._entry_is_valid(cfg, template_id, template_hash_map):
continue
# Return a copy of the config, as we don't want to modify the original
cconfig = copy.deepcopy(cfg)
# Lastly, we have to throw out the template_id, as it's not a valid kwarg
# and just used to identify which template the entry belongs to
del cconfig["template_id"]
# Similarly, the template_hash is not a valid kwarg
cconfig.pop("template_hash", None)
filtered_configs.append(cconfig)
if filtered_configs:
result[template_id] = filtered_configs
return result
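The grouping-and-cleanup step can be illustrated in isolation. The `BLOCK_M` field below is a made-up template kwarg; the bookkeeping fields (`template_id`, `template_hash`) match the code above:

```python
import copy

def group_and_clean(config_list, template_uids):
    # Group configs by template_id, then strip the bookkeeping fields
    # that are not valid template kwargs.
    result = {}
    for cfg in config_list:
        tid = cfg["template_id"]
        if tid not in template_uids:
            continue
        clean = copy.deepcopy(cfg)
        del clean["template_id"]
        clean.pop("template_hash", None)
        result.setdefault(tid, []).append(clean)
    return result
```

Configs whose `template_id` is not among the requested templates are silently dropped, matching the membership check above.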
def _finalize_template_configs(
self,
template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
kernel_inputs: KernelInputs,
templates: list[Union[KernelTemplate, ExternKernelChoice]],
op_name: str,
kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
) -> list[KernelTemplateChoice]:
"""Check lookup table for hits, use those if found, otherwise fall back to parent."""
# Early exit if lookup table system is disabled - call super directly
if not config.lookup_table.active:
return super()._finalize_template_configs(
template_choices, kernel_inputs, templates, op_name, kwarg_overrides
)
# 1. Collect template src_hashes for validation
template_uids = [template.uid for template in templates]
template_hash_map = {}
for template in templates:
src_hash = getattr(template, "src_hash", None)
template_hash_map[template.uid] = src_hash
log.debug(
"Choices: attempting lookup for %s with %d templates",
op_name,
len(template_uids),
)
# 2. Single batch lookup for all templates
lookup_results = self.lookup_template_configs(
kernel_inputs, op_name, template_uids, template_hash_map
)
# 3. Early exit if no lookup table or no matches
if not lookup_results: # Empty dict
log.info("LookupChoices: lookup miss for %s, using fallback", op_name)
return self._fallback(
template_choices,
kernel_inputs,
templates,
op_name,
kwarg_overrides,
)
log.info(
"LookupChoices: lookup hit for %s - found %d/%d templates: %s",
op_name,
len(lookup_results),
len(template_uids),
list(lookup_results.keys()),
)
# 4. Create KTCs only for templates with lookup entries
return self._create_lookup_choices(
lookup_results, templates, kernel_inputs, op_name
)
def _fallback(
self,
template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
kernel_inputs: KernelInputs,
templates: list[Union[KernelTemplate, ExternKernelChoice]],
op_name: str,
kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
) -> list[KernelTemplateChoice]:
"""Fallback to parent if no lookup table or no matches."""
# NOTE: this is broken out so that subclasses can explicitly handle the
# case where the lookup table had a miss, without having to override
# the entire logic
return super()._finalize_template_configs(
template_choices,
kernel_inputs,
templates,
op_name,
kwarg_overrides,
)
def _create_lookup_choices(
self,
lookup_results: dict[str, list[dict[str, Any]]],
templates: list[Union[KernelTemplate, ExternKernelChoice]],
kernel_inputs: KernelInputs,
op_name: str,
) -> list[KernelTemplateChoice]:
"""Create KernelTemplateChoice objects from lookup results using parent's get_ktc method."""
templates_by_uid = {template.uid: template for template in templates}
lookup_choices: list[KernelTemplateChoice] = []
for template_uid, configs in lookup_results.items():
template = templates_by_uid[template_uid]
# Use parent's get_ktc method to get a generator, then get the first base KTC
ktc_generator = self.get_ktc(kernel_inputs, template, op_name)
try:
base_ktc = next(ktc_generator)
except StopIteration:
# No configs from heuristic, skip this template
continue
# For each lookup config, create a KTC with the override kwargs
for c in configs:
lookup_ktc = KernelTemplateChoice(
template=base_ktc.template,
# use the ones from the lookup table
params=DictKernelTemplateParams(c),
extra_kwargs=base_ktc.extra_kwargs,
layout=base_ktc.layout,
inputs=base_ktc.inputs,
)
lookup_choices.append(lookup_ktc)
return lookup_choices


@@ -0,0 +1,415 @@
"""
Lookup table recorder system for capturing autotuning results.
This module provides a system to record and emit autotuning results from kernel selection.
It supports both immediate emission (logging) and table recording (building lookup tables).
"""
import json
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable, Optional
from torch._inductor import config
from torch._inductor.ir import ChoiceCaller
from torch._inductor.kernel_inputs import KernelInputs
from torch._inductor.kernel_template_choice import KernelTemplateChoice
from torch.utils._ordered_set import OrderedSet
from .choices import LookupTableChoices
log = logging.getLogger(__name__)
@dataclass
class LookupTableEntry:
"""Single entry representing one autotuning result"""
key: str # device_key+op_name+input_key
value: dict[str, Any] # Contains template_id and all kwargs
metadata: dict[str, Any] # Contains timing, rank, and other recording metadata
runtime: float  # Benchmark runtime for this entry; used to keep only the topk fastest entries per key
@classmethod
def from_ktc_and_timing(
cls,
ktc: KernelTemplateChoice,
timing: float,
rank: int,
op_name: str,
) -> Optional["LookupTableEntry"]:
"""Create entry from a KTC and its timing"""
# KTC must have a template - this is a requirement
assert ktc.template is not None, "KernelTemplateChoice must have a template"
# Use V.choices_handler to make lookup key if it's a LookupTableChoices instance
key = _make_lookup_key(
ktc.inputs, op_name, config.lookup_table.record_with_device_key
)
if key is None:
return None
# Build value dict from KTC kwargs
value = dict(template_id=ktc.template.uid, **ktc.params.to_serializeable_dict())
# Add template hash if available and configured
if config.lookup_table.record_template_hash:
# Use src_hash directly from the template
template_hash = getattr(ktc.template, "src_hash", None)
if template_hash is not None:
value["template_hash"] = template_hash
# Create metadata dict with timing and rank info
metadata = {
"timing": timing,
"rank": rank,
}
return cls(key=key, value=value, metadata=metadata, runtime=timing)
def _make_lookup_key(
kernel_inputs: KernelInputs, op_name: str, include_device: bool = False
) -> Optional[str]:
"""Make lookup key using V.choices_handler if available, otherwise fall back to a temporary LookupTableChoices instance"""
from torch._inductor.virtualized import V
if hasattr(V, "choices_handler") and isinstance(
V.choices_handler, LookupTableChoices
):
return V.choices_handler.make_lookup_key(kernel_inputs, op_name, include_device)
else:
# Fallback: create a temporary LookupTableChoices instance to use its methods
choices_handler = LookupTableChoices()
return choices_handler.make_lookup_key(kernel_inputs, op_name, include_device)
class Backend:
"""Base class for backends"""
def clear(self) -> None:
pass
class EmitBackend(ABC, Backend):
"""Backend for immediate emission of single entries"""
@abstractmethod
def emit(self, entry: LookupTableEntry) -> None:
pass
class RecordBackend(ABC, Backend):
"""Backend for dumping recorded table"""
@abstractmethod
def dump(self, data: dict[str, list[LookupTableEntry]]) -> None:
pass
# Track registered backends to avoid double registration
_registered_backends: OrderedSet[Any] = OrderedSet()
def _backend_key(
backend_class: type, kwargs: dict[str, Any]
) -> tuple[type, frozenset[Any]]:
"""Create a unique key for backend class + kwargs"""
return (backend_class, frozenset(kwargs.items()) if kwargs else frozenset())
class LogEmitBackend(EmitBackend):
"""Default emit backend that logs entries"""
def emit(self, entry: LookupTableEntry) -> None:
log.debug("LookupTable: %r -> %r", entry.key, entry.value)
class DirectoryRecordBackend(RecordBackend):
"""Default record backend that saves to timestamped files in a directory"""
def __init__(self, directory: str):
self.directory = directory
# Generate timestamped filename with 3-digit millisecond precision
self.setup()
def setup(self) -> None:
now = datetime.now(tz=timezone.utc)
timestamp = now.strftime("%Y%m%d_%H%M%S") + f"_{now.microsecond // 1000:03d}"
filename = f"inductor_lut_{timestamp}.json"
self.filepath = os.path.join(self.directory, filename)
def dump(self, data: dict[str, list[LookupTableEntry]]) -> None:
# Create directory if it doesn't exist
os.makedirs(self.directory, exist_ok=True)
# extract only the value from the entries and dump those
data_values = {}
for k, entries in data.items():
data_values[k] = [e.value for e in entries]
# simply overwrite the file on each dump
with open(self.filepath, "w") as f:
json.dump(data_values, f, indent=2)
def clear(self) -> None:
# generate a new path
self.setup()
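The filename scheme from `setup` can be factored into a pure function for illustration (`lut_filename` is a hypothetical helper, not part of the module):

```python
from datetime import datetime, timezone

def lut_filename(now=None):
    # Same naming scheme as DirectoryRecordBackend.setup: a UTC timestamp
    # followed by a 3-digit millisecond suffix.
    now = now or datetime.now(tz=timezone.utc)
    ts = now.strftime("%Y%m%d_%H%M%S") + f"_{now.microsecond // 1000:03d}"
    return f"inductor_lut_{ts}.json"
```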
class LookupTableRecorder:
"""Main recorder that manages both emit and record backends"""
def __init__(self, topk: Optional[int] = None) -> None:
self.data: dict[str, list[LookupTableEntry]] = {}
self.emit_backends: list[EmitBackend] = []
self.record_backends: list[RecordBackend] = []
# Use provided topk or fall back to config
self.topk = topk if topk is not None else config.lookup_table.recorder_topk
@property
def input_entries(self) -> int:
"""how many unique input entries have been recorded"""
return len(self.data)
def add_backend(self, backend: Backend) -> None:
"""Add a backend to the appropriate list based on its type"""
if isinstance(backend, EmitBackend):
self.emit_backends.append(backend)
elif isinstance(backend, RecordBackend):
self.record_backends.append(backend)
else:
raise ValueError(
f"Backend must be an instance of EmitBackend or RecordBackend, "
f"got {type(backend).__name__}"
)
def emit(self, entry: LookupTableEntry) -> None:
"""Emit a single entry immediately"""
for backend in self.emit_backends:
backend.emit(entry)
def record(self, entry: LookupTableEntry) -> None:
"""Record entry to table and emit it"""
# Always emit when recording
self.emit(entry)
# Initialize key if not exists
if entry.key not in self.data:
self.data[entry.key] = []
# insert and sort; topk is usually small, so bisection isn't worth it
entries_for_key = self.data[entry.key]
entries_for_key.append(entry)
entries_for_key.sort(
key=lambda x: x.runtime if x.runtime is not None else float("inf")
)
# Trim to topk if necessary (only if topk is positive)
if self.topk is not None and self.topk > 0 and len(entries_for_key) > self.topk:
topk: int = self.topk
# Log which entries we're replacing
replaced_entries = entries_for_key[topk:]
log.info(
"Replacing %d entries with new entry (value: %r, runtime: %f) due to topk=%d. "
"Replaced entries: %s",
len(replaced_entries),
entry.value,
entry.runtime,
self.topk,
[{"value": e.value, "runtime": e.runtime} for e in replaced_entries],
)
del entries_for_key[topk:]
def dump(self) -> None:
"""Dump via all record backends"""
for backend in self.record_backends:
backend.dump(self.data)
def clear(self) -> None:
self.data.clear()
for backend in self.emit_backends + self.record_backends:
backend.clear()
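The append/sort/trim behavior in `record` keeps each key's list sorted ascending by runtime and bounded by `topk`. A standalone sketch over plain `(value, runtime)` tuples:

```python
def record_sorted(entries, entry, topk):
    # entries: list of (value, runtime) kept sorted ascending by runtime;
    # None runtimes sort last, and the list is trimmed to topk when positive.
    entries.append(entry)
    entries.sort(key=lambda e: e[1] if e[1] is not None else float("inf"))
    if topk is not None and topk > 0:
        del entries[topk:]
    return entries
```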
# Module-wide instance
_lookup_table_recorder: Optional[LookupTableRecorder] = None
def get_lookup_table_recorder() -> LookupTableRecorder:
"""Get the global lookup table recorder"""
global _lookup_table_recorder
if _lookup_table_recorder is None:
_lookup_table_recorder = LookupTableRecorder()
# Always register any pending backends
_register_pending_backends(_lookup_table_recorder)
assert _lookup_table_recorder is not None
return _lookup_table_recorder
def add_backend(backend: Backend) -> None:
"""Add a backend to the global lookup table recorder"""
recorder = get_lookup_table_recorder()
if recorder is not None:
recorder.add_backend(backend)
def _register_pending_backends(recorder: LookupTableRecorder) -> None:
"""Register built-in backends based on current config"""
global _registered_backends
# Add built-in LogEmitBackend if enabled and not already registered
if config.lookup_table.recorder_emit:
emit_key = _backend_key(LogEmitBackend, {})
if emit_key not in _registered_backends:
try:
recorder.add_backend(LogEmitBackend())
_registered_backends.add(emit_key)
log.debug("Registered LogEmitBackend")
except Exception as e:
log.warning("Failed to register LogEmitBackend: %r", e)
# Add built-in DirectoryRecordBackend if enabled and not already registered
record_dir = config.lookup_table.recorder_record_dir
if record_dir:
record_key = _backend_key(DirectoryRecordBackend, {"directory": record_dir})
if record_key not in _registered_backends:
try:
recorder.add_backend(DirectoryRecordBackend(record_dir))
_registered_backends.add(record_key)
log.debug("Registered DirectoryRecordBackend: %s", record_dir)
except Exception as e:
log.warning(
"Failed to register DirectoryRecordBackend %s: %r", record_dir, e
)
def record_topk_choices(
timings: dict[ChoiceCaller, float],
op_name: str,
input_nodes: list[Any],
choices: list[ChoiceCaller],
profiled_time_fn: Callable[[], dict[Any, Any]],
) -> None:
"""
Feedback function to record topk choices based on timing results.
Args:
timings: Mapping from choices to benchmark times
op_name: Name of the operation (e.g. "mm", "addmm")
input_nodes: List of input ir.Nodes
choices: List of ChoiceCaller objects
profiled_time_fn: Function to get profiled times (unused in this implementation)
"""
# Fast bail if recording not active
if not config.lookup_table.recording_active:
log.debug(
"Recording disabled (recording_active=False) for operation %s", op_name
)
return
# If topk is 0, don't record anything
if config.lookup_table.recorder_topk == 0:
log.debug("Recording disabled (topk=0) for operation %s", op_name)
return
# Get recorder
recorder = get_lookup_table_recorder()
if recorder is None:
log.warning("Failed to get lookup table recorder for operation %s", op_name)
return
# adjust the recorder topk if necessary
recorder.topk = config.lookup_table.recorder_topk
# Filter choices that have valid timings and KTC references
valid_choices = []
filtered_count = 0
for choice in choices:
if (
choice not in timings
or not hasattr(choice, "annotations")
or "ktc" not in choice.annotations
or timings[choice] == float("inf")
):
filtered_count += 1
else:
valid_choices.append(choice)
if filtered_count > 0:
log.debug(
"Recording %s: filtered %d/%d invalid choices",
op_name,
filtered_count,
len(choices),
)
if not valid_choices:
log.debug("Recording %s: no valid choices", op_name)
return
# Sort and trim to topk
sorted_choices = sorted(valid_choices, key=lambda c: timings[c])
if recorder.topk and recorder.topk > 0 and len(sorted_choices) > recorder.topk:
sorted_choices = sorted_choices[: recorder.topk]
# Record each choice
recorded_count = 0
for rank, choice in enumerate(sorted_choices):
ktc = choice.annotations["ktc"]
if ktc is None:
log.warning(
"Recording %s: KTC is None for choice %s",
op_name,
getattr(choice, "name", choice),
)
continue
entry = LookupTableEntry.from_ktc_and_timing(
ktc=ktc, timing=timings[choice], rank=rank, op_name=op_name
)
if entry is not None:
recorder.record(entry)
recorded_count += 1
log.info(
"Recording %s: saved %d/%d entries",
op_name,
recorded_count,
len(sorted_choices),
)
# Every time we record, we also dump. Backends must be implemented so
# that they can accommodate progressive dumping.
recorder.dump()
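Stripped of the KTC bookkeeping, the filter/sort/trim pipeline above amounts to the following sketch (choices are stood in by plain strings):

```python
def select_topk(timings, choices, topk):
    # Mirrors record_topk_choices: drop choices without a finite timing,
    # sort ascending by time, then trim to topk when positive.
    valid = [c for c in choices if c in timings and timings[c] != float("inf")]
    valid.sort(key=lambda c: timings[c])
    if topk and topk > 0:
        valid = valid[:topk]
    return valid
```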
def dump() -> None:
"""Dump the global lookup table recorder"""
recorder = get_lookup_table_recorder()
if recorder is not None:
recorder.dump()
def clear() -> None:
"""Clear the global lookup table recorder"""
recorder = get_lookup_table_recorder()
if recorder is not None:
recorder.clear()
# Auto-register the feedback function when the module is imported
try:
from torch._inductor.select_algorithm import add_feedback_saver
add_feedback_saver(record_topk_choices)
log.debug("Registered lookup table recorder feedback function")
except ImportError:
log.warning(
"Failed to register lookup table recorder feedback function - select_algorithm not available"
)


@@ -207,9 +207,9 @@ def _choices_default():
We virtualize InductorChoices to allow changing inductor heuristics from out of tree.
"""
from torch._inductor.choices import InductorChoices
+ from torch._inductor.lookup_table.choices import LookupTableChoices
- rv = InductorChoices()
+ rv = LookupTableChoices()
setattr(threadlocal, _choices._key, rv)
return rv