[AoTI Minifier] UX Improvement (#143330)

Summary:
- When a user specify `TORCHINDUCTOR_MAX_AUTOTUNE=1` env variable, we add `config.max_autotune=True` to the generated minifier_launcher
- We should do this to other inductor configs as well in a followup Diff

Currently in dynamo and aoti minifier, if a config is overwritten by an env variable, the config will not show up in the config list in the minifier_launcher.py file. As a result, when running the minifier_launcher, they need to re-apply the same env variable.
 This is:
1) not convenient for the users
2) if they copy-paste the minifier_launcher.py to us without including the env variable, we could be confused and not able to reproduce the error.

Underlying implementation change:

- Add `env_default` parameter to `codegen_config()`. If set, configs overriden by the env are not considered default.

Test Plan:
```
 buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:utils -- -r test_codegen_config
```

Differential Revision: D67299312

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143330
Approved by: https://github.com/jansel, https://github.com/eellison
This commit is contained in:
Shangdi Yu
2025-01-07 20:04:19 +00:00
committed by PyTorch MergeBot
parent 096cb874d3
commit 72e8f34715
8 changed files with 78 additions and 9 deletions

View File

@ -1,11 +1,13 @@
# Owner(s): ["module: dynamo"]
import os
import unittest
from unittest.mock import patch
import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string
from torch._dynamo.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA
@ -172,6 +174,25 @@ def forward(self, x_1):
self.assertEqual(list(kwargs["primals_4"].shape), [5])
self.assertEqual(kwargs["primals_5"], 5)
@patch.dict(os.environ, {"TORCHINDUCTOR_MAX_AUTOTUNE": "1", "TEST_ENV": "1"})
def test_generate_env_vars_string(self):
env_strings = generate_env_vars_string()
self.assertIn(
"""os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1'
""",
env_strings,
)
self.assertIn(
"""import os
""",
env_strings,
)
self.assertNotIn(
"""TEST_ENV
""",
env_strings,
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -173,6 +173,9 @@ class TestConfigModule(TestCase):
self.assertEqual(
code,
"""torch.testing._internal.fake_config_module.e_bool = False
torch.testing._internal.fake_config_module.e_env_default = True
torch.testing._internal.fake_config_module.e_env_default_FALSE = False
torch.testing._internal.fake_config_module.e_env_force = True
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
)

View File

@ -1,10 +1,8 @@
# mypy: allow-untyped-defs
import getpass
import inspect
import os
import re
import sys
import tempfile
from os.path import abspath, dirname
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
@ -437,10 +435,6 @@ def default_debug_dir_root():
DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
if DEBUG_DIR_VAR_NAME in os.environ:
return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
elif is_fbcode():
return os.path.join(
tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug"
)
else:
return os.path.join(os.getcwd(), "torch_compile_debug")

View File

@ -250,6 +250,31 @@ def _cuda_system_info_comment():
return model_str
def generate_env_vars_string(*, stable_output=False):
"""
Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton.
"""
if stable_output:
return "# env var omitted due to stable_output=True"
allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"]
skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"]
def filter(key):
return any(string in key for string in allow_list) and key not in skip_list
config_lines = [
f"os.environ['{key}'] = '{value}'"
for key, value in os.environ.items()
if filter(key)
]
config_string = "\n".join(config_lines)
return f"""\
import os
{config_string}
"""
def generate_config_string(*, stable_output=False):
import torch._functorch.config
import torch._inductor.config

View File

@ -28,6 +28,7 @@ from torch._dynamo.debug_utils import (
extra_deps,
extra_imports,
generate_config_string,
generate_env_vars_string,
helper_for_dump_minify,
InputReader,
InputWriter,
@ -264,6 +265,7 @@ def generate_compiler_repro_string(
):
model_str = textwrap.dedent(
f"""
{generate_env_vars_string(stable_output=stable_output)}
import torch
from torch import tensor, device
import torch.fx as fx

View File

@ -20,6 +20,7 @@ from torch._dynamo.debug_utils import (
BuckTargetWriter,
extra_imports,
generate_config_string,
generate_env_vars_string,
helper_for_dump_minify,
InputReader,
InputWriter,
@ -179,6 +180,7 @@ def generate_dynamo_fx_repro_string(
return textwrap.dedent(
f"""
{generate_env_vars_string(stable_output=stable_output)}
from math import inf
import torch
from torch import tensor, device

View File

@ -17,6 +17,7 @@ from torch._dynamo.debug_utils import (
BuckTargetWriter,
extra_imports,
generate_config_string,
generate_env_vars_string,
helper_for_dump_minify,
InputReader,
minifier_dir,
@ -193,6 +194,7 @@ def generate_compiler_repro_exported_program(
):
model_str = textwrap.dedent(
f"""
{generate_env_vars_string(stable_output=stable_output)}
import torch
import torch._inductor.inductor_prims
@ -455,7 +457,7 @@ default settings on this script:
)
subparsers = parser.add_subparsers(
dest="command", metavar="{run,minify,analyze}", required=True
dest="command", metavar="{run,minify}", required=True
)
parser_run = subparsers.add_parser(

View File

@ -408,7 +408,27 @@ class ConfigModule(ModuleType):
setattr(module, constant_name, val)
def _is_default(self, name: str) -> bool:
return self._config[name].user_override is _UNSET_SENTINEL
"""
Returns true if the config is at its default value.
configs overriden by the env are not considered default.
"""
config_val = self._config[name]
# The config is not overridden by the user, and the env_value_default
# is different from the default value (meaning user has set the env to
# change the default value).
not_set_env_default = (
config_val.env_value_default is _UNSET_SENTINEL
or config_val.env_value_default == config_val.default
)
not_set_env_force = (
config_val.env_value_force is _UNSET_SENTINEL
or config_val.env_value_force == config_val.default
)
return (
config_val.user_override is _UNSET_SENTINEL
and not_set_env_default
and not_set_env_force
)
def _get_dict(
self,