Compare commits


1 Commit

Author SHA1 Message Date
e3d00beddd Fix triu_/tril_ overlap handling 2025-10-21 07:54:24 -07:00
16 changed files with 172 additions and 361 deletions

View File

@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(hardware_destructive_interference_size) FreeBlockList {
struct alignas(64) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
struct alignas(64) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
alignas(64) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
alignas(64) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
alignas(64) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {

View File

@ -141,6 +141,8 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}
checkTrilTriuMemoryOverlap(result, self);
bool inplace_op = self.is_same(result);
bool inplace_update = false;

View File

@ -1,3 +1,4 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>
@ -54,4 +55,13 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}
static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}
} // namespace at::native
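The helper added above separates two situations: a true in-place call, where `result.is_same(self)` and only internal overlap within `result` is checked, and an `out=` call on a distinct tensor, where overlap between output and input is rejected as well. A minimal Python-level sketch of the second situation, assuming the generic memory-overlap error is what surfaces here:

```python
import torch

a = torch.randn(4, 4)
alias = a.view(4, 4)   # distinct Tensor object that shares a's storage

# result (alias) and self (a) are different tensors but fully overlap, so
# assert_no_overlap(result, self) is expected to reject the call.
try:
    torch.triu(a, out=alias)
except RuntimeError as e:
    print("rejected:", e)
```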

View File

@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -110,6 +111,8 @@ __global__ void triu_tril_kernel(
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,

View File

@ -9,7 +9,6 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -1,7 +1,6 @@
#pragma once
#include <cstddef>
#include <new>
namespace c10 {
@ -19,12 +18,4 @@ constexpr size_t gPagesize = 4096;
// since the default thp pagesize is 2MB, enable thp only
// for buffers of size 2MB or larger to avoid memory bloating
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
// Cache line size used to avoid false sharing between threads. Falls back to 64
// bytes if C++17 feature is unavailable.
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
} // namespace c10

View File

@ -941,7 +941,7 @@ class EventPool {
private:
struct PerDevicePool {
alignas(hardware_destructive_interference_size) std::mutex mutex_;
alignas(64) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
@ -3758,6 +3758,11 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
static constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
class NativeCachingAllocator : public CUDAAllocator {
private:

View File

@ -554,7 +554,7 @@ static void local_raw_delete(void* ptr);
class XPUAllocator : public DeviceAllocator {
private:
alignas(hardware_destructive_interference_size) std::mutex mutex;
std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
void add_allocated_block(Block* block) {

View File

@ -1839,22 +1839,12 @@ class TestStandaloneCompile(TestCase):
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
@parametrize("graph_partition", (False, True))
@parametrize("is_aot", (False, True))
def test_basic(
self,
device: str,
format: str,
dynamic: bool,
graph_partition: bool,
is_aot: bool,
self, device: str, format: str, dynamic: bool, graph_partition: bool
) -> None:
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
# AOT mode does not support unpacked format
if is_aot and format == "unpacked":
raise unittest.SkipTest("AOT mode does not support unpacked format")
mod = torch.nn.Linear(1, 3, device=device)
x = torch.randn(4, 1, device=device)
if dynamic:
@ -1879,9 +1869,7 @@ class TestStandaloneCompile(TestCase):
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, aot=is_aot
)
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path, format=format)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@ -1897,15 +1885,13 @@ class TestStandaloneCompile(TestCase):
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)
if not is_aot:
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic", (False, True))
@parametrize("is_aot", (False, True))
def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None:
def test_call_in_backend(self, dynamic: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:
@ -1918,7 +1904,7 @@ class TestStandaloneCompile(TestCase):
eager_out = f(x)
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args, aot=is_aot)
return torch._inductor.standalone_compile(gm, args)
with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
@ -2069,8 +2055,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (False, True))
def test_dynamic_shapes_from_graph(self, is_aot: bool):
def test_dynamic_shapes_from_graph(self):
def f(x):
return x.shape[0] * x
@ -2082,7 +2067,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph", aot=is_aot
gm, args, dynamic_shapes="from_graph"
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)
@ -2092,8 +2077,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
@parametrize("is_aot", (False, True))
def test_split_module(self, is_aot):
def test_split_module(self):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0**2) + (a1 / 2)
@ -2132,24 +2116,16 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
split = torch.fx.passes.split_module.split_module(gm, gm, split)
# Each of the split graphs only has one output.
ca0 = torch._inductor.standalone_compile(
split.submod_0, (a0, x, a1), aot=is_aot
)
ca1 = torch._inductor.standalone_compile(
split.submod_1, (b0, x, b1), aot=is_aot
)
ca2 = torch._inductor.standalone_compile(
split.submod_2, (c0, x, c1), aot=is_aot
)
ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
if not is_aot:
# fx graph cache doesn't run in AOT mode
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
# TODO: split_module causes ca1 and ca2 to have different type annotations
# for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
@ -2162,9 +2138,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (False, True))
@parametrize("config_patches", [True, False])
def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot):
def test_dynamic_shapes_from_example_inputs(self, config_patches):
def f(x):
return x.shape[0] * x
@ -2186,7 +2161,6 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
(5, torch.ones(4)),
dynamic_shapes="from_example_inputs",
options={"config_patches": config_patches},
aot=is_aot,
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)
@ -2201,9 +2175,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes, is_aot):
def test_static_shapes(self, dynamic_shapes):
def f(x):
return x.shape[0] * x
@ -2213,7 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot
static_gm, [static_x], dynamic_shapes=dynamic_shapes
)
x = torch.randn(3)
(result,) = compiled_artifact(x)
@ -2225,9 +2198,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes, is_aot):
def test_backend(self, dynamic_shapes):
def f(x):
return x.shape[0] * x
@ -2236,7 +2208,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot
gm, args, dynamic_shapes=dynamic_shapes
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)
@ -2249,8 +2221,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
def test_backend_dynamic_shapes_from_example_inputs(self, is_aot):
def test_backend_dynamic_shapes_from_example_inputs(self):
def f(x):
return x.shape[0] * x
@ -2259,7 +2230,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)

View File

@ -9986,6 +9986,20 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)
@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):

View File

@ -754,10 +754,6 @@ def align_trace_from_beginning(
# Rank 3: [0, 1, 2, 3, 4, 5, None]
# Then we should start from collective 2, not 0, because for any earlier
# collective we don't have complete records from all ranks, so we need to
# ignore them.
# If we don't have any trace from some ranks, ignore them
# as well.
if len(entries[rank]) == 0:
continue
first_record_id = entries[rank][0]["record_id"]
maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)
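The surrounding comment describes the rule used here: the merged trace can only begin at the largest "first record id" seen across ranks, since earlier collectives are missing from at least one rank. A minimal sketch of that rule, assuming `entries` maps rank to a list of dicts carrying `record_id` (a hypothetical helper, not the flight-recorder code itself):

```python
def common_starting_record_id(entries: dict[int, list[dict]]) -> int:
    # Ranks with no trace at all contribute nothing; among the rest, the
    # merged trace can start no earlier than the latest first record id.
    return max(rank_entries[0]["record_id"]
               for rank_entries in entries.values() if rank_entries)

entries = {0: [{"record_id": 0}], 1: [{"record_id": 2}], 2: [{"record_id": 1}]}
print(common_starting_record_id(entries))  # 2: records 0 and 1 are incomplete
```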

View File

@ -1,3 +1,4 @@
import abc
import dataclasses
import importlib
import inspect
@ -14,10 +15,6 @@ from torch._dynamo.graph_utils import _graph_device_type
from torch._dynamo.package import SystemInfo
from . import convert_frame
from .aot_compile_types import (
BundledAOTAutogradSerializableCallable,
SerializableCallable,
)
from .hooks import Hooks
@ -29,6 +26,18 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
def bind_locals(
signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
@ -140,6 +149,53 @@ class AOTCompiledFunction:
self._guard_check_enabled = False
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
if hasattr(self, attr):
return getattr(super(), attr)
else:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
def aot_compile_fullgraph(
model: Any,
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
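For orientation, a hedged sketch of how the `BundledAOTAutogradSerializableCallable` defined above is meant to round-trip; `compiled_fn` and `example_inputs` are placeholders rather than names from this change:

```python
# `compiled_fn` stands for whatever AOTAutograd produced; __init__ asserts
# that it exposes a .serialize() method.
wrapped = BundledAOTAutogradSerializableCallable(compiled_fn)
blob = BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(wrapped)

# ...later, possibly in a different process...
restored = BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(blob)
out = restored(*example_inputs)   # __call__ forwards to the deserialized function
```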

View File

@ -1,61 +0,0 @@
import abc
import pickle
from typing import Any
import torch
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)

View File

@ -391,7 +391,6 @@ def standalone_compile(
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Precompilation API for inductor.
@ -423,5 +422,5 @@ def standalone_compile(
options = options if options else {}
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
)
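A minimal usage sketch of this wrapper, mirroring the tests above; `gm` and `args` are assumed to come from a prior capture step, the path is a placeholder, and the `torch._inductor.CompiledArtifact.load` access path is assumed rather than shown in these hunks:

```python
compiled = torch._inductor.standalone_compile(gm, args, dynamic_shapes="from_graph")
compiled.save(path="/tmp/standalone.bin", format="binary")

# Later, possibly in a fresh process:
loaded = torch._inductor.CompiledArtifact.load(path="/tmp/standalone.bin", format="binary")
out = loaded(*args)
```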

View File

@ -5,12 +5,10 @@ import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
@ -32,9 +30,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class CompiledArtifact(ABC):
class CompiledArtifact:
"""
CompiledArtifact class represents the inductor cache artifacts that
CompiledArtifact class represents the precompiled inductor artifact that
can be invoked in order to avoid repeated compilation.
CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
@ -47,68 +45,11 @@ class CompiledArtifact(ABC):
binary or unpacked data.
Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the cached artifact.
to execute the precompiled artifact.
"""
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
@abstractmethod
def __call__(self, *args: Any) -> Any: ...
@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)
assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)
class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
"""
CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")
_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
@ -142,7 +83,6 @@ class CacheCompiledArtifact(CompiledArtifact):
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
@ -176,51 +116,9 @@ class CacheCompiledArtifact(CompiledArtifact):
log.info("Output code written to: %s", output_file)
@staticmethod
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
from .compile_fx import _CompileFxKwargs
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)
@staticmethod
def _prepare_load(
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
Do format specific prep and loads, return a context manager and key
"""
) -> CompiledArtifact:
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
@ -239,7 +137,8 @@ class CacheCompiledArtifact(CompiledArtifact):
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return key, nullcontext()
cache_dir_ctx: AbstractContextManager[None] = nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
@ -249,105 +148,43 @@ class CacheCompiledArtifact(CompiledArtifact):
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
class AOTCompiledArtifact(CompiledArtifact):
"""
Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
This object is always a serializable callable function.
assert result is not None
(entry, _) = result
This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
is used by torch._dynamo.aot_compile for AOT Precompilation.
"""
from .compile_fx import _CompileFxKwargs
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None # We don't need artifacts, the inner object handles everything
)
@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)
def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)
from torch._inductor.codecache import write_atomic
# Save a sentinel file to indicate that this is AOT
write_atomic(path, writer.to_bytes())
def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)
@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
def standalone_compile(
@ -356,11 +193,7 @@ def standalone_compile(
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Implementation of torch.inductor.standalone_compile
"""
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx
@ -416,7 +249,6 @@ def standalone_compile(
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
@ -424,12 +256,7 @@ def standalone_compile(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)
if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
@ -437,4 +264,4 @@ def standalone_compile(
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)
return CacheCompiledArtifact(compiled_fn, artifacts)
return CompiledArtifact(compiled_fn, artifacts)
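With the header write dropped, `save(format="binary")` as shown above lays the file out as `torch_key()`, then the cache key, then the serialized cache artifacts. A sketch of reading that layout back (placeholder path; this mirrors the load path above rather than adding anything to it):

```python
from torch._inductor.codecache import torch_key
from torch.utils._appending_byte_serializer import BytesReader

with open("/tmp/standalone.bin", "rb") as f:
    reader = BytesReader(f.read())

assert reader.read_bytes() == torch_key()  # artifact must match this build of torch
key = reader.read_str()                    # AOTAutogradCache key to look up on load
artifact_bytes = reader.read_bytes()       # payload for torch.compiler.load_cache_artifacts
assert reader.is_finished()
```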

View File

@ -43,8 +43,6 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix):
def _compile_submod(gm, prefix):
from torch._inductor.standalone_compile import AOTCompiledArtifact
for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith(prefix):
fake_inputs = []
@ -58,12 +56,13 @@ def _compile_submod(gm, prefix):
submod = getattr(gm, node.target)
compiled_fn = torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
)
assert isinstance(compiled_fn, AOTCompiledArtifact)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(compiled_fn)
compiled_submod = _dummy_wrapper(
torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context"
)
)
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
compiled_submod, args=node.args, kwargs=node.kwargs