[inductor] use writable temp file on windows (#159738)

Use `WritableTempFile` on Windows, reference to: https://github.com/pytorch/pytorch/pull/159342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159738
Approved by: https://github.com/angelayi, https://github.com/Skylion007
This commit is contained in:
Xu Han
2025-08-04 21:50:58 +00:00
committed by PyTorch MergeBot
parent 83ba3f1101
commit 510e8b4ae0
5 changed files with 39 additions and 36 deletions

View File

@ -3,12 +3,12 @@ import io
import logging
import subprocess
import sys
import tempfile
import unittest
import torch
import torch._logging.structured
import torch.distributed as dist
from torch._inductor.codecache import WritableTempFile
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
@ -79,7 +79,7 @@ class FxGraphRunnableTest(TestCase):
self.assertTrue(payload, "Expected fx_graph_runnable payload but got nothing")
self.assertIn("def forward", payload) # sanity-check for actual FX code
with tempfile.NamedTemporaryFile("w", suffix=".py") as tmp:
with WritableTempFile("w", suffix=".py") as tmp:
tmp.write(payload)
tmp.flush()
res = subprocess.run(

View File

@ -54,6 +54,7 @@ from torch._dynamo.testing import (
)
from torch._dynamo.utils import call_size, counters, ifdynstaticdefault
from torch._dynamo.variables import builder
from torch._inductor.codecache import WritableTempFile
from torch._inductor.utils import fresh_cache, run_and_get_code
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
@ -11243,7 +11244,7 @@ class AAA:
def fn():
return 3
"""
with tempfile.NamedTemporaryFile(mode="w") as f:
with WritableTempFile(mode="w") as f:
f.write(src)
f.flush()
from torch._dynamo.funcname_cache import get_funcname

View File

@ -20,6 +20,7 @@ from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.codecache import WritableTempFile
from torch._inductor.package import package_aoti
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
@ -5602,7 +5603,7 @@ class AOTInductorTestsTemplate:
example_inputs=example_inputs,
)
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
with WritableTempFile(suffix=".pt2") as f:
package_path = package_aoti(
f.name,
{"model": aoti_files},

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: functorch"]
import json
import tempfile
import zipfile
from pathlib import Path
@ -11,6 +10,7 @@ import torch._inductor
import torch._inductor.decomposition
from torch._higher_order_ops.torchbind import CallTorchBind, enable_torchbind_tracing
from torch._inductor import aot_compile, ir
from torch._inductor.codecache import WritableTempFile
from torch._inductor.package import package_aoti
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu
@ -280,7 +280,7 @@ class TestTorchbind(TestCase):
)
# Test that the files are packaged
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
with WritableTempFile(suffix=".pt2") as f:
package_path = package_aoti(f.name, aoti_files)
with zipfile.ZipFile(package_path, "r") as zip_ref:

View File

@ -30,6 +30,7 @@ from ctypes import c_void_p, CDLL, cdll
from datetime import timedelta
from functools import lru_cache, partial
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from time import time, time_ns
from types import ModuleType
from typing import (
@ -359,6 +360,36 @@ def get_hash(
raise AssertionError(f"Unknown hash type {hash_type}")
class WritableTempFile:
"""
Avoid "Permission denied error" on Windows:
with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file:
# Not writable on Windows:
# https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile
Example:
with WritableTempFile("w", suffix=".gv") as temp_file:
tree.to_dotfile(temp_file.name)
"""
def __init__(
self, mode: str = "w", *, encoding: Any = None, suffix: Any = None
) -> None:
self.mode = mode
self.encoding = encoding
self.suffix = suffix
def __enter__(self) -> _TemporaryFileWrapper[Any]:
self.temp_file = tempfile.NamedTemporaryFile(
self.mode, encoding=self.encoding, suffix=self.suffix, delete=False
)
return self.temp_file
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.temp_file.close()
os.unlink(self.temp_file.name)
def write(
content: Union[str, bytes],
extension: str,
@ -1622,36 +1653,6 @@ class CudaKernelParamCache:
return cls.cache.keys()
class WritableTempFile:
"""
Avoid "Permission denied error" on Windows:
with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file:
# Not writable on Windows:
# https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile
Example:
with WritableTempFile("w", suffix=".gv") as temp_file:
tree.to_dotfile(temp_file.name)
"""
def __init__(
self, mode: str = "w", *, encoding: Any = None, suffix: Any = None
) -> None:
self.mode = mode
self.encoding = encoding
self.suffix = suffix
def __enter__(self) -> Any:
self.temp_file = tempfile.NamedTemporaryFile(
self.mode, encoding=self.encoding, suffix=self.suffix, delete=False
)
return self.temp_file
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.temp_file.close()
os.unlink(self.temp_file.name)
class AotCodeCompiler:
"""
Compile AOT Inductor generated code.