Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Config fuzzer (#139736)

This tool makes it easy to search through config state-space with a minimal reproduction or test. It presents an interface similar to the config bisector: it takes a test_function that should either raise an Exception or return False upon failure. It has two entry points: `fuzz_n_tuple`, which tries every combination of n configs, and `bisect`, which randomly flips configs and tries to find the minimal reproduction upon failure. `bisect` is a much more efficient way to search the space, but `fuzz_n_tuple` can give you peace of mind that a new config will compose with every other config.

It has been used to find three bugs in the inductor config so far:
https://github.com/pytorch/pytorch/issues/140220
https://github.com/pytorch/pytorch/issues/140219
https://github.com/pytorch/pytorch/issues/143524

This PR also adds a number of missing type annotations to the inductor config so that it plays nicely with the fuzzer, which makes the fuzzer a good forcing function for adding types to config.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139736
Approved by: https://github.com/eellison
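For orientation, here is a minimal sketch of how the fuzzer is driven. The ConfigFuzzer constructor and the bisect/fuzz_n_tuple calls mirror the tests added in this PR (test/inductor/test_fuzzer.py below); the toy test function, its sizes, and the seed are only illustrative.

import torch
from torch._inductor import config as inductor_config
from torch._inductor.fuzzer import ConfigFuzzer

def make_test_fn():
    # Factory returning the check to run under each sampled config combination.
    # Failure is signalled by raising an Exception or by returning False.
    def test_fn() -> bool:
        model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
        model(torch.randn(32, 10))
        return True

    return test_fn

fuzzer = ConfigFuzzer(inductor_config, make_test_fn, seed=0)
# Randomly flip configs, then minimize each failing combination:
failing = fuzzer.bisect(num_attempts=2, p=1.0)
# Or exhaustively try every pair of configs (bounded by max_combinations):
results = fuzzer.fuzz_n_tuple(2, max_combinations=100)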
Committed by: PyTorch MergeBot
Parent: 334ee8ba40
Commit: 1376116ab1
test/inductor/test_fuzzer.py | 157 lines (new file)
@@ -0,0 +1,157 @@
# Owner(s): ["module: dynamo"]

import sys
import unittest
from typing import List, Literal

import torch
from torch._inductor import config as inductor_config
from torch._inductor.fuzzer import ConfigFuzzer, SamplingMethod, Status
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal import fake_config_module as fake_config
from torch.testing._internal.inductor_utils import HAS_GPU


def create_simple_test_model_cpu():
    def test_fn() -> bool:
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1)
        )

        x = torch.randn(32, 10)
        model(x)
        return True

    return test_fn


def create_simple_test_model_gpu():
    batch_size = 32
    seq_length = 50
    hidden_size = 768

    inp = torch.randn(batch_size, seq_length, hidden_size, device="cuda")
    weight = torch.randn(hidden_size, hidden_size, device="cuda")

    def test_fn() -> bool:
        matmul_output = inp @ weight
        torch.nn.LayerNorm(hidden_size, device="cuda")(matmul_output)
        return True

    return test_fn


class TestConfigFuzzer(TestCase):
    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_sampling_method_toggle(self):
        toggle = SamplingMethod.dispatch(SamplingMethod.TOGGLE)
        self.assertEqual(toggle("", bool, False), True)
        self.assertEqual(toggle("", bool, True), False)
        self.assertEqual(toggle("", Literal["foo", "bar"], "foo"), "bar")
        self.assertEqual(toggle("", Literal["foo", "bar"], "bar"), "foo")
        self.assertTrue("bar" in toggle("", List[Literal["foo", "bar"]], ["foo"]))
        self.assertTrue("foo" in toggle("", List[Literal["foo", "bar"]], ["bar"]))

    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_sampling_method_random(self):
        random = SamplingMethod.dispatch(SamplingMethod.RANDOM)
        samp = [random("", bool, False) for i in range(1000)]
        self.assertTrue(not all(samp))

    @unittest.skipIf(not HAS_GPU, "requires gpu")
    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_config_fuzzer_inductor_gpu(self):
        fuzzer = ConfigFuzzer(inductor_config, create_simple_test_model_gpu, seed=30)
        self.assertIsNotNone(fuzzer.default)
        fuzzer.reproduce([{"max_fusion_size": 1}])

    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_config_fuzzer_inductor_cpu(self):
        fuzzer = ConfigFuzzer(inductor_config, create_simple_test_model_cpu, seed=100)
        self.assertIsNotNone(fuzzer.default)
        fuzzer.reproduce([{"max_fusion_size": 1}])

    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_config_fuzzer_bisector_exception(self):
        key_1 = {"e_bool": False, "e_optional": None}

        class MyException(Exception):
            pass

        def create_key_1():
            def myfn():
                if not fake_config.e_bool and fake_config.e_optional is None:
                    raise MyException("hi")
                return True

            return myfn

        fuzzer = ConfigFuzzer(fake_config, create_key_1, seed=100, default={})
        results = fuzzer.bisect(num_attempts=2, p=1.0)
        self.assertEqual(len(results), 2)
        for res in results:
            self.assertEqual(res, key_1)

    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_config_fuzzer_bisector_boolean(self):
        key_1 = {"e_bool": False, "e_optional": None}

        def create_key_1():
            def myfn():
                if not fake_config.e_bool and fake_config.e_optional is None:
                    return False
                return True

            return myfn

        fuzzer = ConfigFuzzer(fake_config, create_key_1, seed=100, default={})
        num_attempts = 2
        results = fuzzer.bisect(num_attempts=num_attempts, p=1.0)
        self.assertEqual(len(results), num_attempts)
        for res in results:
            self.assertEqual(res, key_1)

    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_config_fuzzer_n_tuple(self):
        key_1 = {"e_bool": False, "e_optional": None}

        def create_key_1():
            def myfn():
                if not fake_config.e_bool and fake_config.e_optional is None:
                    return False
                return True

            return myfn

        fuzzer = ConfigFuzzer(fake_config, create_key_1, seed=100, default={})
        max_combo = 100
        results = fuzzer.fuzz_n_tuple(2, max_combinations=max_combo)
        self.assertEqual(results.num_ran(), max_combo)
        self.assertEqual(results.lookup(tuple(key_1.keys())), Status.FAILED_RUN_RETURN)

    @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
    def test_config_fuzzer_inductor_bisect(self):
        # these values just chosen randomly, change to different ones if necessary
        key_1 = {"split_reductions": False, "compute_all_bounds": True}

        def create_key_1():
            def myfn():
                if (
                    not inductor_config.split_reductions
                    and inductor_config.compute_all_bounds
                ):
                    return False
                return True

            return myfn

        fuzzer = ConfigFuzzer(inductor_config, create_key_1, seed=100, default={})
        num_attempts = 2
        results = fuzzer.bisect(num_attempts=num_attempts, p=1.0)
        self.assertEqual(len(results), num_attempts)
        for res in results:
            self.assertEqual(res, key_1)


if __name__ == "__main__":
    run_tests()
torch/_inductor/config.py

@@ -1,6 +1,16 @@
import os  # noqa: C101
import sys
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)

import torch
import torch._inductor.custom_graph_pass

@@ -88,7 +98,7 @@ autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()
bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default()

# Force disabled all inductor level caching -- This will override any other caching flag
-force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"
+force_disable_caches: bool = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"

# sleep in inductor for testing
sleep_sec_TESTING_ONLY: Optional[int] = None

@@ -98,16 +108,19 @@ sleep_sec_TESTING_ONLY: Optional[int] = None
# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}),
# If the custom op does not have a layout constraint tag already
# then we assume the following applies.
-custom_op_default_layout_constraint = "needs_fixed_stride_order"
+custom_op_default_layout_constraint: Literal[
+    "needs_fixed_stride_order", "flexible_layout"
+] = "needs_fixed_stride_order"

# The default layout constraint for user-defined triton kernels.
# See "The default layout constraint for custom operators" for options.
-triton_kernel_default_layout_constraint = "needs_fixed_stride_order"
+triton_kernel_default_layout_constraint: Literal[
+    "needs_fixed_stride_order", "flexible_layout"
+] = "needs_fixed_stride_order"

# use cpp wrapper instead of python wrapper
-cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"
-
-c_shim_version = os.environ.get("TORCHINDUCTOR_C_SHIM_VERSION", "2")
+# incompatible with disable_cpp_codegen
+cpp_wrapper: bool = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"

# dead code elimination
dce = False

@@ -136,7 +149,9 @@ memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"
# - "intermediates": all non-outputs share storage, outputs each get unique storage
# - "outputs": two pools, one for intermediates (freed on return) and one for outputs
# - "combined": a single pool for both intermediates and outputs
-memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates")
+memory_pool: Literal["none", "intermediates", "outputs", "combined"] = os.environ.get(
+    "TORCHINDUCTOR_MEMORY_POOL", "intermediates"
+)  # type: ignore[assignment]

# codegen benchmark harness
benchmark_harness = True

@@ -262,16 +277,25 @@ fx_passes_numeric_check: Dict[str, Any] = {
# - If autotune is disabled, this config will always be chosen.
# - If autotune is enabled, it will also compare with fallback aten implementation and fused kernel.
# The use_mixed_mm flag will be ignored if mixed_mm_choice != "default".
-mixed_mm_choice = "heuristic"
+mixed_mm_choice: Literal["default", "triton", "aten", "heuristic"] = "heuristic"

# enable reordering pass for increasing overlap between compute and communication
# only use with fsdp
reorder_for_compute_comm_overlap = False

# passes (in execution order) for increasing overlap between compute and communication
# for built-in passes, use string name; for user-defined passes, pass in the function handle
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
-reorder_for_compute_comm_overlap_passes = [
+reorder_for_compute_comm_overlap_passes: List[
+    Union[
+        str,
+        Callable[
+            [List["torch._inductor.scheduler.BaseSchedulerNode"]],
+            List["torch._inductor.scheduler.BaseSchedulerNode"],
+        ],
+    ]
+] = [
    "reorder_compute_for_overlap",
    "sink_waits",
    "raise_comms",

@@ -334,9 +358,9 @@ max_autotune_conv_backends = os.environ.get(
# Specify the size of the search space for GEMM autotuning.
# DEFAULT - balance between compile time overhead and performance
# EXHAUSTIVE - maximize performance
-max_autotune_gemm_search_space = os.environ.get(
+max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
-).upper()
+).upper()  # type: ignore[assignment]

# Whether we fall back to ATen or hard error when no matches are found during autotuning
autotune_fallback_to_aten = (

@@ -441,10 +465,10 @@ aggressive_fusion = False

# For each fused kernel in the wrapper, comment with the nodes that get fused.
# Useful for debugging fusion.
-debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
-benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
+debug_fusion: bool = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
+benchmark_fusion: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
-loop_ordering_after_fusion = (
+loop_ordering_after_fusion: bool = (
    os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
)

@@ -637,6 +661,7 @@ def decide_compile_threads() -> int:
compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads()

# gemm autotuning global cache dir
+global_cache_dir: Optional[str]
if is_fbcode():
    try:
        from libfb.py import parutil

@@ -727,7 +752,9 @@ profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
# Specify a file where we print out the profiling results.
# None means we do not dump results to a file.
-profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None)
+profile_bandwidth_output: Optional[str] = os.environ.get(
+    "TORCHINDUCTOR_PROFILE_OUTPUT", None
+)
# Switch to do_bench_using_profiling to exclude the CPU overheads
profile_bandwidth_with_do_bench_using_profiling = (
    os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1"

@@ -735,6 +762,7 @@ profile_bandwidth_with_do_bench_using_profiling = (


# TODO: remove later
+# incompatible with cpp_wrapper
disable_cpp_codegen = False


@@ -764,7 +792,7 @@ unsafe_ignore_unsupported_triton_autotune_args: bool = False
# When True, we will check in scheduler.py _codegen that there are no "loops"
# in the call stack; that is to say, the same frame multiple times. This
# ensures that a cProfile trace to this frame will be a straight line without
-# any cycles.
+# any cycles. Incompatible with cpp_wrapper.
check_stack_no_cycles_TESTING_ONLY: bool = False

# When True, complex_memory_overlap always reports True

@@ -799,15 +827,12 @@ class cpp:
    simdlen: Optional[int] = None
    min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
-    cxx = (
+    cxx: Tuple[None, str] = (
        None,  # download gcc12 from conda-forge if conda is installed
        # "g++-12",
        # "g++-11",
        # "g++-10",
        # "clang++",
        os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
        # "g++.par",
-    )
+    )  # type: ignore[assignment]

    # Allow kernel performance profiling via PyTorch profiler
    enable_kernel_profile = (
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1"

@@ -827,7 +852,9 @@ class cpp:
    vec_isa_ok: Optional[bool] = None

    # similar to config.triton.descriptive_names
-    descriptive_names = "original_aten"
+    descriptive_names: Union[
+        bool, Literal["torch", "original_aten", "inductor_node"]
+    ] = "original_aten"

    # how many nodes to allow into a single horizontal fusion
    max_horizontal_fusion_size = int(

@@ -984,7 +1011,9 @@ class triton:
    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
-    descriptive_names = "original_aten"
+    descriptive_names: Union[
+        bool, Literal["torch", "original_aten", "inductor_node"]
+    ] = "original_aten"

    # use alternate codegen for smaller reductions
    persistent_reductions = (

@@ -999,11 +1028,13 @@ class triton:
    # used for debugging cooperative reduction codegen, always generate cooperative_reductions
    force_cooperative_reductions = False

-    # 0/False: disable
+    # 0: disable
    # 1/True: enable, use tuning to pick between different subkernels
    # 2: enable, force using persistent reduction (for debugging)
    # 3: enable, force using non-persistent reduction (for debugging)
-    multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0"))
+    multi_kernel: Literal[0, 1, 2, 3] = int(
+        os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0")
+    )  # type: ignore[assignment]

    # hint to Triton when arguments are divisible by 16
    divisible_by_16 = True

@@ -1064,9 +1095,9 @@ class aot_inductor:
    # 1: enable saving intermediate tensor values
    # 2: enable printing intermediate tensor values
    # 3: enable printing kernel names only (useful for pinpointing troublesome kernels)
-    debug_intermediate_value_printer = os.environ.get(
+    debug_intermediate_value_printer: Literal["0", "1", "2", "3"] = os.environ.get(
        "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0"
-    )
+    )  # type: ignore[assignment]

    # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2
    filtered_kernel_names = os.environ.get(

@@ -1137,7 +1168,7 @@ class cuda:
    version: Optional[str] = None

    # Optimization level for the host compiler.
-    compile_opt_level = "-O1"
+    compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"

    # Whether to enable device LTO (link-time-optimization).
    enable_cuda_lto = False

@@ -1207,8 +1238,11 @@ class rocm:
    # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
    ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]

-    # Optimization level, use to balance compilation speed and runtime performance
-    compile_opt_level = "-O2"
+    # Optimization level, use to balance compilation speed and runtime performance.
+    # The type will not necessarily be comprehensive and won't be enforced at runtime.
+    compile_opt_level: Literal[
+        "-O0", "-O1", "-O2", "-O3", "-Os", "-Oz", "-Omin", "-Ofast", "-Omax"
+    ] = "-O2"

    # Flag to keep debug information in compiled objects
    is_debug = False

@@ -1247,10 +1281,10 @@ class rocm:


# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
-cpu_backend = "cpp"
+cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"

# Backend to use for CUDA codegen either "triton" or "halide" (experimental)
-cuda_backend = "triton"
+cuda_backend: Literal["triton", "halide"] = "triton"


class halide:

@@ -1262,8 +1296,12 @@ class halide:

    # Halide autoscheduler to use, choices are:
    # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
-    scheduler_cuda = "Anderson2021"
-    scheduler_cpu = "Adams2019"
+    scheduler_cuda: Literal[
+        "Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
+    ] = "Anderson2021"
+    scheduler_cpu: Literal[
+        "Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
+    ] = "Adams2019"

    # Controls `no_asserts` flag passed to Halide target (warning: can false positive)
    asserts = False

@@ -1344,7 +1382,7 @@ class trace:
    log_inductor_triton_kernel_to_post_grad_node_info: bool = True


-_save_config_ignore = [
+_save_config_ignore: List[str] = [
    # workaround: "Can't pickle <function ...>"
    "trace.upload_tar",
    "joint_custom_pre_pass",

@@ -1352,7 +1390,7 @@ _save_config_ignore = [
    "pre_grad_custom_pass",
]

-_cache_config_ignore_prefix = [
+_cache_config_ignore_prefix: List[str] = [
    # trace functions are not relevant to config caching
    "trace",
    # uses absolute path
torch/_inductor/fuzzer.py | 1029 lines (new file)
(File diff suppressed because it is too large.)
torch/utils/_config_module.py

@@ -187,14 +187,17 @@ def install_config_module(module: ModuleType) -> None:
                continue

            name = f"{prefix}{key}"
+            annotated_type = type_hints.get(key, None)
            if isinstance(value, CONFIG_TYPES):
-                annotated_type = type_hints.get(key, None)
                config[name] = _ConfigEntry(
                    _Config(default=value, value_type=annotated_type)
                )
                if dest is module:
                    delattr(module, key)
            elif isinstance(value, _Config):
+                if annotated_type is not None and value.value_type is None:
+                    value.value_type = annotated_type
+
                config[name] = _ConfigEntry(value)

                if dest is module:
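The change above pulls the type-hint lookup out of the CONFIG_TYPES branch so that a _Config declared without an explicit value_type also picks up the module-level annotation. As a rough sketch of why the annotations added in this PR matter for the fuzzer (the annotated style follows the inductor config diff above; the claim that the fuzzer consumes value_type is an inference from the commit message, not a documented API):

# Hypothetical config module written in the annotated style this PR encourages.
import os
import sys
from typing import Literal, Optional

force_disable_caches: bool = os.environ.get("MY_FORCE_DISABLE_CACHES") == "1"
memory_pool: Literal["none", "intermediates", "outputs", "combined"] = "intermediates"
sleep_sec_TESTING_ONLY: Optional[int] = None

# install_config_module records each annotation as value_type on the
# corresponding _ConfigEntry (per the diff above), which is what gives
# ConfigFuzzer enough type information to sample sensible values
# (flip a bool, pick another Literal option) rather than guessing from
# the default alone.
from torch.utils._config_module import install_config_module  # noqa: E402

install_config_module(sys.modules[__name__])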