mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 00:24:53 +08:00
Compare commits
9 Commits
codex/add-
...
ciflow/ind
| SHA1 |
|---|
| d6d3367233 |
| a514a050fa |
| 1a71669a22 |
| a91c4ceb08 |
| 8b3dc0d1b0 |
| 485c73a947 |
| 06773663b5 |
| 0bff65503c |
| 5c075d8bcf |
@@ -39,7 +39,7 @@ struct HostBlock {
};

template <typename B>
struct alignas(64) FreeBlockList {
struct alignas(hardware_destructive_interference_size) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}

alignas(64) std::mutex blocks_mutex_;
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;

@@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);

alignas(64) std::mutex events_mutex_;
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block

// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(64) HostStatsStaged stats_;
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};

struct TORCH_API HostAllocator : public at::Allocator {
@@ -141,8 +141,6 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}

checkTrilTriuMemoryOverlap(result, self);

bool inplace_op = self.is_same(result);

bool inplace_update = false;
@@ -1,4 +1,3 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>

@@ -55,13 +54,4 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}

static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}

} // namespace at::native
@@ -5,7 +5,6 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -111,8 +110,6 @@ __global__ void triu_tril_kernel(

template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,
@@ -9,6 +9,7 @@

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
@@ -1,6 +1,7 @@
#pragma once

#include <cstddef>
#include <new>

namespace c10 {

@@ -18,4 +19,12 @@ constexpr size_t gPagesize = 4096;
// since the default thp pagesize is 2MB, enable thp only
// for buffers of size 2MB or larger to avoid memory bloating
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;

// Cache line size used to avoid false sharing between threads. Falls back to 64
// bytes if C++17 feature is unavailable.
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
} // namespace c10
@@ -941,7 +941,7 @@ class EventPool {

 private:
struct PerDevicePool {
alignas(64) std::mutex mutex_;
alignas(hardware_destructive_interference_size) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
@@ -3758,11 +3758,6 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
static constexpr std::size_t hardware_destructive_interference_size = 64;
#endif

class NativeCachingAllocator : public CUDAAllocator {
 private:
@@ -554,7 +554,7 @@ static void local_raw_delete(void* ptr);

class XPUAllocator : public DeviceAllocator {
 private:
std::mutex mutex;
alignas(hardware_destructive_interference_size) std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;

void add_allocated_block(Block* block) {
@@ -3064,11 +3064,12 @@ class GraphModule(torch.nn.Module):
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_10, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
primals_3, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
primals_10, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_1, # SavedForBackwardsAOTOutput(idx=0)
primals_8, # SavedForBackwardsAOTOutput(idx=1)
primals_10, # SavedForBackwardsAOTOutput(idx=2)
primals_3, # SavedForBackwardsAOTOutput(idx=1)
primals_8, # SavedForBackwardsAOTOutput(idx=2)
primals_10, # SavedForBackwardsAOTOutput(idx=3)
)
""", # noqa: B950
)
@@ -3080,6 +3081,7 @@ class GraphModule(torch.nn.Module):
def forward(
self,
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
tangents_1: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
@@ -3097,7 +3099,7 @@ class GraphModule(torch.nn.Module):
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
primals_3, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
)
""", # noqa: B950
@@ -3134,7 +3136,7 @@ class GraphModule(torch.nn.Module):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None

cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s55)" = primals_10 + primals_10
add_2: "Sym(2*s55)" = primals_10 + primals_3; primals_3 = None
return (
cat, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
@@ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase):
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
@parametrize("graph_partition", (False, True))
@parametrize("is_aot", (False, True))
def test_basic(
self, device: str, format: str, dynamic: bool, graph_partition: bool
self,
device: str,
format: str,
dynamic: bool,
graph_partition: bool,
is_aot: bool,
) -> None:
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")

# AOT mode does not support unpacked format
if is_aot and format == "unpacked":
raise unittest.SkipTest("AOT mode does not support unpacked format")

mod = torch.nn.Linear(1, 3, device=device)
x = torch.randn(4, 1, device=device)
if dynamic:
@@ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase):
gm, args, kwargs = self.capture(f)(x)
assert not kwargs

compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact = torch._inductor.standalone_compile(
gm, args, aot=is_aot
)
compiled_artifact.save(path=path, format=format)

self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@@ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase):
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)

self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
if not is_aot:
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic", (False, True))
def test_call_in_backend(self, dynamic: bool) -> None:
@parametrize("is_aot", (False, True))
def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:
@@ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase):
eager_out = f(x)

def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args)
return torch._inductor.standalone_compile(gm, args, aot=is_aot)

with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
@@ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_graph(self):
@parametrize("is_aot", (False, True))
def test_dynamic_shapes_from_graph(self, is_aot: bool):
def f(x):
return x.shape[0] * x

@@ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
assert not kwargs

compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph"
gm, args, dynamic_shapes="from_graph", aot=is_aot
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)
@@ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
def test_split_module(self):
@parametrize("is_aot", (False, True))
def test_split_module(self, is_aot):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0**2) + (a1 / 2)
@@ -2116,16 +2132,24 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
split = torch.fx.passes.split_module.split_module(gm, gm, split)

# Each of the split graphs only has one output.
ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
ca0 = torch._inductor.standalone_compile(
split.submod_0, (a0, x, a1), aot=is_aot
)
ca1 = torch._inductor.standalone_compile(
split.submod_1, (b0, x, b1), aot=is_aot
)
ca2 = torch._inductor.standalone_compile(
split.submod_2, (c0, x, c1), aot=is_aot
)

y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
if not is_aot:
# fx graph cache doesn't run in AOT mode
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
# TODO: split_module causes ca1 and ca2 to have different type annotations
# for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
@@ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (False, True))
@parametrize("config_patches", [True, False])
def test_dynamic_shapes_from_example_inputs(self, config_patches):
def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot):
def f(x):
return x.shape[0] * x

@@ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
(5, torch.ones(4)),
dynamic_shapes="from_example_inputs",
options={"config_patches": config_patches},
aot=is_aot,
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)
@@ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes):
def test_static_shapes(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x

@@ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes
static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot
)
x = torch.randn(3)
(result,) = compiled_artifact(x)
@@ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes):
def test_backend(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x

@@ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):

def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)
@@ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_backend_dynamic_shapes_from_example_inputs(self):
@parametrize("is_aot", (True, False))
def test_backend_dynamic_shapes_from_example_inputs(self, is_aot):
def f(x):
return x.shape[0] * x

@@ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):

def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)
@@ -9986,20 +9986,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)

@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)

@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):
@@ -1,4 +1,3 @@
import abc
import dataclasses
import importlib
import inspect
@@ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type
from torch._dynamo.package import SystemInfo

from . import convert_frame
from .aot_compile_types import (
BundledAOTAutogradSerializableCallable,
SerializableCallable,
)
from .hooks import Hooks


@@ -26,18 +29,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)


class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass

@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass


def bind_locals(
signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
@@ -149,53 +140,6 @@ class AOTCompiledFunction:
self._guard_check_enabled = False


class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.

TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""

def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn

def __getattr__(self, attr: Any) -> Any:
if hasattr(self, attr):
return getattr(super(), attr)
else:
return getattr(self.compiled_fn, attr)

@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result

@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)

entry = pickle.loads(data)

compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)


def aot_compile_fullgraph(
model: Any,
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
torch/_dynamo/aot_compile_types.py (Normal file, 61 lines)
@@ -0,0 +1,61 @@
import abc
import pickle
from typing import Any

import torch


class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass

@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass


class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.

TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""

def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn

def __getattr__(self, attr: Any) -> Any:
return getattr(self.compiled_fn, attr)

@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result

@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)

entry = pickle.loads(data)

compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
@@ -2188,20 +2188,6 @@ class OutputGraph(OutputGraphCommon):
),
)
self.call_cleanup_hooks()
old_fake_mode = self.tracing_context.fake_mode
assert old_fake_mode is not None
if not self.export:
import torch._functorch.config as _config

with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
backend_fake_mode = torch._subclasses.FakeTensorMode(
shape_env=old_fake_mode.shape_env,
)
# TODO(voz): Ostensibily, this should be scoped and
# restore back to old_fake_mode, but doing so currently violates
# a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
self.tracing_context.fake_mode = backend_fake_mode

with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm, self.example_inputs())
@@ -2237,8 +2223,10 @@
)

counters["stats"]["unique_graphs"] += 1
assert old_fake_mode.shape_env is not None
if specializations := old_fake_mode.shape_env.specializations:
if (
specializations
:= self.tracing_context.fake_mode.shape_env.specializations
):
specialization_guards = []
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
sources = [a.source for a in self.graphargs]
@@ -182,7 +182,7 @@ def create_subclass_meta(
def enumerate_filter_symints(lst: Iterable[IntLikeType]) -> list[tuple[int, SymInt]]:
# Capture all SymInts from the iterable.
def symint_check(s: IntLikeType) -> TypeGuard[SymInt]:
return isinstance(s, SymInt) and not s.node.is_nested_int()
return isinstance(s, SymInt) and not s.node.is_nested_int() and not s.node.expr.is_number

return [(i, s) for i, s in enumerate(lst) if symint_check(s)]
@@ -391,6 +391,7 @@ def standalone_compile(
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Precompilation API for inductor.
@@ -422,5 +423,5 @@

options = options if options else {}
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot
)
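Illustrative note (not part of the diff): a minimal sketch of how the `aot` flag added above might be used from user code. It assumes a toy function traced with `torch.fx.symbolic_trace`; the function `f`, the inputs, and the save path are hypothetical, and the exact accepted input types follow the tests in this change rather than any documented contract.

```python
# Hedged sketch, assuming the aot=True path added in this change.
import torch

def f(x):
    return x * x + 1

gm = torch.fx.symbolic_trace(f)          # hypothetical toy GraphModule
example_inputs = [torch.randn(4)]

# aot=True routes compilation through BundledAOTAutogradCache and, per the
# diff, returns an AOTCompiledArtifact instead of a cache-backed artifact.
artifact = torch._inductor.standalone_compile(gm, example_inputs, aot=True)
(result,) = artifact(torch.randn(4))

# AOT artifacts support the binary format only (the unpacked format raises).
artifact.save(path="/tmp/compiled_artifact.bin", format="binary")  # hypothetical path
```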
@@ -5,10 +5,12 @@ import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING

import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
@@ -30,9 +32,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)


class CompiledArtifact:
class CompiledArtifact(ABC):
"""
CompiledArtifact class represents the precompiled inductor artifact that
CompiledArtifact class represents the inductor cache artifacts that
can be invoked in order to avoid repeated compilation.

CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
@@ -45,11 +47,68 @@ class CompiledArtifact:
binary or unpacked data.

Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the precompiled artifact.
to execute the cached artifact.
"""

_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts

@abstractmethod
def __call__(self, *args: Any) -> Any: ...

@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...

@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)

assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader

from .codecache import torch_key

result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)


class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
"""

CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")

def __init__(
self,
@@ -83,6 +142,7 @@ class CompiledArtifact:
from .codecache import torch_key

writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
@@ -116,9 +176,51 @@ class CompiledArtifact:
log.info("Output code written to: %s", output_file)

@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)

result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)

assert result is not None
(entry, _) = result

from .compile_fx import _CompileFxKwargs

fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)

context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)

@staticmethod
def _prepare_load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
Do format specific prep and loads, return a context manager and key
"""
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
@@ -137,8 +239,7 @@ class CompiledArtifact:
assert reader.is_finished()

torch.compiler.load_cache_artifacts(artifact_bytes)

cache_dir_ctx: AbstractContextManager[None] = nullcontext()
return key, nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
@@ -148,43 +249,105 @@ class CompiledArtifact:
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx

with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)

result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)

assert result is not None
(entry, _) = result
class AOTCompiledArtifact(CompiledArtifact):
"""
Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
This object is always a serializable callable function.

from .compile_fx import _CompileFxKwargs
This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
is used by torch._dynamo.aot_compile for AOT Precompilation.
"""

fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")

context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None # We don't need artifacts, the inner object handles everything
)

@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)

def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)

def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter

from .codecache import torch_key

writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)

from torch._inductor.codecache import write_atomic

# Save a sentinel file to indicate that this is AOT
write_atomic(path, writer.to_bytes())

def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)

@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)

@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader

from .codecache import torch_key

result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)


def standalone_compile(
@@ -193,7 +356,11 @@ def standalone_compile(
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Implementation of torch.inductor.standalone_compile
"""
from torch.compiler._cache import CacheArtifactManager

from .compile_fx import compile_fx
@@ -249,6 +416,7 @@ def standalone_compile(
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
@@ -256,7 +424,12 @@ def standalone_compile(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)

if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
@@ -264,4 +437,4 @@ def standalone_compile(
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)

return CompiledArtifact(compiled_fn, artifacts)
return CacheCompiledArtifact(compiled_fn, artifacts)
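Illustrative note (not part of the diff): per the CompiledArtifact.load changes above, a binary file written by either artifact type can be reloaded through the same entry point; load() reads the header bytes and dispatches to AOTCompiledArtifact or CacheCompiledArtifact. A hedged sketch follows; the path is hypothetical and reuses the file saved in the earlier example, and the import path is taken from the diff rather than a documented public API.

```python
# Hedged sketch of the unified load path described in the diff above.
import torch
from torch._inductor.standalone_compile import CompiledArtifact

# Dispatches on the serialized header (AOT_HEADER vs. CACHE_HEADER).
loaded = CompiledArtifact.load(path="/tmp/compiled_artifact.bin", format="binary")
(result,) = loaded(torch.randn(4))  # call signature mirrors the original artifact
```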
@@ -35,7 +35,7 @@ from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedte
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.fake_profile import MissingOpProfile
from torch._logging import dtrace_structured
from torch._prims_common import suggest_memory_format
from torch._prims_common import check_contiguous_sizes_strides, suggest_memory_format
from torch._subclasses.meta_utils import (
assert_eq,
assert_metadata_eq,
@@ -1072,21 +1072,46 @@ def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
Extract the TensorMetadata of a tensor.
"""
memory_format = suggest_memory_format(t)
# Don't call is_contiguous() on a Tensor which has symbolic sizes or things
# will go badly (guards will be messed up?)
if (
t._has_symbolic_sizes_strides
or is_sparse_any(t)
or not t.is_contiguous(memory_format=memory_format)
):
shape = tuple(t.shape)
stride = tuple(t.stride()) if t.layout == torch.strided else ()

if is_sparse_any(t):
is_contiguous = False
else:
if t._has_symbolic_sizes_strides:
still_has_symbolic_sizes_strides = False

def simplify(x: IntLikeType) -> IntLikeType:
if not isinstance(x, SymInt):
return x
value = x.node.expr
if value.is_number:
return int(value)
nonlocal still_has_symbolic_sizes_strides
still_has_symbolic_sizes_strides = True
return x

shape = tuple(simplify(x) for x in shape)
stride = tuple(simplify(x) for x in stride)

if still_has_symbolic_sizes_strides:
# Don't call is_contiguous() on a Tensor which has symbolic sizes or things
# will go badly (guards will be messed up?)
is_contiguous = False
else:
is_contiguous = check_contiguous_sizes_strides(shape, stride, True)
else:
is_contiguous = t.is_contiguous(memory_format=memory_format)

if not is_contiguous:
memory_format = None # type: ignore[assignment]

storage_offset = t.storage_offset()

return TensorMetadata(
t.dtype,
t.shape,
t.stride() if t.layout == torch.strided else (),
shape,
stride,
t.device,
t.layout,
memory_format,
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/runtime/logging.h>

#include <c10/util/Exception.h>
#include <atomic>
#include <chrono>
#include <mutex>
@@ -33,7 +34,7 @@ int64_t LockingLogger::getCounterValue(const std::string& name) const {
return raw_counter.sum / raw_counter.count;
} break;
}
throw std::runtime_error("Unknown aggregation type!");
TORCH_CHECK(false, "Unknown aggregation type!");
}

void LockingLogger::setAggregationType(
@@ -11,6 +11,7 @@
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <limits>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

namespace torch::jit {
@@ -112,20 +113,17 @@ void listRemove<at::Tensor>(Stack& stack) {
}

void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
if (t.requires_grad()) {
throw std::runtime_error(
"Cannot input a tensor that requires grad as a scalar argument");
}
if (!t.sizes().empty()) {
throw std::runtime_error(
"Cannot input a tensor of dimension other than 0 as a scalar argument");
}
if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) {
std::stringstream ss;
ss << "Cannot input a tensor of type " << t.scalar_type()
<< " as an integral argument";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
!t.requires_grad(),
"Cannot input a tensor that requires grad as a scalar argument");
TORCH_CHECK(
t.sizes().empty(),
"Cannot input a tensor of dimension other than 0 as a scalar argument");
TORCH_CHECK(
!toInt || isIntegralType(t.scalar_type(), /*includeBool=*/false),
"Cannot input a tensor of type ",
t.scalar_type(),
" as an integral argument");
}

void checkDoubleInRange(double a) {
@@ -1,5 +1,6 @@
#include <ATen/autocast_mode.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
@@ -159,9 +160,8 @@ void sort_op(Stack& stack) {

if (!g_list.empty()) {
std::stringstream error_str;
if (!isSortableListOfObjectsOrTuples(g_list, error_str)) {
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
isSortableListOfObjectsOrTuples(g_list, error_str), error_str.str());

c10::IValueComparator comparator;
if (reverse) {
@@ -254,9 +254,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
int64_t lo = 0, hi = 0, step = 0;
pop(stack, lo, hi, step);
// error handling when step_val = 0 during runtime
if (step == 0) {
throw std::runtime_error("range() arg 3 must not be zero");
}
TORCH_CHECK(step != 0, "range() arg 3 must not be zero");
if (step > 0 && lo < hi) {
push(stack, 1 + (hi - 1 - lo) / step);
} else if (step < 0 && lo > hi) {
@@ -382,14 +380,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
int64_t val = static_cast<int64_t>(std::stoll(s->string(), &sz));
if (sz == s->string().size()) {
push(stack, val);
} else {
std::stringstream error_str;
error_str << "invalid literal for int() "
<< "with base 10: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"invalid literal for int() ",
"with base 10: '",
s->string(),
"'");
push(stack, val);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@@ -436,14 +433,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
double b = std::stod(s->string(), &sz);
if (sz == s->string().size()) {
push(stack, b);
} else {
std::stringstream error_str;
error_str << "could not convert string "
<< "to float: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"could not convert string ",
"to float: '",
s->string(),
"'");
push(stack, b);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
@@ -1793,10 +1789,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
}

const std::string& separator = ivalue.toStringRef();

if (separator.empty()) {
throw std::runtime_error("ValueError: empty separator");
}
TORCH_CHECK(!separator.empty(), "ValueError: empty separator");

auto count = 0;

@@ -1919,11 +1912,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
if (string.size() > static_cast<std::string::size_type>(width)) {
push(stack, string);
return;
@@ -2092,9 +2083,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@@ -2107,9 +2096,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string substr = pop(stack).toStringRef();
std::string string = pop(stack).toStringRef();
auto result = stringFindImpl(string, substr, start, end, true);
if (result < 0) {
throw std::runtime_error("ValueError: substring not found");
}
TORCH_CHECK(result >= 0, "ValueError: substring not found");
push(stack, result);
},
aliasAnalysisFromSchema()),
@@ -2183,11 +2170,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));

@@ -2207,11 +2192,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
std::string fillchar = pop(stack).toStringRef();
int64_t width = pop(stack).toInt();
std::string string = pop(stack).toStringRef();
if (fillchar.size() != 1) {
// TODO: this should be a TypeError
throw std::runtime_error(
"TypeError: The fill character must be exactly one character long");
}
TORCH_CHECK(
fillchar.size() == 1,
"TypeError: The fill character must be exactly one character long");
auto to_append =
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));

@@ -3358,10 +3341,8 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
int64_t a = 0, b = 0;
lldiv_t divresult = {};
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error(
"ZeroDivisionError: integer division or modulo by zero");
}
TORCH_CHECK(
b != 0, "ZeroDivisionError: integer division or modulo by zero");
divresult = lldiv(a, b);
if (divresult.rem && (a < 0) != (b < 0)) {
divresult.quot -= 1;
@@ -3379,9 +3360,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
[](Stack& stack) {
double a = 0, b = 0;
pop(stack, a, b);
if (b == 0) {
throw std::runtime_error("ZeroDivisionError: float divmod()");
}
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()");
double rem = fmod(a, b);
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
if (rem && (a < 0) != (b < 0)) {
@@ -3426,9 +3405,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
type_a a; \
type_b b; \
pop(stack, a, b); \
if (b == 0) { \
throw std::runtime_error("ZeroDivisionError: float divmod()"); \
} \
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()"); \
double quot = floor(a / b); \
double rem = a - (quot * b); \
push(stack, quot, rem); \
@@ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix):


def _compile_submod(gm, prefix):
from torch._inductor.standalone_compile import AOTCompiledArtifact

for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith(prefix):
fake_inputs = []
@@ -56,13 +58,12 @@ def _compile_submod(gm, prefix):

submod = getattr(gm, node.target)

# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(
torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context"
)
compiled_fn = torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
)

assert isinstance(compiled_fn, AOTCompiledArtifact)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(compiled_fn)
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
compiled_submod, args=node.args, kwargs=node.kwargs