Compare commits

...

7 Commits

Author SHA1 Message Date
b23f4687fd [Inductor][CuTeDSL] Move load_template up two directories (#165868)
Summary:
This is a reland of https://github.com/pytorch/pytorch/pull/165347

Moves the function used to load CuTeDSL Jinja templates up two directories, out of the flex attention folder, so it can be used for more general Inductor templates in the future.

Test Plan: test/inductor/test_flex_flash

Differential Revision: D85013024

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165868
Approved by: https://github.com/jananisriram
2025-10-20 12:14:38 +00:00
2705937080 [CI] Add rocm CI back to trunk for pre-submit/PR jobs (#165674)
Only adding single-GPU shards for now, to observe how current capacity handles it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165674
Approved by: https://github.com/jeffdaily
2025-10-20 12:14:06 +00:00
c1eda348be [cuda] fix triu/tril int32 overflow for large matrices (#164705)
Fixes #136611

Cast blockIdx.x to int64_t before multiplication to prevent overflow when computing linear_idx for matrices larger than 2^31 elements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164705
Approved by: https://github.com/eqy, https://github.com/ngimel
2025-10-20 07:17:41 +00:00
ba93d5636e [cuda] fix nll_loss2d backward bounds check with reduction=none (#165247)
Fixes #49882

Add a missing bounds check in the nll_loss2d backward kernel with reduction=none. The forward kernel already had a CUDA_KERNEL_ASSERT for target bounds; the backward kernel now matches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165247
Approved by: https://github.com/ngimel
2025-10-20 06:25:11 +00:00
722b2b86c9 [dynamo] Remove duplicated guards (#165806)
This was found by looking at a tlparse of an internal job. We will need a deeper audit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165806
Approved by: https://github.com/jansel
2025-10-20 05:50:33 +00:00
e1e8491b31 [1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)
This series of changes converts C-style casts to their C++ alternatives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 04:36:19 +00:00
767199fd9b [flex_attention] replace sliced BlockMask noop with helpful error (#164702)
Fixes part of #163314

After slicing a BlockMask with `[]`, mask_mod was silently replaced with noop_mask. This caused silently incorrect results when users applied transformations to `sliced_mask.mask_mod`.

Replace the noop with `_sliced_mask_mod_error`, which raises a RuntimeError directing users to `base_mask.mask_mod` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702
Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
2025-10-20 03:46:16 +00:00
65 changed files with 653 additions and 348 deletions

View File

@ -190,6 +190,40 @@ jobs:
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit
linux-jammy-rocm-py3_10-build:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
secrets: inherit
inductor-build:
name: inductor-build
uses: ./.github/workflows/_linux-build.yml

View File

@ -146,6 +146,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
int64_t batch_size = target.size(0);
int64_t H = target.size(1);
int64_t W = target.size(2);
int64_t n_classes = grad_input.size(1);
CUDA_KERNEL_LOOP(index, n_threads) {
const int64_t b = index % batch_size;
@ -156,6 +157,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
if (cur_target == ignore_index) {
continue;
}
CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
}
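
The hunk above adds to the backward path the same device-side bounds check the forward path already had. A minimal, hypothetical CUDA sketch of the pattern (names, shapes, and the plain assert are illustrative stand-ins, not the real kernel): validate an index loaded from device memory before using it to address another tensor, so a bad target fails loudly instead of writing out of bounds.

#include <cassert>

// Hypothetical kernel; assert() stands in for CUDA_KERNEL_ASSERT.
// `target[i]` comes from user data, so it is checked before it is used
// to index grad_input.
__global__ void scatter_neg_grad(float* grad_input, const long* target,
                                 const float* grad_output,
                                 long n_classes, long n, long ignore_index) {
  long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  long t = target[i];
  if (t == ignore_index) return;
  assert(t >= 0 && t < n_classes);  // device-side assert: abort, don't corrupt memory
  grad_input[i * n_classes + t] = -grad_output[i];
}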

View File

@ -44,7 +44,7 @@ __global__ void triu_tril_kernel(
const int64_t k,
const int64_t N_padded,
const IndexType last_dim_padded) {
int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
if (linear_idx >= N_padded) {
return;
}
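
The cast above widens the index arithmetic before the multiplication: blockIdx.x, blockDim.x, and threadIdx.x are 32-bit unsigned values, so without it the product wraps once the linear index passes 2^32 - 1. A small host-side sketch of the wraparound, using made-up launch parameters:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical launch parameters large enough to overflow 32 bits.
  uint32_t block_idx = 4200000, block_dim = 1024, thread_idx = 0;
  int64_t wrapped = block_idx * block_dim + thread_idx;           // 32-bit multiply, wraps
  int64_t widened = (int64_t)block_idx * block_dim + thread_idx;  // widened first, as in the fix
  std::printf("wrapped=%lld widened=%lld\n",                      // 5832704 vs 4300800000
              (long long)wrapped, (long long)widened);
  return 0;
}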

View File

@ -102,7 +102,7 @@ uint64_t getNonDeterministicRandom(bool is_cuda) {
} else {
std::random_device rd;
// limit to 53 bits to ensure unique representation in double
s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
s = (((static_cast<uint64_t>(rd())) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
}
return s;
}
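
The mask on the line above keeps only the low 53 bits because a double's mantissa is 53 bits wide; any larger integer no longer has a unique double representation, which is what the original comment is protecting. A quick sketch of the boundary:

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t below = (1ULL << 53) - 1;  // still round-trips through double exactly
  uint64_t above = (1ULL << 53) + 1;  // rounds to 2^53 when converted to double
  std::printf("%d %d\n",
              below == (uint64_t)(double)below,   // prints 1
              above == (uint64_t)(double)above);  // prints 0
  return 0;
}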

View File

@ -20,7 +20,8 @@ void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
c10::DataPtr& data_ptr = storage.mutable_data_ptr();
if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
if (reinterpret_cast<const void*>(data_ptr.get_deleter()) ==
reinterpret_cast<const void*>(&c10::refcounted_deleter)) {
// Data pointer is already shared
return;
}
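
The comparison above identifies a refcounted data pointer by checking whether its deleter is one specific function, i.e. the deleter's address doubles as a type tag; the change only swaps the C-style casts for explicit reinterpret_casts. A hedged sketch of the idiom with hypothetical deleters (not the real c10 functions):

using Deleter = void (*)(void*);

void refcounted_deleter_like(void*) {}  // stand-in for c10::refcounted_deleter
void plain_deleter(void*) {}

bool is_refcounted_like(Deleter d) {
  return reinterpret_cast<const void*>(d) ==
         reinterpret_cast<const void*>(&refcounted_deleter_like);
}

int main() {
  return (is_refcounted_like(&refcounted_deleter_like) &&
          !is_refcounted_like(&plain_deleter))
             ? 0
             : 1;
}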

View File

@ -83,7 +83,7 @@ DEFINE_BINARY(max_slow_path, sym_max, SymInt)
SymInt::operator SymFloat() const {
if (auto ma = maybe_as_int()) {
return SymFloat(double(*ma));
return SymFloat(static_cast<double>(*ma));
} else {
return SymFloat(toSymNodeImplUnowned()->sym_float());
}
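
The series follows one convention throughout: static_cast for conversions the type system understands (numeric conversions, as in the SymFloat line above), and reinterpret_cast only where the bit pattern of an object or function pointer is being reinterpreted. A minimal illustration with throwaway names:

#include <cstdint>

int main() {
  double d = 3.9;
  int truncated = static_cast<int>(d);  // numeric conversion: 3

  int x = 0;
  char* bytes = reinterpret_cast<char*>(&x);                   // view the object as bytes
  std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(&x);  // pointer -> integer
  (void)truncated;
  (void)bytes;
  (void)addr;
  return 0;
}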

View File

@ -44,7 +44,8 @@ bool has_simple_data_ptr(const c10::StorageImpl& storage) {
}
bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
return (void*)data_ptr.get_deleter() == (void*)&cow::cow_deleter;
return reinterpret_cast<const void*>(data_ptr.get_deleter()) ==
reinterpret_cast<const void*>(&cow::cow_deleter);
}
c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {

View File

@ -512,7 +512,7 @@ struct ExpandableSegment {
header.segment_size = segment_size_;
header.num_handles = end - begin;
buf.write((const char*)&header, sizeof(ShareHeader));
buf.write(reinterpret_cast<const char*>(&header), sizeof(ShareHeader));
for (auto i : c10::irange(begin, end)) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
auto& handle = handles_.at(i).value();
@ -528,7 +528,9 @@ struct ExpandableSegment {
TORCH_CHECK(
handle.shareable_handle != std::nullopt,
"shareable_handle is null");
buf.write((const char*)&*handle.shareable_handle, sizeof(int));
buf.write(
reinterpret_cast<const char*>(&*handle.shareable_handle),
sizeof(int));
} else {
if (!handle.shareable_handle) {
CUmemFabricHandle fabric_handle;
@ -541,7 +543,8 @@ struct ExpandableSegment {
handle.shareable_handle != std::nullopt,
"shareable_handle is null");
buf.write(
(const char*)&*handle.shareable_handle, sizeof(CUmemFabricHandle));
reinterpret_cast<const char*>(&*handle.shareable_handle),
sizeof(CUmemFabricHandle));
}
}
return rangeFromHandles(begin, end);
@ -552,7 +555,7 @@ struct ExpandableSegment {
std::vector<c10::DeviceIndex> peers,
std::istream& buf) {
ShareHeader header{};
buf.read((char*)&header, sizeof(ShareHeader));
buf.read(reinterpret_cast<char*>(&header), sizeof(ShareHeader));
auto segment = std::make_unique<ExpandableSegment>(
device, std::nullopt, header.segment_size, std::move(peers));
// older build setups (e.g. multiwheels) do not have this syscall, added 2020
@ -574,11 +577,11 @@ struct ExpandableSegment {
for (auto i : c10::irange(header.num_handles)) {
(void)i;
int fd = 0;
buf.read((char*)&fd, sizeof(int));
buf.read(reinterpret_cast<char*>(&fd), sizeof(int));
auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0);
if (myfd == -1) {
auto err = errno;
close((int)pidfd);
close(static_cast<int>(pidfd));
for (auto& h : segment->handles_) {
C10_CUDA_DRIVER_CHECK(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
@ -598,15 +601,16 @@ struct ExpandableSegment {
(void*)(uintptr_t)myfd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
LOG(INFO) << "use posix fd to import expandable segments.";
close((int)myfd);
close(static_cast<int>(myfd));
segment->handles_.emplace_back(Handle{handle, std::nullopt});
}
close((int)pidfd);
close(static_cast<int>(pidfd));
} else {
for (auto i : c10::irange(header.num_handles)) {
(void)i;
CUmemFabricHandle fabric_handle;
buf.read((char*)&fabric_handle, sizeof(CUmemFabricHandle));
buf.read(
reinterpret_cast<char*>(&fabric_handle), sizeof(CUmemFabricHandle));
CUmemGenericAllocationHandle handle = 0;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_(
&handle,
@ -1059,7 +1063,7 @@ class RingBuffer {
void setMaxEntries(size_t size) {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
alloc_trace_max_entries_ = std::max(size_t(1), size);
alloc_trace_max_entries_ = std::max(static_cast<size_t>(1), size);
}
void insertEntries(const T& entry) {
@ -1991,15 +1995,16 @@ class DeviceCachingAllocator {
while (base_block->prev) {
base_block = base_block->prev;
}
offset = (char*)block->ptr - (char*)base_block->ptr;
offset = static_cast<const char*>(block->ptr) -
static_cast<const char*>(base_block->ptr);
cudaIpcMemHandle_t handle;
C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_block->ptr));
ss.write((char*)&handle, CUDA_IPC_HANDLE_SIZE);
ss.write(reinterpret_cast<const char*>(&handle), CUDA_IPC_HANDLE_SIZE);
} else {
ss.put(SHAREABLE_CUDA_EXPANDABLE_SEGMENT);
auto full_range = block->expandable_segment_->share(
SegmentRange(block->ptr, block->size), ss);
offset = (char*)block->ptr - full_range.ptr;
offset = static_cast<const char*>(block->ptr) - full_range.ptr;
}
return ShareableHandle{offset, ss.str()};
}
@ -3229,7 +3234,8 @@ class DeviceCachingAllocator {
}
total_allocated_memory += size;
p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
p.block = new Block(
p.device(), p.stream(), size, p.pool, static_cast<char*>(ptr));
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.segment[stat_type].increase(1);
stats.reserved_bytes[stat_type].increase(size);
@ -3777,7 +3783,7 @@ class NativeCachingAllocator : public CUDAAllocator {
allocated_blocks;
static size_t get_mutex_shard_id(void* ptr) {
return twang_mix64((size_t)ptr) % kNumMutexShard;
return twang_mix64(reinterpret_cast<uintptr_t>(ptr)) % kNumMutexShard;
}
void add_allocated_block(Block* block) {
@ -3814,8 +3820,8 @@ class NativeCachingAllocator : public CUDAAllocator {
if (size < device_count) {
device_allocator.resize(device_count);
for (const auto i : c10::irange(size, device_count)) {
device_allocator[i] =
std::make_unique<DeviceCachingAllocator>(c10::DeviceIndex(i));
device_allocator[i] = std::make_unique<DeviceCachingAllocator>(
static_cast<c10::DeviceIndex>(i));
}
}
}
@ -4344,7 +4350,7 @@ class NativeCachingAllocator : public CUDAAllocator {
// SHARABLE_CUDA_MALLOC
if (type == SHAREABLE_CUDA_MALLOC) {
cudaIpcMemHandle_t cuda_handle;
ss.read((char*)&cuda_handle, CUDA_IPC_HANDLE_SIZE);
ss.read(reinterpret_cast<char*>(&cuda_handle), CUDA_IPC_HANDLE_SIZE);
C10_CUDA_CHECK(cudaIpcOpenMemHandle(
&cuda_ipc_ptr_, cuda_handle, cudaIpcMemLazyEnablePeerAccess));
} else if (type == SHAREABLE_CUDA_EXPANDABLE_SEGMENT) {
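
Nearly every write and read touched in this file follows the same idiom: a trivially copyable header or handle is serialized by viewing it as raw bytes through reinterpret_cast<const char*> / reinterpret_cast<char*>. A self-contained sketch with a hypothetical header type:

#include <cstddef>
#include <sstream>

struct ShareHeaderLike {  // hypothetical stand-in for ShareHeader
  std::size_t segment_size;
  std::size_t num_handles;
};

int main() {
  ShareHeaderLike out{1 << 20, 4}, in{};
  std::stringstream buf;
  buf.write(reinterpret_cast<const char*>(&out), sizeof(out));
  buf.read(reinterpret_cast<char*>(&in), sizeof(in));
  return (in.segment_size == out.segment_size &&
          in.num_handles == out.num_handles)
             ? 0
             : 1;
}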

View File

@ -46,7 +46,7 @@ bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
struct UsageStreamHash {
size_t operator()(const UsageStream& us) const noexcept {
return std::hash<void*>{}(us.stream) + size_t(us.device);
return std::hash<void*>{}(us.stream) + static_cast<size_t>(us.device);
}
};

View File

@ -128,7 +128,7 @@ std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
} else if (s.isExt()) {
stream << "EXT";
} else {
stream << "PRIORITY " << int(s.getStreamType());
stream << "PRIORITY " << static_cast<int>(s.getStreamType());
}
return stream;
}

View File

@ -46,7 +46,8 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
for (const auto i : c10::irange(replicates)) {
auto delta_ns = end_times[i].t_ - start_times_[i].t_;
auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_;
scale_factors[i] = (double)delta_ns / (double)delta_approx;
scale_factors[i] =
static_cast<double>(delta_ns) / static_cast<double>(delta_approx);
}
std::sort(scale_factors.begin(), scale_factors.end());
long double scale_factor = scale_factors[replicates / 2 + 1];
@ -64,7 +65,8 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
for (const auto i : c10::irange(replicates)) {
auto dt = start_times_[i].t_ - t0;
auto dt_approx =
(double)(start_times_[i].approx_t_ - t0_approx) * scale_factor;
static_cast<double>(start_times_[i].approx_t_ - t0_approx) *
scale_factor;
t0_correction[i] = dt - (time_t)dt_approx; // NOLINT
}
t0 += t0_correction[t0_correction.size() / 2 + 1]; // NOLINT
@ -72,7 +74,9 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
return [=](approx_time_t t_approx) {
// See above for why this is more stable than `A * t_approx + B`.
return t_approx > t0_approx
? (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0
? static_cast<time_t>(
static_cast<double>(t_approx - t0_approx) * scale_factor) +
t0
: 0;
};
}
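
The first rewrite in this hunk (the scale_factors line) converts both deltas to double before dividing so the fractional part of the ratio survives; with plain integer operands the quotient would be truncated. A tiny sketch with invented deltas:

#include <cstdio>

int main() {
  long long delta_ns = 7, delta_approx = 2;  // invented measurements
  double truncated = delta_ns / delta_approx;          // integer division first: 3.0
  double exact = static_cast<double>(delta_ns) /
                 static_cast<double>(delta_approx);    // 3.5
  std::printf("%g %g\n", truncated, exact);
  return 0;
}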

View File

@ -132,15 +132,15 @@ std::ostream& operator<<(std::ostream& o, const uint128& b) {
int div_base_log = 0;
switch (flags & std::ios::basefield) {
case std::ios::hex:
div = (uint64_t)0x1000000000000000u; // 16^15
div = static_cast<uint64_t>(0x1000000000000000u); // 16^15
div_base_log = 15;
break;
case std::ios::oct:
div = (uint64_t)01000000000000000000000u; // 8^21
div = static_cast<uint64_t>(01000000000000000000000u); // 8^21
div_base_log = 21;
break;
default: // std::ios::dec
div = (uint64_t)10000000000000000000u; // 10^19
div = static_cast<uint64_t>(10000000000000000000u); // 10^19
div_base_log = 19;
break;
}

View File

@ -962,6 +962,42 @@ class TypePropagationTests(torch._dynamo.test_case.TestCase):
opt_fn(torch.randn(4, 4))
class DuplicateGuardTest(torch._dynamo.test_case.TestCase):
def test_duplicate_guard(self):
class Foo:
def __init__(self):
self.x = 4
self.bar = 4
foo = Foo()
def fn(x):
if hasattr(foo, "y"):
x = torch.sin(x)
if hasattr(foo, "y"):
x = torch.sin(x)
if hasattr(foo, "bar"):
x = torch.cos(x)
if hasattr(foo, "bar"):
x = torch.cos(x)
return x + foo.x
try:
from .utils import install_guard_manager_testing_hook
except ImportError:
from utils import install_guard_manager_testing_hook
def hook(guard_wrapper, f_locals, builder):
guard_str = str(guard_wrapper)
# One for tensor and one for y
self.assertEqual(guard_str.count("NO_HASATTR"), 2)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
with install_guard_manager_testing_hook(hook):
opt_fn(torch.randn(4, 4))
class RecursiveDictTagTests(torch._dynamo.test_case.TestCase):
def setUp(self):
self._prev = torch._dynamo.config.use_recursive_dict_tags_for_guards

View File

@ -4995,6 +4995,28 @@ class TestBlockMask(InductorTestCase):
block_mask.full_kv_indices[:, :, q_index, :],
)
@supported_platform
def test_sliced_blockmask_mask_mod_error(self, device):
"""Test that sliced BlockMask raises helpful error when used with flex_attention"""
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
base_mask = create_block_mask(
causal_mask, B=1, H=1, Q_LEN=256, KV_LEN=256, device=device
)
sliced_mask = base_mask[:, :, 0]
q = torch.randn(1, 1, 1, 64, device=device)
k = torch.randn(1, 1, 256, 64, device=device)
v = torch.randn(1, 1, 256, 64, device=device)
compiled_fa = torch.compile(flex_attention)
with self.assertRaisesRegex(
RuntimeError, "Cannot use mask_mod from a sliced BlockMask"
):
compiled_fa(q, k, v, block_mask=sliced_mask)
@supported_platform
def test_block_mask_device_change(self, device):
device = torch.device(device)

View File

@ -9931,6 +9931,28 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
C = torch.matmul(A, B)
self.assertEqual(C, B.sum().expand(B.shape))
@onlyCUDA
@largeTensorTest("40GB")
def test_triu_tril_large_matrix_64bit(self, device):
"""
Test triu/tril with large matrices requiring 64-bit indexing.
Regression test for https://github.com/pytorch/pytorch/issues/136611
"""
# 100k x 100k matrix with 10B elements requires 64-bit indexing
q_len = 100000
causal_mask = torch.full((q_len, q_len), float('-inf'), device=device, dtype=torch.float32)
causal_mask.triu_(1)
# Verify row 42950 is correct (previously failed due to int32 overflow at row*col)
row_42950 = causal_mask[42950]
num_zeros = (row_42950 == 0.0).sum().item()
expected_zeros = 42951
self.assertEqual(num_zeros, expected_zeros)
# Verify last row is correct
last_row = causal_mask[-1]
self.assertTrue((last_row == 0.0).all())
@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
def test_triu_tril_extreme_k_values(self, device, dtype):
"""

View File

@ -1054,9 +1054,7 @@ class GuardBuilder(GuardBuilderBase):
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
"pytorch/compiler:guard_nn_modules"
)
self.already_guarded_not_present_in_generic_dict: OrderedSet[
tuple[str, str]
] = OrderedSet()
self.already_added_code_parts: OrderedSet[str] = OrderedSet()
def guard_on_dict_keys_and_ignore_order(
self, example_value: dict[Any, Any], guard: Guard
@ -1849,6 +1847,10 @@ class GuardBuilder(GuardBuilderBase):
code = f"hasattr({ref}, {attr!r})"
else:
code = f"not hasattr({ref}, {attr!r})"
if code in self.already_added_code_parts:
return
self._set_guard_export_info(
guard, [code], provided_guarded_object=self.get(base)
)
@ -1882,6 +1884,7 @@ class GuardBuilder(GuardBuilderBase):
)
else:
base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard))
self.already_added_code_parts.add(code)
def NOT_PRESENT_IN_GENERIC_DICT(
self, guard: Guard, attr: Optional[Any] = None
@ -1892,7 +1895,8 @@ class GuardBuilder(GuardBuilderBase):
base_manager = self.get_guard_manager(guard)
if (ref, attr) in self.already_guarded_not_present_in_generic_dict:
code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
if code in self.already_added_code_parts:
return
mod_dict_source = f"{guard.name}.__dict__"
@ -1902,11 +1906,10 @@ class GuardBuilder(GuardBuilderBase):
guard_manager_enum=GuardManagerType.GUARD_MANAGER,
)
code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
mod_generic_dict_manager.add_dict_contains_guard(
False, attr, get_verbose_code_parts(code, guard)
)
self.already_guarded_not_present_in_generic_dict.add((ref, attr))
self.already_added_code_parts.add(code)
def TYPE_MATCH(self, guard: Guard) -> None:
# ___check_type_id is same as `id(type(x)) == y`
@ -1948,11 +1951,14 @@ class GuardBuilder(GuardBuilderBase):
maybe_not = "not " if invert else ""
code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})"
if code in self.already_added_code_parts:
return
self._set_guard_export_info(guard, [code])
self.get_guard_manager(guard).add_dict_contains_guard(
not invert, key, get_verbose_code_parts(code, guard)
)
self.already_added_code_parts.add(code)
def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None:
set_ref = self.arg_ref(guard)
@ -1960,12 +1966,15 @@ class GuardBuilder(GuardBuilderBase):
contains = not invert # install_dict_contains_guard inverts "contains"
code = f"set.__contains__({set_ref}, {item!r})"
if code in self.already_added_code_parts:
return
self._set_guard_export_info(guard, [code])
self.get_guard_manager(guard).add_set_contains_guard(
contains, item, get_verbose_code_parts(code, guard)
)
self.already_added_code_parts.add(code)
def BOOL_MATCH(self, guard: Guard) -> None:
# checks val == True or val == False

View File

@ -3,6 +3,7 @@
import math
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any, Optional, Union
@ -36,6 +37,7 @@ from ...lowering import (
to_dtype,
)
from ...select_algorithm import realize_inputs
from ...utils import load_template
SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@ -337,13 +339,8 @@ def next_power_of_two(n):
return 2 ** math.ceil(math.log2(n))
_TEMPLATE_DIR = Path(__file__).parent / "templates"
def load_template(name: str) -> str:
"""Load a template file and return its content."""
with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
return f.read()
_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)
# Template strings have been moved to templates/common.py.jinja

View File

@ -29,7 +29,7 @@ from .common import (
freeze_irnodes,
get_fwd_subgraph_outputs,
infer_dense_strides,
load_template,
load_flex_template,
maybe_realize,
set_head_dim_values,
SubgraphResults,
@ -79,9 +79,9 @@ def get_float32_precision():
flex_attention_template = TritonTemplate(
name="flex_attention",
grid=flex_attention_grid,
source=load_template("flex_attention")
+ load_template("utilities")
+ load_template("common"),
source=load_flex_template("flex_attention")
+ load_flex_template("utilities")
+ load_flex_template("common"),
)
@ -469,7 +469,7 @@ def flex_attention_backward_grid(
flex_attention_backward_template = TritonTemplate(
name="flex_attention_backward",
grid=flex_attention_backward_grid,
source=load_template("flex_backwards") + load_template("utilities"),
source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
)

View File

@ -22,7 +22,7 @@ from .common import (
create_num_blocks_fake_generator,
freeze_irnodes,
get_fwd_subgraph_outputs,
load_template,
load_flex_template,
maybe_realize,
set_head_dim_values,
)
@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
flex_decoding_template = TritonTemplate(
name="flex_decoding",
grid=flex_decoding_grid,
source=load_template("flex_decode")
+ load_template("utilities")
+ load_template("common"),
source=load_flex_template("flex_decode")
+ load_flex_template("utilities")
+ load_flex_template("common"),
)

View File

@ -12,7 +12,7 @@ from torch.fx import GraphModule
from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
from ...lowering import empty_strided
from .common import infer_dense_strides, load_template, SubgraphResults
from .common import infer_dense_strides, load_flex_template, SubgraphResults
aten = torch.ops.aten
@ -36,7 +36,7 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate
flash_attention_cutedsl_template = CuteDSLTemplate(
name="flash_attention_cutedsl", source=load_template("flash_attention")
name="flash_attention_cutedsl", source=load_flex_template("flash_attention")
)

View File

@ -67,6 +67,9 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten, tree_map_only
if TYPE_CHECKING:
from pathlib import Path
OPTIMUS_EXCLUDE_POST_GRAD = [
"activation_quantization_aten_pass",
"inductor_autotune_lookup_table",
@ -3885,3 +3888,10 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
return dep_name.startswith(
("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
)
# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
def load_template(name: str, template_dir: Path) -> str:
"""Load a template file and return its content."""
with open(template_dir / f"{name}.py.jinja") as f:
return f.read()

View File

@ -151,7 +151,7 @@ static PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {
static PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPDevice*)_self;
auto self = reinterpret_cast<THPDevice*>(_self);
auto ret = THPObjectPtr{PyTuple_New(2)};
if (!ret)
throw python_error();
@ -221,8 +221,16 @@ typedef PyObject* (*getter)(PyObject*, void*);
// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
static const std::initializer_list<PyGetSetDef> THPDevice_properties = {
{"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
{"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
{"type",
reinterpret_cast<getter>(THPDevice_type),
nullptr,
nullptr,
nullptr},
{"index",
reinterpret_cast<getter>(THPDevice_index),
nullptr,
nullptr,
nullptr},
{nullptr}};
static const std::initializer_list<PyMethodDef> THPDevice_methods = {
@ -242,18 +250,18 @@ PyTypeObject THPDeviceType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPDevice_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPDevice_repr), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
(hashfunc)THPDevice_hash, /* tp_hash */
reinterpret_cast<hashfunc>(THPDevice_hash), /* tp_hash */
// TODO: We're not sure if this is a good idea or not, because making
// torch.device callable means that it will start returning true
// for callable() queries, and that is unexpected. We can always add
// this later, so for now, don't actually implement this
// THPDevice_call, /* tp_call */
nullptr, /* tp_call */
(reprfunc)THPDevice_str, /* tp_str */
reinterpret_cast<reprfunc>(THPDevice_str), /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
@ -261,7 +269,7 @@ PyTypeObject THPDeviceType = {
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
(richcmpfunc)THPDevice_rc, /* tp_richcompare */
static_cast<richcmpfunc>(THPDevice_rc), /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
@ -286,7 +294,8 @@ void THPDevice_init(PyObject* module) {
}
Py_INCREF(&THPDeviceType);
THPUpperModuleOfDevice = module;
if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) {
if (PyModule_AddObject(
module, "device", reinterpret_cast<PyObject*>(&THPDeviceType)) != 0) {
throw python_error();
}
}
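
The property-table rewrites in this file, and in the dtype, event, generator, and storage files below, all have the same shape: the PyGetSetDef slot expects a PyObject* (*)(PyObject*, void*), while the getter is implemented against the concrete struct, so the function pointer is converted with an explicit reinterpret_cast instead of a C-style cast. A compressed, hypothetical sketch of the pattern (THPThing is invented):

#include <Python.h>

struct THPThing {  // hypothetical extension object
  PyObject_HEAD
  long payload;
};

static PyObject* THPThing_payload(THPThing* self, void* /*closure*/) {
  return PyLong_FromLong(self->payload);
}

typedef PyObject* (*getter)(PyObject*, void*);

static PyGetSetDef THPThing_properties[] = {
    {"payload",
     reinterpret_cast<getter>(THPThing_payload),
     nullptr,
     nullptr,
     nullptr},
    {nullptr}};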

View File

@ -69,14 +69,14 @@ static PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) {
* For singletons, a string is returned. The string should be interpreted
* as the name of a global variable.
*/
auto self = (THPDtype*)_self;
auto self = reinterpret_cast<THPDtype*>(_self);
return THPUtils_packString(self->name);
END_HANDLE_TH_ERRORS
}
static PyObject* THPDtype_to_real(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto* self = (THPDtype*)_self;
auto* self = reinterpret_cast<THPDtype*>(_self);
auto scalar_type = self->scalar_type;
if (!at::isFloatingType(self->scalar_type)) {
scalar_type = at::toRealValueType(self->scalar_type);
@ -87,7 +87,7 @@ static PyObject* THPDtype_to_real(PyObject* _self, PyObject* noargs) {
static PyObject* THPDtype_to_complex(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto* self = (THPDtype*)_self;
auto* self = reinterpret_cast<THPDtype*>(_self);
auto scalar_type = self->scalar_type;
if (!at::isComplexType(self->scalar_type)) {
scalar_type = at::toComplexType(self->scalar_type);
@ -100,13 +100,25 @@ typedef PyObject* (*getter)(PyObject*, void*);
static const std::initializer_list<PyGetSetDef> THPDtype_properties = {
{"is_floating_point",
(getter)THPDtype_is_floating_point,
reinterpret_cast<getter>(THPDtype_is_floating_point),
nullptr,
nullptr,
nullptr},
{"is_complex",
reinterpret_cast<getter>(THPDtype_is_complex),
nullptr,
nullptr,
nullptr},
{"is_signed",
reinterpret_cast<getter>(THPDtype_is_signed),
nullptr,
nullptr,
nullptr},
{"itemsize",
reinterpret_cast<getter>(THPDtype_itemsize),
nullptr,
nullptr,
nullptr},
{"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr},
{"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
{"itemsize", (getter)THPDtype_itemsize, nullptr, nullptr, nullptr},
{nullptr}};
static const std::initializer_list<PyMethodDef> THPDtype_methods = {
@ -130,7 +142,7 @@ PyTypeObject THPDtypeType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPDtype_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPDtype_repr), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
@ -190,7 +202,8 @@ void THPDtype_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPDtypeType);
if (PyModule_AddObject(module, "dtype", (PyObject*)&THPDtypeType) != 0) {
if (PyModule_AddObject(
module, "dtype", reinterpret_cast<PyObject*>(&THPDtypeType)) != 0) {
throw python_error();
}
}

View File

@ -48,7 +48,7 @@ static PyObject* THPEvent_pynew(
TORCH_CHECK(ptr, "Failed to allocate memory for Event");
}
THPEvent* self = (THPEvent*)ptr.get();
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
// TODO: blocking and interprocess are not supported yet. To support them, the
// flag system of c10::Event needs to be refactored. C10::Event should also
@ -64,7 +64,7 @@ static PyObject* THPEvent_pynew(
(enable_timing ? c10::EventFlag::BACKEND_DEFAULT
: c10::EventFlag::PYTORCH_DEFAULT));
return (PyObject*)ptr.release();
return static_cast<PyObject*>(ptr.release());
END_HANDLE_TH_ERRORS
}
@ -82,7 +82,7 @@ static void THPEvent_dealloc(THPEvent* self) {
pybind11::gil_scoped_release no_gil{};
self->event.~Event();
}
Py_TYPE(self)->tp_free((PyObject*)self);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
static PyObject* THPEvent_get_device(THPEvent* self, void* unused) {
@ -96,7 +96,7 @@ static PyObject* THPEvent_record(
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
auto self = (THPEvent*)_self;
auto self = reinterpret_cast<THPEvent*>(_self);
PyObject* _stream = Py_None;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr const char* accepted_args[] = {"stream", nullptr};
@ -111,7 +111,7 @@ static PyObject* THPEvent_record(
return nullptr;
}
if (_stream != Py_None) {
auto stream = (THPStream*)_stream;
auto stream = reinterpret_cast<THPStream*>(_stream);
self->event.record(c10::Stream::unpack3(
stream->stream_id,
static_cast<c10::DeviceIndex>(stream->device_index),
@ -130,7 +130,7 @@ static PyObject* THPEvent_from_ipc_handle(
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
auto type = (PyTypeObject*)_type;
auto type = reinterpret_cast<PyTypeObject*>(_type);
static torch::PythonArgParser parser({
"from_ipc_handle(Device device, std::string ipc_handle)",
@ -146,13 +146,13 @@ static PyObject* THPEvent_from_ipc_handle(
if (!ptr) {
return nullptr;
}
THPEvent* self = (THPEvent*)ptr.get();
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
// TODO: for constructing event from ipc handle, the c10::Event needs to have
// more general constructor to achieve that.
new (&self->event) c10::Event(device.type(), c10::EventFlag::PYTORCH_DEFAULT);
return (PyObject*)ptr.release();
return static_cast<PyObject*>(ptr.release());
END_HANDLE_TH_ERRORS
}
@ -174,7 +174,7 @@ static PyObject* THPEvent_wait(
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS {
auto self = (THPEvent*)_self;
auto self = reinterpret_cast<THPEvent*>(_self);
PyObject* _stream = Py_None;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr const char* accepted_args[] = {"stream", nullptr};
@ -189,7 +189,7 @@ static PyObject* THPEvent_wait(
return nullptr;
}
if (_stream != Py_None) {
auto stream = (THPStream*)_stream;
auto stream = reinterpret_cast<THPStream*>(_stream);
self->event.block(c10::Stream::unpack3(
stream->stream_id,
static_cast<c10::DeviceIndex>(stream->device_index),
@ -206,15 +206,15 @@ static PyObject* THPEvent_wait(
static PyObject* THPEvent_query(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPEvent*)_self;
auto self = reinterpret_cast<THPEvent*>(_self);
return PyBool_FromLong(self->event.query());
END_HANDLE_TH_ERRORS
}
static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
HANDLE_TH_ERRORS
auto self = (THPEvent*)_self;
auto other = (THPEvent*)_other;
auto self = reinterpret_cast<THPEvent*>(_self);
auto other = reinterpret_cast<THPEvent*>(_other);
return PyFloat_FromDouble(self->event.elapsedTime(other->event));
END_HANDLE_TH_ERRORS
}
@ -222,7 +222,7 @@ static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS {
pybind11::gil_scoped_release no_gil{};
auto self = (THPEvent*)_self;
auto self = reinterpret_cast<THPEvent*>(_self);
self->event.synchronize();
}
Py_RETURN_NONE;
@ -231,7 +231,7 @@ static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
static PyObject* THPEvent_evend_id(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPEvent*)_self;
auto self = reinterpret_cast<THPEvent*>(_self);
return PyLong_FromVoidPtr(self->event.eventId());
END_HANDLE_TH_ERRORS
}
@ -251,8 +251,16 @@ static PyObject* THPEvent_repr(THPEvent* self) {
// NOLINTNEXTLINE(*c-arrays*, *global-variables)
static struct PyGetSetDef THPEvent_properties[] = {
{"device", (getter)THPEvent_get_device, nullptr, nullptr, nullptr},
{"event_id", (getter)THPEvent_evend_id, nullptr, nullptr, nullptr},
{"device",
reinterpret_cast<getter>(THPEvent_get_device),
nullptr,
nullptr,
nullptr},
{"event_id",
reinterpret_cast<getter>(THPEvent_evend_id),
nullptr,
nullptr,
nullptr},
{nullptr}};
// NOLINTNEXTLINE(*c-arrays*, *global-variables)
@ -280,12 +288,12 @@ PyTypeObject THPEventType = {
"torch.Event", /* tp_name */
sizeof(THPEvent), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)THPEvent_dealloc, /* tp_dealloc */
reinterpret_cast<destructor>(THPEvent_dealloc), /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPEvent_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPEvent_repr), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
@ -322,7 +330,8 @@ void THPEvent_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPEventType);
if (PyModule_AddObject(module, "Event", (PyObject*)&THPEventType) < 0) {
if (PyModule_AddObject(
module, "Event", reinterpret_cast<PyObject*>(&THPEventType)) < 0) {
throw python_error();
}
}

View File

@ -65,7 +65,8 @@ could not be completed because the input matrix is singular.",
"Exception raised when device is out of memory",
PyExc_RuntimeError,
nullptr));
PyTypeObject* type = (PyTypeObject*)THPException_OutOfMemoryError;
PyTypeObject* type =
reinterpret_cast<PyTypeObject*>(THPException_OutOfMemoryError);
type->tp_name = "torch.OutOfMemoryError";
ASSERT_TRUE(
PyModule_AddObject(
@ -133,7 +134,7 @@ could not be completed because the input matrix is singular.",
"Exception raised while executing on device",
PyExc_RuntimeError,
nullptr));
type = (PyTypeObject*)THPException_AcceleratorError;
type = reinterpret_cast<PyTypeObject*>(THPException_AcceleratorError);
ASSERT_TRUE(
PyModule_AddObject(
module, "AcceleratorError", THPException_AcceleratorError) == 0);

View File

@ -21,7 +21,7 @@ using namespace torch;
PyObject* THPGeneratorClass = nullptr;
PyObject* THPGenerator_initDefaultGenerator(const at::Generator& cdata) {
auto type = (PyTypeObject*)THPGeneratorClass;
auto type = reinterpret_cast<PyTypeObject*>(THPGeneratorClass);
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
if (!self)
throw python_error();
@ -49,7 +49,8 @@ static PyObject* THPGenerator_pynew(
auto r = parser.parse(args, kwargs, parsed_args);
auto device = r.deviceWithDefault(0, at::Device(at::kCPU));
THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0));
THPGeneratorPtr self(
reinterpret_cast<THPGenerator*>(type->tp_alloc(type, 0)));
c10::DeviceType device_type = device.type();
if (device_type == at::kCPU) {
@ -60,14 +61,14 @@ static PyObject* THPGenerator_pynew(
.getNewGenerator(device.index());
}
return (PyObject*)self.release();
return reinterpret_cast<PyObject*>(self.release());
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_getState(PyObject* _self, PyObject* noargs) {
using namespace torch::autograd;
HANDLE_TH_ERRORS
auto& gen = ((THPGenerator*)_self)->cdata;
auto& gen = (reinterpret_cast<THPGenerator*>(_self))->cdata;
// See Note [Acquire lock when using random generators]
std::scoped_lock<std::mutex> lock(gen.mutex());
@ -88,7 +89,7 @@ static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
"expected a torch.ByteTensor, but got {}",
Py_TYPE(_new_state)->tp_name));
}
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
auto& gen = self->cdata;
const auto& new_state_tensor = THPVariable_Unpack(_new_state);
@ -97,7 +98,7 @@ static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
gen.set_state(new_state_tensor);
Py_INCREF(self);
return (PyObject*)self;
return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS
}
@ -125,7 +126,7 @@ static PyObject* THPGenerator_graphSafeGetState(
PyObject* _self,
PyObject* noargs) {
HANDLE_TH_ERRORS
auto& gen = ((THPGenerator*)_self)->cdata;
auto& gen = (reinterpret_cast<THPGenerator*>(_self))->cdata;
// See Note [Acquire lock when using random generators]
std::scoped_lock<std::mutex> lock(gen.mutex());
@ -138,7 +139,7 @@ static PyObject* THPGenerator_graphSafeSetState(
PyObject* _self,
PyObject* _state) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
auto& gen = self->cdata;
// See Note [Acquire lock when using random generators]
@ -146,13 +147,13 @@ static PyObject* THPGenerator_graphSafeSetState(
gen.graphsafe_set_state(THPGenerator_Unwrap(_state));
Py_INCREF(self);
return (PyObject*)self;
return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto& gen = ((THPGenerator*)_self)->cdata;
auto& gen = (reinterpret_cast<THPGenerator*>(_self))->cdata;
// See Note [Acquire lock when using random generators]
std::scoped_lock<std::mutex> lock(gen.mutex());
@ -163,7 +164,7 @@ static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {
static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
auto generator = self->cdata;
TORCH_CHECK(
THPUtils_checkLong(seed),
@ -175,13 +176,13 @@ static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
std::scoped_lock<std::mutex> lock(generator.mutex());
generator.set_current_seed(unsigned_seed);
Py_INCREF(self);
return (PyObject*)self;
return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_setOffset(PyObject* _self, PyObject* offset) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
auto generator = self->cdata;
TORCH_CHECK(
THPUtils_checkLong(offset),
@ -193,14 +194,14 @@ static PyObject* THPGenerator_setOffset(PyObject* _self, PyObject* offset) {
std::scoped_lock<std::mutex> lock(generator.mutex());
generator.set_offset(unsigned_offset);
Py_INCREF(self);
return (PyObject*)self;
return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
// See Note [Acquire lock when using random generators]
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
std::scoped_lock<std::mutex> lock(self->cdata.mutex());
uint64_t seed_val = self->cdata.seed();
return THPUtils_packUInt64(seed_val);
@ -209,14 +210,14 @@ static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) {
static PyObject* THPGenerator_initialSeed(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
return THPUtils_packUInt64(self->cdata.current_seed());
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_getOffset(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
return THPUtils_packUInt64(self->cdata.get_offset());
END_HANDLE_TH_ERRORS
}
@ -229,7 +230,7 @@ static PyObject* THPGenerator_get_device(THPGenerator* self, void* unused) {
static PyObject* THPGenerator_reduce(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
auto self = reinterpret_cast<THPGenerator*>(_self);
auto& gen = self->cdata;
auto ret = THPObjectPtr{PyTuple_New(3)};
@ -279,7 +280,11 @@ static PyObject* THPGenerator_pickleSetState(PyObject* _self, PyObject* state) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPGenerator_properties[] = {
{"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr},
{"device",
reinterpret_cast<getter>(THPGenerator_get_device),
nullptr,
nullptr,
nullptr},
{nullptr}};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
@ -349,11 +354,12 @@ static PyTypeObject THPGeneratorType = {
};
bool THPGenerator_init(PyObject* module) {
THPGeneratorClass = (PyObject*)&THPGeneratorType;
THPGeneratorClass = reinterpret_cast<PyObject*>(&THPGeneratorType);
if (PyType_Ready(&THPGeneratorType) < 0)
return false;
Py_INCREF(&THPGeneratorType);
PyModule_AddObject(module, "Generator", (PyObject*)&THPGeneratorType);
PyModule_AddObject(
module, "Generator", reinterpret_cast<PyObject*>(&THPGeneratorType));
return true;
}
@ -377,7 +383,8 @@ PyObject* THPGenerator_Wrap(const Generator& gen) {
return obj;
}
return THPGenerator_NewWithVar((PyTypeObject*)THPGeneratorClass, gen);
return THPGenerator_NewWithVar(
reinterpret_cast<PyTypeObject*>(THPGeneratorClass), gen);
}
at::Generator THPGenerator_Unwrap(PyObject* state) {
@ -395,7 +402,7 @@ at::Generator THPGenerator_Unwrap(PyObject* state) {
PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) {
PyObject* obj = type->tp_alloc(type, 0);
if (obj) {
auto g = (THPGenerator*)obj;
auto g = reinterpret_cast<THPGenerator*>(obj);
new (&g->cdata) Generator(std::move(gen));
set_pyobj(g->cdata, obj);
}

View File

@ -36,7 +36,7 @@ PyTypeObject THPLayoutType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPLayout_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPLayout_repr), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
@ -72,7 +72,8 @@ void THPLayout_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPLayoutType);
if (PyModule_AddObject(module, "layout", (PyObject*)&THPLayoutType) != 0) {
if (PyModule_AddObject(
module, "layout", reinterpret_cast<PyObject*>(&THPLayoutType)) != 0) {
throw python_error();
}
}

View File

@ -29,7 +29,7 @@ static PyObject* THPMemoryFormat_repr(THPMemoryFormat* self) {
}
static PyObject* THPMemoryFormat_reduce(PyObject* _self, PyObject* noargs) {
auto* self = (THPMemoryFormat*)_self;
auto* self = reinterpret_cast<THPMemoryFormat*>(_self);
return THPUtils_packString(self->name);
}
@ -49,7 +49,7 @@ PyTypeObject THPMemoryFormatType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPMemoryFormat_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPMemoryFormat_repr), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
@ -86,7 +86,9 @@ void THPMemoryFormat_init(PyObject* module) {
}
Py_INCREF(&THPMemoryFormatType);
if (PyModule_AddObject(
module, "memory_format", (PyObject*)&THPMemoryFormatType) != 0) {
module,
"memory_format",
reinterpret_cast<PyObject*>(&THPMemoryFormatType)) != 0) {
throw python_error();
}
}

View File

@ -166,7 +166,7 @@ static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) {
for (Py_ssize_t i = 0; i < num_classes; i++) {
PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
TORCH_CHECK(PyType_Check(obj), "expected a PyTypeObject");
PyTypeObject* type = (PyTypeObject*)obj;
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(obj);
THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
if (!module_name)
@ -268,7 +268,7 @@ static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) {
THPUtils_typename(arg));
int32_t x = THPUtils_unpackInt(arg);
double y = 1.0 / x;
return THPUtils_packInt32((int)y);
return THPUtils_packInt32(static_cast<int>(y));
END_HANDLE_TH_ERRORS
}
@ -334,7 +334,7 @@ static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) {
THPUtils_checkLong(arg),
"set_num_threads expects an int, but got ",
THPUtils_typename(arg));
int nthreads = (int)THPUtils_unpackLong(arg);
int nthreads = THPUtils_unpackInt(arg);
TORCH_CHECK(nthreads > 0, "set_num_threads expects a positive integer");
at::set_num_threads(nthreads);
Py_RETURN_NONE;
@ -356,7 +356,7 @@ static PyObject* THPModule_setNumInteropThreads(
"set_num_interop_threads expects an int, "
"but got ",
THPUtils_typename(arg));
int nthreads = (int)THPUtils_unpackLong(arg);
int nthreads = THPUtils_unpackInt(arg);
TORCH_CHECK(
nthreads > 0, "set_num_interop_threads expects a positive integer");
at::set_num_interop_threads(nthreads);
@ -448,7 +448,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
}
if (Py_TYPE(obj) == &PyCFunction_Type) {
PyCFunctionObject* f = (PyCFunctionObject*)obj;
PyCFunctionObject* f = reinterpret_cast<PyCFunctionObject*>(obj);
if (f->m_ml->ml_doc) {
return PyErr_Format(
PyExc_RuntimeError,
@ -457,7 +457,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
}
f->m_ml->ml_doc = doc_str;
} else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
PyMethodDescrObject* m = (PyMethodDescrObject*)obj;
PyMethodDescrObject* m = reinterpret_cast<PyMethodDescrObject*>(obj);
if (m->d_method->ml_doc) {
return PyErr_Format(
PyExc_RuntimeError,
@ -466,8 +466,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
}
m->d_method->ml_doc = doc_str;
} else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj;
PyGetSetDescrObject* m = reinterpret_cast<PyGetSetDescrObject*>(obj);
if (m->d_getset->doc) {
return PyErr_Format(
PyExc_RuntimeError,
@ -476,7 +475,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
}
m->d_getset->doc = doc_str;
} else if (Py_TYPE(obj) == &PyType_Type) {
PyTypeObject* t = (PyTypeObject*)obj;
PyTypeObject* t = reinterpret_cast<PyTypeObject*>(obj);
if (t->tp_doc) {
return PyErr_Format(
PyExc_RuntimeError, "Type '%s' already has a docstring", t->tp_name);
@ -1472,10 +1471,11 @@ static PyObject* THPModule_willEngineExecuteNode(
torch::autograd::Node* node = nullptr;
std::shared_ptr<torch::autograd::Node> node_sp;
if (isTHPFunction) {
node_sp = ((THPFunction*)arg)->cdata.lock();
node_sp = (reinterpret_cast<THPFunction*>(arg))->cdata.lock();
node = node_sp.get();
} else {
node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
node =
(reinterpret_cast<torch::autograd::THPCppFunction*>(arg))->cdata.get();
}
const auto nodes_in_graph =
torch::autograd::get_current_graph_task_nodes_in_graph();
@ -1905,7 +1905,8 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
METH_O,
nullptr},
{"_has_torch_function_variadic",
(PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
reinterpret_cast<PyCFunction>(
reinterpret_cast<void (*)()>(THPModule_has_torch_function_variadic)),
METH_FASTCALL,
nullptr},
{"_ensureCUDADeviceGuardSet",
@ -2612,7 +2613,7 @@ Call this whenever a new thread is created in order to propagate values from
.getAcceleratorHooksInterface(device_type)
.deviceCount();
}
return c10::DeviceIndex(-1);
return static_cast<c10::DeviceIndex>(-1);
});
py_module.def(
@ -2633,7 +2634,7 @@ Call this whenever a new thread is created in order to propagate values from
.getAcceleratorHooksInterface(device_type)
.getCurrentDevice();
}
return c10::DeviceIndex(-1);
return static_cast<c10::DeviceIndex>(-1);
});
py_module.def(
@ -2644,7 +2645,7 @@ Call this whenever a new thread is created in order to propagate values from
.getAcceleratorHooksInterface(device_type)
.exchangeDevice(device_index);
}
return c10::DeviceIndex(-1);
return static_cast<c10::DeviceIndex>(-1);
});
py_module.def(
@ -2656,7 +2657,7 @@ Call this whenever a new thread is created in order to propagate values from
.getAcceleratorHooksInterface(device_type)
.maybeExchangeDevice(device_index);
}
return c10::DeviceIndex(-1);
return static_cast<c10::DeviceIndex>(-1);
});
py_module.def(
@ -2820,8 +2821,8 @@ Call this whenever a new thread is created in order to propagate values from
py::arg("eps"));
const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
THPDefaultCPUGenerator =
(THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
THPDefaultCPUGenerator = reinterpret_cast<THPGenerator*>(
THPGenerator_initDefaultGenerator(defaultGenerator));
// This reference is meant to be given away, so no need to incref here.
ASSERT_TRUE(set_module_attr(
"default_generator",

View File

@ -270,7 +270,7 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
"This probably happened because you took out a weak reference to "
"Tensor and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this tensor via the PyObject will now fail.");
((THPVariable*)pyobj)->cdata =
(reinterpret_cast<THPVariable*>(pyobj))->cdata =
c10::MaybeOwned<torch::autograd::Variable>();
} else if (THPStorage_Check(pyobj)) {
TORCH_WARN(
@ -278,7 +278,8 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
"This probably happened because you took out a weak reference to "
"UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this storage via the PyObject will now fail.");
((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
(reinterpret_cast<THPStorage*>(pyobj))->cdata =
c10::MaybeOwned<c10::Storage>();
}
}
Py_DECREF(pyobj);

View File

@ -23,7 +23,7 @@ PyObject* THPQScheme_New(at::QScheme qscheme, const std::string& name) {
}
static PyObject* THPQScheme_reduce(PyObject* _self, PyObject* noargs) {
auto self = (THPQScheme*)_self;
auto self = reinterpret_cast<THPQScheme*>(_self);
return THPUtils_packString(self->name);
}
@ -48,7 +48,7 @@ PyTypeObject THPQSchemeType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPQScheme_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPQScheme_repr), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
@ -84,7 +84,9 @@ void THPQScheme_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPQSchemeType);
if (PyModule_AddObject(module, "qscheme", (PyObject*)&THPQSchemeType) != 0) {
if (PyModule_AddObject(
module, "qscheme", reinterpret_cast<PyObject*>(&THPQSchemeType)) !=
0) {
throw python_error();
}
}

View File

@ -133,7 +133,8 @@ static PyObject* THPSize_pynew(
static PyObject* THPSize_repr(THPSize* self) {
HANDLE_TH_ERRORS
std::string repr("torch.Size([");
for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
for (Py_ssize_t i = 0; i < PyTuple_Size(reinterpret_cast<PyObject*>(self));
++i) {
if (i != 0) {
repr += ", ";
}
@ -156,7 +157,7 @@ static PyObject* wrap_tuple_fn(Args... args) {
return nullptr;
if (PyTuple_Check(result.get())) {
return PyObject_CallFunctionObjArgs(
(PyObject*)&THPSizeType, result.get(), nullptr);
reinterpret_cast<PyObject*>(&THPSizeType), result.get(), nullptr);
}
return result.release();
}
@ -225,9 +226,9 @@ static PyMappingMethods THPSize_as_mapping = {
static PyObject* THPSize_numel(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPSize*)_self;
auto self = reinterpret_cast<THPSize*>(_self);
int64_t numel = 1;
for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
for (Py_ssize_t i = 0; i < PyTuple_Size(_self); ++i) {
numel *= THPUtils_unpackLong(PyTuple_GET_ITEM(self, i));
}
return THPUtils_packInt64(numel);
@ -236,19 +237,19 @@ static PyObject* THPSize_numel(PyObject* _self, PyObject* noargs) {
static PyObject* THPSize_reduce(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPSize*)_self;
auto self = reinterpret_cast<THPSize*>(_self);
auto ret = THPObjectPtr{PyTuple_New(2)};
if (!ret)
throw python_error();
auto obj = (PyObject*)(&THPSizeType);
auto obj = reinterpret_cast<PyObject*>(&THPSizeType);
Py_INCREF(&THPSizeType);
PyTuple_SET_ITEM(ret.get(), 0, obj);
THPObjectPtr t(PyTuple_New(PyTuple_Size((PyObject*)self)));
THPObjectPtr t(PyTuple_New(PyTuple_Size(_self)));
if (!t)
throw python_error();
for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
for (Py_ssize_t i = 0; i < PyTuple_Size(_self); ++i) {
auto d = PyTuple_GET_ITEM(self, i);
Py_INCREF(d);
PyTuple_SET_ITEM(t.get(), i, d);
@ -279,7 +280,7 @@ PyTypeObject THPSizeType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPSize_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPSize_repr), /* tp_repr */
&THPSize_as_number, /* tp_as_number */
&THPSize_as_sequence, /* tp_as_sequence */
&THPSize_as_mapping, /* tp_as_mapping */
@ -315,7 +316,8 @@ void THPSize_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPSizeType);
if (PyModule_AddObject(module, "Size", (PyObject*)&THPSizeType) < 0) {
if (PyModule_AddObject(
module, "Size", reinterpret_cast<PyObject*>(&THPSizeType)) < 0) {
throw python_error();
}
}

View File

@ -68,7 +68,7 @@ PyObject* THPStorage_NewWithStorage(
PyObject* obj = type->tp_alloc(type, 0);
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
auto s = (THPStorage*)obj;
auto s = reinterpret_cast<THPStorage*>(obj);
new (&s->cdata) c10::MaybeOwned<c10::Storage>();
@ -128,7 +128,7 @@ static bool THPStorage_isPreservable(THPStorage* self) {
}
if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/true) != (PyObject*)self) {
/*ignore_hermetic_tls=*/true) != reinterpret_cast<PyObject*>(self)) {
return false;
}
if (storage.use_count() <= 1) {
@ -170,14 +170,14 @@ static bool THPStorage_tryPreserve(THPStorage* self) {
storage_impl->pyobj_slot()->set_owns_pyobj(true);
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference((PyObject*)self);
_Py_NewReference(reinterpret_cast<PyObject*>(self));
self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
return true;
}
static void THPStorage_subclass_dealloc(PyObject* self) {
THPStorage* _self = (THPStorage*)self;
THPStorage* _self = reinterpret_cast<THPStorage*>(self);
if (THPStorage_tryPreserve(_self)) {
return;
@ -226,8 +226,8 @@ static void THPStorage_subclass_dealloc(PyObject* self) {
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list =
(PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
PyWeakReference** list = reinterpret_cast<PyWeakReference**>(
PyObject_GET_WEAKREFS_LISTPTR(self));
while (*list)
_PyWeakref_ClearRef(*list);
}
@ -549,9 +549,9 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
}
static PyMappingMethods THPStorage_mappingmethods = {
(lenfunc)THPStorage_length,
(binaryfunc)THPStorage_get,
(objobjargproc)THPStorage_set};
reinterpret_cast<lenfunc>(THPStorage_length),
reinterpret_cast<binaryfunc>(THPStorage_get),
reinterpret_cast<objobjargproc>(THPStorage_set)};
struct THPStorageMeta {
PyHeapTypeObject base;
@ -653,7 +653,8 @@ int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1;
}
((PyTypeObject*)cls)->tp_dealloc = (destructor)THPStorage_subclass_dealloc;
(reinterpret_cast<PyTypeObject*>(cls))->tp_dealloc =
static_cast<destructor>(THPStorage_subclass_dealloc);
return 0;
}
@ -674,8 +675,16 @@ typedef PyObject* (*getter)(PyObject*, void*);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPStorage_properties[] = {
{"device", (getter)THPStorage_device, nullptr, nullptr, nullptr},
{"_cdata", (getter)THPStorage_get_cdata, nullptr, nullptr, nullptr},
{"device",
reinterpret_cast<getter>(THPStorage_device),
nullptr,
nullptr,
nullptr},
{"_cdata",
reinterpret_cast<getter>(THPStorage_get_cdata),
nullptr,
nullptr,
nullptr},
{nullptr}};
bool THPStorage_init(PyObject* module) {
@ -687,20 +696,22 @@ bool THPStorage_init(PyObject* module) {
if (PyType_Ready(&THPStorageMetaType) < 0)
return false;
Py_INCREF(&THPStorageMetaType);
PyModule_AddObject(module, "_StorageMeta", (PyObject*)&THPStorageMetaType);
PyModule_AddObject(
module, "_StorageMeta", reinterpret_cast<PyObject*>(&THPStorageMetaType));
THPStorageType.tp_methods = methods.data();
THPStorageType.tp_getset = THPStorage_properties;
if (PyType_Ready(&THPStorageType) < 0)
return false;
Py_INCREF(&THPStorageType);
PyModule_AddObject(module, "StorageBase", (PyObject*)&THPStorageType);
PyModule_AddObject(
module, "StorageBase", reinterpret_cast<PyObject*>(&THPStorageType));
return true;
}
void THPStorage_postInit(PyObject* module) {
THPStorageClass =
(PyTypeObject*)PyObject_GetAttrString(module, "UntypedStorage");
THPStorageClass = reinterpret_cast<PyTypeObject*>(
PyObject_GetAttrString(module, "UntypedStorage"));
if (!THPStorageClass)
throw python_error();
}
@ -711,5 +722,5 @@ void THPStorage_assertNotNull(THPStorage* storage) {
}
void THPStorage_assertNotNull(PyObject* obj) {
THPStorage_assertNotNull((THPStorage*)obj);
THPStorage_assertNotNull(reinterpret_cast<THPStorage*>(obj));
}
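
The hunks above all apply the same mechanical rule, so a standalone illustration may help: numeric or enum conversions become static_cast, while pointer reinterpretation (a binding struct viewed as a PyObject) becomes reinterpret_cast, making the intent of each cast visible at the call site. The types below (PyObjectLike, THPStorageLike) are stand-ins for illustration only, not the real CPython/PyTorch structs.

#include <cstddef>
#include <cstdint>

struct PyObjectLike {};             // stand-in for PyObject (illustrative only)
struct THPStorageLike {             // stand-in for THPStorage (illustrative only)
  PyObjectLike ob_base;
  std::int64_t cdata;
};

// Numeric conversion: the value changes representation, so static_cast.
std::size_t to_size(std::int64_t n) {
  return static_cast<std::size_t>(n);
}

// Pointer reinterpretation: the same bytes viewed as another type, so
// reinterpret_cast states that intent explicitly.
PyObjectLike* as_pyobject(THPStorageLike* self) {
  return reinterpret_cast<PyObjectLike*>(self);
}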

View File

@ -297,7 +297,7 @@ static PyObject* THPStorage_fromBuffer(
size_bytes = count * element_size;
}
if (offset + (count * (Py_ssize_t)element_size) > buffer.len) {
if (offset + (count * static_cast<Py_ssize_t>(element_size)) > buffer.len) {
PyErr_SetString(
PyExc_ValueError,
fmt::format(
@ -309,7 +309,7 @@ static PyObject* THPStorage_fromBuffer(
return nullptr;
}
uint8_t* src = (uint8_t*)buffer.buf;
uint8_t* src = static_cast<uint8_t*>(buffer.buf);
auto fake_mode_active =
c10::impl::TorchDispatchModeTLS::get_mode(
c10::impl::TorchDispatchModeKey::FAKE) != std::nullopt;
@ -508,8 +508,8 @@ static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {
// advanced position
const auto fd_current_pos = LSEEK(fd, 0, SEEK_CUR);
LSEEK(fd, fd_original_pos, SEEK_SET);
const auto seek_return =
PyObject_CallMethod(file, "seek", "Li", (long long)fd_current_pos, 0);
const auto seek_return = PyObject_CallMethod(
file, "seek", "Li", static_cast<long long>(fd_current_pos), 0);
if (seek_return == nullptr) {
return nullptr;
}
@ -521,18 +521,19 @@ static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {
static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) {
HANDLE_TH_ERRORS
auto self = (THPStorage*)_self;
auto self = reinterpret_cast<THPStorage*>(_self);
TORCH_CHECK(
THPUtils_checkLong(new_cdata),
"given an invalid argument to "
"_set_cdata - expected an int or long, but got ",
THPUtils_typename(new_cdata));
c10::StorageImpl* ptr = (c10::StorageImpl*)PyLong_AsVoidPtr(new_cdata);
c10::StorageImpl* ptr =
static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(new_cdata));
self->cdata.~MaybeOwned<c10::Storage>();
self->cdata = c10::MaybeOwned<c10::Storage>::owned(
c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr)));
Py_INCREF(self);
return (PyObject*)self;
return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS
}

View File

@ -256,7 +256,7 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) {
"a file descriptor (int) and storage size (int)");
return nullptr;
}
int tmp_fd = (int)THPUtils_unpackLong(_tmp_fd);
int tmp_fd = THPUtils_unpackInt(_tmp_fd);
int64_t size = THPUtils_unpackLong(_size);
int fd = dup(tmp_fd);
if (fd == -1) {
@ -312,8 +312,8 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
auto shandle =
c10::cuda::CUDACachingAllocator::shareIpcHandle(storage.mutable_data());
_handle = PyBytes_FromStringAndSize(
shandle.handle.c_str(), (Py_ssize_t)shandle.handle.size());
_offset_bytes = PyLong_FromSsize_t((Py_ssize_t)shandle.offset);
shandle.handle.c_str(), static_cast<Py_ssize_t>(shandle.handle.size()));
_offset_bytes = PyLong_FromSsize_t(static_cast<Py_ssize_t>(shandle.offset));
// Put Storage Data behind new ref counting context
// See Note [CUDA IPC Refcounting implementation explained]
@ -334,7 +334,7 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
}
_event_handle = PyBytes_FromStringAndSize(
(char*)&ipc_event_handle, CUDA_IPC_HANDLE_SIZE);
reinterpret_cast<const char*>(&ipc_event_handle), CUDA_IPC_HANDLE_SIZE);
_event_sync_required = PyBool_FromLong(sent_data->event_sync_required_);
}
@ -385,7 +385,7 @@ static PyObject* THPStorage_releaseIPCCounter(
}
std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
ptrdiff_t ref_counter_offset =
(ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
static_cast<ptrdiff_t>(THPUtils_unpackLong(_ref_counter_offset));
// We don't want to break existing code, so resource deletion is best
// effort basis. Exception expected if producer process terminated
// before consumer released data.
@ -446,10 +446,9 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
return nullptr;
}
size_t storage_size =
(size_t)THPUtils_unpackLong(_size_bytes) / sizeof(uint8_t);
size_t storage_size = THPUtils_unpackUInt64(_size_bytes) / sizeof(uint8_t);
ptrdiff_t storage_offset_bytes =
(ptrdiff_t)THPUtils_unpackLong(_offset_bytes);
static_cast<ptrdiff_t>(THPUtils_unpackLong(_offset_bytes));
const auto device = c10::checked_convert<c10::DeviceIndex>(
THPUtils_unpackLong(_device), "c10::DeviceIndex");
@ -480,11 +479,11 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
// Offset the basePtr to reconstruct the real storage
// devPtr = basePtr + storage_offset
void* devPtr = basePtr.get();
devPtr = (char*)devPtr + storage_offset_bytes;
devPtr = static_cast<char*>(devPtr) + storage_offset_bytes;
std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
ptrdiff_t ref_counter_offset =
(ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
static_cast<ptrdiff_t>(THPUtils_unpackLong(_ref_counter_offset));
struct IpcDeleterContext {
std::string ref_counter_handle;
@ -578,7 +577,8 @@ static PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'");
c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
c10::StorageImpl* weak_storage =
static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(arg));
if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) {
return THPStorage_Wrap(
c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));
@ -594,7 +594,8 @@ static PyObject* THPStorage_freeWeakRef(PyObject* _unused, PyObject* arg) {
}
TORCH_CHECK(
THPUtils_checkLong(arg), "_free_weak_ref(): arg must be an 'int'");
c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
c10::StorageImpl* weak_storage =
static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(arg));
c10::raw::weak_intrusive_ptr::decref(weak_storage);
Py_RETURN_NONE;
@ -604,7 +605,8 @@ static PyObject* THPStorage_freeWeakRef(PyObject* _unused, PyObject* arg) {
static PyObject* THPStorage_expired(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "_expired(): arg must be an 'int'");
c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
c10::StorageImpl* weak_storage =
static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(arg));
return PyBool_FromLong(
c10::raw::weak_intrusive_ptr::use_count(weak_storage) == 0);
END_HANDLE_TH_ERRORS

View File

@ -74,7 +74,7 @@ static PyObject* THPStream_pynew(
return nullptr;
}
THPStream* self = (THPStream*)ptr.get();
THPStream* self = reinterpret_cast<THPStream*>(ptr.get());
// If torch.Stream is not created from existing Stream, then create a new one.
// It requires other device backends override getNewStream method. How the new
@ -96,7 +96,7 @@ static PyObject* THPStream_pynew(
self->device_type = static_cast<int64_t>(stream_opt->device_type());
self->context = nullptr;
return (PyObject*)ptr.release();
return static_cast<PyObject*>(ptr.release());
END_HANDLE_TH_ERRORS
}
@ -108,7 +108,7 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
throw python_error();
}
THPStream* self = (THPStream*)ptr.get();
THPStream* self = reinterpret_cast<THPStream*>(ptr.get());
self->stream_id = stream.id();
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
self->device_index = static_cast<int64_t>(stream.device_index());
@ -119,7 +119,7 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
}
static void THPStream_dealloc(THPStream* self) {
Py_TYPE(self)->tp_free((PyObject*)self);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
static PyObject* THPStream_get_device(THPStream* self, void* unused) {
@ -132,7 +132,7 @@ static PyObject* THPStream_get_device(THPStream* self, void* unused) {
static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto self = (THPStream*)_self;
auto self = reinterpret_cast<THPStream*>(_self);
return PyBool_FromLong(c10::Stream::unpack3(
self->stream_id,
@ -146,7 +146,7 @@ static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS {
pybind11::gil_scoped_release no_gil;
auto self = (THPStream*)_self;
auto self = reinterpret_cast<THPStream*>(_self);
c10::Stream::unpack3(
self->stream_id,
@ -160,8 +160,8 @@ static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {
static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {
HANDLE_TH_ERRORS {
auto self = (THPStream*)_self;
auto event = (THPEvent*)_event;
auto self = reinterpret_cast<THPStream*>(_self);
auto event = reinterpret_cast<THPEvent*>(_event);
c10::Stream::unpack3(
self->stream_id,
static_cast<c10::DeviceIndex>(self->device_index),
@ -174,8 +174,8 @@ static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {
static PyObject* THPStream_wait_stream(PyObject* _self, PyObject* _other) {
HANDLE_TH_ERRORS {
auto self = (THPStream*)_self;
auto other_stream = (THPStream*)_other;
auto self = reinterpret_cast<THPStream*>(_self);
auto other_stream = reinterpret_cast<THPStream*>(_other);
c10::Event new_event(
static_cast<c10::DeviceType>(other_stream->device_type),
c10::EventFlag::PYTORCH_DEFAULT);
@ -198,7 +198,7 @@ static PyObject* THPStream_record_event(
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
auto self = (THPStream*)_self;
auto self = reinterpret_cast<THPStream*>(_self);
PyObject* _new_event = nullptr;
PyObject* _event = Py_None;
@ -222,13 +222,13 @@ static PyObject* THPStream_record_event(
static_cast<c10::DeviceType>(self->device_type),
c10::EventFlag::PYTORCH_DEFAULT);
}
auto new_event = (THPEvent*)_new_event;
auto new_event = reinterpret_cast<THPEvent*>(_new_event);
TORCH_CHECK(new_event, "event must not be null");
new_event->event.record(c10::Stream::unpack3(
self->stream_id,
static_cast<c10::DeviceIndex>(self->device_index),
static_cast<c10::DeviceType>(self->device_type)));
return (PyObject*)new_event;
return reinterpret_cast<PyObject*>(new_event);
END_HANDLE_TH_ERRORS
}
@ -260,7 +260,7 @@ static PyObject* THPStream_eq(THPStream* self, THPStream* other) {
static PyObject* THPStream_enter(PyObject* _self, PyObject* unused) {
HANDLE_TH_ERRORS
auto self = (THPStream*)_self;
auto self = reinterpret_cast<THPStream*>(_self);
c10::DeviceType stream_device_type =
static_cast<c10::DeviceType>(self->device_type);
// No operation is performed if the stream does not belong to an accelerator.
@ -304,7 +304,7 @@ static PyObject* THPStream_enter(PyObject* _self, PyObject* unused) {
static PyObject* THPStream_exit(PyObject* _self, PyObject* unused) {
HANDLE_TH_ERRORS
auto self = (THPStream*)_self;
auto self = reinterpret_cast<THPStream*>(_self);
// No operation is performed if the stream does not belong to an accelerator.
if (C10_UNLIKELY(!at::accelerator::isAccelerator(
static_cast<c10::DeviceType>(self->device_type)))) {
@ -323,7 +323,7 @@ static PyObject* THPStream_exit(PyObject* _self, PyObject* unused) {
auto ctx_device_index = THPObjectPtr(py_device_index);
TORCH_INTERNAL_ASSERT(
ctx_stream.get(), "ctx_stream should be present on the context dict.");
auto prev_stream = (THPStream*)(ctx_stream.get());
auto prev_stream = reinterpret_cast<THPStream*>(ctx_stream.get());
TORCH_INTERNAL_ASSERT(
ctx_device_index.get(),
"ctx_device_index should be present on the context dict.");
@ -360,10 +360,14 @@ static PyObject* THPStream_richcompare(
} else {
switch (op) {
case Py_EQ:
result = THPStream_eq((THPStream*)self, (THPStream*)other);
result = THPStream_eq(
reinterpret_cast<THPStream*>(self),
reinterpret_cast<THPStream*>(other));
break;
case Py_NE:
result = THPStream_ne((THPStream*)self, (THPStream*)other);
result = THPStream_ne(
reinterpret_cast<THPStream*>(self),
reinterpret_cast<THPStream*>(other));
break;
default:
result = Py_False;
@ -393,7 +397,11 @@ static const std::initializer_list<PyMemberDef> THPStream_members = {
{nullptr}};
static const std::initializer_list<PyGetSetDef> THPStream_properties = {
{"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr},
{"device",
reinterpret_cast<getter>(THPStream_get_device),
nullptr,
nullptr,
nullptr},
{nullptr}};
static const std::initializer_list<PyMethodDef> THPStream_methods = {
@ -405,7 +413,7 @@ static const std::initializer_list<PyMethodDef> THPStream_methods = {
castPyCFunctionWithKeywords(THPStream_record_event),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
{"__eq__", reinterpret_cast<PyCFunction>(THPStream_eq), METH_O, nullptr},
{"__enter__", THPStream_enter, METH_NOARGS, nullptr},
{"__exit__", THPStream_exit, METH_VARARGS, nullptr},
{nullptr}};
@ -415,16 +423,16 @@ static PyTypeObject THPStreamType = {
"torch.Stream", /* tp_name */
sizeof(THPStream), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)THPStream_dealloc, /* tp_dealloc */
reinterpret_cast<destructor>(THPStream_dealloc), /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPStream_repr, /* tp_repr */
reinterpret_cast<reprfunc>(THPStream_repr), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
(hashfunc)THPStream_hash, /* tp_hash */
reinterpret_cast<hashfunc>(THPStream_hash), /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
@ -462,7 +470,8 @@ void THPStream_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPStreamType);
if (PyModule_AddObject(module, "Stream", (PyObject*)&THPStreamType) < 0) {
if (PyModule_AddObject(
module, "Stream", reinterpret_cast<PyObject*>(&THPStreamType)) < 0) {
throw python_error();
}
}

View File

@ -273,18 +273,34 @@ static PyObject* THPIInfo_str(THPIInfo* self) {
}
static const std::initializer_list<PyGetSetDef> THPFInfo_properties = {
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
{"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
{"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
{"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
{"smallest_normal",
(getter)THPFInfo_smallest_normal,
{"bits",
reinterpret_cast<getter>(THPDTypeInfo_bits),
nullptr,
nullptr,
nullptr},
{"eps", reinterpret_cast<getter>(THPFInfo_eps), nullptr, nullptr, nullptr},
{"max", reinterpret_cast<getter>(THPFInfo_max), nullptr, nullptr, nullptr},
{"min", reinterpret_cast<getter>(THPFInfo_min), nullptr, nullptr, nullptr},
{"smallest_normal",
reinterpret_cast<getter>(THPFInfo_smallest_normal),
nullptr,
nullptr,
nullptr},
{"tiny",
reinterpret_cast<getter>(THPFInfo_tiny),
nullptr,
nullptr,
nullptr},
{"resolution",
reinterpret_cast<getter>(THPFInfo_resolution),
nullptr,
nullptr,
nullptr},
{"dtype",
reinterpret_cast<getter>(THPFInfo_dtype),
nullptr,
nullptr,
nullptr},
{"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
{"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
{"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
{nullptr}};
PyTypeObject THPFInfoType = {
@ -297,13 +313,13 @@ PyTypeObject THPFInfoType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPFInfo_str, /* tp_repr */
reinterpret_cast<reprfunc>(THPFInfo_str), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
(reprfunc)THPFInfo_str, /* tp_str */
reinterpret_cast<reprfunc>(THPFInfo_str), /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
@ -311,7 +327,7 @@ PyTypeObject THPFInfoType = {
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
(richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
reinterpret_cast<richcmpfunc>(THPDTypeInfo_compare), /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
@ -330,10 +346,18 @@ PyTypeObject THPFInfoType = {
};
static const std::initializer_list<PyGetSetDef> THPIInfo_properties = {
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
{"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
{"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
{"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
{"bits",
reinterpret_cast<getter>(THPDTypeInfo_bits),
nullptr,
nullptr,
nullptr},
{"max", reinterpret_cast<getter>(THPIInfo_max), nullptr, nullptr, nullptr},
{"min", reinterpret_cast<getter>(THPIInfo_min), nullptr, nullptr, nullptr},
{"dtype",
reinterpret_cast<getter>(THPIInfo_dtype),
nullptr,
nullptr,
nullptr},
{nullptr}};
PyTypeObject THPIInfoType = {
@ -346,13 +370,13 @@ PyTypeObject THPIInfoType = {
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPIInfo_str, /* tp_repr */
reinterpret_cast<reprfunc>(THPIInfo_str), /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
(reprfunc)THPIInfo_str, /* tp_str */
reinterpret_cast<reprfunc>(THPIInfo_str), /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
@ -360,7 +384,7 @@ PyTypeObject THPIInfoType = {
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
(richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
reinterpret_cast<richcmpfunc>(THPDTypeInfo_compare), /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
@ -383,14 +407,16 @@ void THPDTypeInfo_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPFInfoType);
if (PyModule_AddObject(module, "finfo", (PyObject*)&THPFInfoType) != 0) {
if (PyModule_AddObject(
module, "finfo", reinterpret_cast<PyObject*>(&THPFInfoType)) != 0) {
throw python_error();
}
if (PyType_Ready(&THPIInfoType) < 0) {
throw python_error();
}
Py_INCREF(&THPIInfoType);
if (PyModule_AddObject(module, "iinfo", (PyObject*)&THPIInfoType) != 0) {
if (PyModule_AddObject(
module, "iinfo", reinterpret_cast<PyObject*>(&THPIInfoType)) != 0) {
throw python_error();
}
}

View File

@ -25,7 +25,7 @@ c10::intrusive_ptr<rpc::Message> CleanupAutogradContextReq::toMessageImpl() && {
std::unique_ptr<CleanupAutogradContextReq> CleanupAutogradContextReq::
fromMessage(const rpc::Message& message) {
// unpickle and get the context_id we need to clean up
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
IValue ivalue_context_id = jit::unpickle(
payload,

View File

@ -47,7 +47,7 @@ c10::intrusive_ptr<Message> PropagateGradientsReq::toMessageImpl() && {
std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
const Message& message) {
// Unpickle the message and retrieve tupleElements.
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
IValue tuple = jit::unpickle(
payload,

View File

@ -37,7 +37,7 @@ c10::intrusive_ptr<Message> RRefBackwardReq::toMessageImpl() && {
std::unique_ptr<RRefBackwardReq> RRefBackwardReq::fromMessage(
const Message& message) {
// Unpickle the message and retrieve tupleElements.
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
IValue tuple = jit::unpickle(
payload,

View File

@ -225,7 +225,7 @@ class File {
while (count > 0) {
auto rv = syscall([this, buf, count] { return ::read(fd_, buf, count); });
SYSASSERT(rv, "read");
buf = (uint8_t*)buf + rv;
buf = static_cast<uint8_t*>(buf) + rv;
count -= rv;
}
}

View File

@ -2476,7 +2476,7 @@ static at::Tensor& checkSingleTensor(std::vector<at::Tensor>& tensors) {
static uint32_t checkTag(int32_t tag) {
TORCH_CHECK(tag >= 0, "Tag must be nonnegative");
return (uint32_t)tag;
return static_cast<uint32_t>(tag);
}
c10::intrusive_ptr<Work> ProcessGroupGloo::send(

View File

@ -207,7 +207,7 @@ class SendBuffer {
SendBuffer(detail::TCPClient& client, detail::QueryType cmd)
: client(client) {
buffer.reserve(32); // enough for most commands
buffer.push_back((uint8_t)cmd);
buffer.push_back(static_cast<uint8_t>(cmd));
}
void appendString(const std::string& str) {
@ -224,7 +224,7 @@ class SendBuffer {
template <typename T>
void appendValue(T value) {
uint8_t* begin = (uint8_t*)&value;
uint8_t* begin = reinterpret_cast<uint8_t*>(&value);
buffer.insert(buffer.end(), begin, begin + sizeof(T));
maybeFlush();
}

View File

@ -36,14 +36,14 @@ Other callbacks don't provide exception safety so avoid there.
// backlog. This should be at least world size to avoid issues on init. We set
// it to -1 to use the host max value, which is controlled by `somaxconn`.
auto constexpr DEFAULT_BACKLOG = -1;
auto constexpr MAX_KEY_COUNT = size_t(128 * 1024);
auto constexpr MAX_KEY_COUNT = static_cast<size_t>(128 * 1024);
auto constexpr MAX_STRING_LEN = 8 * 1024;
auto constexpr MAX_PAYLOAD_LEN = 8 * 1024 * 1024;
// This controls the preferred size for buffers.
// Too small and we'll need multiple buffers for one request
// Too big and we might be taxing malloc
auto constexpr ALLOC_BUFFER_SIZE = size_t(4096);
auto constexpr ALLOC_BUFFER_SIZE = static_cast<size_t>(4096);
class UvHandle : public c10::intrusive_ptr_target {
public:
~UvHandle() override = default;
@ -78,7 +78,7 @@ class UvHandle : public c10::intrusive_ptr_target {
private:
static c10::intrusive_ptr<UvHandle> reclaim(uv_handle_t* handle) {
auto h = (UvHandle*)uv_handle_get_data(handle);
auto h = static_cast<UvHandle*>(uv_handle_get_data(handle));
return c10::intrusive_ptr<UvHandle>::reclaim(h);
}
@ -97,7 +97,8 @@ class UvTcpSocket : public UvHandle {
}
static c10::intrusive_ptr<UvTcpSocket> borrow(uv_stream_t* handle) {
auto h = (UvTcpSocket*)uv_handle_get_data((uv_handle_t*)handle);
auto h = static_cast<UvTcpSocket*>(
uv_handle_get_data(reinterpret_cast<uv_handle_t*>(handle)));
return h->iptr();
}
@ -107,7 +108,7 @@ class UvTcpSocket : public UvHandle {
uv_buf_t* buf) {
suggested_size = std::min(suggested_size, ALLOC_BUFFER_SIZE);
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
buf->base = (char*)malloc(suggested_size);
buf->base = static_cast<char*>(malloc(suggested_size));
buf->len = suggested_size;
}
@ -168,7 +169,8 @@ class UvTcpSocket : public UvHandle {
formatSockAddr(reinterpret_cast<struct ::sockaddr*>(&addr), addrLen);
}
int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback);
int res = uv_read_start(
reinterpret_cast<uv_stream_t*>(&client), alloc_buffer, read_callback);
if (res) {
C10D_WARNING(
"Failed to setup read callback. client:{} code:{} name:{} desc:{}.",
@ -181,12 +183,12 @@ class UvTcpSocket : public UvHandle {
}
uv_handle_t* unsafeGetHandle() override {
return (uv_handle_t*)&client;
return reinterpret_cast<uv_handle_t*>(&client);
}
protected:
uv_stream_t* unsafeGetStream() {
return (uv_stream_t*)&client;
return reinterpret_cast<uv_stream_t*>(&client);
}
uv_tcp_t* unsafeGetSocket() {
@ -217,7 +219,7 @@ class UvTcpServer : public UvTcpSocket {
auto res = c10::make_intrusive<UvTcpServer>(loop);
res->handleReady();
try {
int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket);
int uv_res = uv_tcp_open(res->unsafeGetSocket(), socket);
C10D_CHECK_WITH(
SocketError,
uv_res == 0,
@ -266,9 +268,11 @@ class UvTcpServer : public UvTcpSocket {
struct sockaddr_storage addr{};
int uv_res = 0;
if (useIpv6) {
uv_res = uv_ip6_addr("::", port, (struct sockaddr_in6*)&addr);
uv_res = uv_ip6_addr(
"::", port, reinterpret_cast<struct sockaddr_in6*>(&addr));
} else {
uv_res = uv_ip4_addr("0.0.0.0", port, (struct sockaddr_in*)&addr);
uv_res = uv_ip4_addr(
"0.0.0.0", port, reinterpret_cast<struct sockaddr_in*>(&addr));
}
TORCH_CHECK_WITH(
DistStoreError,
@ -286,7 +290,9 @@ class UvTcpServer : public UvTcpSocket {
uv_strerror(uv_res));
uv_res = uv_tcp_bind(
res->unsafeGetSocket(), (const struct ::sockaddr*)&addr, 0);
res->unsafeGetSocket(),
reinterpret_cast<const struct ::sockaddr*>(&addr),
0);
C10D_CHECK_WITH(
SocketError,
uv_res == 0,
@ -329,8 +335,9 @@ class UvTcpServer : public UvTcpSocket {
}
void accept(const c10::intrusive_ptr<UvTcpSocket>& socket) {
int res =
uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle());
int res = uv_accept(
unsafeGetStream(),
reinterpret_cast<uv_stream_t*>(socket->unsafeGetHandle()));
C10D_CHECK_WITH(
SocketError,
res == 0,
@ -352,7 +359,8 @@ class UvTcpServer : public UvTcpSocket {
}
static c10::intrusive_ptr<UvTcpServer> borrow(uv_stream_t* handle) {
auto h = (UvTcpServer*)uv_handle_get_data((uv_handle_t*)handle);
auto h = static_cast<UvTcpServer*>(
uv_handle_get_data(reinterpret_cast<uv_handle_t*>(handle)));
return h->iptr();
}
@ -389,7 +397,8 @@ class WriterPayload : public c10::intrusive_ptr_target {
static c10::intrusive_ptr<WriterPayload> reclaim(uv_write_t* request) {
/* This method returns a intrusive_ptr that does not increase the refcount.
*/
auto h = (WriterPayload*)uv_req_get_data((uv_req_t*)request);
auto h = static_cast<WriterPayload*>(
uv_req_get_data(reinterpret_cast<uv_req_t*>(request)));
return c10::intrusive_ptr<WriterPayload>::reclaim(h);
}
@ -427,15 +436,19 @@ class WriterPayload : public c10::intrusive_ptr_target {
std::vector<uint8_t>&& in_data,
c10::intrusive_ptr<UvHandle> handle)
: data(std::move(in_data)), handle(std::move(handle)) {
uv_req_set_data((uv_req_t*)&req, this);
uv_req_set_data(reinterpret_cast<uv_req_t*>(&req), this);
}
~WriterPayload() override = default;
void send() {
buf = uv_buf_init((char*)data.data(), data.size());
buf = uv_buf_init(reinterpret_cast<char*>(data.data()), data.size());
int res = uv_write(
&req, (uv_stream_t*)handle->unsafeGetHandle(), &buf, 1, write_done);
&req,
reinterpret_cast<uv_stream_t*>(handle->unsafeGetHandle()),
&buf,
1,
write_done);
if (res) {
C10D_WARNING(
@ -584,7 +597,7 @@ class ChunkedStream {
if (available() < size)
return false;
str.resize(size);
return read_many((char*)str.data(), size);
return read_many(str.data(), size);
}
bool read_payload(std::vector<uint8_t>& data) {
@ -604,7 +617,7 @@ class ChunkedStream {
if (available() < size_in_bytes)
return false;
data.resize(size);
return read_many((char*)data.data(), size_in_bytes);
return read_many(reinterpret_cast<char*>(data.data()), size_in_bytes);
}
size_t available() {
@ -703,15 +716,15 @@ class LibUVStoreDaemon : public BackgroundThread {
int port_;
static LibUVStoreDaemon& from_uv(uv_handle_t* stream) {
return *(LibUVStoreDaemon*)uv_handle_get_data(stream);
return *static_cast<LibUVStoreDaemon*>(uv_handle_get_data(stream));
}
static void on_new_connection(uv_stream_t* server, int status) {
from_uv((uv_handle_t*)server).onConnect(status);
from_uv(reinterpret_cast<uv_handle_t*>(server)).onConnect(status);
}
static void on_exit_request(uv_async_t* handle) {
from_uv((uv_handle_t*)handle).onExitRequest();
from_uv(reinterpret_cast<uv_handle_t*>(handle)).onExitRequest();
}
void onConnect(int status);
@ -739,12 +752,12 @@ class UvClient : public UvTcpSocket {
if (!stream.read1(command))
break;
if (store->isMiscellaneousClient(iptr())) {
if ((QueryType)command != QueryType::VALIDATE)
if (static_cast<QueryType>(command) != QueryType::VALIDATE)
return;
if (!parse_validate_command())
return;
} else {
switch ((QueryType)command) {
switch (static_cast<QueryType>(command)) {
case QueryType::PING:
if (!parse_ping_command())
return;
@ -983,7 +996,7 @@ class UvClient : public UvTcpSocket {
if (store->waitKeys(keys, iptr())) {
StreamWriter sw(iptr());
sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
sw.write1(static_cast<uint8_t>(WaitResponseType::STOP_WAITING));
sw.send();
}
@ -1102,7 +1115,7 @@ class UvClient : public UvTcpSocket {
C10D_TRACE("cancel_wait address:{}", this->address());
StreamWriter sw(iptr());
sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED);
sw.write1(static_cast<uint8_t>(WaitResponseType::WAIT_CANCELED));
sw.send();
return true;
@ -1187,7 +1200,7 @@ void LibUVStoreDaemon::onConnect(int status) {
void LibUVStoreDaemon::onExitRequest() {
C10D_DEBUG("Store exit requested\n");
uv_close((uv_handle_t*)&exit_handle_, nullptr);
uv_close(reinterpret_cast<uv_handle_t*>(&exit_handle_), nullptr);
uv_stop(&loop_);
}
@ -1228,12 +1241,12 @@ LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) {
uv_async_init(&loop_, &exit_handle_, LibUVStoreDaemon::on_exit_request) ==
0,
"Failed to init uv async event");
uv_handle_set_data((uv_handle_t*)&exit_handle_, this);
uv_handle_set_data(reinterpret_cast<uv_handle_t*>(&exit_handle_), this);
}
LibUVStoreDaemon::~LibUVStoreDaemon() {
if (!is_running()) {
uv_close((uv_handle_t*)&exit_handle_, nullptr);
uv_close(reinterpret_cast<uv_handle_t*>(&exit_handle_), nullptr);
uv_run(&loop_, UV_RUN_NOWAIT);
if (uv_loop_close(&loop_) != 0) {
C10D_ERROR("loop cleanup didn't work");
@ -1477,7 +1490,7 @@ void LibUVStoreDaemon::wakeupWaitingClients(const std::string& key) {
for (const auto& client : socketsToWait->second) {
if (--keysAwaited_[client] == 0) {
StreamWriter sw(client->iptr());
sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
sw.write1(static_cast<uint8_t>(WaitResponseType::STOP_WAITING));
sw.send();
}
}
@ -1491,7 +1504,7 @@ void LibUVStoreDaemon::wakeupOneWaitingClient(const std::string& key) {
for (const auto& client : socketsToWait->second) {
if (--keysAwaited_[client] == 0) {
StreamWriter sw(client->iptr());
sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
sw.write1(static_cast<uint8_t>(WaitResponseType::STOP_WAITING));
sw.send();
return;
}
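
A minimal sketch (assuming libuv is available) of the set_data/get_data round trip the store daemon relies on above: the owning C++ object is stashed on the handle and recovered with a cast inside the static C callback. Daemon and on_wakeup are illustrative names, not types from this file.

#include <uv.h>

struct Daemon {
  uv_async_t handle{};

  void init(uv_loop_t* loop) {
    uv_async_init(loop, &handle, &Daemon::on_async);
    // Stash `this` on the handle so the plain C callback can find the object.
    uv_handle_set_data(reinterpret_cast<uv_handle_t*>(&handle), this);
  }

  static void on_async(uv_async_t* h) {
    auto* self = static_cast<Daemon*>(
        uv_handle_get_data(reinterpret_cast<uv_handle_t*>(h)));
    self->on_wakeup();
  }

  void on_wakeup() { /* handle the wakeup */ }
};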

View File

@ -443,7 +443,8 @@ PyTypeObject* GetReduceOpMetaclass() {
spec.basicsize = base_metaclass->tp_basicsize;
spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
spec.slots = slots;
PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec);
PyTypeObject* metaclass =
reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
if (!metaclass)
throw py::error_already_set();
return metaclass;
@ -812,7 +813,10 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
// `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol.
// https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
py::class_<::c10d::ReduceOp> reduce_op(
module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"(
module,
"ReduceOp",
py::metaclass(reinterpret_cast<PyObject*>(GetReduceOpMetaclass())),
R"(
An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``.

View File

@ -136,9 +136,9 @@ Reducer::Reducer(
{
std::set<int> unique_devices;
for (const auto& v : params_) {
auto device_idx = int(v.device().index());
if (unique_devices.find(device_idx) == unique_devices.end()) {
unique_devices.insert(device_idx);
auto device_idx = static_cast<int>(v.device().index());
auto [_, inserted] = unique_devices.emplace(device_idx);
if (inserted) {
if (unique_devices.size() > 1) {
is_multi_device_module_ = true;
break;
@ -168,7 +168,7 @@ Reducer::Reducer(
}
// All variables are expected to have their `grad_fn` set to the gradient
// accumulation function (since they are leafs in the autograd graph).
// accumulation function (since they are leaves in the autograd graph).
// We store pointers to these functions such that we can check if they are
// used in an autograd pass. If they are not, we know their grad tensors
// can be marked as ready for reduction.

View File

@ -76,7 +76,7 @@ class CudaTimer : public Timer {
if (milliseconds < 0) {
return std::nullopt;
}
return int64_t(milliseconds * kMilliSecondToNanosSecond);
return static_cast<int64_t>(milliseconds * kMilliSecondToNanosSecond);
}
};

View File

@ -220,7 +220,7 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) {
}
// if we can't resolve the hostname, display the IP address
if (addr->sa_family == AF_INET) {
struct sockaddr_in* psai = (struct sockaddr_in*)&addr;
struct sockaddr_in* psai = reinterpret_cast<struct sockaddr_in*>(&addr);
// NOLINTNEXTLINE(*array*)
char ip[INET_ADDRSTRLEN];
if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) !=
@ -228,7 +228,7 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) {
return fmt::format("{}:{}", ip, psai->sin_port);
}
} else if (addr->sa_family == AF_INET6) {
struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr;
struct sockaddr_in6* psai = reinterpret_cast<struct sockaddr_in6*>(&addr);
// NOLINTNEXTLINE(*array*)
char ip[INET6_ADDRSTRLEN];
if (inet_ntop(addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) !=

View File

@ -178,7 +178,7 @@ std::vector<int> IpcChannel::all_gather_fds(
int rank,
const std::vector<int>& pids,
int fd) {
int world_size = (int)pids.size();
int world_size = static_cast<int>(pids.size());
std::vector<int> fds(pids.size());
fds[rank] = fd;
@ -197,7 +197,7 @@ int IpcChannel::broadcast_fds(
int src_rank,
const std::vector<int>& pids,
int fd) {
int world_size = (int)pids.size();
int world_size = static_cast<int>(pids.size());
if (rank == src_rank) {
for (int dst_rank = 0; dst_rank < world_size; ++dst_rank) {

View File

@ -125,7 +125,7 @@ static at::Tensor empty_strided_p2p_persistent(
const size_t numel = std::accumulate(
size.begin(),
size.end(),
size_t(1),
static_cast<size_t>(1),
// NOLINTNEXTLINE(modernize-use-transparent-functors)
std::multiplies<size_t>());
const size_t element_size = c10::elementSize(dtype);
@ -230,7 +230,7 @@ at::Tensor empty_strided_p2p(
const size_t numel = std::accumulate(
size.begin(),
size.end(),
size_t(1),
static_cast<size_t>(1),
// NOLINTNEXTLINE(modernize-use-transparent-functors)
std::multiplies<size_t>());
const size_t element_size = c10::elementSize(dtype);

View File

@ -23,7 +23,8 @@ std::unordered_map<std::string, worker_id_t> collectNames(
}
std::vector<uint8_t> workerNameVector = store.get(std::to_string(workerId));
std::string workerName(
(char*)workerNameVector.data(), workerNameVector.size());
reinterpret_cast<char*>(workerNameVector.data()),
workerNameVector.size());
TORCH_CHECK(
nameToId.find(workerName) == nameToId.end(),
@ -91,7 +92,8 @@ std::unordered_map<std::string, worker_id_t> collectCurrentNames(
// Get the current list of workers
std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
allWorkerInfos = std::string(
(char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
reinterpret_cast<const char*>(allWorkerInfosKeyVector.data()),
allWorkerInfosKeyVector.size());
// workerInfos are comma separated with a comma at the end (e.g.
// "Name1-Rank1,Name2-Rank2,Name3-Rank2,") parse list of workers.
if (!allWorkerInfos.empty()) {
@ -132,7 +134,8 @@ void removeCurrentName(
// Get current list of names/ranks
std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
std::string allWorkerInfos = std::string(
(char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
reinterpret_cast<const char*>(allWorkerInfosKeyVector.data()),
allWorkerInfosKeyVector.size());
// Remove the current name and rank
std::string str_to_erase = fmt::format("{}-{},", selfName, selfId);
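
For reference, a hedged sketch of parsing the comma-terminated "Name-Rank," list format mentioned in the hunk above, assuming the rank part is numeric; parseWorkerInfos is an illustrative helper, not a function from this file.

#include <cstddef>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, int> parseWorkerInfos(const std::string& s) {
  std::unordered_map<std::string, int> nameToId;
  std::size_t start = 0;
  while (true) {
    const std::size_t comma = s.find(',', start);
    if (comma == std::string::npos) {
      break;  // the list ends with a trailing comma, so nothing remains here
    }
    const std::string entry = s.substr(start, comma - start);
    const std::size_t dash = entry.rfind('-');
    if (dash != std::string::npos) {
      nameToId[entry.substr(0, dash)] = std::stoi(entry.substr(dash + 1));
    }
    start = comma + 1;
  }
  return nameToId;
}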

View File

@ -149,13 +149,13 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (RpcAgent::*)(void) const) &
RpcAgent::getWorkerInfo,
static_cast<const WorkerInfo& (RpcAgent::*)(void) const>(
&RpcAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
RpcAgent::getWorkerInfo,
static_cast<const WorkerInfo& (RpcAgent::*)(const std::string&)
const>(&RpcAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_infos",
@ -611,28 +611,28 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (TensorPipeAgent::*)(void) const) &
RpcAgent::getWorkerInfo,
static_cast<const WorkerInfo& (TensorPipeAgent::*)(void) const>(
&RpcAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
TensorPipeAgent::getWorkerInfo,
static_cast<const WorkerInfo& (TensorPipeAgent::*)(const std::string&)
const>(&TensorPipeAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
TensorPipeAgent::getWorkerInfo,
static_cast<const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id)
const>(&TensorPipeAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_infos",
(std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
TensorPipeAgent::getWorkerInfos,
static_cast<std::vector<WorkerInfo> (TensorPipeAgent::*)() const>(
&TensorPipeAgent::getWorkerInfos),
py::call_guard<py::gil_scoped_release>())
.def(
"_get_device_map",
(DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst)
const)&TensorPipeAgent::getDeviceMap,
static_cast<DeviceMap (TensorPipeAgent::*)(const WorkerInfo& dst)
const>(&TensorPipeAgent::getDeviceMap),
py::call_guard<py::gil_scoped_release>())
.def(
"_get_backend_options",

View File

@ -32,7 +32,7 @@ c10::intrusive_ptr<Message> PythonRemoteCall::toMessageImpl() && {
std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
auto value = jit::unpickle(

View File

@ -74,7 +74,7 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
[this,
// std::function must be copyable, hence we have to cast the unique_ptr to
// a shared_ptr here.
rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
rpc = std::shared_ptr<RpcCommandBase>(std::move(rpc)),
messageType = request.type(),
streams = std::move(streams)](JitFuture& /* unused */) mutable {
// The cost of pre-request check is minimal thanks to

View File

@ -13,7 +13,7 @@ RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
}
WorkerInfo::WorkerInfo(std::string name, int64_t id)
: WorkerInfo(std::move(name), (worker_id_t)id) {
: WorkerInfo(std::move(name), static_cast<worker_id_t>(id)) {
TORCH_CHECK(
id <= std::numeric_limits<worker_id_t>::max(),
"RPC worker id ",

View File

@ -15,7 +15,7 @@ c10::ivalue::TupleElements toIValues(const Message& message, MessageType type) {
type,
", but got ",
message.type());
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
auto value = jit::unpickle(
@ -87,7 +87,7 @@ std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage(
id <= std::numeric_limits<worker_id_t>::max(),
"ScriptRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
return std::make_unique<ScriptRRefFetchCall>(
worker_id_t(id), RRefId::fromIValue(values[0]));
static_cast<worker_id_t>(id), RRefId::fromIValue(values[0]));
}
c10::intrusive_ptr<Message> PythonRRefFetchCall::toMessageImpl() && {
@ -109,7 +109,7 @@ std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage(
id <= std::numeric_limits<worker_id_t>::max(),
"PythonRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
return std::make_unique<PythonRRefFetchCall>(
worker_id_t(id), RRefId::fromIValue(values[0]));
static_cast<worker_id_t>(id), RRefId::fromIValue(values[0]));
}
const std::vector<at::IValue>& RRefFetchRet::values() {

View File

@ -127,7 +127,7 @@ c10::intrusive_ptr<Message> ScriptCall::toMessageImpl() && {
}
std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
auto value = jit::unpickle(
payload,

View File

@ -65,7 +65,7 @@ c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
auto value = jit::unpickle(

View File

@ -20,7 +20,7 @@ c10::intrusive_ptr<Message> ScriptResp::toMessageImpl() && {
}
std::unique_ptr<ScriptResp> ScriptResp::fromMessage(const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload = message.payload().data();
auto payload_size = message.payload().size();
auto value = jit::unpickle(
payload,

View File

@ -304,9 +304,10 @@ void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) {
}
float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const {
return currentCount_ == 0
? 0
: static_cast<float>((double)currentSum_ / (double)currentCount_);
return currentCount_ == 0 ? 0
: static_cast<float>(
static_cast<double>(currentSum_) /
static_cast<double>(currentCount_));
}
//////////////////////// TensorpipeRpcAgent /////////////////////////////////
@ -503,8 +504,9 @@ void TensorPipeAgent::startImpl() {
for (const auto& p : workerNameToInfo_) {
const auto& name = p.first;
auto nodeAddrData = nameToAddressStore_.get(name);
auto nodeAddrStr =
std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
auto nodeAddrStr = std::string(
reinterpret_cast<const char*>(nodeAddrData.data()),
nodeAddrData.size());
workerNameToURL_.insert({name, nodeAddrStr});
}
@ -1240,8 +1242,9 @@ void TensorPipeAgent::updateGroupMembership(
// TODO: we should get nodeAddrStr in the joining process, then pass in as
// an argument rather than getting from store each time
auto nodeAddrData = nameToAddressStore_.get(name);
auto nodeAddrStr =
std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
auto nodeAddrStr = std::string(
reinterpret_cast<const char*>(nodeAddrData.data()),
nodeAddrData.size());
workerNameToURL_.insert({name, nodeAddrStr});
for (const auto& it : reverseDeviceMaps) {

View File

@ -106,23 +106,23 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (TensorPipeAgent::*)(void) const) &
RpcAgent::getWorkerInfo,
static_cast<const WorkerInfo& (TensorPipeAgent::*)(void) const>(
&RpcAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
TensorPipeAgent::getWorkerInfo,
static_cast<const WorkerInfo& (TensorPipeAgent::*)(const std::string&)
const>(&TensorPipeAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
TensorPipeAgent::getWorkerInfo,
static_cast<const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id)
const>(&TensorPipeAgent::getWorkerInfo),
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_infos",
(std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
TensorPipeAgent::getWorkerInfos,
static_cast<std::vector<WorkerInfo> (TensorPipeAgent::*)() const>(
&TensorPipeAgent::getWorkerInfos),
py::call_guard<py::gil_scoped_release>());
#endif // USE_TENSORPIPE

View File

@ -507,8 +507,7 @@ std::vector<at::IValue> readWrappedPayload(
" but additional payload size is ",
additionalPayloadSize);
auto wrappedPayloadBegin =
static_cast<const char*>(message.payload().data()) + payload.size() -
additionalPayloadSize;
message.payload().data() + payload.size() - additionalPayloadSize;
std::vector<torch::Tensor> tensorTable;
IValue tuple = jit::unpickle(
wrappedPayloadBegin,

View File

@ -257,7 +257,7 @@ void THPStorage_writeFileRaw(
at::device(self->device()).dtype(c10::kByte),
{self->device()});
cpu_tensor = device_tensor.to(at::kCPU);
data = (uint8_t*)cpu_tensor.data_ptr();
data = static_cast<uint8_t*>(cpu_tensor.data_ptr());
}
if (save_size) {
if (torch::utils::THP_nativeByteOrder() ==
@ -266,8 +266,8 @@ void THPStorage_writeFileRaw(
else {
int64_t nsize{}; // convert big endian cpu to little endian storage
torch::utils::THP_encodeBuffer(
(uint8_t*)&nsize,
(const int64_t*)&numel,
reinterpret_cast<uint8_t*>(&nsize),
reinterpret_cast<const int64_t*>(&numel),
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
1);
doWrite(fd, &nsize, sizeof(int64_t));
@ -279,7 +279,7 @@ void THPStorage_writeFileRaw(
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
doWrite(fd, data, size_bytes);
} else {
size_t buffer_size = std::min(numel, (size_t)5000);
size_t buffer_size = std::min(numel, static_cast<size_t>(5000));
std::vector<uint8_t> le_buffer;
le_buffer.resize(buffer_size * element_size);
for (size_t i = 0; i < numel; i += buffer_size) {
@ -287,19 +287,19 @@ void THPStorage_writeFileRaw(
if (element_size == 2) {
torch::utils::THP_encodeBuffer(
le_buffer.data(),
(const int16_t*)data + i,
reinterpret_cast<const int16_t*>(data) + i,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
} else if (element_size == 4) {
torch::utils::THP_encodeBuffer(
le_buffer.data(),
(const int32_t*)data + i,
reinterpret_cast<const int32_t*>(data) + i,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
} else if (element_size == 8) {
torch::utils::THP_encodeBuffer(
le_buffer.data(),
(const int64_t*)data + i,
reinterpret_cast<const int64_t*>(data) + i,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
}
@ -333,7 +333,8 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
if (torch::utils::THP_nativeByteOrder() ==
torch::utils::THPByteOrder::THP_BIG_ENDIAN) {
int64_t tsize = size; // convert little endian storage to big endian cpu
torch::utils::THP_decodeBuffer(&size, (const uint8_t*)&tsize, true, 1);
torch::utils::THP_decodeBuffer(
&size, reinterpret_cast<const uint8_t*>(&tsize), true, 1);
}
size_t nbytes = element_size * size;
if (!storage.defined()) {
@ -358,7 +359,7 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
data = static_cast<uint8_t*>(storage->mutable_data());
} else {
cpu_data.resize(nbytes);
data = (uint8_t*)cpu_data.data();
data = reinterpret_cast<uint8_t*>(cpu_data.data());
}
// fast track for bytes and little endian
@ -367,7 +368,7 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
doRead(file, data, storage->nbytes());
} else {
int64_t buffer_size = std::min(size, (int64_t)5000);
int64_t buffer_size = std::min(size, static_cast<int64_t>(5000));
std::vector<uint8_t> le_buffer;
le_buffer.resize(buffer_size * element_size);
@ -378,13 +379,22 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
// NOLINTNEXTLINE(bugprone-branch-clone)
if (element_size == 2) {
torch::utils::THP_decodeBuffer(
(int16_t*)data + i, le_buffer.data(), true, to_convert);
reinterpret_cast<int16_t*>(data) + i,
le_buffer.data(),
true,
to_convert);
} else if (element_size == 4) {
torch::utils::THP_decodeBuffer(
(int32_t*)data + i, le_buffer.data(), true, to_convert);
reinterpret_cast<int32_t*>(data) + i,
le_buffer.data(),
true,
to_convert);
} else if (element_size == 8) {
torch::utils::THP_decodeBuffer(
(int64_t*)data + i, le_buffer.data(), true, to_convert);
reinterpret_cast<int64_t*>(data) + i,
le_buffer.data(),
true,
to_convert);
}
}
}
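
A rough sketch (not the THP_encodeBuffer API) of the bounce-buffer idea used above when the host byte order is big-endian: elements are byte-swapped into a small temporary buffer and written out in chunks, so the temporary never grows with the storage. swap16, encode_le16, and the 5000-element chunk size are illustrative choices.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

static inline std::uint16_t swap16(std::uint16_t v) {
  return static_cast<std::uint16_t>((v << 8) | (v >> 8));
}

// Append src[0..count) to `out` as little-endian bytes, chunking the swap so
// the temporary buffer stays small even for very large storages.
void encode_le16(const std::uint16_t* src,
                 std::size_t count,
                 std::vector<std::uint8_t>& out) {
  constexpr std::size_t kChunk = 5000;
  std::vector<std::uint16_t> le(kChunk);
  for (std::size_t i = 0; i < count; i += kChunk) {
    const std::size_t n = std::min(kChunk, count - i);
    for (std::size_t j = 0; j < n; ++j) {
      le[j] = swap16(src[i + j]);  // big-endian host value -> little-endian bytes
    }
    const auto* bytes = reinterpret_cast<const std::uint8_t*>(le.data());
    out.insert(out.end(), bytes, bytes + n * sizeof(std::uint16_t));
  }
}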

View File

@ -84,7 +84,7 @@ std::vector<int> THPUtils_unpackIntTuple(PyObject* arg) {
TORCH_CHECK(THPUtils_checkIntTuple(arg), "Couldn't unpack int tuple");
std::vector<int> values(PyTuple_GET_SIZE(arg));
for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i));
values[i] = THPUtils_unpackInt(PyTuple_GET_ITEM(arg, i));
}
return values;
}

View File

@ -357,6 +357,33 @@ def noop_mask(
return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)
def _sliced_mask_mod_error(
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor,
) -> Tensor:
"""
Raises a helpful error when using mask_mod from a sliced BlockMask.
After slicing a BlockMask, the mask_mod is reset and cannot be used directly.
Users must reassign mask_mod from the original (unsliced) BlockMask.
"""
raise RuntimeError(
"Cannot use mask_mod from a sliced BlockMask. "
"When you slice a BlockMask using [], the mask_mod attribute is reset. "
"You must set it from the original BlockMask's mask_mod."
"\n\nIncorrect usage:"
"\n base_mask = create_block_mask(my_mask_fn, ...)"
"\n sliced_mask = base_mask[:, :, block_idx]"
"\n sliced_mask.mask_mod = apply_offset(sliced_mask.mask_mod, offset) # WRONG!"
"\n\nCorrect usage:"
"\n base_mask = create_block_mask(my_mask_fn, ...)"
"\n sliced_mask = base_mask[:, :, block_idx]"
"\n sliced_mask.mask_mod = apply_offset(base_mask.mask_mod, offset) # Use base_mask!"
)
_DEFAULT_SPARSE_BLOCK_SIZE = 128
_LARGE_SPARSE_BLOCK_SIZE = 1 << 30
@ -710,7 +737,7 @@ class BlockMask:
new_full_kv_num_blocks,
new_full_kv_indices,
BLOCK_SIZE=self.BLOCK_SIZE,
mask_mod=None,
mask_mod=_sliced_mask_mod_error,
seq_lengths=self.seq_lengths,
compute_q_blocks=self.q_indices is not None,
)
@ -1414,6 +1441,11 @@ def flex_attention(
if block_mask is None:
block_mask = _create_empty_block_mask(query, key)
# If BlockMask was sliced, its mask_mod is intentionally replaced with an error-raising stub.
# This guard ensures we surface the intended error message before any shape-based checks.
if getattr(block_mask, "mask_mod", None) is _sliced_mask_mod_error:
raise RuntimeError("Cannot use mask_mod from a sliced BlockMask")
if (
block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE