mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00

Compare commits: codex/add-... vs viable/str... (10 commits)

| SHA1 |
|---|
| 6c4412f72b |
| 78bf6186f2 |
| c40048472c |
| 3dfd0c7584 |
| e6ba4d0725 |
| bdf7cb9d9c |
| 6aed378958 |
| 8b3dc0d1b0 |
| 06773663b5 |
| 0bff65503c |
@@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(64) FreeBlockList {
struct alignas(hardware_destructive_interference_size) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};

@@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_

@@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(64) std::mutex blocks_mutex_;
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;

@@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(64) std::mutex events_mutex_;
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(64) HostStatsStaged stats_;
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {
@@ -141,8 +141,6 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}
checkTrilTriuMemoryOverlap(result, self);
bool inplace_op = self.is_same(result);
bool inplace_update = false;

@@ -1,4 +1,3 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>

@@ -55,13 +54,4 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}
static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}
} // namespace at::native
@@ -5,7 +5,6 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>

@@ -111,8 +110,6 @@ __global__ void triu_tril_kernel(
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,

@@ -202,7 +202,6 @@ supported:
- select_backward
- _trilinear
- linalg_pinv.atol_rtol_tensor
- svd
- logsumexp.out
symint:
- empty.memory_format
@@ -9,6 +9,7 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

@@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(
@@ -1,6 +1,7 @@
#pragma once
#include <cstddef>
#include <new>
namespace c10 {

@@ -18,4 +19,12 @@ constexpr size_t gPagesize = 4096;
// since the default thp pagesize is 2MB, enable thp only
// for buffers of size 2MB or larger to avoid memory bloating
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
// Cache line size used to avoid false sharing between threads. Falls back to 64
// bytes if C++17 feature is unavailable.
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
} // namespace c10
@@ -941,7 +941,7 @@ class EventPool {
private:
struct PerDevicePool {
alignas(64) std::mutex mutex_;
alignas(hardware_destructive_interference_size) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;

@@ -3758,11 +3758,6 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
static constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
class NativeCachingAllocator : public CUDAAllocator {
private:
@@ -554,7 +554,7 @@ static void local_raw_delete(void* ptr);
class XPUAllocator : public DeviceAllocator {
private:
std::mutex mutex;
alignas(hardware_destructive_interference_size) std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
void add_allocated_block(Block* block) {
@@ -16,7 +16,7 @@ find_path(vecLib_INCLUDE_DIR vecLib.h
DOC "vecLib include directory"
PATHS /System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
/System/Library/${__veclib_include_suffix}
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
${CMAKE_OSX_SYSROOT}/System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
NO_DEFAULT_PATH)
@@ -318,17 +318,19 @@ class inner_f(torch.nn.Module):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x, scale=1.0):
def forward(self, x, *, scale):
return self.linear(x) * scale
model = ModuleWithKwargs()
inputs = (torch.randn(4, 3),)
kwargs = {"scale": 2.0}
kwargs = {"scale": torch.tensor(2.0)}
gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
with ExitStack() as stack:
# Export joint with descriptors
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, kwargs, decompositions=decomposition_table
stack, gm, inputs, kwargs, decompositions=decomposition_table
)
# Test the exported graph structure

@@ -336,9 +338,17 @@ class inner_f(torch.nn.Module):
print_output=False, expanded_def=True
)
# For some reason PYTORCH_TEST_WITH_CROSSREF will add extra spaces.
# I tried to fix this in normalize_gm but there are too many files
# depending on that behavior..
graph_code_str = normalize_gm(graph_code)
graph_code_str = "\n".join(
[line for line in graph_code_str.split("\n") if len(line.rstrip()) > 0]
)
# Expect test on the printed graph
self.assertExpectedInline(
normalize_gm(graph_code),
graph_code_str,
"""\
class inner_f(torch.nn.Module):
def forward(

@@ -346,19 +356,20 @@ class inner_f(torch.nn.Module):
primals,
tangents,
):
primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear_weight')
primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear_bias')
primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
primals_4: "f32[]" # PlainAOTInput(idx=1)
tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
primals_1, primals_2, primals_3, primals_4 , tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
transpose: "f32[3, 2]" = torch.ops.prims.transpose.default(primals_1, [1, 0]); primals_1 = None
mm: "f32[4, 2]" = torch.ops.aten.mm.default(primals_3, transpose); transpose = None
mul: "f32[4, 2]" = torch.ops.prims.mul.default(mm, 1.0); mm = None
mul_1: "f32[2]" = torch.ops.prims.mul.default(primals_2, 1.0); primals_2 = None
broadcast_in_dim: "f32[4, 2]" = torch.ops.prims.broadcast_in_dim.default(mul_1, [4, 2], [1]); mul_1 = None
add: "f32[4, 2]" = torch.ops.prims.add.default(mul, broadcast_in_dim); mul = broadcast_in_dim = None
mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, 2.0); add = None
mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, 2.0); tangents_1 = None
mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, primals_4); add = None
mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, primals_4); tangents_1 = primals_4 = None
transpose_1: "f32[2, 4]" = torch.ops.prims.transpose.default(mul_3, [1, 0])
mm_1: "f32[2, 3]" = torch.ops.aten.mm.default(transpose_1, primals_3); transpose_1 = primals_3 = None
transpose_2: "f32[3, 2]" = torch.ops.prims.transpose.default(mm_1, [1, 0]); mm_1 = None

@@ -368,12 +379,11 @@ class inner_f(torch.nn.Module):
transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
return pytree.tree_unflatten([
mul_2, # PlainAOTOutput(idx=0)
transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_weight'))
as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_bias'))
None, # None
None, # None
], self._out_spec)
""",
], self._out_spec)""",
)
# Compile the result
@@ -7356,6 +7356,7 @@ metadata incorrectly.
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
@unittest.expectedFailure
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_rms_norm(self):
# Only CUDA rms norm fails to be decomposed
@@ -1,74 +0,0 @@
# Owner(s): ["module: inductor"]
import tempfile
import unittest
import zipfile
import torch
import torch._inductor.config
from torch._environment import is_fbcode
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_CI
from torch.testing._internal.inductor_utils import HAS_GPU, requires_gpu
class Simple(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(16, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
class TestAOTInductorWindowsCrossCompilation(TestCase):
@requires_gpu()
def test_simple_so(self):
if is_fbcode() or IS_CI:
raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
# TODO: enable in CI
with torch.no_grad():
device = "cuda"
model = Simple().to(device=device)
example_inputs = (torch.randn(8, 10, device=device),)
batch_dim = torch.export.Dim("batch", min=1, max=1024)
exported = torch.export.export(
model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
)
package_path = torch._inductor.aoti_compile_and_package(
exported,
inductor_configs={
"aot_inductor.model_name_for_generated_files": "model",
"aot_inductor.cross_target_platform": "windows",
"aot_inductor.link_libtorch": False,
# TODO: need to add aoti_shim_library_path for CI
"aot_inductor.aoti_shim_library": "executorch",
# no fallback ops
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON,CPP",
"max_autotune_conv_backends": "TRITON,CPP",
"aot_inductor.embed_kernel_binary": True,
# simplify things for now
"aot_inductor.precompile_headers": False,
"aot_inductor.package_constants_on_disk_format": "binary_blob",
"aot_inductor.package_constants_in_so": False,
},
)
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(package_path, "r") as zf:
zf.extractall(tmpdir)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU:
run_tests(needs="filelock")
@@ -9,6 +9,7 @@ from typing import Any, Optional
import torch
import torch._inductor.config
from torch._environment import is_fbcode
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu

@@ -77,6 +78,9 @@ class WindowsCrossCompilationTestFramework:
"This test should run on Linux for cross-compilation"
)
if is_fbcode():
raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
self.assertTrue("WINDOWS_CUDA_HOME" in os.environ)
with torch.no_grad():

@@ -128,6 +132,9 @@ class WindowsCrossCompilationTestFramework:
if platform.system() != "Windows":
raise unittest.SkipTest("This test should run on Windows")
if is_fbcode():
raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
if not HAS_GPU:
raise unittest.SkipTest("Test requires GPU")
@@ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase):
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
@parametrize("graph_partition", (False, True))
@parametrize("is_aot", (False, True))
def test_basic(
self, device: str, format: str, dynamic: bool, graph_partition: bool
self,
device: str,
format: str,
dynamic: bool,
graph_partition: bool,
is_aot: bool,
) -> None:
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
# AOT mode does not support unpacked format
if is_aot and format == "unpacked":
raise unittest.SkipTest("AOT mode does not support unpacked format")
mod = torch.nn.Linear(1, 3, device=device)
x = torch.randn(4, 1, device=device)
if dynamic:

@@ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase):
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact = torch._inductor.standalone_compile(
gm, args, aot=is_aot
)
compiled_artifact.save(path=path, format=format)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

@@ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase):
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
if not is_aot:
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic", (False, True))
def test_call_in_backend(self, dynamic: bool) -> None:
@parametrize("is_aot", (False, True))
def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:

@@ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase):
eager_out = f(x)
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args)
return torch._inductor.standalone_compile(gm, args, aot=is_aot)
with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)

@@ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_graph(self):
@parametrize("is_aot", (False, True))
def test_dynamic_shapes_from_graph(self, is_aot: bool):
def f(x):
return x.shape[0] * x

@@ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph"
gm, args, dynamic_shapes="from_graph", aot=is_aot
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)

@@ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
def test_split_module(self):
@parametrize("is_aot", (False, True))
def test_split_module(self, is_aot):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0**2) + (a1 / 2)

@@ -2116,16 +2132,24 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
split = torch.fx.passes.split_module.split_module(gm, gm, split)
# Each of the split graphs only has one output.
ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
ca0 = torch._inductor.standalone_compile(
split.submod_0, (a0, x, a1), aot=is_aot
)
ca1 = torch._inductor.standalone_compile(
split.submod_1, (b0, x, b1), aot=is_aot
)
ca2 = torch._inductor.standalone_compile(
split.submod_2, (c0, x, c1), aot=is_aot
)
y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
if not is_aot:
# fx graph cache doesn't run in AOT mode
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
# TODO: split_module causes ca1 and ca2 to have different type annotations
# for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)

@@ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (False, True))
@parametrize("config_patches", [True, False])
def test_dynamic_shapes_from_example_inputs(self, config_patches):
def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot):
def f(x):
return x.shape[0] * x

@@ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
(5, torch.ones(4)),
dynamic_shapes="from_example_inputs",
options={"config_patches": config_patches},
aot=is_aot,
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)

@@ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes):
def test_static_shapes(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x

@@ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes
static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot
)
x = torch.randn(3)
(result,) = compiled_artifact(x)

@@ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes):
def test_backend(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x

@@ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)

@@ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_backend_dynamic_shapes_from_example_inputs(self):
@parametrize("is_aot", (True, False))
def test_backend_dynamic_shapes_from_example_inputs(self, is_aot):
def f(x):
return x.shape[0] * x

@@ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)
@@ -1543,26 +1543,22 @@ class CPUReproTests(TestCase):
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
inputs = (
x,
scale,
zero_point,
use_dequant,
use_quant,
quant_min,
quant_max,
dtype,
dequant_out_dtype,
self.common(
fn,
(
x,
scale,
zero_point,
use_dequant,
use_quant,
quant_min,
quant_max,
dtype,
dequant_out_dtype,
),
)
self.common(fn, inputs)
check_metrics_vec_kernel_count(1)
# Check that both main and tail loops are vectorized
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
compiled_fn = torch.compile(fn)
_, code = run_and_get_cpp_code(compiled_fn, *inputs)
FileCheck().check_count("loadu", 2, exactly=True).run(code)
@requires_vectorization
def test_dequant_quant_lowering_uint8(self):
self._test_dequant_quant_lowering_helper(torch.uint8)
@@ -4814,22 +4810,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
).run(code)
def test_double_reduction_vec(self):
def fn(x):
return x.sum(dim=1)

@@ -4839,22 +4819,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
).run(code)
def test_convert_fp32_to_double_vec(self):
def fn(x):
return x.to(torch.double)

@@ -4864,22 +4828,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn(22, 22)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::convert<double,2,float,1>", 2, exactly=True
).run(code)
def test_convert_double_to_fp32_vec(self):
def fn(x):
return x.to(torch.float32)

@@ -4889,22 +4837,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::convert<float,1,double,2>", 2, exactly=True
).run(code)
def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
# https://github.com/pytorch/pytorch/issues/115260
p0 = torch.tensor([1.0879], dtype=torch.float16)
@@ -85,7 +85,6 @@ def init_lists():
"linalg_inv_ex",
"linalg_pinv.atol_rtol_tensor",
"logsumexp",
"svd",
}
# For some ops, we don't support all variants. Here we use formatted_name
# to uniquely identify the variant.

@@ -221,15 +220,20 @@ class TestLazyOpInfo(TestCase):
torch._lazy.wait_device_ops()
prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
metrics = remove_suffixes(torch._lazy.metrics.counter_names())
cands = [f"{prefix}::{op.name}{symint_suffix}"]
# check aliases
for alias in op.aliases:
cands.append(f"{prefix}::{alias.name}{symint_suffix}")
self.assertTrue(
any(c in metrics for c in cands), f"none of {cands} not found in {metrics}"
found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
torch._lazy.metrics.counter_names()
)
# check aliases
if not found:
for alias in op.aliases:
alias_found = (
f"{prefix}::{alias.name}{symint_suffix}"
in remove_suffixes(torch._lazy.metrics.counter_names())
)
found = found or alias_found
if found:
break
self.assertTrue(found)
@ops(
[
@@ -1258,10 +1258,11 @@ class DecompOneOffTests(TestCase):
)
# check RMSNorm was fused with sinh
self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
self.assertTrue(
"triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
in generated_codes[1]
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
)
@@ -9986,20 +9986,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)
@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):
@@ -404,7 +404,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.max_unpool3d,
aten.mish,
aten.mish_,
aten.mish_backward,
aten.mse_loss,
aten.mse_loss_backward,
aten.multi_margin_loss,

@@ -420,7 +419,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten._fused_rms_norm,
aten._fused_rms_norm_backward,
aten.new_empty,
aten.new_full,

@@ -477,7 +475,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.silu,
aten.silu_,
aten.silu_backward.grad_input,
aten.silu_backward,
aten.sinc,
aten.sinc_,
aten.slice_backward,
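For illustration only (not part of this diff), one way to check which entries actually land in the core ATen decomposition table that `_core_aten_decompositions_post_autograd` feeds; the specific ops queried below are arbitrary examples, and whether each check prints `True` or `False` depends on which side of this compare the build contains.

```python
import torch
from torch._decomp import core_aten_decompositions

# The table maps OpOverload -> decomposition function.
table = core_aten_decompositions()

# Present or absent depending on whether the `aten._fused_rms_norm,` entry
# above is kept in this build.
print(torch.ops.aten._fused_rms_norm.default in table)

# An op that stays listed above should still resolve to a decomposition.
print(torch.ops.aten.mse_loss.default in table)
```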
@@ -1757,61 +1757,6 @@ def native_layer_norm_backward_out(
return grad_input
@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
input: Tensor,
normalized_shape: list[int],
weight: Optional[Tensor],
eps: Optional[float],
) -> tuple[Tensor, Tensor]:
dims_to_reduce: list[int] = []
for i in range(len(normalized_shape)):
dims_to_reduce.append(input.dim() - i - 1)
# upcast is needed for fp16 and bf16
computation_dtype = utils.get_computation_dtype(input.dtype)
upcasted_input = input.to(computation_dtype)
# computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
if eps is None:
if computation_dtype in (torch.float32, torch.complex64):
eps_val = torch.finfo(torch.float32).eps
else:
eps_val = torch.finfo(torch.float64).eps
else:
eps_val = eps
rqrst_input = torch.rsqrt(
# NB: don't inplace here, will violate functional IR invariant
# NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp
torch.ops.aten.add.Scalar(
torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val
)
)
upcasted_result = upcasted_input.mul(rqrst_input)
if weight is not None:
upcasted_result = upcasted_result.mul(weight)
# NB: nested should be dead here, just here for fidelity
is_nested = input.is_nested or (weight is not None and weight.is_nested)
memory_format = utils.suggest_memory_format(input)
is_channels_last = memory_format in (
torch.channels_last,
torch.channels_last_3d,
)
if not is_nested and not is_channels_last:
upcasted_result = upcasted_result.contiguous()
rqrst_input = rqrst_input.contiguous()
# Cast normalized result back to original input type
result = upcasted_result.type_as(input)
return result, rqrst_input
@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
grad_out: Tensor,
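For reference, the decomposition removed above computes standard RMSNorm over the trailing `normalized_shape` dimensions; summarizing the code in formula form:

$$
r = \Bigl(\tfrac{1}{|D|}\textstyle\sum_{i \in D} x_i^{2} + \varepsilon\Bigr)^{-1/2},
\qquad
y = (x \cdot r)\,\odot\, w
$$

where $D$ is the set of reduced (trailing) dimensions, $x$ is the input upcast to the computation dtype, $w$ is the optional weight, $\varepsilon$ defaults to the float32 or float64 machine epsilon depending on the computation dtype, and the function returns $y$ (cast back to the input dtype) together with $r$.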
@@ -1,4 +1,3 @@
import abc
import dataclasses
import importlib
import inspect

@@ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type
from torch._dynamo.package import SystemInfo
from . import convert_frame
from .aot_compile_types import (
BundledAOTAutogradSerializableCallable,
SerializableCallable,
)
from .hooks import Hooks

@@ -26,18 +29,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
def bind_locals(
signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:

@@ -149,53 +140,6 @@ class AOTCompiledFunction:
self._guard_check_enabled = False
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
if hasattr(self, attr):
return getattr(super(), attr)
else:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
def aot_compile_fullgraph(
model: Any,
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
torch/_dynamo/aot_compile_types.py (new file, 61 lines)

@@ -0,0 +1,61 @@
import abc
import pickle
from typing import Any
import torch
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
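A toy sketch (not part of this diff) of the contract the new `SerializableCallable` base class defines: serialize to bytes with one classmethod, reconstruct a callable with the other. The `PickledCallable` subclass below is purely illustrative and just pickles a plain Python function; the real `BundledAOTAutogradSerializableCallable` above wraps an AOTAutograd-produced artifact instead.

```python
import pickle
from typing import Any

# Assumes the new module introduced by this compare is present.
from torch._dynamo.aot_compile_types import SerializableCallable


class PickledCallable(SerializableCallable):
    def __init__(self, fn) -> None:
        self.fn = fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.fn(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(cls, fn: "PickledCallable") -> bytes:
        return pickle.dumps(fn.fn)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        return cls(pickle.loads(data))


def double(x):
    return 2 * x


blob = PickledCallable.serialize_compile_artifacts(PickledCallable(double))
restored = PickledCallable.deserialize_compile_artifacts(blob)
assert restored(21) == 42
```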
@@ -1347,6 +1347,15 @@ def create_functional_call(
maybe_disable_thunkify(),
):
if isinstance(mod, torch.fx.GraphModule):
if kwargs:
# Handle **kwargs. FX only natively supports positional
# arguments (through placeholders).
arg_list = list(args[params_len:])
arg_list.extend(list(kwargs.values()))
args = tuple(arg_list)
else:
args = args[params_len:]
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled."

@@ -1355,9 +1364,7 @@
fake_mode = detect_fake_mode()
assert fake_mode is not None
fake_mode.epoch += 1
out = PropagateUnbackedSymInts(mod).run(
*args[params_len:], **kwargs
)
out = PropagateUnbackedSymInts(mod).run(*args)
else:
out = mod(*args[params_len:], **kwargs)
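A minimal sketch (not from the diff, and only an analogy to the dynamo-captured case above) of why keyword arguments get appended to the positional list before calling an `fx.GraphModule`: FX exposes every input as a positional placeholder, so a graph traced from `f(x, scale)` accepts `scale` as a trailing positional argument.

```python
import torch
import torch.fx as fx


def f(x, scale):
    return x * scale


gm = fx.symbolic_trace(f)
# Both inputs appear as positional placeholders in the graph.
print([n.target for n in gm.graph.nodes if n.op == "placeholder"])  # ['x', 'scale']

args, kwargs = (torch.ones(2),), {"scale": torch.tensor(3.0)}
# Same trick as in create_functional_call: append kwargs.values() positionally.
flat_args = tuple(list(args) + list(kwargs.values()))
print(gm(*flat_args))
```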
@@ -391,6 +391,7 @@ def standalone_compile(
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
aot: bool = False,  # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Precompilation API for inductor.

@@ -422,5 +423,5 @@
options = options if options else {}
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot
)
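A usage sketch mirroring the pattern the tests above exercise; it assumes the `aot` flag added in this compare, and the file path is illustrative only.

```python
import torch
from torch._inductor.standalone_compile import CompiledArtifact


def f(x):
    return (x.sin() + 1.0) * x.shape[0]


def backend(gm, example_inputs, **kwargs):
    # aot=True routes through BundledAOTAutogradCache and returns an
    # AOTCompiledArtifact (binary save format only).
    artifact = torch._inductor.standalone_compile(gm, example_inputs, aot=True)
    artifact.save(path="/tmp/compiled_artifact.bin", format="binary")
    return artifact


x = torch.randn(4)
compiled = torch.compile(f, fullgraph=True, backend=backend)
print(compiled(x))

# Later, reload the saved artifact and call it without recompiling.
loaded = CompiledArtifact.load(path="/tmp/compiled_artifact.bin", format="binary")
print(loaded(x))
```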
@@ -159,14 +159,11 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
]
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
torch.float64,
torch.float,
torch.bfloat16,
torch.float16,
torch.uint8,
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
]
@@ -5,10 +5,12 @@ import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex

@@ -30,9 +32,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class CompiledArtifact:
class CompiledArtifact(ABC):
"""
CompiledArtifact class represents the precompiled inductor artifact that
CompiledArtifact class represents the inductor cache artifacts that
can be invoked in order to avoid repeated compilation.
CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)

@@ -45,11 +47,68 @@ class CompiledArtifact:
binary or unpacked data.
Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the precompiled artifact.
to execute the cached artifact.
"""
_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
@abstractmethod
def __call__(self, *args: Any) -> Any: ...
@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)
assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)
class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
"""
CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")
def __init__(
self,

@@ -83,6 +142,7 @@ class CompiledArtifact:
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)

@@ -116,9 +176,51 @@ class CompiledArtifact:
log.info("Output code written to: %s", output_file)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
from .compile_fx import _CompileFxKwargs
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)
@staticmethod
def _prepare_load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
Do format specific prep and loads, return a context manager and key
"""
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":

@@ -137,8 +239,7 @@ class CompiledArtifact:
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
cache_dir_ctx: AbstractContextManager[None] = nullcontext()
return key, nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)

@@ -148,43 +249,105 @@ class CompiledArtifact:
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
class AOTCompiledArtifact(CompiledArtifact):
"""
Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
This object is always a serializable callable function.
from .compile_fx import _CompileFxKwargs
This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
is used by torch._dynamo.aot_compile for AOT Precompilation.
"""
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")
context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None  # We don't need artifacts, the inner object handles everything
)
@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)
def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)
from torch._inductor.codecache import write_atomic
# Save a sentinel file to indicate that this is AOT
write_atomic(path, writer.to_bytes())
def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)
@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
def standalone_compile(

@@ -193,7 +356,11 @@ def standalone_compile(
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False,  # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Implementation of torch.inductor.standalone_compile
"""
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx

@@ -249,6 +416,7 @@
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)

@@ -256,7 +424,12 @@
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)
if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(

@@ -264,4 +437,4 @@
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)
return CompiledArtifact(compiled_fn, artifacts)
return CacheCompiledArtifact(compiled_fn, artifacts)
@@ -15,7 +15,6 @@ from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
_disable_infra_mode,
autograd_would_have_decomposed,
return_and_correct_aliasing,
TorchDispatchMode,
)

@@ -410,13 +409,8 @@ class FunctionalTensorMode(TorchDispatchMode):
return False
return True
# in normal torch.compile IR, we only decompose an op if autograd
# would have decomposed it (NB: autograd may have been skipped if
# we are in inference mode)
# TODO: the flatten here can potentially be deduped with the
# unwrapping pytree_map later
flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
return autograd_would_have_decomposed(func, flat_args_kwargs)
# in normal torch.compile IR, we decompose functional composite ops
return True
if (
func not in FunctionalTensor.metadata_fns
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/runtime/logging.h>
#include <c10/util/Exception.h>
#include <atomic>
#include <chrono>
#include <mutex>

@@ -33,7 +34,7 @@ int64_t LockingLogger::getCounterValue(const std::string& name) const {
return raw_counter.sum / raw_counter.count;
} break;
}
throw std::runtime_error("Unknown aggregation type!");
TORCH_CHECK(false, "Unknown aggregation type!");
}
void LockingLogger::setAggregationType(

@@ -11,6 +11,7 @@
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <limits>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
namespace torch::jit {

@@ -112,20 +113,17 @@ void listRemove<at::Tensor>(Stack& stack) {
}
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
if (t.requires_grad()) {
throw std::runtime_error(
"Cannot input a tensor that requires grad as a scalar argument");
}
if (!t.sizes().empty()) {
throw std::runtime_error(
"Cannot input a tensor of dimension other than 0 as a scalar argument");
}
if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) {
std::stringstream ss;
ss << "Cannot input a tensor of type " << t.scalar_type()
<< " as an integral argument";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
!t.requires_grad(),
"Cannot input a tensor that requires grad as a scalar argument");
TORCH_CHECK(
t.sizes().empty(),
"Cannot input a tensor of dimension other than 0 as a scalar argument");
TORCH_CHECK(
!toInt || isIntegralType(t.scalar_type(), /*includeBool=*/false),
"Cannot input a tensor of type ",
t.scalar_type(),
" as an integral argument");
}
void checkDoubleInRange(double a) {

@@ -1,5 +1,6 @@
#include <ATen/autocast_mode.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/custom_operator.h>

@@ -159,9 +160,8 @@ void sort_op(Stack& stack) {
if (!g_list.empty()) {
std::stringstream error_str;
if (!isSortableListOfObjectsOrTuples(g_list, error_str)) {
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
isSortableListOfObjectsOrTuples(g_list, error_str), error_str.str());
c10::IValueComparator comparator;
if (reverse) {

@@ -254,9 +254,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
int64_t lo = 0, hi = 0, step = 0;
pop(stack, lo, hi, step);
// error handling when step_val = 0 during runtime
if (step == 0) {
throw std::runtime_error("range() arg 3 must not be zero");
}
TORCH_CHECK(step != 0, "range() arg 3 must not be zero");
if (step > 0 && lo < hi) {
push(stack, 1 + (hi - 1 - lo) / step);
} else if (step < 0 && lo > hi) {

@@ -382,14 +380,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto s = pop(stack).toString();
std::string::size_type sz = 0;
int64_t val = static_cast<int64_t>(std::stoll(s->string(), &sz));
if (sz == s->string().size()) {
push(stack, val);
} else {
std::stringstream error_str;
error_str << "invalid literal for int() "
<< "with base 10: '" << s->string() << "'";
throw std::runtime_error(error_str.str());
}
TORCH_CHECK(
sz == s->string().size(),
"invalid literal for int() ",
"with base 10: '",
|
||||
s->string(),
|
||||
"'");
|
||||
push(stack, val);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
@ -436,14 +433,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
auto s = pop(stack).toString();
|
||||
std::string::size_type sz = 0;
|
||||
double b = std::stod(s->string(), &sz);
|
||||
if (sz == s->string().size()) {
|
||||
push(stack, b);
|
||||
} else {
|
||||
std::stringstream error_str;
|
||||
error_str << "could not convert string "
|
||||
<< "to float: '" << s->string() << "'";
|
||||
throw std::runtime_error(error_str.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
sz == s->string().size(),
|
||||
"could not convert string ",
|
||||
"to float: '",
|
||||
s->string(),
|
||||
"'");
|
||||
push(stack, b);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
@ -1793,10 +1789,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
|
||||
}
|
||||
|
||||
const std::string& separator = ivalue.toStringRef();
|
||||
|
||||
if (separator.empty()) {
|
||||
throw std::runtime_error("ValueError: empty separator");
|
||||
}
|
||||
TORCH_CHECK(!separator.empty(), "ValueError: empty separator");
|
||||
|
||||
auto count = 0;
|
||||
|
||||
@ -1919,11 +1912,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
|
||||
std::string fillchar = pop(stack).toStringRef();
|
||||
int64_t width = pop(stack).toInt();
|
||||
std::string string = pop(stack).toStringRef();
|
||||
if (fillchar.size() != 1) {
|
||||
// TODO: this should be a TypeError
|
||||
throw std::runtime_error(
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
fillchar.size() == 1,
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
if (string.size() > static_cast<std::string::size_type>(width)) {
|
||||
push(stack, string);
|
||||
return;
|
||||
@ -2092,9 +2083,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
|
||||
std::string substr = pop(stack).toStringRef();
|
||||
std::string string = pop(stack).toStringRef();
|
||||
auto result = stringFindImpl(string, substr, start, end);
|
||||
if (result < 0) {
|
||||
throw std::runtime_error("ValueError: substring not found");
|
||||
}
|
||||
TORCH_CHECK(result >= 0, "ValueError: substring not found");
|
||||
push(stack, result);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
@ -2107,9 +2096,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
|
||||
std::string substr = pop(stack).toStringRef();
|
||||
std::string string = pop(stack).toStringRef();
|
||||
auto result = stringFindImpl(string, substr, start, end, true);
|
||||
if (result < 0) {
|
||||
throw std::runtime_error("ValueError: substring not found");
|
||||
}
|
||||
TORCH_CHECK(result >= 0, "ValueError: substring not found");
|
||||
push(stack, result);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
@ -2183,11 +2170,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
|
||||
std::string fillchar = pop(stack).toStringRef();
|
||||
int64_t width = pop(stack).toInt();
|
||||
std::string string = pop(stack).toStringRef();
|
||||
if (fillchar.size() != 1) {
|
||||
// TODO: this should be a TypeError
|
||||
throw std::runtime_error(
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
fillchar.size() == 1,
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
auto to_append =
|
||||
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
|
||||
|
||||
@ -2207,11 +2192,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{
|
||||
std::string fillchar = pop(stack).toStringRef();
|
||||
int64_t width = pop(stack).toInt();
|
||||
std::string string = pop(stack).toStringRef();
|
||||
if (fillchar.size() != 1) {
|
||||
// TODO: this should be a TypeError
|
||||
throw std::runtime_error(
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
fillchar.size() == 1,
|
||||
"TypeError: The fill character must be exactly one character long");
|
||||
auto to_append =
|
||||
std::max(int64_t(0), width - static_cast<int64_t>(string.size()));
|
||||
|
||||
@ -3358,10 +3341,8 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
|
||||
int64_t a = 0, b = 0;
|
||||
lldiv_t divresult = {};
|
||||
pop(stack, a, b);
|
||||
if (b == 0) {
|
||||
throw std::runtime_error(
|
||||
"ZeroDivisionError: integer division or modulo by zero");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
b != 0, "ZeroDivisionError: integer division or modulo by zero");
|
||||
divresult = lldiv(a, b);
|
||||
if (divresult.rem && (a < 0) != (b < 0)) {
|
||||
divresult.quot -= 1;
|
||||
@ -3379,9 +3360,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
|
||||
[](Stack& stack) {
|
||||
double a = 0, b = 0;
|
||||
pop(stack, a, b);
|
||||
if (b == 0) {
|
||||
throw std::runtime_error("ZeroDivisionError: float divmod()");
|
||||
}
|
||||
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()");
|
||||
double rem = fmod(a, b);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
||||
if (rem && (a < 0) != (b < 0)) {
|
||||
@ -3426,9 +3405,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{
|
||||
type_a a; \
|
||||
type_b b; \
|
||||
pop(stack, a, b); \
|
||||
if (b == 0) { \
|
||||
throw std::runtime_error("ZeroDivisionError: float divmod()"); \
|
||||
} \
|
||||
TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()"); \
|
||||
double quot = floor(a / b); \
|
||||
double rem = a - (quot * b); \
|
||||
push(stack, quot, rem); \
|
||||
|
||||
@ -466,14 +466,6 @@ at::Tensor LazyNativeFunctions::linalg_pinv(
|
||||
linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::svd(
|
||||
const at::Tensor& self,
|
||||
bool some,
|
||||
bool compute_uv) {
|
||||
return at::functionalization::functionalize_aten_op<ATEN_OP(svd)>::call(
|
||||
self, some, compute_uv);
|
||||
}
|
||||
|
||||
// functionalize_aten_op can't handle out= ops directly.
|
||||
// Instead, we can call the composite kernel from core, and copy and mutations
|
||||
// back to the inputs.
|
||||
|
||||
@ -21,10 +21,6 @@ backends are ready, this list allows opt-in one at a time.
|
||||
PRESERVED_ATEN_CIA_OPS = {
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.upsample_nearest2d.vec,
|
||||
# NB: don't use the C++ decomp, because it is not functional!
|
||||
torch.ops.aten.silu_backward.default,
|
||||
torch.ops.aten.mish_backward.default,
|
||||
torch.ops.aten._fused_rms_norm.default,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -63,7 +63,6 @@ from torch.utils._python_dispatch import (
|
||||
_disable_infra_mode,
|
||||
_push_mode,
|
||||
_unset_infra_mode,
|
||||
autograd_would_have_decomposed,
|
||||
TorchDispatchMode,
|
||||
)
|
||||
from torch.utils._stats import count
|
||||
@ -1033,16 +1032,11 @@ def proxy_call(
|
||||
return r
|
||||
|
||||
# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
|
||||
if (
|
||||
not pre_dispatch
|
||||
and func
|
||||
not in [
|
||||
torch.ops.aten.size.default,
|
||||
torch.ops.aten.stride.default,
|
||||
torch.ops.aten.storage_offset.default,
|
||||
]
|
||||
and autograd_would_have_decomposed(func, flat_args_kwargs)
|
||||
):
|
||||
if not pre_dispatch and func not in [
|
||||
torch.ops.aten.size.default,
|
||||
torch.ops.aten.stride.default,
|
||||
torch.ops.aten.storage_offset.default,
|
||||
]:
|
||||
with proxy_mode:
|
||||
r = func.decompose(*args, **kwargs)
|
||||
if r is not NotImplemented:
|
||||
|
||||
@ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix):
|
||||
|
||||
|
||||
def _compile_submod(gm, prefix):
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_module" and node.target.startswith(prefix):
|
||||
fake_inputs = []
|
||||
@ -56,13 +58,12 @@ def _compile_submod(gm, prefix):
|
||||
|
||||
submod = getattr(gm, node.target)
|
||||
|
||||
# _dummy_wrapper is to make call_function happy
|
||||
compiled_submod = _dummy_wrapper(
|
||||
torch._inductor.standalone_compile(
|
||||
submod, fake_inputs, dynamic_shapes="from_tracing_context"
|
||||
)
|
||||
compiled_fn = torch._inductor.standalone_compile(
|
||||
submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
|
||||
)
|
||||
|
||||
assert isinstance(compiled_fn, AOTCompiledArtifact)
|
||||
# _dummy_wrapper is to make call_function happy
|
||||
compiled_submod = _dummy_wrapper(compiled_fn)
|
||||
with gm.graph.inserting_after(node):
|
||||
new_node = gm.graph.call_function(
|
||||
compiled_submod, args=node.args, kwargs=node.kwargs
|
||||
|
||||
@ -38,8 +38,7 @@ class BytesWriter:
|
||||
digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
|
||||
4, byteorder="big", signed=False
|
||||
)
|
||||
if len(digest) != CHECKSUM_DIGEST_SIZE:
|
||||
raise AssertionError("Computed checksum digest has unexpected size")
|
||||
assert len(digest) == CHECKSUM_DIGEST_SIZE
|
||||
self._data[0:CHECKSUM_DIGEST_SIZE] = digest
|
||||
return bytes(self._data)
|
||||
|
||||
@ -47,13 +46,11 @@ class BytesWriter:
|
||||
class BytesReader:
|
||||
def __init__(self, data: bytes) -> None:
|
||||
# Check for data corruption
|
||||
if len(data) < CHECKSUM_DIGEST_SIZE:
|
||||
raise AssertionError("Input data is too short to contain checksum")
|
||||
assert len(data) >= CHECKSUM_DIGEST_SIZE
|
||||
digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
|
||||
4, byteorder="big", signed=False
|
||||
)
|
||||
if len(digest) != CHECKSUM_DIGEST_SIZE:
|
||||
raise AssertionError("Computed checksum digest has unexpected size")
|
||||
assert len(digest) == CHECKSUM_DIGEST_SIZE
|
||||
if data[0:CHECKSUM_DIGEST_SIZE] != digest:
|
||||
raise RuntimeError(
|
||||
"Bytes object is corrupted, checksum does not match. "
|
||||
@ -123,11 +120,7 @@ class AppendingByteSerializer(Generic[T]):
|
||||
@staticmethod
|
||||
def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]:
|
||||
reader = BytesReader(data)
|
||||
if reader.read_uint64() != _ENCODING_VERSION:
|
||||
raise AssertionError(
|
||||
f"Encoding version mismatch in AppendingByteSerializer.to_list, \
|
||||
got {reader.read_uint64()}"
|
||||
)
|
||||
assert reader.read_uint64() == _ENCODING_VERSION
|
||||
|
||||
result: list[T] = []
|
||||
while not reader.is_finished():
|
||||
|
||||
@ -85,16 +85,12 @@ class _Config(Generic[T]):
|
||||
)
|
||||
|
||||
if self.alias is not None:
|
||||
if (
|
||||
self.default is not _UNSET_SENTINEL
|
||||
or self.justknob is not None
|
||||
or self.env_name_default is not None
|
||||
or self.env_name_force is not None
|
||||
):
|
||||
raise AssertionError(
|
||||
"if alias is set, none of {default, justknob, \
|
||||
env_name_default and env_name_force} can be set"
|
||||
)
|
||||
assert (
|
||||
self.default is _UNSET_SENTINEL
|
||||
and self.justknob is None
|
||||
and self.env_name_default is None
|
||||
and self.env_name_force is None
|
||||
), "if alias is set, none of {default, justknob and env var} can be set"
|
||||
|
||||
@staticmethod
|
||||
def string_or_list_of_string_to_list(
|
||||
@ -104,8 +100,7 @@ class _Config(Generic[T]):
|
||||
return None
|
||||
if isinstance(val, str):
|
||||
return [val]
|
||||
if not isinstance(val, list):
|
||||
raise AssertionError(f"val is not a list, got {type(val)}")
|
||||
assert isinstance(val, list)
|
||||
return val
|
||||
|
||||
|
||||
@ -198,10 +193,7 @@ def install_config_module(module: ModuleType) -> None:
|
||||
if dest is module:
|
||||
delattr(module, key)
|
||||
elif isinstance(value, type):
|
||||
if value.__module__ != module.__name__:
|
||||
raise AssertionError(
|
||||
f"subconfig class {value} must be defined in module {module.__name__}"
|
||||
)
|
||||
assert value.__module__ == module.__name__
|
||||
# a subconfig with `class Blah:` syntax
|
||||
proxy = SubConfigProxy(module, f"{name}.")
|
||||
visit(value, proxy, f"{name}.")
|
||||
@ -242,8 +234,10 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
|
||||
prev_name = ""
|
||||
maybe_current = token.string.strip()
|
||||
if COMPILE_IGNORED_MARKER in maybe_current:
|
||||
if current_comment != ("", -1):
|
||||
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
|
||||
assert current_comment == (
|
||||
"",
|
||||
-1,
|
||||
), f"unconsumed {COMPILE_IGNORED_MARKER}"
|
||||
current_comment = maybe_current, token.start[0]
|
||||
elif token.type == tokenize.NAME:
|
||||
# Only accept the first name token, to handle if you have
|
||||
@ -260,8 +254,7 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
|
||||
assignments.add(prev_name)
|
||||
current_comment = "", -1 # reset
|
||||
prev_name = ""
|
||||
if current_comment != ("", -1):
|
||||
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
|
||||
assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
|
||||
return assignments
|
||||
|
||||
|
||||
@ -313,22 +306,20 @@ class _ConfigEntry:
|
||||
|
||||
# Ensure justknobs and envvars are allowlisted types
|
||||
if self.justknob is not None and self.default is not None:
|
||||
if not isinstance(self.default, bool):
|
||||
raise AssertionError(
|
||||
f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
)
|
||||
assert isinstance(self.default, bool), (
|
||||
f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
)
|
||||
if self.value_type is not None and (
|
||||
config.env_name_default is not None or config.env_name_force is not None
|
||||
):
|
||||
if self.value_type not in (
|
||||
assert self.value_type in (
|
||||
bool,
|
||||
str,
|
||||
Optional[bool],
|
||||
Optional[str],
|
||||
):
|
||||
raise AssertionError(
|
||||
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
|
||||
)
|
||||
), (
|
||||
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
|
||||
)
|
||||
|
||||
|
||||
class ConfigModule(ModuleType):
|
||||
@ -426,10 +417,7 @@ class ConfigModule(ModuleType):
|
||||
|
||||
def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
|
||||
data = self._get_alias_module_and_name(entry)
|
||||
if data is None:
|
||||
raise AssertionError(
|
||||
"alias data should not be None when setting alias value"
|
||||
)
|
||||
assert data is not None
|
||||
module, constant_name = data
|
||||
setattr(module, constant_name, val)
|
||||
|
||||
@ -654,32 +642,19 @@ class ConfigModule(ModuleType):
|
||||
changes: dict[str, Any]
|
||||
if arg1 is not None:
|
||||
if arg2 is not None:
|
||||
if not isinstance(arg1, str):
|
||||
raise AssertionError(
|
||||
"first argument must be a string when passing 2 positional args to patch"
|
||||
)
|
||||
assert isinstance(arg1, str)
|
||||
# patch("key", True) syntax
|
||||
changes = {arg1: arg2}
|
||||
else:
|
||||
if not isinstance(arg1, dict):
|
||||
raise AssertionError(
|
||||
"first argument must be a dict when passing a single positional arg to patch"
|
||||
)
|
||||
assert isinstance(arg1, dict)
|
||||
# patch({"key": True}) syntax
|
||||
changes = arg1
|
||||
if kwargs:
|
||||
raise AssertionError(
|
||||
"cannot pass both positional and keyword arguments to patch"
|
||||
)
|
||||
assert not kwargs
|
||||
else:
|
||||
# patch(key=True) syntax
|
||||
changes = kwargs
|
||||
if arg2 is not None:
|
||||
raise AssertionError(
|
||||
"second positional argument is only valid when first argument is a key string"
|
||||
)
|
||||
if not isinstance(changes, dict):
|
||||
raise AssertionError(f"expected `dict` got {type(changes)}")
|
||||
assert arg2 is None
|
||||
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
|
||||
prior: dict[str, Any] = {}
|
||||
config = self
|
||||
|
||||
@ -688,10 +663,7 @@ class ConfigModule(ModuleType):
|
||||
self.changes = changes
|
||||
|
||||
def __enter__(self) -> None:
|
||||
if prior:
|
||||
raise AssertionError(
|
||||
"prior should be empty when entering ConfigPatch"
|
||||
)
|
||||
assert not prior
|
||||
for key in self.changes.keys():
|
||||
# KeyError on invalid entry
|
||||
prior[key] = config.__getattr__(key)
|
||||
|
||||
@ -21,8 +21,7 @@ This file should be imported into any file that uses install_config_module like
|
||||
Note that the import should happen before the call to install_config_module(), otherwise runtime errors may occur.
|
||||
"""
|
||||
|
||||
if not TYPE_CHECKING: # noqa: PYI002
|
||||
raise AssertionError("Do not use at runtime") # noqa: W291
|
||||
assert TYPE_CHECKING, "Do not use at runtime"
|
||||
|
||||
def save_config() -> bytes: ...
|
||||
def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
|
||||
|
||||
@ -217,10 +217,7 @@ class ContentStoreReader:
|
||||
weights_only=True,
|
||||
map_location=device,
|
||||
)._untyped_storage
|
||||
if s is None:
|
||||
raise AssertionError(
|
||||
f"expected storage for hash {h} in {os.path.join(self.loc, 'storages')}, got None"
|
||||
)
|
||||
assert s is not None
|
||||
if self.storage_cache is not None:
|
||||
self.storage_cache[device][h] = StorageWeakRef(s)
|
||||
return s
|
||||
|
||||
@ -86,14 +86,13 @@ def context_decorator(ctx, func):
|
||||
be a multi-shot context manager that can be directly invoked multiple times)
|
||||
or a callable that produces a context manager.
|
||||
"""
|
||||
if callable(ctx) and hasattr(ctx, "__enter__"):
|
||||
raise AssertionError(
|
||||
f"Passed in {ctx} is both callable and also a valid context manager "
|
||||
"(has __enter__), making it ambiguous which interface to use. If you "
|
||||
"intended to pass a context manager factory, rewrite your call as "
|
||||
"context_decorator(lambda: ctx()); if you intended to pass a context "
|
||||
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
|
||||
)
|
||||
assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
|
||||
f"Passed in {ctx} is both callable and also a valid context manager "
|
||||
"(has __enter__), making it ambiguous which interface to use. If you "
|
||||
"intended to pass a context manager factory, rewrite your call as "
|
||||
"context_decorator(lambda: ctx()); if you intended to pass a context "
|
||||
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
|
||||
)
|
||||
|
||||
if not callable(ctx):
|
||||
|
||||
|
||||
@ -933,10 +933,7 @@ def _broadcast_to_and_flatten(
|
||||
treespec: TreeSpec,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Optional[list[Any]]:
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise AssertionError(
|
||||
f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}"
|
||||
)
|
||||
assert _is_pytreespec_instance(treespec)
|
||||
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
try:
|
||||
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
|
||||
|
||||
@ -87,18 +87,12 @@ class DeviceContext(TorchFunctionMode):
|
||||
# or else someone else has popped it!
|
||||
for _ in range(_len_torch_function_stack() - 1):
|
||||
mode = _pop_mode()
|
||||
if isinstance(mode, DeviceContext):
|
||||
raise AssertionError(
|
||||
"Found nested DeviceContext on the mode stack where none expected"
|
||||
)
|
||||
assert not isinstance(mode, DeviceContext)
|
||||
cur_stack.append(mode)
|
||||
|
||||
if _len_torch_function_stack() > 0:
|
||||
mode = _pop_mode()
|
||||
if not isinstance(mode, DeviceContext):
|
||||
raise AssertionError(
|
||||
"Expected a DeviceContext at the bottom of the mode stack"
|
||||
)
|
||||
assert isinstance(mode, DeviceContext)
|
||||
|
||||
for mode in reversed(cur_stack):
|
||||
_push_mode(mode)
|
||||
|
||||
@ -31,8 +31,7 @@ def cache_method(
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
if kwargs:
|
||||
raise AssertionError("cache_method does not accept keyword arguments")
|
||||
assert not kwargs
|
||||
if not (cache := getattr(self, cache_name, None)):
|
||||
cache = {}
|
||||
setattr(self, cache_name, cache)
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import warnings
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
|
||||
from typing import Optional, overload, Protocol, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch
|
||||
@ -21,10 +20,6 @@ from torch._C import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
|
||||
# - We need a better user-facing api for _DisableTorchDispatch that
|
||||
# is able to selectively disable __torch_dispatch__ of a particular class.
|
||||
@ -88,8 +83,7 @@ class TorchDispatchMode:
|
||||
|
||||
def __init__(self, _dispatch_key=None):
|
||||
if _dispatch_key is not None:
|
||||
if not isinstance(_dispatch_key, torch._C.DispatchKey):
|
||||
raise AssertionError("_dispatch_key must be a torch._C.DispatchKey")
|
||||
assert isinstance(_dispatch_key, torch._C.DispatchKey)
|
||||
self.__dict__["_dispatch_key"] = _dispatch_key
|
||||
|
||||
self.old_dispatch_mode_flags: deque[bool] = deque()
|
||||
@ -218,24 +212,16 @@ def _get_current_dispatch_mode() -> Optional[TorchDispatchMode]:
|
||||
|
||||
|
||||
def _detect_infra_mode(key):
|
||||
if key not in (
|
||||
assert key in [
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL,
|
||||
torch._C._TorchDispatchModeKey.PROXY,
|
||||
):
|
||||
raise AssertionError(
|
||||
f"key must be either FUNCTIONAL ({torch._C._TorchDispatchModeKey.FUNCTIONAL}) \
|
||||
or PROXY ({torch._C._TorchDispatchModeKey.PROXY}) _TorchDispatchModeKey, \
|
||||
got {key}"
|
||||
)
|
||||
]
|
||||
from torch._ops import _get_dispatch_mode_pre_dispatch
|
||||
|
||||
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
|
||||
post_dispatch_mode = torch._C._get_dispatch_mode(key)
|
||||
|
||||
if pre_dispatch_mode is not None and post_dispatch_mode is not None:
|
||||
raise AssertionError(
|
||||
"At most one of pre_dispatch_mode and post_dispatch_mode may be active"
|
||||
)
|
||||
assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)
|
||||
|
||||
if pre_dispatch_mode is None:
|
||||
return post_dispatch_mode
|
||||
@ -261,13 +247,10 @@ def _unset_infra_mode(key):
|
||||
|
||||
|
||||
def _disable_infra_mode(key):
|
||||
if key not in (
|
||||
assert key in (
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL,
|
||||
torch._C._TorchDispatchModeKey.PROXY,
|
||||
):
|
||||
raise AssertionError(
|
||||
"key must be either FUNCTIONAL or PROXY _TorchDispatchModeKey"
|
||||
)
|
||||
)
|
||||
mode_unset = _unset_infra_mode(key)
|
||||
try:
|
||||
yield mode_unset
|
||||
@ -288,10 +271,7 @@ def _get_current_dispatch_mode_stack() -> list[TorchDispatchMode]:
|
||||
|
||||
def _push_mode(mode: TorchDispatchMode):
|
||||
k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
|
||||
if k is not None and k != torch._C.DispatchKey.PreDispatch:
|
||||
raise AssertionError(
|
||||
"mode._dispatch_key must be None or DispatchKey.PreDispatch"
|
||||
)
|
||||
assert k is None or k == torch._C.DispatchKey.PreDispatch
|
||||
if k is None:
|
||||
_push_on_torch_dispatch_stack(mode)
|
||||
return
|
||||
@ -434,7 +414,7 @@ class TensorWithFlatten(Protocol):
|
||||
@overload
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch._prims_common.DeviceLikeType] = None,
|
||||
device: Optional["torch._prims_common.DeviceLikeType"] = None,
|
||||
dtype: Optional[torch.types._dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
copy: bool = False,
|
||||
@ -529,16 +509,14 @@ def transform_subclass(t, callback, outer_size=None, outer_stride=None):
|
||||
# NB: Purposefully guard here to simplify the inner / outer symbols.
|
||||
# Using sym_eq() for symbolic comparison can result in an expression that's too
|
||||
# difficult to guard on, so we use == here.
|
||||
if sub.shape != outer_size:
|
||||
raise AssertionError(
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"shape equal to {outer_size}, but got: {sub.shape}"
|
||||
)
|
||||
if sub.stride() != outer_stride:
|
||||
raise AssertionError(
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"stride equal to {outer_stride}, but got: {sub.stride()}"
|
||||
)
|
||||
assert sub.shape == outer_size, (
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"shape equal to {outer_size}, but got: {sub.shape}"
|
||||
)
|
||||
assert sub.stride() == outer_stride, (
|
||||
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
|
||||
f"stride equal to {outer_stride}, but got: {sub.stride()}"
|
||||
)
|
||||
|
||||
return sub
|
||||
|
||||
@ -555,12 +533,9 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||
It does this by unsafely overwriting the storage field of the output tensor
|
||||
to be the same storage as the input.
|
||||
"""
|
||||
if not isinstance(func, torch._ops.OpOverload):
|
||||
raise AssertionError(f"func must be an OpOverload, got {type(args)}")
|
||||
if not isinstance(args, tuple):
|
||||
raise AssertionError(f"args must be a tuple, got {type(args)}")
|
||||
if not isinstance(outs, (list, tuple)):
|
||||
raise AssertionError(f"outs must be a list or tuple, got {type(args)}")
|
||||
assert isinstance(func, torch._ops.OpOverload)
|
||||
assert isinstance(args, tuple)
|
||||
assert isinstance(outs, (list, tuple))
|
||||
|
||||
def alias_non_inplace_storage(arg, ret):
|
||||
# This is hopefully a reasonable assert:
|
||||
@ -581,11 +556,10 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||
):
|
||||
ret_list = ret if isinstance(ret, list) else [ret]
|
||||
for r in ret_list:
|
||||
if type(arg) is not type(r):
|
||||
raise AssertionError(
|
||||
f"Called {str(func)} with input of type {type(arg)}\n"
|
||||
f"and output of type {type(ret)}. But expected types to match."
|
||||
)
|
||||
assert type(arg) is type(
|
||||
r
|
||||
), f"""Called {str(func)} with input of type {type(arg)}
|
||||
and output of type {type(ret)}. But expected types to match."""
|
||||
# Need to call a non-dispatcher helper, because we explicitly do **not**
|
||||
# want our subclass to intercept the set_() call.
|
||||
# instead, our subclass should directly have its storage swapped out.
|
||||
@ -601,8 +575,7 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||
for r in ret:
|
||||
torch._functionalize_unsafe_set(r, arg)
|
||||
else:
|
||||
if not isinstance(ret, torch.Tensor):
|
||||
raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
|
||||
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
|
||||
torch._functionalize_unsafe_set(ret, arg)
|
||||
|
||||
for arg_idx, schema_arg in enumerate(schema_info.args):
|
||||
@ -646,10 +619,7 @@ def get_alias_info(func) -> SchemaInfo:
|
||||
# properly for some ops that output tensorlists)
|
||||
if func.namespace == "aten":
|
||||
torchgen_schema_str = str(func._schema)
|
||||
if not torchgen_schema_str.startswith("aten::"):
|
||||
raise AssertionError(
|
||||
"Expected torchgen schema string to start with 'aten::'"
|
||||
)
|
||||
assert torchgen_schema_str.startswith("aten::")
|
||||
# remove the aten:: namespace, which is added by the torchscript parser,
|
||||
# and torchgen doesn't know how to handle
|
||||
torchgen_schema_str = torchgen_schema_str[6:]
|
||||
@ -712,64 +682,6 @@ def get_alias_info(func) -> SchemaInfo:
|
||||
return schema_info
|
||||
|
||||
|
||||
def autograd_would_have_decomposed(
|
||||
func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
|
||||
) -> bool:
|
||||
"""
|
||||
Suppose that an operator has CompositeImplicitAutograd decomp registered.
|
||||
Would autograd have used this decomposition? It will only use it if there
|
||||
isn't an explicit backend registration for the device as well. This function
|
||||
will tell if this would have occurred.
|
||||
|
||||
Why do we need to apply these decompositions later? When inference mode is
|
||||
on, the autograd key is bypassed entirely, so a lower level mode cannot rely
|
||||
on the decomposition have been applied. It's easy to accidentally never apply
|
||||
the decomposition, resulting in an operator showing up in a graph that
|
||||
is unexpected.
|
||||
|
||||
Why do we need to AVOID applying the decomposition when autograd wouldn't
|
||||
have decomposed? If autograd doesn't decompose, this means in eager mode
|
||||
we would have run the fused kernel. It must be possible to trace this
|
||||
fused kernel directly into the graph for fidelity with eager (NB: a user
|
||||
has the option of then further decomposing at proxy tensor mode via
|
||||
decomposition table, but we must preserve it to proxy mode to have the
|
||||
choice.)
|
||||
|
||||
Why does functionalization need to also perform the test here? This is
|
||||
because some CompositeImplicitAutograd decompositions are not functional.
|
||||
If we are eventually going to decompose, we need to do this while we can
|
||||
still turn functionalization back on, so those decompositions get functionalized.
|
||||
So an early decomposition in functionalization may still be necessary. Note that
|
||||
if proxy tensor decomposition process could turn functionalization back on, this
|
||||
wouldn't be necessary, and maybe that is a useful thing to do anyway because
|
||||
the decomposition table is user specified and a user could violate the functional
|
||||
decomp requirement with a bad decomp. If this happened, then you could always
|
||||
pass through functionalization.
|
||||
"""
|
||||
has_backend_registration = False
|
||||
for a in flat_args:
|
||||
if isinstance(a, torch.Tensor):
|
||||
backend_key = torch._C._parse_dispatch_key(
|
||||
torch._C._dispatch_key_for_device(a.device.type)
|
||||
)
|
||||
if backend_key is None:
|
||||
raise AssertionError(
|
||||
"Failed to infer backend dispatch key from tensor device"
|
||||
)
|
||||
# TODO: use func.has_kernel_for_dispatch_key(backend_key)
|
||||
# but this one checks py_impl and CompositeImplicitAutograd
|
||||
# incorrectly shows up as has backend reg here
|
||||
has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
|
||||
func.name(), backend_key
|
||||
)
|
||||
|
||||
# in theory we should take all backend keys and take the highest priority one
|
||||
# to properly mimic the dispatcher,
|
||||
# this just grabs the first tensor and takes its device key
|
||||
break
|
||||
return not has_backend_registration
|
||||
|
||||
|
||||
# See NOTE[SchemaInfo int_tags] above.
|
||||
_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload]
|
||||
|
||||
@ -799,8 +711,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||
if not alias_set or not x.is_write:
|
||||
return None
|
||||
# torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
|
||||
if len(alias_set) != 1:
|
||||
raise AssertionError("Expected alias_set to contain exactly one element")
|
||||
assert len(alias_set) == 1
|
||||
# timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
|
||||
# set of size 1 on Python 3.13.
|
||||
return next(iter(alias_set))
|
||||
@ -814,10 +725,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||
i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set
|
||||
]
|
||||
# For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments.
|
||||
if len(arg_indices) != 1:
|
||||
raise AssertionError(
|
||||
"Expected exactly one argument index for the given output alias"
|
||||
)
|
||||
assert len(arg_indices) == 1
|
||||
idx = arg_indices[0]
|
||||
arg_info = schema_info.args[idx]
|
||||
if arg_info.name is not None and arg_info.name in new_kwargs:
|
||||
@ -843,10 +751,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||
]
|
||||
# Assumption: we have a very small number of inplace_view ops that follow a strict schema:
|
||||
# there is only a single argument that gets its metadata mutated.
|
||||
if len(mutated_args) != 1:
|
||||
raise AssertionError(
|
||||
"expected exactly one mutated arg for inplace_view ops"
|
||||
)
|
||||
assert len(mutated_args) == 1
|
||||
# This check exists because we generally *do* want to update the metadata of any wrapper subclasses,
|
||||
# but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor.
|
||||
# so we don't actually need to update the metadata (and attempting to do so causes errors)
|
||||
|
||||
@ -476,8 +476,7 @@ def _is_constant_holder(spec: "TreeSpec") -> bool:
|
||||
|
||||
def _retrieve_constant(spec: "TreeSpec") -> Any:
|
||||
"""Given a spec from a pytree registered with register_constant, retrieves the constant"""
|
||||
if not _is_constant_holder(spec):
|
||||
raise AssertionError("spec does not correspond to a registered constant pytree")
|
||||
assert _is_constant_holder(spec)
|
||||
return tree_unflatten([], spec)
|
||||
|
||||
|
||||
@ -900,25 +899,17 @@ def _defaultdict_serialize(context: Context) -> DumpableContext:
|
||||
|
||||
|
||||
def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
|
||||
if not isinstance(dumpable_context, dict):
|
||||
raise AssertionError("dumpable_context must be a dict")
|
||||
|
||||
expected_keys = {
|
||||
assert isinstance(dumpable_context, dict)
|
||||
assert set(dumpable_context) == {
|
||||
"default_factory_module",
|
||||
"default_factory_name",
|
||||
"dict_context",
|
||||
}
|
||||
if set(dumpable_context) != expected_keys:
|
||||
raise AssertionError(
|
||||
f"dumpable_context keys must be {expected_keys}, got {set(dumpable_context)}"
|
||||
)
|
||||
|
||||
default_factory_module = dumpable_context["default_factory_module"]
|
||||
default_factory_name = dumpable_context["default_factory_name"]
|
||||
if not isinstance(default_factory_module, str):
|
||||
raise AssertionError("default_factory_module must be a string")
|
||||
if not isinstance(default_factory_name, str):
|
||||
raise AssertionError("default_factory_name must be a string")
|
||||
assert isinstance(default_factory_module, str)
|
||||
assert isinstance(default_factory_name, str)
|
||||
module = importlib.import_module(default_factory_module)
|
||||
default_factory = getattr(module, default_factory_name)
|
||||
|
||||
@ -1742,8 +1733,7 @@ def _broadcast_to_and_flatten(
|
||||
treespec: TreeSpec,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Optional[list[Any]]:
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise AssertionError("treespec must be a TreeSpec")
|
||||
assert isinstance(treespec, TreeSpec)
|
||||
|
||||
if tree_is_leaf(tree, is_leaf=is_leaf):
|
||||
return [tree] * treespec.num_leaves
|
||||
|
||||
@ -206,8 +206,7 @@ class CapturedTraceback:
|
||||
import torch._C._profiler
|
||||
|
||||
if script or cpp:
|
||||
if skip != 0:
|
||||
raise AssertionError("skip with script/cpp NYI")
|
||||
assert skip == 0, "skip with script/cpp NYI"
|
||||
|
||||
return CapturedTraceback(
|
||||
torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
|
||||
|
||||
@ -430,8 +430,9 @@ def _get_custom_mod_func(func_name: str):
|
||||
it is marked as private. It is a convenience function for backend implementers to
|
||||
more easily call the hooks into their backend extensions.
|
||||
"""
|
||||
if not isinstance(func_name, str):
|
||||
raise AssertionError(f"func_name must be `str`, but got `{type(func_name)}`.")
|
||||
assert isinstance(func_name, str), (
|
||||
f"func_name must be `str`, but got `{type(func_name)}`."
|
||||
)
|
||||
backend_name = _get_privateuse1_backend_name()
|
||||
custom_device_mod = getattr(torch, backend_name, None)
|
||||
function = getattr(custom_device_mod, func_name, None)
|
||||
|
||||
@ -119,12 +119,10 @@ def bundle_inputs(
|
||||
# Fortunately there is a function in _recursive that does exactly that conversion.
|
||||
cloned_module = wrap_cpp_module(clone)
|
||||
if isinstance(inputs, dict):
|
||||
if not isinstance(info, dict) and info is not None:
|
||||
raise AssertionError("If inputs is a dict, info must be a dict or None")
|
||||
assert isinstance(info, dict) or info is None
|
||||
augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
|
||||
else:
|
||||
if not isinstance(info, list) and info is not None:
|
||||
raise AssertionError("If inputs is a list, info must be a list or None")
|
||||
assert isinstance(info, list) or info is None
|
||||
augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
|
||||
return cloned_module
|
||||
|
||||
|
||||
@ -1034,10 +1034,8 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
|
||||
out += f"{line['filename']}:{line['line']}:{line['name']}\n"
|
||||
out += "\n\n"
|
||||
return out
|
||||
if capture_logs_fwd.logs is None:
|
||||
raise AssertionError("capture_logs_fwd.logs is None")
|
||||
if capture_logs_recompute.logs is None:
|
||||
raise AssertionError("capture_logs_recompute.logs is None")
|
||||
assert capture_logs_fwd.logs is not None
|
||||
assert capture_logs_recompute.logs is not None
|
||||
raise CheckpointError(
|
||||
_checkpoint_error_template.format(
|
||||
forward_traces=get_str_tb("original", capture_logs_fwd),
|
||||
@ -1075,14 +1073,12 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
|
||||
def pack_hook(x):
|
||||
x = x.detach() if x.requires_grad else x
|
||||
target_frame = target_frame_ref()
|
||||
if target_frame is None:
|
||||
raise AssertionError("Internal error: target_frame reference is None")
|
||||
assert target_frame is not None # appease mypy
|
||||
recomp_idx = target_frame.recomp_counter[gid]
|
||||
target_frame.recomp_counter[gid] += 1
|
||||
|
||||
if recomp_idx >= len(target_frame.weak_holders):
|
||||
if target_frame.early_stop:
|
||||
raise AssertionError("Unexpected state: target_frame.early_stop is set")
|
||||
assert not target_frame.early_stop
|
||||
if not target_frame.forward_completed:
|
||||
# We run into this case when early stop is not enabled and do
|
||||
# grad within checkpoint.
|
||||
@ -1519,14 +1515,12 @@ def _checkpoint_without_reentrant_generator(
|
||||
device_module = _get_device_module(device_type)
|
||||
forward_context, recompute_context = context_fn()
|
||||
if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
|
||||
if (
|
||||
not isinstance(forward_context, TorchDispatchMode)
|
||||
or not isinstance(recompute_context, TorchDispatchMode)
|
||||
):
|
||||
raise AssertionError(
|
||||
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` "
|
||||
"must generate a tuple of two `TorchDispatchMode`s."
|
||||
)
|
||||
assert (
|
||||
isinstance(forward_context, TorchDispatchMode) and
|
||||
isinstance(recompute_context, TorchDispatchMode)
|
||||
), \
|
||||
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
|
||||
"must generate a tuple of two `TorchDispatchMode`s."
|
||||
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
||||
device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)
|
||||
|
||||
|
||||
@ -290,8 +290,7 @@ def _get_icpx_version() -> str:
|
||||
match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
|
||||
version = ['0', '0', '0'] if match is None else list(match.groups())
|
||||
version = list(map(int, version))
|
||||
if len(version) != 3:
|
||||
raise AssertionError("Failed to parse DPC++ compiler version")
|
||||
assert len(version) == 3, "Failed to parse DPC++ compiler version"
|
||||
# Aligning version format with what torch.version.xpu() returns
|
||||
return f"{version[0]}{version[1]:02}{version[2]:02}"
|
||||
|
||||
@ -325,8 +324,7 @@ def _get_sycl_device_flags(cflags):
|
||||
# We need last occurrence of -fsycl-targets as it will be the one taking effect.
|
||||
# So searching in reversed list.
|
||||
flags = [f for f in reversed(cflags) if f.startswith('-fsycl-targets=')]
|
||||
if not flags:
|
||||
raise AssertionError("bug: -fsycl-targets should have been amended to cflags")
|
||||
assert flags, "bug: -fsycl-targets should have been amended to cflags"
|
||||
|
||||
arch_list = _get_sycl_arch_list()
|
||||
if arch_list != '':
|
||||
@ -664,8 +662,7 @@ class BuildExtension(build_ext):
|
||||
extension = next(extension_iter, None)
|
||||
|
||||
if sycl_ext:
|
||||
if not self.use_ninja:
|
||||
raise AssertionError("ninja is required to build sycl extensions.")
|
||||
assert self.use_ninja, "ninja is required to build sycl extensions."
|
||||
|
||||
if cuda_ext and not IS_HIP_EXTENSION:
|
||||
_check_cuda_version(compiler_name, compiler_version)
|
||||
@ -697,10 +694,7 @@ class BuildExtension(build_ext):
|
||||
self._define_torch_extension_name(extension)
|
||||
|
||||
if 'nvcc_dlink' in extension.extra_compile_args:
|
||||
if not self.use_ninja:
|
||||
raise AssertionError(
|
||||
f"With dlink=True, ninja is required to build cuda extension {extension.name}."
|
||||
)
|
||||
assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
|
||||
|
||||
# Register .cu, .cuh, .hip, .mm and .sycl as valid source extensions.
|
||||
# NOTE: At the moment .sycl is not a standard extension for SYCL supported
|
||||
@ -2659,11 +2653,9 @@ def _import_module_from_library(module_name, path, is_python_module):
|
||||
if is_python_module:
|
||||
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
|
||||
spec = importlib.util.spec_from_file_location(module_name, filepath)
|
||||
if spec is None:
|
||||
raise AssertionError(f"Failed to create spec for module {module_name} at {filepath}")
|
||||
assert spec is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
if not isinstance(spec.loader, importlib.abc.Loader):
|
||||
raise AssertionError("spec.loader is not a valid importlib Loader")
|
||||
assert isinstance(spec.loader, importlib.abc.Loader)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
else:
|
||||
@ -2865,10 +2857,8 @@ e.
|
||||
ldflags = sanitize_flags(ldflags)
|
||||
|
||||
# Sanity checks...
|
||||
if len(sources) != len(objects):
|
||||
raise AssertionError("sources and objects lists must be the same length")
|
||||
if len(sources) == 0:
|
||||
raise AssertionError("At least one source is required to build a library")
|
||||
assert len(sources) == len(objects)
|
||||
assert len(sources) > 0
|
||||
|
||||
compiler = get_cxx_compiler()
|
||||
|
||||
|
||||
@ -133,8 +133,9 @@ def from_dlpack(
|
||||
if device is not None:
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
if not isinstance(device, torch.device):
|
||||
raise AssertionError(f"from_dlpack: unsupported device type: {type(device)}")
|
||||
assert isinstance(device, torch.device), (
|
||||
f"from_dlpack: unsupported device type: {type(device)}"
|
||||
)
|
||||
kwargs["dl_device"] = torch._C._torchDeviceToDLDevice(device)
|
||||
|
||||
ext_device = ext_tensor.__dlpack_device__()
|
||||
@ -162,10 +163,10 @@ def from_dlpack(
|
||||
dlpack = ext_tensor.__dlpack__(**kwargs)
|
||||
|
||||
else:
|
||||
if device is not None or copy is not None:
|
||||
raise AssertionError(
|
||||
"device and copy kwargs not supported when ext_tensor is already a DLPack capsule."
|
||||
)
|
||||
assert device is None and copy is None, (
|
||||
"device and copy kwargs not supported when ext_tensor is "
|
||||
"already a DLPack capsule."
|
||||
)
|
||||
# Old versions just call the converter
|
||||
dlpack = ext_tensor
|
||||
return torch._C._from_dlpack(dlpack)
|
||||
|
||||
@ -62,8 +62,7 @@ def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
|
||||
# Inputs contains the shapes of two matrices.
|
||||
m, k = a_shape
|
||||
k2, n = b_shape
|
||||
if k != k2:
|
||||
raise AssertionError(f"matmul: inner dimensions must match (k == k2), got {k} and {k2}")
|
||||
assert k == k2
|
||||
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
|
||||
return m * n * 2 * k
|
||||
|
||||
@ -79,10 +78,8 @@ def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
|
||||
# Inputs contains the shapes of two tensor.
|
||||
b, m, k = a_shape
|
||||
b2, k2, n = b_shape
|
||||
if b != b2:
|
||||
raise AssertionError(f"bmm: batch dimensions must match (b == b2), got {b} and {b2}")
|
||||
if k != k2:
|
||||
raise AssertionError(f"bmm: inner dimensions must match (k == k2), got {k} and {k2}")
|
||||
assert b == b2
|
||||
assert k == k2
|
||||
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
|
||||
flop = b * m * n * 2 * k
|
||||
return flop
|
||||
@ -269,8 +266,7 @@ def sdpa_flop_count(query_shape, key_shape, value_shape):
|
||||
b, h, s_q, d_q = query_shape
|
||||
_b2, _h2, s_k, _d2 = key_shape
|
||||
_b3, _h3, _s3, d_v = value_shape
|
||||
if not b == _b2 == _b3 or not h == _h2 == _h3 or not d_q == _d2 or not s_k == _s3 or not d_q == _d2:
|
||||
raise AssertionError("sdpa_flop_count: query/key/value shapes are incompatible")
|
||||
assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
|
||||
total_flops = 0
|
||||
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
|
||||
total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
|
||||
@ -324,21 +320,15 @@ def _unpack_flash_attention_nested_shapes(
|
||||
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
|
||||
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
|
||||
# So the flops calculation in this case is an overestimate of the actual flops.
|
||||
if len(key.shape) != 3:
|
||||
raise AssertionError("sdpa_flop_count: expected key.shape to be 3-dimensional")
|
||||
if len(value.shape) != 3:
|
||||
raise AssertionError("sdpa_flop_count: expected value.shape to be 3-dimensional")
|
||||
if grad_out is not None and grad_out.shape != query.shape:
|
||||
raise AssertionError("sdpa_flop_count: grad_out.shape must match query.shape when provided")
|
||||
assert len(key.shape) == 3
|
||||
assert len(value.shape) == 3
|
||||
assert grad_out is None or grad_out.shape == query.shape
|
||||
_, h_q, d_q = query.shape
|
||||
_, h_k, d_k = key.shape
|
||||
_, h_v, d_v = value.shape
|
||||
if cum_seq_q is None:
|
||||
raise AssertionError("sdpa_flop_count: cum_seq_q must not be None")
|
||||
if cum_seq_k is None:
|
||||
raise AssertionError("sdpa_flop_count: cum_seq_k must not be None")
|
||||
if cum_seq_q.shape != cum_seq_k.shape:
|
||||
raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
|
||||
assert cum_seq_q is not None
|
||||
assert cum_seq_k is not None
|
||||
assert cum_seq_q.shape == cum_seq_k.shape
|
||||
seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
|
||||
seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
|
||||
for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
|
||||
@ -378,22 +368,15 @@ def _unpack_efficient_attention_nested_shapes(
|
||||
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
|
||||
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
|
||||
# So the flops calculation in this case is an overestimate of the actual flops.
|
||||
if len(key.shape) != 4:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: expected key.shape to be 4-dimensional")
|
||||
if len(value.shape) != 4:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: expected value.shape to be 4-dimensional")
|
||||
if grad_out is not None and grad_out.shape != query.shape:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: grad_out.shape must match query.shape when provided")
|
||||
assert len(key.shape) == 4
|
||||
assert len(value.shape) == 4
|
||||
assert grad_out is None or grad_out.shape == query.shape
|
||||
_, _, h_q, d_q = query.shape
|
||||
_, _, h_k, d_k = key.shape
|
||||
_, _, h_v, d_v = value.shape
|
||||
if cu_seqlens_q is None:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_q must not be None")
|
||||
if cu_seqlens_k is None:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_k must not be None")
|
||||
if cu_seqlens_q.shape != cu_seqlens_k.shape:
|
||||
raise AssertionError("_unpack_efficient_attention_nested_shapes: "
|
||||
"cu_seqlens_q and cu_seqlens_k must have the same shape")
|
||||
assert cu_seqlens_q is not None
|
||||
assert cu_seqlens_k is not None
|
||||
assert cu_seqlens_q.shape == cu_seqlens_k.shape
|
||||
seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
|
||||
seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
|
||||
for len_q, len_k in zip(seqlens_q, seqlens_k):
|
||||
@ -477,10 +460,8 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape
|
||||
_b2, _h2, s_k, _d2 = key_shape
|
||||
_b3, _h3, _s3, d_v = value_shape
|
||||
_b4, _h4, _s4, _d4 = grad_out_shape
|
||||
if not b == _b2 == _b3 == _b4 or not h == _h2 == _h3 == _h4 or not d_q == _d2:
|
||||
raise AssertionError("sdpa_backward_flop_count: batch/heads/dimension mismatch among tensors")
|
||||
if not d_v == _d4 or not s_k == _s3 or not s_q == _s4:
|
||||
raise AssertionError("sdpa_backward_flop_count: grad_out/value/key/query shapes are incompatible")
|
||||
assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
|
||||
assert d_v == _d4 and s_k == _s3 and s_q == _s4
|
||||
total_flops = 0
|
||||
# Step 1: We recompute the scores matrix.
|
||||
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
|
||||
@ -761,8 +742,7 @@ class FlopCounterMode:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
if self.mode is None:
|
||||
raise AssertionError("Internal error: FlopCounter.__exit__ called but mode is None")
|
||||
assert self.mode is not None
|
||||
b = self.mode.__exit__(*args)
|
||||
self.mode = None # break cycles
|
||||
self.mod_tracker.__exit__()
|
||||
|
||||
@ -238,8 +238,7 @@ class BackwardHook:
|
||||
self.grad_outputs = None
|
||||
|
||||
if local_grad_outputs is not None:
|
||||
if self.output_tensors_index is None:
|
||||
raise AssertionError("output_tensors_index should not be None when grad_outputs is not None")
|
||||
assert self.output_tensors_index is not None # mypy
|
||||
return tuple(local_grad_outputs[i] for i in self.output_tensors_index)
|
||||
|
||||
grad_fn.register_hook(hook)
|
||||
|
||||
@ -137,12 +137,9 @@ class MkldnnBatchNorm(torch.jit.ScriptModule):
|
||||
def __init__(self, dense_module):
|
||||
super().__init__()
|
||||
|
||||
if dense_module.training:
|
||||
raise AssertionError("Only support eval mode batchnorm for mkldnn path now")
|
||||
if not dense_module.track_running_stats:
|
||||
raise AssertionError("Only support track_running_stats=True for mkldnn path now")
|
||||
if not dense_module.affine:
|
||||
raise AssertionError("Only support affine=True for mkldnn path now")
|
||||
assert not dense_module.training
|
||||
assert dense_module.track_running_stats
|
||||
assert dense_module.affine
|
||||
|
||||
if dense_module.momentum is None:
|
||||
self.exponential_average_factor = 0.0
|
||||
@ -207,9 +204,8 @@ class MkldnnPrelu(torch.jit.ScriptModule):
|
||||
return y
|
||||
|
||||
def to_mkldnn(module, dtype=torch.float):
|
||||
if dtype not in (torch.float, torch.bfloat16, torch.half):
|
||||
raise AssertionError("MKLDNN only support float, bfloat16, and half path now")
|
||||
|
||||
assert dtype in [torch.float, torch.bfloat16, torch.half], \
|
||||
"MKLDNN only support float, bfloat16, and half path now"
|
||||
|
||||
def m_fn(m, d):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
|
||||
@ -5,8 +5,7 @@ import torch._C
|
||||
|
||||
def format_time(time_us=None, time_ms=None, time_s=None):
|
||||
"""Define time formatting."""
|
||||
if time_us is not None or time_ms is not None or time_s is not None:
|
||||
raise AssertionError("Expected at least one of time_us, time_ms, time_s is not None.")
|
||||
assert sum([time_us is not None, time_ms is not None, time_s is not None]) == 1
|
||||
|
||||
US_IN_SECOND = 1e6
|
||||
US_IN_MS = 1e3
|
||||
|
||||
@ -351,16 +351,14 @@ class TensorWeakRef:
|
||||
ref: WeakRef[Tensor]
|
||||
|
||||
def __init__(self, tensor: Tensor):
|
||||
if not isinstance(tensor, Tensor):
|
||||
raise AssertionError(f"expected torch.Tensor, got {type(tensor)}.")
|
||||
assert isinstance(tensor, Tensor)
|
||||
self.ref = weakref.ref(tensor)
|
||||
|
||||
def __call__(self):
|
||||
out = self.ref()
|
||||
if out is None:
|
||||
return out
|
||||
if not isinstance(out, Tensor):
|
||||
raise AssertionError(f"expected torch.Tensor, got {type(out)}.")
|
||||
assert isinstance(out, Tensor)
|
||||
# TODO, add _fix_weakref type binding
|
||||
out._fix_weakref() # type: ignore[attr-defined]
|
||||
return out
|
||||
|
||||
@ -1024,22 +1024,8 @@ def gen_functionalization_registration(
|
||||
) -> list[str]:
|
||||
@with_native_function
|
||||
def emit_registration_helper(f: NativeFunction) -> str:
|
||||
if f.has_composite_implicit_autograd_kernel:
|
||||
metadata = composite_implicit_autograd_index.get_kernel(f)
|
||||
assert metadata is not None
|
||||
native_api_name = metadata.kernel
|
||||
sig = NativeSignature(f.func, symint=metadata.supports_symint())
|
||||
# Note [Composite view ops in the functionalization pass]
|
||||
# We don't need to worry about implemententing functionalization kernels for views with
|
||||
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
|
||||
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
|
||||
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
|
||||
registration_str = (
|
||||
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
|
||||
)
|
||||
else:
|
||||
# non-composite view ops (and inplace ops) get a normal registration.
|
||||
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
|
||||
assert not f.has_composite_implicit_autograd_kernel
|
||||
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
|
||||
return f'm.impl("{f.func.name}", {registration_str});'
|
||||
|
||||
# Don't generate kernels in mobile build
|
||||
@ -1052,8 +1038,12 @@ def gen_functionalization_registration(
|
||||
if str(g.view.func.name) == "lift_fresh":
|
||||
return []
|
||||
view_str = []
|
||||
view_str.append(emit_registration_helper(g.view))
|
||||
if g.view_inplace is not None:
|
||||
if not g.view.has_composite_implicit_autograd_kernel:
|
||||
view_str.append(emit_registration_helper(g.view))
|
||||
if (
|
||||
g.view_inplace is not None
|
||||
and not g.view_inplace.has_composite_implicit_autograd_kernel
|
||||
):
|
||||
assert g.view_inplace.is_view_op
|
||||
view_str.append(emit_registration_helper(g.view_inplace))
|
||||
return view_str
|
||||
|
||||
Reference in New Issue
Block a user