From 510e8b4ae0a61d02f773ec81964113a82beedfb7 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 4 Aug 2025 21:50:58 +0000 Subject: [PATCH] [inductor] use writable temp file on windows (#159738) Use `WritableTempFile` on Windows, reference to: https://github.com/pytorch/pytorch/pull/159342 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159738 Approved by: https://github.com/angelayi, https://github.com/Skylion007 --- test/dynamo/test_fx_graph_runnable.py | 4 +- test/dynamo/test_misc.py | 3 +- test/inductor/test_aot_inductor.py | 3 +- test/inductor/test_torchbind.py | 4 +- torch/_inductor/codecache.py | 61 ++++++++++++++------------- 5 files changed, 39 insertions(+), 36 deletions(-) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index 0164b6f9c680..d5ad0c160c4b 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -3,12 +3,12 @@ import io import logging import subprocess import sys -import tempfile import unittest import torch import torch._logging.structured import torch.distributed as dist +from torch._inductor.codecache import WritableTempFile from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE @@ -79,7 +79,7 @@ class FxGraphRunnableTest(TestCase): self.assertTrue(payload, "Expected fx_graph_runnable payload but got nothing") self.assertIn("def forward", payload) # sanity-check for actual FX code - with tempfile.NamedTemporaryFile("w", suffix=".py") as tmp: + with WritableTempFile("w", suffix=".py") as tmp: tmp.write(payload) tmp.flush() res = subprocess.run( diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 9c23c3e4a495..82c0368c5b15 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -54,6 +54,7 @@ from torch._dynamo.testing import ( ) from torch._dynamo.utils import call_size, counters, ifdynstaticdefault from torch._dynamo.variables import builder +from torch._inductor.codecache import WritableTempFile from torch._inductor.utils import fresh_cache, run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize @@ -11243,7 +11244,7 @@ class AAA: def fn(): return 3 """ - with tempfile.NamedTemporaryFile(mode="w") as f: + with WritableTempFile(mode="w") as f: f.write(src) f.flush() from torch._dynamo.funcname_cache import get_funcname diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 1e95673b32f9..63a30103d378 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -20,6 +20,7 @@ from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters from torch._inductor import config +from torch._inductor.codecache import WritableTempFile from torch._inductor.package import package_aoti from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.test_case import TestCase @@ -5602,7 +5603,7 @@ class AOTInductorTestsTemplate: example_inputs=example_inputs, ) - with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + with WritableTempFile(suffix=".pt2") as f: package_path = package_aoti( f.name, {"model": aoti_files}, diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index 40695e6affb1..631a4fce31fd 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -1,6 +1,5 @@ # Owner(s): ["module: functorch"] import json -import tempfile import zipfile from pathlib import Path @@ -11,6 +10,7 @@ import torch._inductor import torch._inductor.decomposition from torch._higher_order_ops.torchbind import CallTorchBind, enable_torchbind_tracing from torch._inductor import aot_compile, ir +from torch._inductor.codecache import WritableTempFile from torch._inductor.package import package_aoti from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu @@ -280,7 +280,7 @@ class TestTorchbind(TestCase): ) # Test that the files are packaged - with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + with WritableTempFile(suffix=".pt2") as f: package_path = package_aoti(f.name, aoti_files) with zipfile.ZipFile(package_path, "r") as zip_ref: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 53f166f183d5..451f72f62169 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -30,6 +30,7 @@ from ctypes import c_void_p, CDLL, cdll from datetime import timedelta from functools import lru_cache, partial from pathlib import Path +from tempfile import _TemporaryFileWrapper from time import time, time_ns from types import ModuleType from typing import ( @@ -359,6 +360,36 @@ def get_hash( raise AssertionError(f"Unknown hash type {hash_type}") +class WritableTempFile: + """ + Avoid "Permission denied error" on Windows: + with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file: + # Not writable on Windows: + # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile + + Example: + with WritableTempFile("w", suffix=".gv") as temp_file: + tree.to_dotfile(temp_file.name) + """ + + def __init__( + self, mode: str = "w", *, encoding: Any = None, suffix: Any = None + ) -> None: + self.mode = mode + self.encoding = encoding + self.suffix = suffix + + def __enter__(self) -> _TemporaryFileWrapper[Any]: + self.temp_file = tempfile.NamedTemporaryFile( + self.mode, encoding=self.encoding, suffix=self.suffix, delete=False + ) + return self.temp_file + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.temp_file.close() + os.unlink(self.temp_file.name) + + def write( content: Union[str, bytes], extension: str, @@ -1622,36 +1653,6 @@ class CudaKernelParamCache: return cls.cache.keys() -class WritableTempFile: - """ - Avoid "Permission denied error" on Windows: - with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file: - # Not writable on Windows: - # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile - - Example: - with WritableTempFile("w", suffix=".gv") as temp_file: - tree.to_dotfile(temp_file.name) - """ - - def __init__( - self, mode: str = "w", *, encoding: Any = None, suffix: Any = None - ) -> None: - self.mode = mode - self.encoding = encoding - self.suffix = suffix - - def __enter__(self) -> Any: - self.temp_file = tempfile.NamedTemporaryFile( - self.mode, encoding=self.encoding, suffix=self.suffix, delete=False - ) - return self.temp_file - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.temp_file.close() - os.unlink(self.temp_file.name) - - class AotCodeCompiler: """ Compile AOT Inductor generated code.