mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] use writable temp file on windows (#159738)
Use `WritableTempFile` on Windows, reference to: https://github.com/pytorch/pytorch/pull/159342 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159738 Approved by: https://github.com/angelayi, https://github.com/Skylion007
This commit is contained in:
@ -3,12 +3,12 @@ import io
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._logging.structured
|
||||
import torch.distributed as dist
|
||||
from torch._inductor.codecache import WritableTempFile
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
|
||||
|
||||
@ -79,7 +79,7 @@ class FxGraphRunnableTest(TestCase):
|
||||
self.assertTrue(payload, "Expected fx_graph_runnable payload but got nothing")
|
||||
self.assertIn("def forward", payload) # sanity-check for actual FX code
|
||||
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py") as tmp:
|
||||
with WritableTempFile("w", suffix=".py") as tmp:
|
||||
tmp.write(payload)
|
||||
tmp.flush()
|
||||
res = subprocess.run(
|
||||
|
@ -54,6 +54,7 @@ from torch._dynamo.testing import (
|
||||
)
|
||||
from torch._dynamo.utils import call_size, counters, ifdynstaticdefault
|
||||
from torch._dynamo.variables import builder
|
||||
from torch._inductor.codecache import WritableTempFile
|
||||
from torch._inductor.utils import fresh_cache, run_and_get_code
|
||||
from torch.ao.quantization import MinMaxObserver
|
||||
from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
@ -11243,7 +11244,7 @@ class AAA:
|
||||
def fn():
|
||||
return 3
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
with WritableTempFile(mode="w") as f:
|
||||
f.write(src)
|
||||
f.flush()
|
||||
from torch._dynamo.funcname_cache import get_funcname
|
||||
|
@ -20,6 +20,7 @@ from torch._dynamo.device_interface import get_interface_for_device
|
||||
from torch._dynamo.testing import rand_strided, same
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codecache import WritableTempFile
|
||||
from torch._inductor.package import package_aoti
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch._inductor.test_case import TestCase
|
||||
@ -5602,7 +5603,7 @@ class AOTInductorTestsTemplate:
|
||||
example_inputs=example_inputs,
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
||||
with WritableTempFile(suffix=".pt2") as f:
|
||||
package_path = package_aoti(
|
||||
f.name,
|
||||
{"model": aoti_files},
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: functorch"]
|
||||
import json
|
||||
import tempfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
@ -11,6 +10,7 @@ import torch._inductor
|
||||
import torch._inductor.decomposition
|
||||
from torch._higher_order_ops.torchbind import CallTorchBind, enable_torchbind_tracing
|
||||
from torch._inductor import aot_compile, ir
|
||||
from torch._inductor.codecache import WritableTempFile
|
||||
from torch._inductor.package import package_aoti
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu
|
||||
@ -280,7 +280,7 @@ class TestTorchbind(TestCase):
|
||||
)
|
||||
|
||||
# Test that the files are packaged
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
||||
with WritableTempFile(suffix=".pt2") as f:
|
||||
package_path = package_aoti(f.name, aoti_files)
|
||||
|
||||
with zipfile.ZipFile(package_path, "r") as zip_ref:
|
||||
|
@ -30,6 +30,7 @@ from ctypes import c_void_p, CDLL, cdll
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache, partial
|
||||
from pathlib import Path
|
||||
from tempfile import _TemporaryFileWrapper
|
||||
from time import time, time_ns
|
||||
from types import ModuleType
|
||||
from typing import (
|
||||
@ -359,6 +360,36 @@ def get_hash(
|
||||
raise AssertionError(f"Unknown hash type {hash_type}")
|
||||
|
||||
|
||||
class WritableTempFile:
|
||||
"""
|
||||
Avoid "Permission denied error" on Windows:
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file:
|
||||
# Not writable on Windows:
|
||||
# https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile
|
||||
|
||||
Example:
|
||||
with WritableTempFile("w", suffix=".gv") as temp_file:
|
||||
tree.to_dotfile(temp_file.name)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, mode: str = "w", *, encoding: Any = None, suffix: Any = None
|
||||
) -> None:
|
||||
self.mode = mode
|
||||
self.encoding = encoding
|
||||
self.suffix = suffix
|
||||
|
||||
def __enter__(self) -> _TemporaryFileWrapper[Any]:
|
||||
self.temp_file = tempfile.NamedTemporaryFile(
|
||||
self.mode, encoding=self.encoding, suffix=self.suffix, delete=False
|
||||
)
|
||||
return self.temp_file
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.temp_file.close()
|
||||
os.unlink(self.temp_file.name)
|
||||
|
||||
|
||||
def write(
|
||||
content: Union[str, bytes],
|
||||
extension: str,
|
||||
@ -1622,36 +1653,6 @@ class CudaKernelParamCache:
|
||||
return cls.cache.keys()
|
||||
|
||||
|
||||
class WritableTempFile:
|
||||
"""
|
||||
Avoid "Permission denied error" on Windows:
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file:
|
||||
# Not writable on Windows:
|
||||
# https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile
|
||||
|
||||
Example:
|
||||
with WritableTempFile("w", suffix=".gv") as temp_file:
|
||||
tree.to_dotfile(temp_file.name)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, mode: str = "w", *, encoding: Any = None, suffix: Any = None
|
||||
) -> None:
|
||||
self.mode = mode
|
||||
self.encoding = encoding
|
||||
self.suffix = suffix
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
self.temp_file = tempfile.NamedTemporaryFile(
|
||||
self.mode, encoding=self.encoding, suffix=self.suffix, delete=False
|
||||
)
|
||||
return self.temp_file
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.temp_file.close()
|
||||
os.unlink(self.temp_file.name)
|
||||
|
||||
|
||||
class AotCodeCompiler:
|
||||
"""
|
||||
Compile AOT Inductor generated code.
|
||||
|
Reference in New Issue
Block a user