Compare commits

..

9 Commits

Author SHA1 Message Date
d6d3367233 Update on "WIP: fix attempt #2"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-24 21:41:42 -07:00
a514a050fa Update base for Update on "WIP: fix attempt #2"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-24 21:41:42 -07:00
1a71669a22 Update on "WIP: fix attempt #2"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-24 09:50:31 -07:00
a91c4ceb08 Update base for Update on "WIP: fix attempt #2"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-24 09:50:31 -07:00
8b3dc0d1b0 Better error handling in torch/csrc/jit/runtime/* (#165118)
Refactor error handling by using TORCH_CHECK for improved clarity in constants and scope management in some files in torch/csrc/jit/runtime/*

Fixes some parts of ISSUE https://github.com/pytorch/pytorch/issues/148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165118
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 15:22:49 +00:00
485c73a947 Update on "WIP: fix attempt #2"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 08:12:51 -07:00
06773663b5 Implement an AOT precompile mode for standalone_compile (#165843)
This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and then allows regional_inductor to use this so that we can start aot compiling regional compiler graphs. The diff above this will attempt to allow GraphPickler to fully serialize graphs that have regionally compiled subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
2025-10-21 15:02:45 +00:00
0bff65503c Move hardware_destructive_interference_size to c10/core/alignment.h (#160067)
# Motivation
Move `hardware_destructive_interference_size` to `c10/core/alignment.h`, which gives a chance to reuse it across different accelerators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160067
Approved by: https://github.com/Skylion007, https://github.com/EikanWang
2025-10-21 14:39:46 +00:00
5c075d8bcf WIP: fix attempt #2
[ghstack-poisoned]
2025-10-20 22:03:34 -07:00
22 changed files with 452 additions and 276 deletions

View File

@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(64) FreeBlockList {
struct alignas(hardware_destructive_interference_size) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(64) std::mutex blocks_mutex_;
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(64) std::mutex events_mutex_;
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(64) HostStatsStaged stats_;
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {

View File

@ -141,8 +141,6 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}
checkTrilTriuMemoryOverlap(result, self);
bool inplace_op = self.is_same(result);
bool inplace_update = false;

View File

@ -1,4 +1,3 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>
@ -55,13 +54,4 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}
static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}
} // namespace at::native

View File

@ -5,7 +5,6 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -111,8 +110,6 @@ __global__ void triu_tril_kernel(
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,

View File

@ -9,6 +9,7 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -1,6 +1,7 @@
#pragma once
#include <cstddef>
#include <new>
namespace c10 {
@ -18,4 +19,12 @@ constexpr size_t gPagesize = 4096;
// since the default thp pagesize is 2MB, enable thp only
// for buffers of size 2MB or larger to avoid memory bloating
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
// Cache line size used to avoid false sharing between threads. Falls back to 64
// bytes if C++17 feature is unavailable.
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
} // namespace c10

View File

@ -941,7 +941,7 @@ class EventPool {
private:
struct PerDevicePool {
alignas(64) std::mutex mutex_;
alignas(hardware_destructive_interference_size) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
@ -3758,11 +3758,6 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
static constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
class NativeCachingAllocator : public CUDAAllocator {
private:

View File

@ -554,7 +554,7 @@ static void local_raw_delete(void* ptr);
class XPUAllocator : public DeviceAllocator {
private:
std::mutex mutex;
alignas(hardware_destructive_interference_size) std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
void add_allocated_block(Block* block) {

View File

@ -3064,11 +3064,12 @@ class GraphModule(torch.nn.Module):
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_10, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
primals_3, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
primals_10, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_1, # SavedForBackwardsAOTOutput(idx=0)
primals_8, # SavedForBackwardsAOTOutput(idx=1)
primals_10, # SavedForBackwardsAOTOutput(idx=2)
primals_3, # SavedForBackwardsAOTOutput(idx=1)
primals_8, # SavedForBackwardsAOTOutput(idx=2)
primals_10, # SavedForBackwardsAOTOutput(idx=3)
)
""", # noqa: B950
)
@ -3080,6 +3081,7 @@ class GraphModule(torch.nn.Module):
def forward(
self,
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
tangents_1: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
@ -3097,7 +3099,7 @@ class GraphModule(torch.nn.Module):
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
primals_3, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
)
""", # noqa: B950
@ -3134,7 +3136,7 @@ class GraphModule(torch.nn.Module):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s55)" = primals_10 + primals_10
add_2: "Sym(2*s55)" = primals_10 + primals_3; primals_3 = None
return (
cat, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')

View File

@ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase):
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
@parametrize("graph_partition", (False, True))
@parametrize("is_aot", (False, True))
def test_basic(
self, device: str, format: str, dynamic: bool, graph_partition: bool
self,
device: str,
format: str,
dynamic: bool,
graph_partition: bool,
is_aot: bool,
) -> None:
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
# AOT mode does not support unpacked format
if is_aot and format == "unpacked":
raise unittest.SkipTest("AOT mode does not support unpacked format")
mod = torch.nn.Linear(1, 3, device=device)
x = torch.randn(4, 1, device=device)
if dynamic:
@ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase):
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact = torch._inductor.standalone_compile(
gm, args, aot=is_aot
)
compiled_artifact.save(path=path, format=format)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase):
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
if not is_aot:
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic", (False, True))
def test_call_in_backend(self, dynamic: bool) -> None:
@parametrize("is_aot", (False, True))
def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:
@ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase):
eager_out = f(x)
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args)
return torch._inductor.standalone_compile(gm, args, aot=is_aot)
with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
@ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_graph(self):
@parametrize("is_aot", (False, True))
def test_dynamic_shapes_from_graph(self, is_aot: bool):
def f(x):
return x.shape[0] * x
@ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph"
gm, args, dynamic_shapes="from_graph", aot=is_aot
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)
@ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
def test_split_module(self):
@parametrize("is_aot", (False, True))
def test_split_module(self, is_aot):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0**2) + (a1 / 2)
@ -2116,16 +2132,24 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
split = torch.fx.passes.split_module.split_module(gm, gm, split)
# Each of the split graphs only has one output.
ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
ca0 = torch._inductor.standalone_compile(
split.submod_0, (a0, x, a1), aot=is_aot
)
ca1 = torch._inductor.standalone_compile(
split.submod_1, (b0, x, b1), aot=is_aot
)
ca2 = torch._inductor.standalone_compile(
split.submod_2, (c0, x, c1), aot=is_aot
)
y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
if not is_aot:
# fx graph cache doesn't run in AOT mode
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
# TODO: split_module causes ca1 and ca2 to have different type annotations
# for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
@ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (False, True))
@parametrize("config_patches", [True, False])
def test_dynamic_shapes_from_example_inputs(self, config_patches):
def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot):
def f(x):
return x.shape[0] * x
@ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
(5, torch.ones(4)),
dynamic_shapes="from_example_inputs",
options={"config_patches": config_patches},
aot=is_aot,
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)
@ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes):
def test_static_shapes(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x
@ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes
static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot
)
x = torch.randn(3)
(result,) = compiled_artifact(x)
@ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes):
def test_backend(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x
@ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)
@ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_backend_dynamic_shapes_from_example_inputs(self):
@parametrize("is_aot", (True, False))
def test_backend_dynamic_shapes_from_example_inputs(self, is_aot):
def f(x):
return x.shape[0] * x
@ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)

View File

@ -9986,20 +9986,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)
@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):

View File

@ -1,4 +1,3 @@
import abc
import dataclasses
import importlib
import inspect
@ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type
from torch._dynamo.package import SystemInfo
from . import convert_frame
from .aot_compile_types import (
BundledAOTAutogradSerializableCallable,
SerializableCallable,
)
from .hooks import Hooks
@ -26,18 +29,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
def bind_locals(
signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
@ -149,53 +140,6 @@ class AOTCompiledFunction:
self._guard_check_enabled = False
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
if hasattr(self, attr):
return getattr(super(), attr)
else:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
def aot_compile_fullgraph(
model: Any,
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],

View File

@ -0,0 +1,61 @@
import abc
import pickle
from typing import Any
import torch
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)

View File

@ -2188,20 +2188,6 @@ class OutputGraph(OutputGraphCommon):
),
)
self.call_cleanup_hooks()
old_fake_mode = self.tracing_context.fake_mode
assert old_fake_mode is not None
if not self.export:
import torch._functorch.config as _config
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
backend_fake_mode = torch._subclasses.FakeTensorMode(
shape_env=old_fake_mode.shape_env,
)
# TODO(voz): Ostensibily, this should be scoped and
# restore back to old_fake_mode, but doing so currently violates
# a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
self.tracing_context.fake_mode = backend_fake_mode
with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm, self.example_inputs())
@ -2237,8 +2223,10 @@ class OutputGraph(OutputGraphCommon):
)
counters["stats"]["unique_graphs"] += 1
assert old_fake_mode.shape_env is not None
if specializations := old_fake_mode.shape_env.specializations:
if (
specializations
:= self.tracing_context.fake_mode.shape_env.specializations
):
specialization_guards = []
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
sources = [a.source for a in self.graphargs]

View File

@ -182,7 +182,7 @@ def create_subclass_meta(
def enumerate_filter_symints(lst: Iterable[IntLikeType]) -> list[tuple[int, SymInt]]:
# Capture all SymInts from the iterable.
def symint_check(s: IntLikeType) -> TypeGuard[SymInt]:
return isinstance(s, SymInt) and not s.node.is_nested_int()
return isinstance(s, SymInt) and not s.node.is_nested_int() and not s.node.expr.is_number
return [(i, s) for i, s in enumerate(lst) if symint_check(s)]

View File

@ -391,6 +391,7 @@ def standalone_compile(
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Precompilation API for inductor.
@ -422,5 +423,5 @@ def standalone_compile(
options = options if options else {}
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot
)

View File

@ -5,10 +5,12 @@ import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
@ -30,9 +32,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class CompiledArtifact:
class CompiledArtifact(ABC):
"""
CompiledArtifact class represents the precompiled inductor artifact that
CompiledArtifact class represents the inductor cache artifacts that
can be invoked in order to avoid repeated compilation.
CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
@ -45,11 +47,68 @@ class CompiledArtifact:
binary or unpacked data.
Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the precompiled artifact.
to execute the cached artifact.
"""
_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
@abstractmethod
def __call__(self, *args: Any) -> Any: ...
@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)
assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)
class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
"""
CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")
def __init__(
self,
@ -83,6 +142,7 @@ class CompiledArtifact:
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
@ -116,9 +176,51 @@ class CompiledArtifact:
log.info("Output code written to: %s", output_file)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
from .compile_fx import _CompileFxKwargs
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)
@staticmethod
def _prepare_load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
Do format specific prep and loads, return a context manager and key
"""
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
@ -137,8 +239,7 @@ class CompiledArtifact:
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
cache_dir_ctx: AbstractContextManager[None] = nullcontext()
return key, nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
@ -148,43 +249,105 @@ class CompiledArtifact:
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
class AOTCompiledArtifact(CompiledArtifact):
"""
Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
This object is always a serializable callable function.
from .compile_fx import _CompileFxKwargs
This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
is used by torch._dynamo.aot_compile for AOT Precompilation.
"""
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")
context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None # We don't need artifacts, the inner object handles everything
)
@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)
def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)
from torch._inductor.codecache import write_atomic
# Save a sentinel file to indicate that this is AOT
write_atomic(path, writer.to_bytes())
def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)
@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
def standalone_compile(
@ -193,7 +356,11 @@ def standalone_compile(
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Implementation of torch.inductor.standalone_compile
"""
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx
@ -249,6 +416,7 @@ def standalone_compile(
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
@ -256,7 +424,12 @@ def standalone_compile(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)
if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
@ -264,4 +437,4 @@ def standalone_compile(
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)
return CompiledArtifact(compiled_fn, artifacts)
return CacheCompiledArtifact(compiled_fn, artifacts)

View File

@ -35,7 +35,7 @@ from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedte
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.fake_profile import MissingOpProfile
from torch._logging import dtrace_structured
from torch._prims_common import suggest_memory_format
from torch._prims_common import check_contiguous_sizes_strides, suggest_memory_format
from torch._subclasses.meta_utils import (
assert_eq,
assert_metadata_eq,
@ -1072,21 +1072,46 @@ def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
Extract the TensorMetadata of a tensor.
"""
memory_format = suggest_memory_format(t)
# Don't call is_contiguous() on a Tensor which has symbolic sizes or things
# will go badly (guards will be messed up?)
if (
t._has_symbolic_sizes_strides
or is_sparse_any(t)
or not t.is_contiguous(memory_format=memory_format)
):
shape = tuple(t.shape)
stride = tuple(t.stride()) if t.layout == torch.strided else ()
if is_sparse_any(t):
is_contiguous = False
else:
if t._has_symbolic_sizes_strides:
still_has_symbolic_sizes_strides = False
def simplify(x: IntLikeType) -> IntLikeType:
if not isinstance(x, SymInt):
return x
value = x.node.expr
if value.is_number:
return int(value)
nonlocal still_has_symbolic_sizes_strides
still_has_symbolic_sizes_strides = True
return x
shape = tuple(simplify(x) for x in shape)
stride = tuple(simplify(x) for x in stride)
if still_has_symbolic_sizes_strides:
# Don't call is_contiguous() on a Tensor which has symbolic sizes or things
# will go badly (guards will be messed up?)
is_contiguous = False
else:
is_contiguous = check_contiguous_sizes_strides(shape, stride, True)
else:
is_contiguous = t.is_contiguous(memory_format=memory_format)
if not is_contiguous:
memory_format = None # type: ignore[assignment]
storage_offset = t.storage_offset()
return TensorMetadata(
t.dtype,
t.shape,
t.stride() if t.layout == torch.strided else (),
shape,
stride,
t.device,
t.layout,
memory_format,

View File

@ -1,5 +1,6 @@
#include <torch/csrc/jit/runtime/logging.h>
#include <c10/util/Exception.h>
#include <atomic>
#include <chrono>
#include <mutex>
@ -33,7 +34,7 @@ int64_t LockingLogger::getCounterValue(const std::string& name) const {
return raw_counter.sum / raw_counter.count;
} break;
}
throw std::runtime_error("Unknown aggregation type!");
TORCH_CHECK(false, "Unknown aggregation type!");
}
void LockingLogger::setAggregationType(

View File

@ -11,6 +11,7 @@
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <limits>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
namespace torch::jit {
@ -112,20 +113,17 @@ void listRemove<at::Tensor>(Stack& stack) {
}
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
if (t.requires_grad()) {
throw std::runtime_error(
"Cannot input a tensor that requires grad as a scalar argument");
}
if (!t.sizes().empty()) {
throw std::runtime_error(
"Cannot input a tensor of dimension other than 0 as a scalar argument");
}
if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) {
std::stringstream ss;
ss << "Cannot input a tensor of type " << t.scalar_type()
<< " as an integral argument";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
!t.requires_grad(),
"Cannot input a tensor that requires grad as a scalar argument");
TORCH_CHECK(
t.sizes().empty(),
"Cannot input a tensor of dimension other than 0 as a scalar argument");
TORCH_CHECK(
!toInt || isIntegralType(t.scalar_type(), /*includeBool=*/false),
"Cannot input a tensor of type ",
t.scalar_type(),
" as an integral argument");
}
void checkDoubleInRange(double a) {

View File

@ -1,5 +1,6 @@
#include <ATen/autocast_mode.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
@ -159,9 +160,8 @@ void sort_op(Stack& stack) {
if (!g_list.empty()) {
std::stringstream error_str;
if (!isSortableListOfObjectsOrTuples(g_list, error_str)) {
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
isSortableListOfObjectsOrTuples(g_list, error_str), error_str.str());
c10::IValueComparator comparator;
if (reverse) {
@ -254,9 +254,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
int64_t lo = 0, hi = 0, step = 0;
pop(stack, lo, hi, step);
// error handling when step_val = 0 during runtime
if (step == 0) {
throw std::runtime_error("range() arg 3 must not be zero");
}
TORCH_CHECK(step != 0, "range() arg 3 must not be zero");
if (step > 0 && lo < hi) {
push(stack, 1 + (hi - 1 - lo) / step);
} else if (step < 0 && lo > hi) {
@ -382,14 +380,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
int64_t val = static_cast<int64_t>(std::stoll(s->string(), &sz));
if (sz == s->string().size()) {
push(stack, val);
} else {
std::stringstream error_str;
error_str << "invalid literal for int() "
<< "with base 10: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"invalid literal for int() ",
"with base 10: '",
s->string(),
"'");
push(stack, val);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@ -436,14 +433,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
double b = std::stod(s->string(), &sz);
if (sz == s->string().size()) {
push(stack, b);
} else {
std::stringstream error_str;
error_str << "could not convert string "
<< "to float: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"could not convert string ",
"to float: '",
s->string(),
"'");
push(stack, b);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@ -1793,10 +1789,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
}
const std::string& separator = ivalue.toStringRef();
if (separator.empty()) {
throw std::runtime_error("ValueError: empty separator");
}
TORCH_CHECK(!separator.empty(), "ValueError: empty separator");
auto count = 0;
@ -1919,11 +1912,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
if (string.size() > static_cast<std::string::size_type>(width)) {
push(stack, string);
return;
@ -2092,9 +2083,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@ -2107,9 +2096,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end, true);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@ -2183,11 +2170,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
@ -2207,11 +2192,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
@ -3358,10 +3341,8 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
int64_t a = 0, b = 0;
lldiv_t divresult = {};
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error(
"ZeroDivisionError: integer division or modulo by zero");
}
TORCH_CHECK(
b != 0, "ZeroDivisionError: integer division or modulo by zero");
divresult = lldiv(a, b);
if (divresult.rem && (a < 0) != (b < 0)) {
divresult.quot -= 1;
@ -3379,9 +3360,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
[](Stack& stack) {
double a = 0, b = 0;
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error("ZeroDivisionError: float divmod()");
}
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()");
double rem = fmod(a, b);
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
if (rem && (a < 0) != (b < 0)) {
@ -3426,9 +3405,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
type_a a; \
type_b b; \
pop(stack, a, b); \
if (b == 0) { \
throw std::runtime_error("ZeroDivisionError: float divmod()"); \
} \
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()"); \
double quot = floor(a / b); \
double rem = a - (quot * b); \
push(stack, quot, rem); \

View File

@ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix):
def _compile_submod(gm, prefix):
from torch._inductor.standalone_compile import AOTCompiledArtifact
for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith(prefix):
fake_inputs = []
@ -56,13 +58,12 @@ def _compile_submod(gm, prefix):
submod = getattr(gm, node.target)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(
torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context"
)
compiled_fn = torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
)
assert isinstance(compiled_fn, AOTCompiledArtifact)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(compiled_fn)
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
compiled_submod, args=node.args, kwargs=node.kwargs