Update
[ghstack-poisoned]
@@ -2080,17 +2080,6 @@ write_are_deterministic_algorithms_enabled = (

```python
)


class lookup_table:
    # Lookup table for template config overrides
    table: Optional[dict[str, list[dict[str, Any]]]] = None

    # Enable template src_hash checking in the lookup table to prevent using stale configs.
    # If True, configs with a 'template_hash' field will be compared against the template's
    # src_hash at runtime and filtered out if they don't match. If False, no
    # hash checking is performed.
    check_src_hash: bool = True


class test_configs:
    force_extern_kernel_in_multi_template: bool = False
```
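These knobs live under `torch._inductor.config`; for example, stale-config filtering can be disabled explicitly (a sketch based on the config schema in this hunk):

```python
from torch._inductor import config

# Hash checking is on by default; disable it to accept entries whose
# 'template_hash' no longer matches the template source.
config.lookup_table.check_src_hash = False
```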
@@ -1,253 +0,0 @@
# Template Lookup Table System

The template lookup table system provides a way to pre-configure kernel template parameters for specific operations and
input configurations, bypassing the default choice generation and autotuning process.

## Overview

The lookup table system replaces default choice generation with pre-configured template parameters for specific
operations and input configurations. It is orthogonal to `max-autotune(-gemm)` in the following way.

If a lookup table is provided and there is a match:

- We check whether the template(s) in the match are currently in use
- If so, we use the pre-configured template(s) and config and bypass choice generation
  - If more than one choice is provided, we run autotune among the pre-configured choices
- If not, we fall back to the default choice generation process, including max-autotune(-gemm) logic

If there is no match, we fall back to the default choice generation process, including max-autotune(-gemm) logic.

## Configuration

Enable the system by setting both:

```python
from torch._inductor import config
config.lookup_table.table = your_table_dict
# You also need to set it as the default choice handler
from torch._inductor.lookup_table import LookupTableChoices
torch._inductor.V.set_choices_handler(LookupTableChoices())
```

### Device Key Handling

The key schema format is described in detail in the [Key Schemas](#key-schemas) section below.

Configure device key behavior:

```python
# Control whether entries include device-specific keys for lookups
# Device-agnostic entries work across different GPU models
```

**Lookup Behavior**: During lookup, the system automatically tries both key formats:

1. **Device-specific key** (e.g., `"NVIDIA H100+input_data+mm"`) - tried first
2. **Device-agnostic key** (e.g., `"input_data+mm"`) - tried if the device-specific lookup fails

**Priority**: If both device-specific and device-agnostic entries exist for the same inputs, the device-specific entry
takes priority.

**NOTE**: Device-based keys simplify hardware-specific optimization without complex build rules. They are currently
limited to the device name only. If you need additional conditional key attributes (e.g., CUDA version filtering),
please file an issue or submit a patch.
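The two-step lookup can be illustrated with a minimal sketch (the table contents and the helper name below are hypothetical; the real logic lives in `LookupTableChoices.lookup_template_configs`):

```python
# Minimal sketch of the two-step key lookup; the table contents are illustrative only.
def find_configs(table, device_specific_key, device_agnostic_key):
    # Device-specific entries take priority over device-agnostic ones.
    for key in (device_specific_key, device_agnostic_key):
        if key is not None and key in table:
            return table[key]
    return None  # Miss: fall back to default choice generation


table = {"input_data+mm": [{"template_id": "triton::mm"}]}
assert find_configs(table, "NVIDIA H100+input_data+mm", "input_data+mm") == [
    {"template_id": "triton::mm"}
]
```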

## Behavior

When the table is active, the following behavior occurs for all supported operations:

### Match Found

- Uses pre-configured choices from the table instead of generating default choices
- Bypasses autotuning if only a single choice is provided
- If multiple choices are provided, autotuning occurs among those choices only

### No Match Found

- Standard default behavior - generates choices using heuristics and max-autotune settings

### Table Not Set or Inactive

- Standard default behavior - generates choices using heuristics and max-autotune settings

## Supported Operations

Currently supports the `mm`, `addmm`, `bmm`, `mm_plus_mm`, and `scaled_mm` operations with

- Triton
- ATEN
- DecomposeK

## Table Format

The table is a dictionary with keys in the format:

```
"input_key+op_name"
```

Where:

- `input_key`: Generated from the `KernelInputs.key` property; represents tensor shapes/dtypes/strides
- `op_name`: Operation name (`"mm"`, `"addmm"`, etc.)

Each value is a list of configuration dictionaries containing:

- `template_id`: Template identifier (`"triton::mm"`, `"triton::mm_persistent_tma"`, `"decompose_k"`, etc.)
- Template-specific parameters (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`, `num_warps`, etc.)

## Key Schemas

**NOTE**: The key schema format is subject to change as the system evolves.

The lookup table uses composite keys to match kernel configurations. See
[Implementation Details](#implementation-details) below for more technical information about key generation. This
section describes the structure of these keys.

### Key Format Structure

Keys follow the pattern:

```
[device_name+]input_key+[additional_params+]op_name
```

Components (a sketch assembling such a key follows this list):

- **device_name** (optional): GPU device identifier (e.g., `"NVIDIA H100"`)
  - Obtained from `torch.cuda.get_device_properties().gcnArchName`
  - Enables device-specific optimizations
  - When omitted, creates device-agnostic entries that work across hardware
- **input_key**: Tensor configuration representation from `KernelInputs.key`
  - Format: `((dtype, shape, stride), (dtype, shape, stride), ...)`
  - Each tuple represents one input tensor's properties
  - Example: `((torch.float16, [128, 256], [0, 1]), (torch.float16, [64, 256], [256, 1]))`
  - Order matches the operation's input argument order
- **additional_params** (optional): Operation-specific parameters
  - Format: `key1=value1&key2=value2`
  - Example: `alpha=1&beta=1` for addmm operations
- **op_name**: Operation identifier
  - Examples: `"mm"`, `"addmm"`, `"bmm"`, `"mm_plus_mm"`, `"scaled_mm"`
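For illustration, such a key can be assembled with simple string joins (the values below are hypothetical; real keys are produced by `LookupTableChoices.make_lookup_key`):

```python
# Assemble a composite lookup key from its components (illustrative values).
device_name = "NVIDIA H100"  # optional; drop it for a device-agnostic key
input_key = "((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))"
additional_params = "alpha=1&beta=1"  # optional; operation-specific scalars
op_name = "addmm"

key = "+".join(p for p in (device_name, input_key, additional_params, op_name) if p)
```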

### Key Examples

**Device-specific key for addmm:**

```
"NVIDIA H100+((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm"
```

**Device-agnostic key for mm:**

```
"((torch.float16, [64, 128], [128, 1]), (torch.float16, [128, 256], [256, 1]))+mm"
```

**Key with no additional parameters:**

```
"((torch.float32, [512, 512], [512, 1]), (torch.float32, [512, 512], [512, 1]))+bmm"
```

### Lookup Strategy

During lookup, the system tries keys in priority order:

1. **Device-specific key** - checked first if device information is available
2. **Device-agnostic key** - fallback if the device-specific lookup fails

This allows tables to contain:

- Device-optimized configurations (higher priority)
- Portable configurations that work across devices
- A mix of both for flexible deployment

## Example Table

This is an example table for a single input, showing two configurations:

```python
table = {
    "((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm": [
        {
            "template_id": "triton::mm",
            "EVEN_K": True,
            "USE_FAST_ACCUM": False,
            "ACC_TYPE": "tl.float32",
            "num_stages": 2,
            "num_warps": 4,
            "BLOCK_M": 32,
            "BLOCK_N": 32,
            "BLOCK_K": 64,
            "hint_override": None,
            "GROUP_M": 8,
            "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896",
        },
        {
            "template_id": "aten::bias_addmm",
        },
    ]
}
```
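Hooked up end to end, this might look as follows (a sketch: a lookup hit requires the compiled addmm's input shapes, dtypes, and strides to match the key above exactly):

```python
import torch
from torch._inductor import config
from torch._inductor.lookup_table import LookupTableChoices

config.lookup_table.table = table  # the example table above
torch._inductor.V.set_choices_handler(LookupTableChoices())

# On a hit, an addmm lowered by Inductor autotunes only among the two
# pre-configured choices; on a miss, default choice generation runs.
```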

## Source Hashing Safety

The lookup table system includes source hashing to prevent using stale configurations when template code changes.

### Configuration

- **Enabled by default**: `torch._inductor.config.lookup_table.check_src_hash = True`
- **Optional field**: Add `"template_hash"` to table entries for enhanced safety

### Behavior

When source hash checking is enabled (a minimal sketch of the filtering rule follows this list):

- Template configurations with `"template_hash"` fields are validated against current template source hashes
- Mismatched hashes indicate the template code has changed since the configuration was created
- Stale configurations are automatically filtered out with a warning message
- Configurations without hash fields are kept, for backward compatibility or for users who prefer looser validation
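The rule reduces to a small predicate (sketch only; the actual implementation, including logging, is `LookupTableChoices._entry_is_valid` later in this diff):

```python
from typing import Any, Optional


# Keep an entry unless both hashes are present and disagree (sketch).
def entry_is_fresh(cfg: dict[str, Any], template_src_hash: Optional[str]) -> bool:
    config_hash = cfg.get("template_hash")
    if config_hash is None or template_src_hash is None:
        # Missing hashes are tolerated for backward compatibility.
        return True
    return config_hash == template_src_hash
```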

### Example with Template Hash

```python
{
    "template_id": "triton::mm",
    "BLOCK_M": 32,
    "BLOCK_N": 32,
    "BLOCK_K": 16,
    "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896",
}
```

## Performance Impact

- **Lookup Hit**: Eliminates heuristic choice generation and, if only a single choice is provided, autotuning overhead
- **Lookup Miss**: Default behavior, including heuristic choice generation and autotuning
- **Memory**: The table is held in memory; key generation and lookup add minimal overhead

## Implementation Details

### Key Generation

- Device key: Uses `torch.cuda.get_device_properties().gcnArchName` (e.g., "NVIDIA H100")
- Input key: Generated from `KernelInputs.key` containing tensor properties
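The device portion can be inspected directly (assumes a CUDA device is available):

```python
import torch

# gcnArchName is the identifier used for device-specific keys; the exact
# string (e.g., "NVIDIA H100") depends on your hardware.
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(0).gcnArchName)
```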

### Entry Points

The system is accessed through (a sketch call follows this list):

- `lookup_template_configs(kernel_inputs, op_name, template_uids)` - main lookup function
- `LookupTableChoices._finalize_template_configs()` - integration point with the existing choice system
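A direct call might look like this (a sketch: `kernel_inputs` must be a real `KernelInputs` instance, which Inductor constructs internally during lowering):

```python
choices = LookupTableChoices()
# Returns {} on a miss, or {template_uid: [config, ...]} on a hit.
results = choices.lookup_template_configs(
    kernel_inputs, "mm", ["triton::mm", "decompose_k"]
)
```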

### Error Handling

- Validates that config dictionaries contain the required `template_id` field
- Gracefully handles non-CUDA devices by returning empty results
@@ -1,32 +0,0 @@

```python
"""
Template lookup table system for PyTorch Inductor.

This package provides functionality for:
- Loading pre-configured template choices from lookup tables
- Managing template configurations and choices

All functionality is contained within the LookupTableChoices class.
You can customize any aspect by subclassing LookupTableChoices and overriding methods.

Usage:
    # Basic usage
    choices = LookupTableChoices()
    V.set_choices_handler(choices)

    # Custom usage
    class MyCustomChoices(LookupTableChoices):
        def _get_lookup_table(self):
            return my_custom_table

        def make_lookup_key(self, kernel_inputs, op_name, include_device=False):
            return f"custom_{op_name}_{hash(str(kernel_inputs))}"

    V.set_choices_handler(MyCustomChoices())
"""

from .choices import LookupTableChoices


__all__ = [
    "LookupTableChoices",
]
```

@@ -1,418 +0,0 @@

```python
from __future__ import annotations

import copy
import logging
from functools import lru_cache
from typing import Any, Optional, TYPE_CHECKING, Union

import torch
from torch._inductor import config
from torch._inductor.choices import InductorChoices
from torch._inductor.kernel_template_choice import KernelTemplateChoice
from torch._inductor.template_heuristics.params import DictKernelTemplateParams


log = logging.getLogger(__name__)


if TYPE_CHECKING:
    from collections.abc import Generator

    from torch._inductor.codegen.common import KernelTemplate
    from torch._inductor.kernel_inputs import KernelInputs
    from torch._inductor.select_algorithm import ExternKernelChoice


class LookupTableChoices(InductorChoices):
    """
    InductorChoices subclass that uses the lookup table when available, otherwise falls back to the parent.
    All lookup functionality is contained within this class and can be customized by overriding methods.
    """

    def _get_lookup_table(self) -> dict[str, list[dict[str, Any]]]:
        """
        Get the template lookup table from config.
        Override this method to use custom lookup table sources (database, API, etc.).
        """
        if not torch.cuda.is_available() or config.lookup_table.table is None:
            return {}
        return config.lookup_table.table

    @staticmethod
    @lru_cache
    def _get_device_key(device: torch.device) -> Optional[str]:
        """
        Generate a device key for lookup table indexing.
        For CPU devices, returns None.
        For CUDA devices, returns the props.gcnArchName string.
        """
        if device.type != "cuda":
            # Only CUDA devices are supported; this indicates that the system
            # is not in use for this device
            return None

        # Get CUDA device properties
        props = torch.cuda.get_device_properties(device.index)
        return props.gcnArchName

    @staticmethod
    def _generate_kernel_inputs_key(kernel_inputs: KernelInputs) -> str:
        """
        Generate a key based on input node properties and scalars.
        The key includes dtype, size, and stride information for each input node,
        plus scalar values as key=value pairs separated by & signs.
        """
        # Get node information using existing methods
        dtypes = kernel_inputs.dtypes()
        shapes = kernel_inputs.shapes_hinted()
        strides = kernel_inputs.strides_hinted()

        # Create tuple of (dtype, shape_list, stride_list) for each node
        node_info = tuple(
            (dtype, list(shape), list(stride))
            for dtype, shape, stride in zip(dtypes, shapes, strides)
        )

        # Create base key from node information
        fmt_key = str(node_info)
        # Add scalar information if present
        if kernel_inputs._scalars:
            # Sort scalars for consistent key generation and join with &
            scalar_parts = [
                f"{key}={value}"
                for key, value in sorted(kernel_inputs._scalars.items())
            ]
            scalars_key = "&".join(scalar_parts)
            fmt_key = f"{fmt_key}+{scalars_key}"

        return fmt_key

    def make_lookup_key(
        self, kernel_inputs: KernelInputs, op_name: str, include_device: bool = False
    ) -> Optional[str]:
        """
        Create a flattened lookup key from kernel inputs and operation name.
        Override this method to customize key generation.

        Args:
            kernel_inputs: KernelInputs object containing input nodes and scalars
            op_name: Operation name (e.g., "mm", "addmm")
            include_device: Whether to include the device key in the generated key

        Returns:
            A string key combining device (optional), operation, and input information
        """
        device = kernel_inputs.device()
        dev_key = self._get_device_key(device)
        if dev_key is None:
            # The system does not run when dev_key is None, regardless of
            # whether include_device is True or False
            return None
        if not include_device:
            dev_key = None

        # Generate the input key using our staticmethod
        input_key = self._generate_kernel_inputs_key(kernel_inputs)

        # Create the flattened lookup key
        if dev_key is not None:
            key_parts = [dev_key, input_key, op_name]
        else:
            key_parts = [input_key, op_name]

        return "+".join(key_parts)

    def make_lookup_key_variants(
        self, kernel_inputs: KernelInputs, op_name: str
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Generate both device-specific and device-agnostic lookup keys.
        Override this method to customize key variant generation.

        Args:
            kernel_inputs: KernelInputs object containing input nodes and scalars
            op_name: Operation name (e.g., "mm", "addmm")

        Returns:
            Tuple of (device_key, device_agnostic_key). Either may be None if generation fails.
        """
        device_key = self.make_lookup_key(kernel_inputs, op_name, include_device=True)
        device_agnostic_key = self.make_lookup_key(
            kernel_inputs, op_name, include_device=False
        )

        return device_key, device_agnostic_key

    @staticmethod
    def _entry_is_valid(
        cfg: dict[str, Any],
        template_id: str,
        template_hash_map: Optional[dict[str, Optional[str]]],
    ) -> bool:
        """
        Check if a config entry is valid based on template hash validation.

        Args:
            cfg: Configuration dictionary that may contain a template_hash field
            template_id: The template identifier
            template_hash_map: Optional mapping from template_uid to src_hash for validation

        Returns:
            True if the config is valid and should be kept, False if it should be filtered out
        """
        # If hash checking is disabled or no hash map is provided, keep the config
        if not config.lookup_table.check_src_hash or not template_hash_map:
            return True

        template_hash = template_hash_map.get(template_id)
        config_hash = cfg.get("template_hash")

        # Both hashes present - validate that they match
        if template_hash is not None and config_hash is not None:
            if config_hash != template_hash:
                log.warning(
                    "Hash validation failed for template '%s': config_hash='%s' != template_hash='%s'. "
                    "Template code may have changed. Filtering out config: %s",
                    template_id,
                    config_hash,
                    template_hash,
                    {k: v for k, v in cfg.items() if k != "template_hash"},
                )
                return False
            else:
                log.debug(
                    "Hash validation passed for template '%s': hash='%s'",
                    template_id,
                    template_hash,
                )
                return True
        # Config has no hash - keep it
        elif config_hash is None:
            log.debug(
                "Config for template '%s' has no hash - keeping it (template_hash='%s')",
                template_id,
                template_hash,
            )
            return True
        # Template has no hash - keep the config
        else:
            log.debug(
                "Template '%s' has no src_hash - keeping config with hash '%s'",
                template_id,
                config_hash,
            )
            return True

    def lookup_template_configs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
        template_uids: list[str],
        template_hash_map: Optional[dict[str, Optional[str]]] = None,
    ) -> dict[str, list[dict[str, Any]]]:
        """
        Unified function to look up template configurations for multiple templates.
        Override this method to customize lookup logic.

        Args:
            kernel_inputs: KernelInputs object containing input nodes and scalars
            op_name: Operation name (e.g., "mm", "addmm")
            template_uids: List of template identifiers (e.g., ["mm", "tma", "decompose_k"])
            template_hash_map: Optional mapping from template_uid to src_hash for validation

        Returns:
            {}: No lookup table in use, or no matches found for any template
            {"template_uid1": [config1, config2], ...}: Matches found, filtered configurations
        """
        lookup_table = self._get_lookup_table()
        if not lookup_table:
            log.debug("Lookup table: no table configured or CUDA unavailable")
            return {}

        # Try both key variants: device-specific first, then device-agnostic.
        # If both exist, device-specific takes priority.
        device_key, device_agnostic_key = self.make_lookup_key_variants(
            kernel_inputs, op_name
        )

        config_list = []

        for key_type, key in [
            ("device-specific", device_key),
            ("device-agnostic", device_agnostic_key),
        ]:
            if key is not None:
                config_list = lookup_table.get(key, [])
                if config_list:
                    log.debug(
                        "Lookup table: found %d configs using %s key '%s' for %s",
                        len(config_list),
                        key_type,
                        key,
                        op_name,
                    )
                    break
        else:
            log.debug(
                "Lookup table: no match for %s (tried keys: %s, %s) (table has %d keys)",
                op_name,
                device_key,
                device_agnostic_key,
                len(lookup_table),
            )
            return {}

        log.debug(
            "Lookup table: found %d configs for %s templates %s",
            len(config_list),
            op_name,
            template_uids,
        )
        # Group configs by template_id
        configs_by_template: dict[str, list[dict[str, Any]]] = {}
        for cfg in config_list:
            if not isinstance(cfg, dict):
                raise ValueError(
                    f"Config for {op_name} operation is not a dictionary: {cfg}"
                )
            if "template_id" not in cfg:
                raise ValueError(
                    f"Config for {op_name} operation missing required 'template_id' field: {cfg}"
                )

            template_id = cfg["template_id"]
            if template_id in template_uids:
                if template_id not in configs_by_template:
                    configs_by_template[template_id] = []
                configs_by_template[template_id].append(cfg)

        # Check template hashes and clean up the template_id field
        result = {}
        for template_id, matching_configs in configs_by_template.items():
            filtered_configs = []
            for cfg in matching_configs:
                # Check the template hash using the helper function
                if not self._entry_is_valid(cfg, template_id, template_hash_map):
                    continue

                # Return a copy of the config, as we don't want to modify the original
                cconfig = copy.deepcopy(cfg)
                # Lastly, we have to drop the template_id, as it's not a valid kwarg;
                # it is only used to identify which template the entry belongs to
                del cconfig["template_id"]
                # Similarly, the template_hash is not a valid kwarg
                cconfig.pop("template_hash", None)
                filtered_configs.append(cconfig)

            if filtered_configs:
                result[template_id] = filtered_configs

        return result

    def _finalize_template_configs(
        self,
        template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
        kernel_inputs: KernelInputs,
        templates: list[Union[KernelTemplate, ExternKernelChoice]],
        op_name: str,
        kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
    ) -> list[KernelTemplateChoice]:
        """Check the lookup table for hits; use those if found, otherwise fall back to the parent."""
        # 1. Collect template src_hashes for validation
        template_uids = [template.uid for template in templates]
        template_hash_map = {}
        for template in templates:
            src_hash = getattr(template, "src_hash", None)
            template_hash_map[template.uid] = src_hash

        log.debug(
            "Choices: attempting lookup for %s with %d templates",
            op_name,
            len(template_uids),
        )

        # 2. Single batch lookup for all templates
        lookup_results = self.lookup_template_configs(
            kernel_inputs, op_name, template_uids, template_hash_map
        )

        # 3. Early exit if there is no lookup table or no matches
        if not lookup_results:  # Empty dict
            log.info("LookupChoices: lookup miss for %s, using fallback", op_name)
            return self._fallback(
                template_choices,
                kernel_inputs,
                templates,
                op_name,
                kwarg_overrides,
            )

        log.info(
            "LookupChoices: lookup hit for %s - found %d/%d templates: %s",
            op_name,
            len(lookup_results),
            len(template_uids),
            list(lookup_results.keys()),
        )

        # 4. Create KTCs only for templates with lookup entries
        return self._create_lookup_choices(
            lookup_results, templates, kernel_inputs, op_name
        )

    def _fallback(
        self,
        template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
        kernel_inputs: KernelInputs,
        templates: list[Union[KernelTemplate, ExternKernelChoice]],
        op_name: str,
        kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
    ) -> list[KernelTemplateChoice]:
        """Fall back to the parent if there is no lookup table or no matches."""
        # NOTE: this is broken out so that subclasses are able to override it
        # to explicitly handle the situation where the lookup table had a miss,
        # rather than overriding the entire logic
        return super()._finalize_template_configs(
            template_choices,
            kernel_inputs,
            templates,
            op_name,
            kwarg_overrides,
        )

    def _create_lookup_choices(
        self,
        lookup_results: dict[str, list[dict[str, Any]]],
        templates: list[Union[KernelTemplate, ExternKernelChoice]],
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> list[KernelTemplateChoice]:
        """Create KernelTemplateChoice objects from lookup results using the parent's get_ktc method."""
        templates_by_uid = {template.uid: template for template in templates}
        lookup_choices: list[KernelTemplateChoice] = []

        for template_uid, configs in lookup_results.items():
            template = templates_by_uid[template_uid]

            # Use the parent's get_ktc method to get a generator, then take the first base KTC
            ktc_generator = self.get_ktc(kernel_inputs, template, op_name)

            try:
                base_ktc = next(ktc_generator)
            except StopIteration:
                # No configs from the heuristic; skip this template
                continue

            # For each lookup config, create a KTC with the override kwargs
            for c in configs:
                lookup_ktc = KernelTemplateChoice(
                    template=base_ktc.template,
                    # use the params from the lookup table
                    params=DictKernelTemplateParams(c),
                    extra_kwargs=base_ktc.extra_kwargs,
                    layout=base_ktc.layout,
                    inputs=base_ktc.inputs,
                )
                lookup_choices.append(lookup_ktc)

        return lookup_choices
```