Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-03 07:24:58 +08:00

Compare commits: ciflow/tru...trunk-tagg (19 commits)
| SHA1 |
|---|
| 6c4b2b681f |
| aeb036b90e |
| acf2915502 |
| 4f7f43253d |
| 779296a3fc |
| 8f06a1308f |
| 240c13394e |
| 150682ba7f |
| ca7360e996 |
| 0bf604320f |
| 9875e70da8 |
| 69a4bfe8bb |
| 62a263b8d4 |
| 0da1f911dc |
| 8700d68fef |
| ab82456c16 |
| b23f4687fd |
| 2705937080 |
| c1eda348be |
.github/workflows/trunk-tagging.yml (vendored): 147 changed lines
@@ -58,8 +58,10 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"

- name: Validate commit SHA
run: |
@@ -87,7 +89,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi

- name: Create and push tag with retry
- name: Create and push tag(s) with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@@ -112,14 +114,23 @@ jobs:
return 1
}

# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0

echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT

# Retry configuration
MAX_RETRIES=5
@@ -194,31 +205,111 @@ jobs:
}
}

# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.

# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true

if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after

# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"

if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi

commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"

while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"

# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi

echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"

rm -f "${commits_file}"

if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
exit 0
else
echo "Tag creation failed after all retry attempts"
exit 1
# workflow_dispatch path (single SHA tagging preserved)

# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi

echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
fi

- name: Tag creation summary
if: always()
run: |
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi

echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
fi
.github/workflows/trunk.yml (vendored): 34 changed lines
@@ -190,6 +190,40 @@ jobs:
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit

linux-jammy-rocm-py3_10-build:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
]}
secrets: inherit

linux-jammy-rocm-py3_10-test:
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
secrets: inherit

inductor-build:
name: inductor-build
uses: ./.github/workflows/_linux-build.yml
@@ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI)

set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)

set(fbgemm_genai_cuh
set(fbgemm_genai_mx8mx8bf16_grouped
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)

target_include_directories(fbgemm_genai PRIVATE
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${fbgemm_genai_cuh}
${fbgemm_genai_mx8mx8bf16_grouped}
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
@@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,
@@ -856,13 +856,9 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
launch_vectorized_templated_kernel<
func_t,
array_t,
@@ -870,9 +866,12 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
cret_t,
carg0_t,
carg1_t>(
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
numel,
f,
data,
@@ -880,7 +879,6 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};

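The hunk above replaces repeated `decltype(c10::impl::ScalarTypeToCPPType<...>::t)` spellings with named `constexpr ScalarType` locals plus the `ScalarTypeToCPPTypeT` alias, keeping the runtime comparison and the type computation visibly in sync. A minimal self-contained sketch of this compile-time table dispatch pattern (all names below are hypothetical, not PyTorch's):

```cpp
#include <array>
#include <cstdio>

enum class ScalarType { Float, Double };

// Map an enum value to a C++ type at compile time (mirrors ScalarTypeToCPPTypeT).
template <ScalarType S> struct ToCpp;
template <> struct ToCpp<ScalarType::Float>  { using type = float; };
template <> struct ToCpp<ScalarType::Double> { using type = double; };
template <ScalarType S> using ToCppT = typename ToCpp<S>::type;

// Compile-time table of (return type, argument type) specializations.
constexpr std::array<std::array<ScalarType, 2>, 2> specializations{{
    {ScalarType::Float, ScalarType::Float},
    {ScalarType::Double, ScalarType::Double},
}};

template <int arg_index>
void launch_if_match(ScalarType ret, ScalarType arg) {
  // Name the table entries once; the if-condition and the using-aliases
  // then provably refer to the same table slots.
  constexpr ScalarType sret = specializations[arg_index][0];
  constexpr ScalarType sarg = specializations[arg_index][1];
  if (ret == sret && arg == sarg) {
    using cret_t = ToCppT<sret>;  // resolved at compile time
    std::printf("dispatched, sizeof(ret) = %zu\n", sizeof(cret_t));
  }
}

int main() {
  launch_if_match<0>(ScalarType::Float, ScalarType::Float);  // matches slot 0
}
```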
@@ -44,7 +44,7 @@ __global__ void triu_tril_kernel(
const int64_t k,
const int64_t N_padded,
const IndexType last_dim_padded) {
int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
if (linear_idx >= N_padded) {
return;
}
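The one-character fix above widens the index arithmetic: `blockIdx.x`, `blockDim.x`, and `threadIdx.x` are 32-bit values, so their product and sum are computed in 32 bits and can wrap before the assignment to `int64_t` ever sees them. A small host-side C++ illustration of the same pitfall (illustrative values, not the kernel itself):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Mimics CUDA's built-ins: blockIdx.x and blockDim.x are 32-bit unsigned.
  uint32_t block_idx = 17000000;  // a block index from a very large grid
  uint32_t block_dim = 256;

  // Wrong: the product is computed in 32 bits and wraps modulo 2^32
  // before the widening assignment happens.
  int64_t wrapped = block_idx * block_dim;

  // Right: widen one operand first, so the whole expression is 64-bit.
  int64_t correct = (int64_t)block_idx * block_dim;

  std::printf("wrapped=%lld correct=%lld\n", (long long)wrapped, (long long)correct);
  // wrapped = 57032704 (4352000000 mod 2^32), correct = 4352000000
}
```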
@@ -441,7 +441,7 @@ kernel void applySYRK(
uint3 tid [[thread_position_in_threadgroup]],
uint3 tgid [[threadgroup_position_in_grid]],
uint3 tpg [[threads_per_threadgroup]],
uint sgitg [[simdgroup_index_in_threadgroup]]) {
uint warp_id [[simdgroup_index_in_threadgroup]]) {
const uint tx = tid.x;
const uint ty = tid.y;
const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32;
@@ -474,11 +474,8 @@ kernel void applySYRK(
(actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0);

if (use_simdgroup) {
uint warp_id = sgitg;

simdgroup_matrix<float, 8, 8> negative_identity =
simdgroup_matrix<float, 8, 8>(-1.0);
simdgroup_matrix<float, 8, 8> identity = simdgroup_matrix<float, 8, 8>(1.0);
simdgroup_matrix<float, 8, 8> Prod;
simdgroup_matrix<float, 8, 8> Afrag;
simdgroup_matrix<float, 8, 8> Bfrag;
@@ -521,8 +518,7 @@ kernel void applySYRK(
/* transpose = */ upper);

simdgroup_multiply(Prod, Afrag, Bfrag);
simdgroup_multiply(Prod, Prod, negative_identity);
simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod);
simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag);
}

simdgroup_store(
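The last hunk fuses two instructions into one: instead of negating the product with a multiply by a -1 identity matrix and then accumulating, the new code issues a single multiply-accumulate computing Cfrag = Prod * (-I) + Cfrag. Both forms evaluate C - A*B; a scalar C++ analogue of the equivalence (a sketch, with the corresponding simdgroup calls noted in comments):

```cpp
#include <cassert>

int main() {
  float A = 3.0f, B = 4.0f, C = 10.0f;

  // Old sequence: negate the product, then accumulate it into C.
  float prod = A * B;              // simdgroup_multiply(Prod, Afrag, Bfrag)
  float neg = prod * -1.0f;        // simdgroup_multiply(Prod, Prod, negative_identity)
  float c_old = C + 1.0f * neg;    // simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod)

  // New sequence: one fused multiply-accumulate, C = Prod * (-1) + C.
  float c_new = prod * -1.0f + C;  // simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag)

  assert(c_old == c_new);          // both compute C - A*B = -2
  return 0;
}
```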
@@ -102,7 +102,7 @@ uint64_t getNonDeterministicRandom(bool is_cuda) {
} else {
std::random_device rd;
// limit to 53 bits to ensure unique representation in double
s = (((static_cast<uint64_t>(rd())) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
}
return s;
}
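The mask `0x1FFFFFFFFFFFFF` in this hunk is 2^53 - 1: an IEEE-754 double has a 53-bit significand, so every integer up to 2^53 is exactly representable, which is what the comment's "unique representation in double" refers to. A standalone check:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  static_assert(0x1FFFFFFFFFFFFF == (1ULL << 53) - 1, "mask is 2^53 - 1");

  // Every value up to 2^53 round-trips through double exactly...
  uint64_t s = (1ULL << 53) - 1;
  assert(static_cast<uint64_t>(static_cast<double>(s)) == s);

  // ...but 2^53 + 1 does not: it rounds to 2^53, so the round trip fails.
  uint64_t too_big = (1ULL << 53) + 1;
  assert(static_cast<uint64_t>(static_cast<double>(too_big)) != too_big);
  return 0;
}
```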
@@ -20,8 +20,7 @@ void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
c10::DataPtr& data_ptr = storage.mutable_data_ptr();

if (reinterpret_cast<const void*>(data_ptr.get_deleter()) ==
reinterpret_cast<const void*>(&c10::refcounted_deleter)) {
if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
// Data pointer is already shared
return;
}
@@ -52,6 +52,19 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT

inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}

inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
case ScalarType::name: \
@@ -295,6 +308,12 @@ inline bool canCast(const ScalarType from, const ScalarType to) {

C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);

inline std::ostream& operator<<(
std::ostream& stream,
at::ScalarType scalar_type) {
return stream << toString(scalar_type);
}

// Returns a pair of strings representing the names for each dtype.
// The returned pair is (name, legacy_name_if_applicable)
C10_API std::pair<std::string, std::string> getDtypeNames(
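With `toString` and the stream inserter added above, a `ScalarType` value can be streamed directly. A self-contained analogue of the X-macro pattern the hunk uses (simplified; the real `AT_FORALL_...` macro covers all dtypes):

```cpp
#include <iostream>
#include <sstream>

enum class ScalarType { Float, Half };

// X-macro listing the types, like AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
#define FORALL_TYPES(_) _(Float) _(Half)

inline const char* toString(ScalarType t) {
#define DEFINE_CASE(name) \
  case ScalarType::name:  \
    return #name;
  switch (t) {
    FORALL_TYPES(DEFINE_CASE)
    default:
      return "UNKNOWN_SCALAR";
  }
#undef DEFINE_CASE
}

inline std::ostream& operator<<(std::ostream& stream, ScalarType t) {
  return stream << toString(t);
}

int main() {
  std::ostringstream ss;
  ss << ScalarType::Half;         // uses the new inserter
  std::cout << ss.str() << "\n";  // prints "Half"
}
```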
@@ -83,7 +83,7 @@ DEFINE_BINARY(max_slow_path, sym_max, SymInt)

SymInt::operator SymFloat() const {
if (auto ma = maybe_as_int()) {
return SymFloat(static_cast<double>(*ma));
return SymFloat(double(*ma));
} else {
return SymFloat(toSymNodeImplUnowned()->sym_float());
}
@@ -44,8 +44,7 @@ bool has_simple_data_ptr(const c10::StorageImpl& storage) {
}

bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
return reinterpret_cast<const void*>(data_ptr.get_deleter()) ==
reinterpret_cast<const void*>(&cow::cow_deleter);
return (void*)data_ptr.get_deleter() == (void*)&cow::cow_deleter;
}

c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
@@ -512,7 +512,7 @@ struct ExpandableSegment {
header.segment_size = segment_size_;
header.num_handles = end - begin;

buf.write(reinterpret_cast<const char*>(&header), sizeof(ShareHeader));
buf.write((const char*)&header, sizeof(ShareHeader));
for (auto i : c10::irange(begin, end)) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
auto& handle = handles_.at(i).value();
@@ -528,9 +528,7 @@ struct ExpandableSegment {
TORCH_CHECK(
handle.shareable_handle != std::nullopt,
"shareable_handle is null");
buf.write(
reinterpret_cast<const char*>(&*handle.shareable_handle),
sizeof(int));
buf.write((const char*)&*handle.shareable_handle, sizeof(int));
} else {
if (!handle.shareable_handle) {
CUmemFabricHandle fabric_handle;
@@ -543,8 +541,7 @@ struct ExpandableSegment {
handle.shareable_handle != std::nullopt,
"shareable_handle is null");
buf.write(
reinterpret_cast<const char*>(&*handle.shareable_handle),
sizeof(CUmemFabricHandle));
(const char*)&*handle.shareable_handle, sizeof(CUmemFabricHandle));
}
}
return rangeFromHandles(begin, end);
@@ -555,7 +552,7 @@ struct ExpandableSegment {
std::vector<c10::DeviceIndex> peers,
std::istream& buf) {
ShareHeader header{};
buf.read(reinterpret_cast<char*>(&header), sizeof(ShareHeader));
buf.read((char*)&header, sizeof(ShareHeader));
auto segment = std::make_unique<ExpandableSegment>(
device, std::nullopt, header.segment_size, std::move(peers));
// older build setups (e.g. multiwheels) do not have this syscall, added 2020
@@ -577,11 +574,11 @@ struct ExpandableSegment {
for (auto i : c10::irange(header.num_handles)) {
(void)i;
int fd = 0;
buf.read(reinterpret_cast<char*>(&fd), sizeof(int));
buf.read((char*)&fd, sizeof(int));
auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0);
if (myfd == -1) {
auto err = errno;
close(static_cast<int>(pidfd));
close((int)pidfd);
for (auto& h : segment->handles_) {
C10_CUDA_DRIVER_CHECK(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
@@ -601,16 +598,15 @@ struct ExpandableSegment {
(void*)(uintptr_t)myfd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
LOG(INFO) << "use posix fd to import expandable segments.";
close(static_cast<int>(myfd));
close((int)myfd);
segment->handles_.emplace_back(Handle{handle, std::nullopt});
}
close(static_cast<int>(pidfd));
close((int)pidfd);
} else {
for (auto i : c10::irange(header.num_handles)) {
(void)i;
CUmemFabricHandle fabric_handle;
buf.read(
reinterpret_cast<char*>(&fabric_handle), sizeof(CUmemFabricHandle));
buf.read((char*)&fabric_handle, sizeof(CUmemFabricHandle));
CUmemGenericAllocationHandle handle = 0;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_(
&handle,
@@ -1063,7 +1059,7 @@ class RingBuffer {

void setMaxEntries(size_t size) {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
alloc_trace_max_entries_ = std::max(static_cast<size_t>(1), size);
alloc_trace_max_entries_ = std::max(size_t(1), size);
}

void insertEntries(const T& entry) {
@@ -1995,16 +1991,15 @@ class DeviceCachingAllocator {
while (base_block->prev) {
base_block = base_block->prev;
}
offset = static_cast<const char*>(block->ptr) -
static_cast<const char*>(base_block->ptr);
offset = (char*)block->ptr - (char*)base_block->ptr;
cudaIpcMemHandle_t handle;
C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_block->ptr));
ss.write(reinterpret_cast<const char*>(&handle), CUDA_IPC_HANDLE_SIZE);
ss.write((char*)&handle, CUDA_IPC_HANDLE_SIZE);
} else {
ss.put(SHAREABLE_CUDA_EXPANDABLE_SEGMENT);
auto full_range = block->expandable_segment_->share(
SegmentRange(block->ptr, block->size), ss);
offset = static_cast<const char*>(block->ptr) - full_range.ptr;
offset = (char*)block->ptr - full_range.ptr;
}
return ShareableHandle{offset, ss.str()};
}
@@ -3234,8 +3229,7 @@ class DeviceCachingAllocator {
}

total_allocated_memory += size;
p.block = new Block(
p.device(), p.stream(), size, p.pool, static_cast<char*>(ptr));
p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.segment[stat_type].increase(1);
stats.reserved_bytes[stat_type].increase(size);
@@ -3783,7 +3777,7 @@ class NativeCachingAllocator : public CUDAAllocator {
allocated_blocks;

static size_t get_mutex_shard_id(void* ptr) {
return twang_mix64(reinterpret_cast<uintptr_t>(ptr)) % kNumMutexShard;
return twang_mix64((size_t)ptr) % kNumMutexShard;
}

void add_allocated_block(Block* block) {
@@ -3820,8 +3814,8 @@ class NativeCachingAllocator : public CUDAAllocator {
if (size < device_count) {
device_allocator.resize(device_count);
for (const auto i : c10::irange(size, device_count)) {
device_allocator[i] = std::make_unique<DeviceCachingAllocator>(
static_cast<c10::DeviceIndex>(i));
device_allocator[i] =
std::make_unique<DeviceCachingAllocator>(c10::DeviceIndex(i));
}
}
}
@@ -4350,7 +4344,7 @@ class NativeCachingAllocator : public CUDAAllocator {
// SHARABLE_CUDA_MALLOC
if (type == SHAREABLE_CUDA_MALLOC) {
cudaIpcMemHandle_t cuda_handle;
ss.read(reinterpret_cast<char*>(&cuda_handle), CUDA_IPC_HANDLE_SIZE);
ss.read((char*)&cuda_handle, CUDA_IPC_HANDLE_SIZE);
C10_CUDA_CHECK(cudaIpcOpenMemHandle(
&cuda_ipc_ptr_, cuda_handle, cudaIpcMemLazyEnablePeerAccess));
} else if (type == SHAREABLE_CUDA_EXPANDABLE_SEGMENT) {
@@ -46,7 +46,7 @@ bool operator==(const UsageStream& lhs, const UsageStream& rhs) {

struct UsageStreamHash {
size_t operator()(const UsageStream& us) const noexcept {
return std::hash<void*>{}(us.stream) + static_cast<size_t>(us.device);
return std::hash<void*>{}(us.stream) + size_t(us.device);
}
};

@@ -128,7 +128,7 @@ std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
} else if (s.isExt()) {
stream << "EXT";
} else {
stream << "PRIORITY " << static_cast<int>(s.getStreamType());
stream << "PRIORITY " << int(s.getStreamType());
}
return stream;
}
@@ -46,8 +46,7 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
for (const auto i : c10::irange(replicates)) {
auto delta_ns = end_times[i].t_ - start_times_[i].t_;
auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_;
scale_factors[i] =
static_cast<double>(delta_ns) / static_cast<double>(delta_approx);
scale_factors[i] = (double)delta_ns / (double)delta_approx;
}
std::sort(scale_factors.begin(), scale_factors.end());
long double scale_factor = scale_factors[replicates / 2 + 1];
@@ -65,8 +64,7 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
for (const auto i : c10::irange(replicates)) {
auto dt = start_times_[i].t_ - t0;
auto dt_approx =
static_cast<double>(start_times_[i].approx_t_ - t0_approx) *
scale_factor;
(double)(start_times_[i].approx_t_ - t0_approx) * scale_factor;
t0_correction[i] = dt - (time_t)dt_approx; // NOLINT
}
t0 += t0_correction[t0_correction.size() / 2 + 1]; // NOLINT
@@ -74,9 +72,7 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
return [=](approx_time_t t_approx) {
// See above for why this is more stable than `A * t_approx + B`.
return t_approx > t0_approx
? static_cast<time_t>(
static_cast<double>(t_approx - t0_approx) * scale_factor) +
t0
? (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0
: 0;
};
}
@@ -18,7 +18,6 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

#include <array>
#include <cstddef>
@@ -41,106 +40,200 @@ namespace c10 {
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
/// the underlying constexpr calls, we rely on apparent-type dispatch for
/// inheritance. This should be fine because their memory format is the same,
/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final : public HeaderOnlyArrayRef<T> {
class ArrayRef final {
public:
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
/// SmallVector.
using iterator = const T*;
using const_iterator = const T*;
using size_type = size_t;
using value_type = T;

using reverse_iterator = std::reverse_iterator<iterator>;

private:
/// The start of the array, in an external buffer.
const T* Data;

/// The number of elements.
size_type Length;

void debugCheckNullptrInvariant() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Data != nullptr || Length == 0,
"created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
}

public:
/// @name Constructors
/// @{

using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;
/// Construct an empty ArrayRef.
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}

/// Construct an ArrayRef from a std::vector.
/// This constructor is identical to the one in HeaderOnlyArrayRef, but we
/// include it to help with Class Template Argument Deduction (CTAD).
/// Without it, CTAD can fail sometimes due to the indirect constructor
/// inheritance. So we explicitly include this constructor.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
/// Construct an ArrayRef from a single element.
// TODO Make this explicit
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}

/// Construct an ArrayRef from a pointer and length.
constexpr ArrayRef(const T* data, size_t length)
: Data(data), Length(length) {
debugCheckNullptrInvariant();
}

/// Construct an ArrayRef from a range.
constexpr ArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {
debugCheckNullptrInvariant();
}

/// Construct an ArrayRef from a SmallVector. This is templated in order to
/// avoid instantiating SmallVectorTemplateCommon<T> whenever we
/// copy-construct an ArrayRef.
/// NOTE: this is the only constructor that is not inherited from
/// HeaderOnlyArrayRef.
template <typename U>
/* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
: HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
: Data(Vec.data()), Length(Vec.size()) {
debugCheckNullptrInvariant();
}

template <
typename Container,
typename U = decltype(std::declval<Container>().data()),
typename = std::enable_if_t<
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
/* implicit */ ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();
}

/// Construct an ArrayRef from a std::vector.
// The enable_if stuff here makes sure that this isn't used for
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
// bitfield.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(
!std::is_same_v<T, bool>,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}

/// Construct an ArrayRef from a std::array
template <size_t N>
/* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}

/// Construct an ArrayRef from a C array.
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
/* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}

/// Construct an ArrayRef from a std::initializer_list.
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}

/// @}
/// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
/// @name Simple Operations
/// @{

constexpr iterator begin() const {
return Data;
}
constexpr iterator end() const {
return Data + Length;
}

// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}

constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
constexpr reverse_iterator rend() const {
return reverse_iterator(begin());
}

/// Check if all elements in the array satisfy the given expression
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
return std::all_of(cbegin(), cend(), pred);
}

/// empty - Check if the array is empty.
constexpr bool empty() const {
return Length == 0;
}

constexpr const T* data() const {
return Data;
}

/// size - Get the array size.
constexpr size_t size() const {
return Length;
}

/// front - Get the first element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& front() const {
TORCH_CHECK(
!this->empty(), "ArrayRef: attempted to access front() of empty list");
return this->Data[0];
!empty(), "ArrayRef: attempted to access front() of empty list");
return Data[0];
}

/// back - Get the last element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& back() const {
TORCH_CHECK(
!this->empty(), "ArrayRef: attempted to access back() of empty list");
return this->Data[this->Length - 1];
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
return Data[Length - 1];
}

/// equals - Check for element-wise equality.
constexpr bool equals(ArrayRef RHS) const {
return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
}

/// slice(n, m) - Take M elements of the array starting at element N
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N, size_t M) const {
TORCH_CHECK(
N + M <= this->size(),
N + M <= size(),
"ArrayRef: invalid slice, N = ",
N,
"; M = ",
M,
"; size = ",
this->size());
return ArrayRef<T>(this->data() + N, M);
size());
return ArrayRef<T>(data() + N, M);
}

/// slice(n) - Chop off the first N elements of the array.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N) const {
TORCH_CHECK(
N <= this->size(),
"ArrayRef: invalid slice, N = ",
N,
"; size = ",
this->size());
return slice(N, this->size() - N); // should this slice be this->slice?
N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
return slice(N, size() - N);
}

/// @}
/// @name Operator Overloads
/// @{
constexpr const T& operator[](size_t Index) const {
return Data[Index];
}

/// Vector compatibility
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& at(size_t Index) const {
TORCH_CHECK(
Index < this->Length,
Index < Length,
"ArrayRef: invalid index Index = ",
Index,
"; Length = ",
this->Length);
return this->Data[Index];
Length);
return Data[Index];
}

/// Disallow accidental assignment from a temporary.
@@ -160,6 +253,13 @@ class ArrayRef final : public HeaderOnlyArrayRef<T> {
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
std::initializer_list<U>) = delete;

/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(Data, Data + Length);
}

/// @}
};

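The NOTE in this hunk describes the inheritance scheme being changed: a header-only base with no virtual methods, plus a `final` derived class that shadows a few accessors with checked versions, relying on the caller's static (apparent) type for dispatch. A minimal sketch of that pattern with hypothetical names (not the PyTorch classes):

```cpp
#include <cstddef>
#include <stdexcept>

// Header-only base: no virtuals, so calls bind by the caller's static type
// and the two classes keep an identical memory layout.
template <typename T>
class BaseRef {
 public:
  constexpr BaseRef(const T* data, size_t len) : data_(data), len_(len) {}
  constexpr size_t size() const { return len_; }
  constexpr const T& front() const { return data_[0]; }  // unchecked here

 protected:
  const T* data_;
  size_t len_;
};

// Derived final: shadows (hides, does not override) front() with a version
// that produces a friendlier error, analogous to TORCH_CHECK vs STD_TORCH_CHECK.
template <typename T>
class CheckedRef final : public BaseRef<T> {
 public:
  using BaseRef<T>::BaseRef;
  constexpr const T& front() const {
    if (this->len_ == 0) throw std::runtime_error("front() of empty list");
    return this->data_[0];
  }
};

int main() {
  int xs[] = {1, 2, 3};
  CheckedRef<int> r(xs, 3);
  return r.front() - 1;  // static type CheckedRef, so the checked path runs
}
```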
@@ -132,15 +132,15 @@ std::ostream& operator<<(std::ostream& o, const uint128& b) {
int div_base_log = 0;
switch (flags & std::ios::basefield) {
case std::ios::hex:
div = static_cast<uint64_t>(0x1000000000000000u); // 16^15
div = (uint64_t)0x1000000000000000u; // 16^15
div_base_log = 15;
break;
case std::ios::oct:
div = static_cast<uint64_t>(01000000000000000000000u); // 8^21
div = (uint64_t)01000000000000000000000u; // 8^21
div_base_log = 21;
break;
default: // std::ios::dec
div = static_cast<uint64_t>(10000000000000000000u); // 10^19
div = (uint64_t)10000000000000000000u; // 10^19
div_base_log = 19;
break;
}
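Each `div` constant in this hunk is the largest power of the output base that fits in a `uint64_t`, which lets the 128-bit value be printed in fixed-width 64-bit chunks of `div_base_log` digits. A compile-time check of those claims (sketch):

```cpp
#include <cstdint>

// Largest power-of-base chunks that fit in uint64_t, as in the hunk above.
static_assert(0x1000000000000000u == 1ULL << 60, "16^15 = 2^60");        // hex
static_assert(01000000000000000000000u == 1ULL << 63, "8^21 = 2^63");    // oct
static_assert(10000000000000000000u <= UINT64_MAX, "10^19 fits");        // dec
// One more base multiple would overflow 64 bits:
// 16^16 = 2^64, 8^22 = 2^66, and 10^20 > 2^64 (~1.845 * 10^19).

int main() { return 0; }
```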
@@ -7,7 +7,6 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
@@ -1,52 +0,0 @@
#include <gtest/gtest.h>

#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

#include <vector>

using torch::headeronly::HeaderOnlyArrayRef;

TEST(TestHeaderOnlyArrayRef, TestEmpty) {
HeaderOnlyArrayRef<float> arr;
ASSERT_TRUE(arr.empty());
}

TEST(TestHeaderOnlyArrayRef, TestSingleton) {
float val = 5.0f;
HeaderOnlyArrayRef<float> arr(val);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 1);
EXPECT_EQ(arr[0], val);
}

TEST(TestHeaderOnlyArrayRef, TestAPIs) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 7);
for (size_t i = 0; i < arr.size(); i++) {
EXPECT_EQ(arr[i], i + 1);
EXPECT_EQ(arr.at(i), i + 1);
}
EXPECT_EQ(arr.front(), 1);
EXPECT_EQ(arr.back(), 7);
ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3)));
}

TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr({1, 2, 3, 4, 5, 6, 7});
auto res_vec = arr.vec();
for (size_t i = 0; i < vec.size(); i++) {
EXPECT_EQ(vec[i], res_vec[i]);
}
}

TEST(TestHeaderOnlyArrayRef, TestFromRange) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec.data() + 3, vec.data() + 7);
auto res_vec = arr.vec();
for (size_t i = 0; i < res_vec.size(); i++) {
EXPECT_EQ(vec[i + 3], res_vec[i]);
}
}
@@ -53,24 +53,3 @@ TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)

#undef DEFINE_CHECK
#undef TEST_FORALL

TEST(TestScalarType, toString) {
using torch::headeronly::ScalarType;

#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}

TEST(TestScalarType, operator_left_shift) {
using torch::headeronly::ScalarType;

#define DEFINE_CHECK(_, name) \
{ \
std::stringstream ss; \
ss << ScalarType::name; \
EXPECT_EQ(ss.str(), #name); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
@@ -311,9 +311,10 @@ void boxed_fill_infinity(
}

Tensor my_pad(Tensor t) {
std::vector<int64_t> padding = {1, 2, 2, 1};
std::string mode = "constant";
double value = 0.0;
return pad(t, {1, 2, 2, 1}, mode, value);
return pad(t, padding, mode, value);
}

void boxed_my_pad(
@@ -341,9 +342,6 @@ void boxed_my_narrow(
}

Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);
@@ -355,8 +353,9 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
}

Tensor my_new_zeros_dtype_variant(Tensor t) {
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(at::ScalarType::Float);
return new_zeros(t, {2, 5}, dtype);
return new_zeros(t, sizes, dtype);
}

void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@@ -430,7 +429,8 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
}

Tensor my_amax_vec(Tensor t) {
return amax(t, {0,1}, false);
std::vector<int64_t> v = {0,1};
return amax(t, v, false);
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@@ -13454,47 +13454,6 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
y = torch.tensor(5)
f(x, y)

def test_cond_ra_pollution(self):
def compute(x, w):
return torch.nn.functional.linear(x, w)

def nop(x, w):
torch._check(x.shape[0] == 0)
return torch.empty_like(x)

def chunked_compute(x, w):
return torch.cond(x.shape[0] > 0, compute, nop, (x, w))

x, w = (
torch.randn(4, 16, requires_grad=True),
torch.randn(16, 16, requires_grad=True),
)

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16)

def forward(self, x):
return chunked_compute(x, self.linear.weight)

torch._dynamo.decorators.mark_unbacked(x, 0)
orig_mod = Model()
mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(
orig_mod
)(x)
torch.export._trace._restore_state_dict(orig_mod, mod)

# Previously, this would cause an error because torch._check(x.shape[0] == 0)
# would propagate a runtime assertion from the subgraph (nop branch) into
# the main graph, leading to an incorrect assertion error. The sequence:
# 1) Trace through the nop subgraph.
# 2) Add a runtime assert that u0 == 0 which erroneously would update the
#    global shape environment.
# 3) When generating runtime asserts for the main graph, the shape
#    environment incorrectly asserts u0 == 0, causing a false assertion.
mod(x)

def test_full_graph_capture_scalar_outputs(self):
@torch.compile(fullgraph=True)
def foo(a):
@@ -9931,6 +9931,28 @@ scipy_lobpcg  | {eq_err_scipy:10.2e}  | {eq_err_general_scipy:10.2e} | {iters2:
C = torch.matmul(A, B)
self.assertEqual(C, B.sum().expand(B.shape))

@onlyCUDA
@largeTensorTest("40GB")
def test_triu_tril_large_matrix_64bit(self, device):
"""
Test triu/tril with large matrices requiring 64-bit indexing.
Regression test for https://github.com/pytorch/pytorch/issues/136611
"""
# 100k x 100k matrix with 10B elements requires 64-bit indexing
q_len = 100000
causal_mask = torch.full((q_len, q_len), float('-inf'), device=device, dtype=torch.float32)
causal_mask.triu_(1)

# Verify row 42950 is correct (previously failed due to int32 overflow at row*col)
row_42950 = causal_mask[42950]
num_zeros = (row_42950 == 0.0).sum().item()
expected_zeros = 42951
self.assertEqual(num_zeros, expected_zeros)

# Verify last row is correct
last_row = causal_mask[-1]
self.assertTrue((last_row == 0.0).all())

@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
def test_triu_tril_extreme_k_values(self, device, dtype):
"""
third_party/fbgemm (vendored): 2 changed lines
Submodule third_party/fbgemm updated: c0b988d39a...3cefe0564a
@@ -201,19 +201,6 @@ for hip_platform_file in hip_platform_files:
sources.write(line)
print(f"{hip_platform_file} updated")

# NOTE: fbgemm sources needing hipify
# fbgemm is its own project with its own build system. pytorch uses fbgemm as
# a submodule to acquire some gpu source files but compiles only those sources
# instead of using fbgemm's own build system. One of the source files refers
# to a header file that is the result of running hipify, but fbgemm uses
# slightly different hipify settings than pytorch. fbgemm normally hipifies
# and renames tuning_cache.cuh to tuning_cache_hip.cuh, but pytorch's settings
# for hipify puts it into its own 'hip' directory. After hipify runs below with
# the added fbgemm file, we move it to its expected location.
fbgemm_dir = "third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize"
fbgemm_original = f"{fbgemm_dir}/tuning_cache.cuh"
fbgemm_move_src = f"{fbgemm_dir}/hip/tuning_cache.cuh"
fbgemm_move_dst = f"{fbgemm_dir}/tuning_cache_hip.cuh"

hipify_python.hipify(
project_directory=proj_dir,
@@ -225,26 +212,7 @@ hipify_python.hipify(
"torch/_inductor/codegen/cpp_wrapper_cpu.py",
"torch/_inductor/codegen/cpp_wrapper_gpu.py",
"torch/_inductor/codegen/wrapper.py",
fbgemm_original,
],
out_of_place_only=args.out_of_place_only,
hip_clang_launch=is_hip_clang(),
)

# only update the file if it changes or doesn't exist
do_write = True
src_lines = None
with open(fbgemm_move_src) as src:
src_lines = src.readlines()
if os.path.exists(fbgemm_move_dst):
dst_lines = None
with open(fbgemm_move_dst) as dst:
dst_lines = dst.readlines()
if src_lines == dst_lines:
print(f"{fbgemm_move_dst} skipped")
do_write = False
if do_write:
with open(fbgemm_move_dst, "w") as dst:
for line in src_lines:
dst.write(line)
print(f"{fbgemm_move_dst} updated")
@@ -1353,30 +1353,23 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
# NB: 0 is predicate
ix = 1 if branch else 2
# TODO: Support kwargs
ra_context = contextlib.nullcontext()
if hasattr(args[0], "sym_num"):
pred = args[0].sym_num.node.expr
prelude = pred if branch else ~pred
ra_context = tx.output.shape_env.patch_ra_prelude(prelude)

with ra_context:
(
(ret_val, ret_spec),
ret_graph,
ret_lifted_freevars,
) = speculate_subgraph(
tx,
args[ix],
operands_seq,
{},
"cond",
source_target=self.value,
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
(
(ret_val, ret_spec),
ret_graph,
ret_lifted_freevars,
) = speculate_subgraph(
tx,
args[ix],
operands_seq,
{},
"cond",
source_target=self.value,
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)

if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)):
unimplemented(
@@ -200,10 +200,9 @@ class SuperVariable(VariableTracker):
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
fn_vt = VariableTracker.build(
tx, unpatched_nn_module_init, source=source
)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented_v2(
gb_type="Unsupported super().__init__() call",
@@ -231,8 +230,9 @@ class SuperVariable(VariableTracker):
elif isinstance(inner_fn, staticmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=source
).call_function(tx, args, kwargs)
elif isinstance(inner_fn, classmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
@@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
tx, self.objvar.value_type, cls_source
)

fn_vt = VariableTracker.build(
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
)
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=AttrSource(source, "__func__")
).call_function(tx, [cls_variable, *args], kwargs)
elif isinstance(inner_fn, types.FunctionType):
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, [self.objvar] + args, kwargs)
elif isinstance(inner_fn, types.MethodType):
return variables.UserMethodVariable(
inner_fn.__func__, self.objvar, source=source
@@ -574,8 +574,10 @@ class ComptimeVariable(VariableTracker):
from ..comptime import comptime

# To support the comptime.print_graph convenience accessors
return VariableTracker.build(
tx, getattr(comptime, name), source=AttrSource(self.source, name)
from .functions import UserFunctionVariable

return UserFunctionVariable(
getattr(comptime, name), source=AttrSource(self.source, name)
)

def call_function(
@@ -769,8 +771,9 @@ class AutogradFunctionVariable(VariableTracker):
sig = inspect.signature(fn)
if len(args) - 1 == len(sig._parameters):
args = args[1:]  # Don't use context
fn_vt = VariableTracker.build(tx, fn, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=source).call_function(
tx, args, kwargs
)
elif isinstance(fn, types.MethodType):
return variables.UserMethodVariable(
fn.__func__,
@@ -796,8 +799,9 @@ class AutogradFunctionVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)

fn_source = AttrSource(self.source, "backward")
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx, args, kwargs
)

def call_function(self, tx: "InstructionTranslator", args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
@@ -1022,12 +1026,10 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
assert tx.one_graph or tx.error_on_graph_break, (
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
)
fn_vt = VariableTracker.build(
tx,
return variables.UserFunctionVariable(
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
source=self.source,
)
return fn_vt.call_function(
).call_function(
tx,
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
kwargs,
@@ -293,8 +293,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
return VariableTracker.build(tx, obj.__get__(self.value), source)
elif isinstance(obj, classmethod):
if isinstance(obj.__func__, property):
fget_vt = VariableTracker.build(tx, obj.__func__.fget)
return fget_vt.call_function(tx, [self], {})
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
tx, [self], {}
)
return variables.UserMethodVariable(obj.__func__, self, source=source)
elif isinstance(obj, types.ClassMethodDescriptorType):
# e.g.: inspect.getattr_static(dict, "fromkeys")
@@ -1788,7 +1789,7 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
fn_variable = VariableTracker.build(tx, self.value.forward.__func__)
fn_variable = variables.UserFunctionVariable(self.value.forward.__func__)
args = [self] + args
return tx.inline_user_function_return(
fn_variable,
@@ -951,7 +951,8 @@ class TritonCSEVariable(CSEVariable):
# We'll use this to track which masks the variable needs when used for indirect indexing
self.mask_vars: OrderedSet[str] = OrderedSet()
assert dtype is not None, "TritonCSEVariable must have dtype"
assert shape is not None, "TritonCSEVariable must have shape"
# TODO: uncomment this and fix the few failures left
# assert shape is not None, "TritonCSEVariable must have shape"

def update_on_args(self, name, args, kwargs):
for arg in args:

@@ -628,7 +628,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)
@@ -3,6 +3,7 @@

import math
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any, Optional, Union

@@ -36,6 +37,7 @@ from ...lowering import (
to_dtype,
)
from ...select_algorithm import realize_inputs
from ...utils import load_template


SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@@ -337,13 +339,8 @@ def next_power_of_two(n):
return 2 ** math.ceil(math.log2(n))


_TEMPLATE_DIR = Path(__file__).parent / "templates"


def load_template(name: str) -> str:
"""Load a template file and return its content."""
with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
return f.read()
_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)


# Template strings have been moved to templates/common.py.jinja
@@ -29,7 +29,7 @@ from .common import (
freeze_irnodes,
get_fwd_subgraph_outputs,
infer_dense_strides,
load_template,
load_flex_template,
maybe_realize,
set_head_dim_values,
SubgraphResults,
@@ -79,9 +79,9 @@ def get_float32_precision():
flex_attention_template = TritonTemplate(
name="flex_attention",
grid=flex_attention_grid,
source=load_template("flex_attention")
+ load_template("utilities")
+ load_template("common"),
source=load_flex_template("flex_attention")
+ load_flex_template("utilities")
+ load_flex_template("common"),
)


@@ -469,7 +469,7 @@ def flex_attention_backward_grid(
flex_attention_backward_template = TritonTemplate(
name="flex_attention_backward",
grid=flex_attention_backward_grid,
source=load_template("flex_backwards") + load_template("utilities"),
source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
)

@@ -22,7 +22,7 @@ from .common import (
create_num_blocks_fake_generator,
freeze_irnodes,
get_fwd_subgraph_outputs,
load_template,
load_flex_template,
maybe_realize,
set_head_dim_values,
)
@@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
flex_decoding_template = TritonTemplate(
name="flex_decoding",
grid=flex_decoding_grid,
source=load_template("flex_decode")
+ load_template("utilities")
+ load_template("common"),
source=load_flex_template("flex_decode")
+ load_flex_template("utilities")
+ load_flex_template("common"),
)

@ -12,7 +12,7 @@ from torch.fx import GraphModule
|
||||
|
||||
from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
|
||||
from ...lowering import empty_strided
|
||||
from .common import infer_dense_strides, load_template, SubgraphResults
|
||||
from .common import infer_dense_strides, load_flex_template, SubgraphResults
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
@ -36,7 +36,7 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
|
||||
|
||||
flash_attention_cutedsl_template = CuteDSLTemplate(
|
||||
name="flash_attention_cutedsl", source=load_template("flash_attention")
|
||||
name="flash_attention_cutedsl", source=load_flex_template("flash_attention")
|
||||
)
|
||||
|
||||
|
||||
|
||||
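Each TritonTemplate above builds its `source` by string-concatenating several loaded jinja fragments, so shared `utilities`/`common` definitions travel with each kernel body. A toy sketch with in-memory fragments (hypothetical contents) instead of files:

fragments = {
    "flex_attention": "{# kernel body #}\n",
    "utilities": "{# shared helpers #}\n",
    "common": "{# common epilogue #}\n",
}


def load_flex_template(name: str) -> str:
    # Hypothetical in-memory stand-in for the file-based loader above.
    return fragments[name]


source = (
    load_flex_template("flex_attention")
    + load_flex_template("utilities")
    + load_flex_template("common")
)
assert source.count("{#") == 3  # one fragment per chunk, in order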
@@ -3496,13 +3496,24 @@ def user_autotune(
    )


def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
def foreach(triton_meta, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
    configs = []

    # Naive autotuning path for num_warps
    if disable_pointwise_autotuning(inductor_meta) and not (
        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
    ):
        configs.append(triton.Config({}, num_stages=1, num_warps=8))
    else:
        for warps in [1, 2, 4, 8]:
            configs.append(triton.Config({}, num_stages=1, num_warps=warps))

    return cached_autotune(
        None,
        [triton.Config({}, num_stages=1, num_warps=num_warps)],
        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
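The rewritten `foreach` drops the caller-supplied `num_warps` and instead chooses between one fixed config and a sweep over warp counts. A standalone sketch of just that selection logic, with the autotuning-disabled decision passed in as a boolean (an assumption; the real code calls `disable_pointwise_autotuning`):

def pick_foreach_configs(inductor_meta, autotune_disabled):
    # Mirror of the branching above; configs are plain dicts so the sketch
    # runs without Triton installed.
    if autotune_disabled and not (
        inductor_meta.get("max_autotune")
        or inductor_meta.get("max_autotune_pointwise")
    ):
        return [{"num_stages": 1, "num_warps": 8}]
    return [{"num_stages": 1, "num_warps": w} for w in (1, 2, 4, 8)]


assert len(pick_foreach_configs({}, autotune_disabled=True)) == 1
assert len(pick_foreach_configs({"max_autotune": True}, autotune_disabled=True)) == 4
assert len(pick_foreach_configs({}, autotune_disabled=False)) == 4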
@@ -67,6 +67,9 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten, tree_map_only


if TYPE_CHECKING:
    from pathlib import Path

OPTIMUS_EXCLUDE_POST_GRAD = [
    "activation_quantization_aten_pass",
    "inductor_autotune_lookup_table",

@@ -3885,3 +3888,10 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
    return dep_name.startswith(
        ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
    )


# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
def load_template(name: str, template_dir: Path) -> str:
    """Load a template file and return its content."""
    with open(template_dir / f"{name}.py.jinja") as f:
        return f.read()
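The comment above warns that jinja templates must ship with the installed package or `load_template` will not find them at runtime. In generic setuptools terms (a sketch, not PyTorch's actual setup.py, where the torch_package_data list plays this role), that means declaring the files as package data:

from setuptools import setup

setup(
    name="mypkg",  # hypothetical package for illustration
    packages=["mypkg"],
    # Without this, *.py.jinja files are left out of wheels/installs and
    # open(template_dir / f"{name}.py.jinja") fails on installed copies.
    package_data={"mypkg": ["templates/*.py.jinja"]},
)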
@@ -334,37 +334,14 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
        from torch._higher_order_ops.utils import _has_gen_schema

        if _has_gen_schema(self):
            try:
                schema = self.gen_schema(*args, **kwargs)
                if any(arg.is_write for arg in schema.arguments):
                    raise RuntimeError(
                        f"The {self.name()} HigherOrderOperator does not currently support training "
                        "with in-place input or buffer mutations "
                        "If you require this feature, please submit an issue to PyTorch. "
                        "Alternatively, consider creating your own custom autograd.Function. "
                    )
            except RuntimeError as e:
                if "Expected cond to be True, but got False" in str(e):
                    # Although we attempt to detect in-place input or buffer mutations,
                    # the current approach in CondOp::gen_schema is not fully reliable.
                    # Specifically, we invoke materialize_as_graph on both the true and false
                    # subgraphs with the provided inputs at runtime (not compile time).
                    # This can lead to unintended side effects: for example, consider the following code:
                    #
                    # def nop(x, w):
                    #     torch._check(x.shape[0] == 0)
                    #
                    # torch.cond(x.shape[0] > 0, compute, nop, (x, w))
                    #
                    # If, at runtime, x.shape[0] > 0, the assertion in nop will be triggered,
                    # even though that branch is not actually taken. As a result, strictly enforcing
                    # a hard failure based on this check would incorrectly penalize valid programs
                    # due to the unsoundness of our detection mechanism. Therefore, rather than
                    # failing outright, we conservatively proceed under the assumption that there
                    # are no in-place input or buffer mutations.
                    pass
                else:
                    raise
            schema = self.gen_schema(*args, **kwargs)
            if any(arg.is_write for arg in schema.arguments):
                raise RuntimeError(
                    f"The {self.name()} HigherOrderOperator does not currently support training "
                    "with in-place input or buffer mutations "
                    "If you require this feature, please submit an issue to PyTorch. "
                    "Alternatively, consider creating your own custom autograd.Function. "
                )

        return fn(*args, **kwargs)
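The hunk above deletes the defensive `except` branch: the inline comment explains that `CondOp::gen_schema` could raise a spurious "Expected cond to be True, but got False" while tracing an untaken branch, so the old code swallowed exactly that error and re-raised everything else. A minimal sketch of the removed tolerate-one-known-false-positive pattern, detached from torch:

def check_no_mutations(run_check):
    # Run an unsound checker; tolerate its one known false-positive error,
    # re-raise anything else.
    try:
        run_check()
    except RuntimeError as e:
        if "Expected cond to be True, but got False" in str(e):
            pass  # assume no in-place mutations (the conservative choice here)
        else:
            raise


def spurious():
    raise RuntimeError("Expected cond to be True, but got False")


check_no_mutations(lambda: None)  # clean check: passes
check_no_mutations(spurious)      # known false positive: tolerated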
@@ -151,7 +151,7 @@ static PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {

static PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPDevice*>(_self);
  auto self = (THPDevice*)_self;
  auto ret = THPObjectPtr{PyTuple_New(2)};
  if (!ret)
    throw python_error();

@@ -221,16 +221,8 @@ typedef PyObject* (*getter)(PyObject*, void*);
// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in

static const std::initializer_list<PyGetSetDef> THPDevice_properties = {
    {"type",
     reinterpret_cast<getter>(THPDevice_type),
     nullptr,
     nullptr,
     nullptr},
    {"index",
     reinterpret_cast<getter>(THPDevice_index),
     nullptr,
     nullptr,
     nullptr},
    {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
    {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
    {nullptr}};

static const std::initializer_list<PyMethodDef> THPDevice_methods = {

@@ -250,18 +242,18 @@ PyTypeObject THPDeviceType = {
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPDevice_repr), /* tp_repr */
    (reprfunc)THPDevice_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    reinterpret_cast<hashfunc>(THPDevice_hash), /* tp_hash */
    (hashfunc)THPDevice_hash, /* tp_hash */
    // TODO: We're not sure if this is a good idea or not, because making
    // torch.device callable means that it will start returning true
    // for callable() queries, and that is unexpected. We can always add
    // this later, so for now, don't actually implement this
    // THPDevice_call, /* tp_call */
    nullptr, /* tp_call */
    reinterpret_cast<reprfunc>(THPDevice_str), /* tp_str */
    (reprfunc)THPDevice_str, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */

@@ -269,7 +261,7 @@ PyTypeObject THPDeviceType = {
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    static_cast<richcmpfunc>(THPDevice_rc), /* tp_richcompare */
    (richcmpfunc)THPDevice_rc, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */

@@ -294,8 +286,7 @@ void THPDevice_init(PyObject* module) {
  }
  Py_INCREF(&THPDeviceType);
  THPUpperModuleOfDevice = module;
  if (PyModule_AddObject(
          module, "device", reinterpret_cast<PyObject*>(&THPDeviceType)) != 0) {
  if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) {
    throw python_error();
  }
}
@@ -69,14 +69,14 @@ static PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) {
 * For singletons, a string is returned. The string should be interpreted
 * as the name of a global variable.
 */
  auto self = reinterpret_cast<THPDtype*>(_self);
  auto self = (THPDtype*)_self;
  return THPUtils_packString(self->name);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPDtype_to_real(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto* self = reinterpret_cast<THPDtype*>(_self);
  auto* self = (THPDtype*)_self;
  auto scalar_type = self->scalar_type;
  if (!at::isFloatingType(self->scalar_type)) {
    scalar_type = at::toRealValueType(self->scalar_type);

@@ -87,7 +87,7 @@ static PyObject* THPDtype_to_real(PyObject* _self, PyObject* noargs) {

static PyObject* THPDtype_to_complex(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto* self = reinterpret_cast<THPDtype*>(_self);
  auto* self = (THPDtype*)_self;
  auto scalar_type = self->scalar_type;
  if (!at::isComplexType(self->scalar_type)) {
    scalar_type = at::toComplexType(self->scalar_type);

@@ -100,25 +100,13 @@ typedef PyObject* (*getter)(PyObject*, void*);

static const std::initializer_list<PyGetSetDef> THPDtype_properties = {
    {"is_floating_point",
     reinterpret_cast<getter>(THPDtype_is_floating_point),
     nullptr,
     nullptr,
     nullptr},
    {"is_complex",
     reinterpret_cast<getter>(THPDtype_is_complex),
     nullptr,
     nullptr,
     nullptr},
    {"is_signed",
     reinterpret_cast<getter>(THPDtype_is_signed),
     nullptr,
     nullptr,
     nullptr},
    {"itemsize",
     reinterpret_cast<getter>(THPDtype_itemsize),
     (getter)THPDtype_is_floating_point,
     nullptr,
     nullptr,
     nullptr},
    {"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr},
    {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
    {"itemsize", (getter)THPDtype_itemsize, nullptr, nullptr, nullptr},
    {nullptr}};

static const std::initializer_list<PyMethodDef> THPDtype_methods = {

@@ -142,7 +130,7 @@ PyTypeObject THPDtypeType = {
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPDtype_repr), /* tp_repr */
    (reprfunc)THPDtype_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */

@@ -202,8 +190,7 @@ void THPDtype_init(PyObject* module) {
    throw python_error();
  }
  Py_INCREF(&THPDtypeType);
  if (PyModule_AddObject(
          module, "dtype", reinterpret_cast<PyObject*>(&THPDtypeType)) != 0) {
  if (PyModule_AddObject(module, "dtype", (PyObject*)&THPDtypeType) != 0) {
    throw python_error();
  }
}
@@ -48,7 +48,7 @@ static PyObject* THPEvent_pynew(
    TORCH_CHECK(ptr, "Failed to allocate memory for Event");
  }

  THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
  THPEvent* self = (THPEvent*)ptr.get();

  // TODO: blocking and interprocess are not supported yet. To support them, the
  // flag system of c10::Event needs to be refactored. C10::Event should also

@@ -64,7 +64,7 @@ static PyObject* THPEvent_pynew(
      (enable_timing ? c10::EventFlag::BACKEND_DEFAULT
                     : c10::EventFlag::PYTORCH_DEFAULT));

  return static_cast<PyObject*>(ptr.release());
  return (PyObject*)ptr.release();
  END_HANDLE_TH_ERRORS
}

@@ -82,7 +82,7 @@ static void THPEvent_dealloc(THPEvent* self) {
    pybind11::gil_scoped_release no_gil{};
    self->event.~Event();
  }
  Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
  Py_TYPE(self)->tp_free((PyObject*)self);
}

static PyObject* THPEvent_get_device(THPEvent* self, void* unused) {

@@ -96,7 +96,7 @@ static PyObject* THPEvent_record(
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPEvent*>(_self);
  auto self = (THPEvent*)_self;
  PyObject* _stream = Py_None;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* accepted_args[] = {"stream", nullptr};

@@ -111,7 +111,7 @@ static PyObject* THPEvent_record(
    return nullptr;
  }
  if (_stream != Py_None) {
    auto stream = reinterpret_cast<THPStream*>(_stream);
    auto stream = (THPStream*)_stream;
    self->event.record(c10::Stream::unpack3(
        stream->stream_id,
        static_cast<c10::DeviceIndex>(stream->device_index),

@@ -130,7 +130,7 @@ static PyObject* THPEvent_from_ipc_handle(
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  auto type = reinterpret_cast<PyTypeObject*>(_type);
  auto type = (PyTypeObject*)_type;

  static torch::PythonArgParser parser({
      "from_ipc_handle(Device device, std::string ipc_handle)",

@@ -146,13 +146,13 @@ static PyObject* THPEvent_from_ipc_handle(
  if (!ptr) {
    return nullptr;
  }
  THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
  THPEvent* self = (THPEvent*)ptr.get();

  // TODO: for constructing event from ipc handle, the c10::Event needs to have
  // more general constructor to achieve that.
  new (&self->event) c10::Event(device.type(), c10::EventFlag::PYTORCH_DEFAULT);

  return static_cast<PyObject*>(ptr.release());
  return (PyObject*)ptr.release();
  END_HANDLE_TH_ERRORS
}

@@ -174,7 +174,7 @@ static PyObject* THPEvent_wait(
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS {
    auto self = reinterpret_cast<THPEvent*>(_self);
    auto self = (THPEvent*)_self;
    PyObject* _stream = Py_None;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    constexpr const char* accepted_args[] = {"stream", nullptr};

@@ -189,7 +189,7 @@ static PyObject* THPEvent_wait(
      return nullptr;
    }
    if (_stream != Py_None) {
      auto stream = reinterpret_cast<THPStream*>(_stream);
      auto stream = (THPStream*)_stream;
      self->event.block(c10::Stream::unpack3(
          stream->stream_id,
          static_cast<c10::DeviceIndex>(stream->device_index),

@@ -206,15 +206,15 @@ static PyObject* THPEvent_wait(

static PyObject* THPEvent_query(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPEvent*>(_self);
  auto self = (THPEvent*)_self;
  return PyBool_FromLong(self->event.query());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPEvent*>(_self);
  auto other = reinterpret_cast<THPEvent*>(_other);
  auto self = (THPEvent*)_self;
  auto other = (THPEvent*)_other;
  return PyFloat_FromDouble(self->event.elapsedTime(other->event));
  END_HANDLE_TH_ERRORS
}

@@ -222,7 +222,7 @@ static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS {
    pybind11::gil_scoped_release no_gil{};
    auto self = reinterpret_cast<THPEvent*>(_self);
    auto self = (THPEvent*)_self;
    self->event.synchronize();
  }
  Py_RETURN_NONE;

@@ -231,7 +231,7 @@ static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {

static PyObject* THPEvent_evend_id(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPEvent*>(_self);
  auto self = (THPEvent*)_self;
  return PyLong_FromVoidPtr(self->event.eventId());
  END_HANDLE_TH_ERRORS
}

@@ -251,16 +251,8 @@ static PyObject* THPEvent_repr(THPEvent* self) {

// NOLINTNEXTLINE(*c-arrays*, *global-variables)
static struct PyGetSetDef THPEvent_properties[] = {
    {"device",
     reinterpret_cast<getter>(THPEvent_get_device),
     nullptr,
     nullptr,
     nullptr},
    {"event_id",
     reinterpret_cast<getter>(THPEvent_evend_id),
     nullptr,
     nullptr,
     nullptr},
    {"device", (getter)THPEvent_get_device, nullptr, nullptr, nullptr},
    {"event_id", (getter)THPEvent_evend_id, nullptr, nullptr, nullptr},
    {nullptr}};

// NOLINTNEXTLINE(*c-arrays*, *global-variables)

@@ -288,12 +280,12 @@ PyTypeObject THPEventType = {
    "torch.Event", /* tp_name */
    sizeof(THPEvent), /* tp_basicsize */
    0, /* tp_itemsize */
    reinterpret_cast<destructor>(THPEvent_dealloc), /* tp_dealloc */
    (destructor)THPEvent_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPEvent_repr), /* tp_repr */
    (reprfunc)THPEvent_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */

@@ -330,8 +322,7 @@ void THPEvent_init(PyObject* module) {
    throw python_error();
  }
  Py_INCREF(&THPEventType);
  if (PyModule_AddObject(
          module, "Event", reinterpret_cast<PyObject*>(&THPEventType)) < 0) {
  if (PyModule_AddObject(module, "Event", (PyObject*)&THPEventType) < 0) {
    throw python_error();
  }
}
@@ -65,8 +65,7 @@ could not be completed because the input matrix is singular.",
      "Exception raised when device is out of memory",
      PyExc_RuntimeError,
      nullptr));
  PyTypeObject* type =
      reinterpret_cast<PyTypeObject*>(THPException_OutOfMemoryError);
  PyTypeObject* type = (PyTypeObject*)THPException_OutOfMemoryError;
  type->tp_name = "torch.OutOfMemoryError";
  ASSERT_TRUE(
      PyModule_AddObject(

@@ -134,7 +133,7 @@ could not be completed because the input matrix is singular.",
      "Exception raised while executing on device",
      PyExc_RuntimeError,
      nullptr));
  type = reinterpret_cast<PyTypeObject*>(THPException_AcceleratorError);
  type = (PyTypeObject*)THPException_AcceleratorError;
  ASSERT_TRUE(
      PyModule_AddObject(
          module, "AcceleratorError", THPException_AcceleratorError) == 0);
@@ -21,7 +21,7 @@ using namespace torch;
PyObject* THPGeneratorClass = nullptr;

PyObject* THPGenerator_initDefaultGenerator(const at::Generator& cdata) {
  auto type = reinterpret_cast<PyTypeObject*>(THPGeneratorClass);
  auto type = (PyTypeObject*)THPGeneratorClass;
  auto self = THPObjectPtr{type->tp_alloc(type, 0)};
  if (!self)
    throw python_error();

@@ -49,8 +49,7 @@ static PyObject* THPGenerator_pynew(
  auto r = parser.parse(args, kwargs, parsed_args);
  auto device = r.deviceWithDefault(0, at::Device(at::kCPU));

  THPGeneratorPtr self(
      reinterpret_cast<THPGenerator*>(type->tp_alloc(type, 0)));
  THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0));

  c10::DeviceType device_type = device.type();
  if (device_type == at::kCPU) {

@@ -61,14 +60,14 @@ static PyObject* THPGenerator_pynew(
        .getNewGenerator(device.index());
  }

  return reinterpret_cast<PyObject*>(self.release());
  return (PyObject*)self.release();
  END_HANDLE_TH_ERRORS
}

static PyObject* THPGenerator_getState(PyObject* _self, PyObject* noargs) {
  using namespace torch::autograd;
  HANDLE_TH_ERRORS
  auto& gen = (reinterpret_cast<THPGenerator*>(_self))->cdata;
  auto& gen = ((THPGenerator*)_self)->cdata;

  // See Note [Acquire lock when using random generators]
  std::scoped_lock<std::mutex> lock(gen.mutex());

@@ -89,7 +88,7 @@ static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
        "expected a torch.ByteTensor, but got {}",
        Py_TYPE(_new_state)->tp_name));
  }
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  auto& gen = self->cdata;
  const auto& new_state_tensor = THPVariable_Unpack(_new_state);

@@ -98,7 +97,7 @@ static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
  gen.set_state(new_state_tensor);

  Py_INCREF(self);
  return reinterpret_cast<PyObject*>(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}

@@ -126,7 +125,7 @@ static PyObject* THPGenerator_graphSafeGetState(
    PyObject* _self,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& gen = (reinterpret_cast<THPGenerator*>(_self))->cdata;
  auto& gen = ((THPGenerator*)_self)->cdata;

  // See Note [Acquire lock when using random generators]
  std::scoped_lock<std::mutex> lock(gen.mutex());

@@ -139,7 +138,7 @@ static PyObject* THPGenerator_graphSafeSetState(
    PyObject* _self,
    PyObject* _state) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  auto& gen = self->cdata;

  // See Note [Acquire lock when using random generators]

@@ -147,13 +146,13 @@ static PyObject* THPGenerator_graphSafeSetState(
  gen.graphsafe_set_state(THPGenerator_Unwrap(_state));

  Py_INCREF(self);
  return reinterpret_cast<PyObject*>(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& gen = (reinterpret_cast<THPGenerator*>(_self))->cdata;
  auto& gen = ((THPGenerator*)_self)->cdata;

  // See Note [Acquire lock when using random generators]
  std::scoped_lock<std::mutex> lock(gen.mutex());

@@ -164,7 +163,7 @@ static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {

static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  auto generator = self->cdata;
  TORCH_CHECK(
      THPUtils_checkLong(seed),

@@ -176,13 +175,13 @@ static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
  std::scoped_lock<std::mutex> lock(generator.mutex());
  generator.set_current_seed(unsigned_seed);
  Py_INCREF(self);
  return reinterpret_cast<PyObject*>(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPGenerator_setOffset(PyObject* _self, PyObject* offset) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  auto generator = self->cdata;
  TORCH_CHECK(
      THPUtils_checkLong(offset),

@@ -194,14 +193,14 @@ static PyObject* THPGenerator_setOffset(PyObject* _self, PyObject* offset) {
  std::scoped_lock<std::mutex> lock(generator.mutex());
  generator.set_offset(unsigned_offset);
  Py_INCREF(self);
  return reinterpret_cast<PyObject*>(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  // See Note [Acquire lock when using random generators]
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  std::scoped_lock<std::mutex> lock(self->cdata.mutex());
  uint64_t seed_val = self->cdata.seed();
  return THPUtils_packUInt64(seed_val);

@@ -210,14 +209,14 @@ static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) {

static PyObject* THPGenerator_initialSeed(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  return THPUtils_packUInt64(self->cdata.current_seed());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPGenerator_getOffset(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  return THPUtils_packUInt64(self->cdata.get_offset());
  END_HANDLE_TH_ERRORS
}

@@ -230,7 +229,7 @@ static PyObject* THPGenerator_get_device(THPGenerator* self, void* unused) {

static PyObject* THPGenerator_reduce(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPGenerator*>(_self);
  auto self = (THPGenerator*)_self;
  auto& gen = self->cdata;

  auto ret = THPObjectPtr{PyTuple_New(3)};

@@ -280,11 +279,7 @@ static PyObject* THPGenerator_pickleSetState(PyObject* _self, PyObject* state) {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPGenerator_properties[] = {
    {"device",
     reinterpret_cast<getter>(THPGenerator_get_device),
     nullptr,
     nullptr,
     nullptr},
    {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr},
    {nullptr}};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)

@@ -354,12 +349,11 @@ static PyTypeObject THPGeneratorType = {
};

bool THPGenerator_init(PyObject* module) {
  THPGeneratorClass = reinterpret_cast<PyObject*>(&THPGeneratorType);
  THPGeneratorClass = (PyObject*)&THPGeneratorType;
  if (PyType_Ready(&THPGeneratorType) < 0)
    return false;
  Py_INCREF(&THPGeneratorType);
  PyModule_AddObject(
      module, "Generator", reinterpret_cast<PyObject*>(&THPGeneratorType));
  PyModule_AddObject(module, "Generator", (PyObject*)&THPGeneratorType);
  return true;
}

@@ -383,8 +377,7 @@ PyObject* THPGenerator_Wrap(const Generator& gen) {
    return obj;
  }

  return THPGenerator_NewWithVar(
      reinterpret_cast<PyTypeObject*>(THPGeneratorClass), gen);
  return THPGenerator_NewWithVar((PyTypeObject*)THPGeneratorClass, gen);
}

at::Generator THPGenerator_Unwrap(PyObject* state) {

@@ -402,7 +395,7 @@ at::Generator THPGenerator_Unwrap(PyObject* state) {
PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) {
  PyObject* obj = type->tp_alloc(type, 0);
  if (obj) {
    auto g = reinterpret_cast<THPGenerator*>(obj);
    auto g = (THPGenerator*)obj;
    new (&g->cdata) Generator(std::move(gen));
    set_pyobj(g->cdata, obj);
  }
@@ -36,7 +36,7 @@ PyTypeObject THPLayoutType = {
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPLayout_repr), /* tp_repr */
    (reprfunc)THPLayout_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */

@@ -72,8 +72,7 @@ void THPLayout_init(PyObject* module) {
    throw python_error();
  }
  Py_INCREF(&THPLayoutType);
  if (PyModule_AddObject(
          module, "layout", reinterpret_cast<PyObject*>(&THPLayoutType)) != 0) {
  if (PyModule_AddObject(module, "layout", (PyObject*)&THPLayoutType) != 0) {
    throw python_error();
  }
}

@@ -29,7 +29,7 @@ static PyObject* THPMemoryFormat_repr(THPMemoryFormat* self) {
}

static PyObject* THPMemoryFormat_reduce(PyObject* _self, PyObject* noargs) {
  auto* self = reinterpret_cast<THPMemoryFormat*>(_self);
  auto* self = (THPMemoryFormat*)_self;
  return THPUtils_packString(self->name);
}

@@ -49,7 +49,7 @@ PyTypeObject THPMemoryFormatType = {
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPMemoryFormat_repr), /* tp_repr */
    (reprfunc)THPMemoryFormat_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */

@@ -86,9 +86,7 @@ void THPMemoryFormat_init(PyObject* module) {
  }
  Py_INCREF(&THPMemoryFormatType);
  if (PyModule_AddObject(
          module,
          "memory_format",
          reinterpret_cast<PyObject*>(&THPMemoryFormatType)) != 0) {
          module, "memory_format", (PyObject*)&THPMemoryFormatType) != 0) {
    throw python_error();
  }
}
@@ -166,7 +166,7 @@ static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) {
  for (Py_ssize_t i = 0; i < num_classes; i++) {
    PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
    TORCH_CHECK(PyType_Check(obj), "expected a PyTypeObject");
    PyTypeObject* type = reinterpret_cast<PyTypeObject*>(obj);
    PyTypeObject* type = (PyTypeObject*)obj;

    THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
    if (!module_name)

@@ -268,7 +268,7 @@ static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) {
      THPUtils_typename(arg));
  int32_t x = THPUtils_unpackInt(arg);
  double y = 1.0 / x;
  return THPUtils_packInt32(static_cast<int>(y));
  return THPUtils_packInt32((int)y);
  END_HANDLE_TH_ERRORS
}

@@ -334,7 +334,7 @@ static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) {
      THPUtils_checkLong(arg),
      "set_num_threads expects an int, but got ",
      THPUtils_typename(arg));
  int nthreads = THPUtils_unpackInt(arg);
  int nthreads = (int)THPUtils_unpackLong(arg);
  TORCH_CHECK(nthreads > 0, "set_num_threads expects a positive integer");
  at::set_num_threads(nthreads);
  Py_RETURN_NONE;

@@ -356,7 +356,7 @@ static PyObject* THPModule_setNumInteropThreads(
      "set_num_interop_threads expects an int, "
      "but got ",
      THPUtils_typename(arg));
  int nthreads = THPUtils_unpackInt(arg);
  int nthreads = (int)THPUtils_unpackLong(arg);
  TORCH_CHECK(
      nthreads > 0, "set_num_interop_threads expects a positive integer");
  at::set_num_interop_threads(nthreads);

@@ -448,7 +448,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
  }

  if (Py_TYPE(obj) == &PyCFunction_Type) {
    PyCFunctionObject* f = reinterpret_cast<PyCFunctionObject*>(obj);
    PyCFunctionObject* f = (PyCFunctionObject*)obj;
    if (f->m_ml->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,

@@ -457,7 +457,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
    }
    f->m_ml->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
    PyMethodDescrObject* m = reinterpret_cast<PyMethodDescrObject*>(obj);
    PyMethodDescrObject* m = (PyMethodDescrObject*)obj;
    if (m->d_method->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,

@@ -466,7 +466,8 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
    }
    m->d_method->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) {
    PyGetSetDescrObject* m = reinterpret_cast<PyGetSetDescrObject*>(obj);
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
    PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj;
    if (m->d_getset->doc) {
      return PyErr_Format(
          PyExc_RuntimeError,

@@ -475,7 +476,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
    }
    m->d_getset->doc = doc_str;
  } else if (Py_TYPE(obj) == &PyType_Type) {
    PyTypeObject* t = reinterpret_cast<PyTypeObject*>(obj);
    PyTypeObject* t = (PyTypeObject*)obj;
    if (t->tp_doc) {
      return PyErr_Format(
          PyExc_RuntimeError, "Type '%s' already has a docstring", t->tp_name);

@@ -1471,11 +1472,10 @@ static PyObject* THPModule_willEngineExecuteNode(
  torch::autograd::Node* node = nullptr;
  std::shared_ptr<torch::autograd::Node> node_sp;
  if (isTHPFunction) {
    node_sp = (reinterpret_cast<THPFunction*>(arg))->cdata.lock();
    node_sp = ((THPFunction*)arg)->cdata.lock();
    node = node_sp.get();
  } else {
    node =
        (reinterpret_cast<torch::autograd::THPCppFunction*>(arg))->cdata.get();
    node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
  }
  const auto nodes_in_graph =
      torch::autograd::get_current_graph_task_nodes_in_graph();

@@ -1905,8 +1905,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
     METH_O,
     nullptr},
    {"_has_torch_function_variadic",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)()>(THPModule_has_torch_function_variadic)),
     (PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
    {"_ensureCUDADeviceGuardSet",

@@ -2613,7 +2612,7 @@ Call this whenever a new thread is created in order to propagate values from
              .getAcceleratorHooksInterface(device_type)
              .deviceCount();
        }
        return static_cast<c10::DeviceIndex>(-1);
        return c10::DeviceIndex(-1);
      });

  py_module.def(

@@ -2634,7 +2633,7 @@ Call this whenever a new thread is created in order to propagate values from
              .getAcceleratorHooksInterface(device_type)
              .getCurrentDevice();
        }
        return static_cast<c10::DeviceIndex>(-1);
        return c10::DeviceIndex(-1);
      });

  py_module.def(

@@ -2645,7 +2644,7 @@ Call this whenever a new thread is created in order to propagate values from
              .getAcceleratorHooksInterface(device_type)
              .exchangeDevice(device_index);
        }
        return static_cast<c10::DeviceIndex>(-1);
        return c10::DeviceIndex(-1);
      });

  py_module.def(

@@ -2657,7 +2656,7 @@ Call this whenever a new thread is created in order to propagate values from
              .getAcceleratorHooksInterface(device_type)
              .maybeExchangeDevice(device_index);
        }
        return static_cast<c10::DeviceIndex>(-1);
        return c10::DeviceIndex(-1);
      });

  py_module.def(

@@ -2821,8 +2820,8 @@ Call this whenever a new thread is created in order to propagate values from
      py::arg("eps"));

  const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
  THPDefaultCPUGenerator = reinterpret_cast<THPGenerator*>(
      THPGenerator_initDefaultGenerator(defaultGenerator));
  THPDefaultCPUGenerator =
      (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
  // This reference is meant to be given away, so no need to incref here.
  ASSERT_TRUE(set_module_attr(
      "default_generator",
@@ -270,7 +270,7 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
        "This probably happened because you took out a weak reference to "
        "Tensor and didn't call _fix_weakref() after dereferencing it. "
        "Subsequent accesses to this tensor via the PyObject will now fail.");
    (reinterpret_cast<THPVariable*>(pyobj))->cdata =
    ((THPVariable*)pyobj)->cdata =
        c10::MaybeOwned<torch::autograd::Variable>();
  } else if (THPStorage_Check(pyobj)) {
    TORCH_WARN(

@@ -278,8 +278,7 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
        "This probably happened because you took out a weak reference to "
        "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
        "Subsequent accesses to this storage via the PyObject will now fail.");
    (reinterpret_cast<THPStorage*>(pyobj))->cdata =
        c10::MaybeOwned<c10::Storage>();
    ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
  }
}
Py_DECREF(pyobj);
@@ -23,7 +23,7 @@ PyObject* THPQScheme_New(at::QScheme qscheme, const std::string& name) {
}

static PyObject* THPQScheme_reduce(PyObject* _self, PyObject* noargs) {
  auto self = reinterpret_cast<THPQScheme*>(_self);
  auto self = (THPQScheme*)_self;
  return THPUtils_packString(self->name);
}

@@ -48,7 +48,7 @@ PyTypeObject THPQSchemeType = {
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPQScheme_repr), /* tp_repr */
    (reprfunc)THPQScheme_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */

@@ -84,9 +84,7 @@ void THPQScheme_init(PyObject* module) {
    throw python_error();
  }
  Py_INCREF(&THPQSchemeType);
  if (PyModule_AddObject(
          module, "qscheme", reinterpret_cast<PyObject*>(&THPQSchemeType)) !=
      0) {
  if (PyModule_AddObject(module, "qscheme", (PyObject*)&THPQSchemeType) != 0) {
    throw python_error();
  }
}
@@ -133,8 +133,7 @@ static PyObject* THPSize_pynew(
static PyObject* THPSize_repr(THPSize* self) {
  HANDLE_TH_ERRORS
  std::string repr("torch.Size([");
  for (Py_ssize_t i = 0; i < PyTuple_Size(reinterpret_cast<PyObject*>(self));
       ++i) {
  for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
    if (i != 0) {
      repr += ", ";
    }

@@ -157,7 +156,7 @@ static PyObject* wrap_tuple_fn(Args... args) {
    return nullptr;
  if (PyTuple_Check(result.get())) {
    return PyObject_CallFunctionObjArgs(
        reinterpret_cast<PyObject*>(&THPSizeType), result.get(), nullptr);
        (PyObject*)&THPSizeType, result.get(), nullptr);
  }
  return result.release();
}

@@ -226,9 +225,9 @@ static PyMappingMethods THPSize_as_mapping = {

static PyObject* THPSize_numel(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPSize*>(_self);
  auto self = (THPSize*)_self;
  int64_t numel = 1;
  for (Py_ssize_t i = 0; i < PyTuple_Size(_self); ++i) {
  for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
    numel *= THPUtils_unpackLong(PyTuple_GET_ITEM(self, i));
  }
  return THPUtils_packInt64(numel);

@@ -237,19 +236,19 @@ static PyObject* THPSize_numel(PyObject* _self, PyObject* noargs) {

static PyObject* THPSize_reduce(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPSize*>(_self);
  auto self = (THPSize*)_self;
  auto ret = THPObjectPtr{PyTuple_New(2)};
  if (!ret)
    throw python_error();

  auto obj = reinterpret_cast<PyObject*>(&THPSizeType);
  auto obj = (PyObject*)(&THPSizeType);
  Py_INCREF(&THPSizeType);
  PyTuple_SET_ITEM(ret.get(), 0, obj);

  THPObjectPtr t(PyTuple_New(PyTuple_Size(_self)));
  THPObjectPtr t(PyTuple_New(PyTuple_Size((PyObject*)self)));
  if (!t)
    throw python_error();
  for (Py_ssize_t i = 0; i < PyTuple_Size(_self); ++i) {
  for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
    auto d = PyTuple_GET_ITEM(self, i);
    Py_INCREF(d);
    PyTuple_SET_ITEM(t.get(), i, d);

@@ -280,7 +279,7 @@ PyTypeObject THPSizeType = {
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPSize_repr), /* tp_repr */
    (reprfunc)THPSize_repr, /* tp_repr */
    &THPSize_as_number, /* tp_as_number */
    &THPSize_as_sequence, /* tp_as_sequence */
    &THPSize_as_mapping, /* tp_as_mapping */

@@ -316,8 +315,7 @@ void THPSize_init(PyObject* module) {
    throw python_error();
  }
  Py_INCREF(&THPSizeType);
  if (PyModule_AddObject(
          module, "Size", reinterpret_cast<PyObject*>(&THPSizeType)) < 0) {
  if (PyModule_AddObject(module, "Size", (PyObject*)&THPSizeType) < 0) {
    throw python_error();
  }
}
@@ -68,7 +68,7 @@ PyObject* THPStorage_NewWithStorage(
  PyObject* obj = type->tp_alloc(type, 0);
  TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");

  auto s = reinterpret_cast<THPStorage*>(obj);
  auto s = (THPStorage*)obj;

  new (&s->cdata) c10::MaybeOwned<c10::Storage>();

@@ -128,7 +128,7 @@ static bool THPStorage_isPreservable(THPStorage* self) {
  }

  if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
          /*ignore_hermetic_tls=*/true) != reinterpret_cast<PyObject*>(self)) {
          /*ignore_hermetic_tls=*/true) != (PyObject*)self) {
    return false;
  }
  if (storage.use_count() <= 1) {

@@ -170,14 +170,14 @@ static bool THPStorage_tryPreserve(THPStorage* self) {
  storage_impl->pyobj_slot()->set_owns_pyobj(true);
  // When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
  // ensure the PyObject is in a valid state
  _Py_NewReference(reinterpret_cast<PyObject*>(self));
  _Py_NewReference((PyObject*)self);

  self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
  return true;
}

static void THPStorage_subclass_dealloc(PyObject* self) {
  THPStorage* _self = reinterpret_cast<THPStorage*>(self);
  THPStorage* _self = (THPStorage*)self;

  if (THPStorage_tryPreserve(_self)) {
    return;

@@ -226,8 +226,8 @@ static void THPStorage_subclass_dealloc(PyObject* self) {
     being finalized that has already been destroyed. */
  if (type->tp_weaklistoffset) {
    /* Modeled after GET_WEAKREFS_LISTPTR() */
    PyWeakReference** list = reinterpret_cast<PyWeakReference**>(
        PyObject_GET_WEAKREFS_LISTPTR(self));
    PyWeakReference** list =
        (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
    while (*list)
      _PyWeakref_ClearRef(*list);
  }

@@ -549,9 +549,9 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
}

static PyMappingMethods THPStorage_mappingmethods = {
    reinterpret_cast<lenfunc>(THPStorage_length),
    reinterpret_cast<binaryfunc>(THPStorage_get),
    reinterpret_cast<objobjargproc>(THPStorage_set)};
    (lenfunc)THPStorage_length,
    (binaryfunc)THPStorage_get,
    (objobjargproc)THPStorage_set};

struct THPStorageMeta {
  PyHeapTypeObject base;

@@ -653,8 +653,7 @@ int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
  if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
    return -1;
  }
  (reinterpret_cast<PyTypeObject*>(cls))->tp_dealloc =
      static_cast<destructor>(THPStorage_subclass_dealloc);
  ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPStorage_subclass_dealloc;
  return 0;
}

@@ -675,16 +674,8 @@ typedef PyObject* (*getter)(PyObject*, void*);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPStorage_properties[] = {
    {"device",
     reinterpret_cast<getter>(THPStorage_device),
     nullptr,
     nullptr,
     nullptr},
    {"_cdata",
     reinterpret_cast<getter>(THPStorage_get_cdata),
     nullptr,
     nullptr,
     nullptr},
    {"device", (getter)THPStorage_device, nullptr, nullptr, nullptr},
    {"_cdata", (getter)THPStorage_get_cdata, nullptr, nullptr, nullptr},
    {nullptr}};

bool THPStorage_init(PyObject* module) {

@@ -696,22 +687,20 @@ bool THPStorage_init(PyObject* module) {
  if (PyType_Ready(&THPStorageMetaType) < 0)
    return false;
  Py_INCREF(&THPStorageMetaType);
  PyModule_AddObject(
      module, "_StorageMeta", reinterpret_cast<PyObject*>(&THPStorageMetaType));
  PyModule_AddObject(module, "_StorageMeta", (PyObject*)&THPStorageMetaType);

  THPStorageType.tp_methods = methods.data();
  THPStorageType.tp_getset = THPStorage_properties;
  if (PyType_Ready(&THPStorageType) < 0)
    return false;
  Py_INCREF(&THPStorageType);
  PyModule_AddObject(
      module, "StorageBase", reinterpret_cast<PyObject*>(&THPStorageType));
  PyModule_AddObject(module, "StorageBase", (PyObject*)&THPStorageType);
  return true;
}

void THPStorage_postInit(PyObject* module) {
  THPStorageClass = reinterpret_cast<PyTypeObject*>(
      PyObject_GetAttrString(module, "UntypedStorage"));
  THPStorageClass =
      (PyTypeObject*)PyObject_GetAttrString(module, "UntypedStorage");
  if (!THPStorageClass)
    throw python_error();
}

@@ -722,5 +711,5 @@ void THPStorage_assertNotNull(THPStorage* storage) {
}

void THPStorage_assertNotNull(PyObject* obj) {
  THPStorage_assertNotNull(reinterpret_cast<THPStorage*>(obj));
  THPStorage_assertNotNull((THPStorage*)obj);
}
@@ -297,7 +297,7 @@ static PyObject* THPStorage_fromBuffer(
    size_bytes = count * element_size;
  }

  if (offset + (count * static_cast<Py_ssize_t>(element_size)) > buffer.len) {
  if (offset + (count * (Py_ssize_t)element_size) > buffer.len) {
    PyErr_SetString(
        PyExc_ValueError,
        fmt::format(

@@ -309,7 +309,7 @@ static PyObject* THPStorage_fromBuffer(
    return nullptr;
  }

  uint8_t* src = static_cast<uint8_t*>(buffer.buf);
  uint8_t* src = (uint8_t*)buffer.buf;
  auto fake_mode_active =
      c10::impl::TorchDispatchModeTLS::get_mode(
          c10::impl::TorchDispatchModeKey::FAKE) != std::nullopt;

@@ -508,8 +508,8 @@ static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {
  // advanced position
  const auto fd_current_pos = LSEEK(fd, 0, SEEK_CUR);
  LSEEK(fd, fd_original_pos, SEEK_SET);
  const auto seek_return = PyObject_CallMethod(
      file, "seek", "Li", static_cast<long long>(fd_current_pos), 0);
  const auto seek_return =
      PyObject_CallMethod(file, "seek", "Li", (long long)fd_current_pos, 0);
  if (seek_return == nullptr) {
    return nullptr;
  }

@@ -521,19 +521,18 @@ static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {

static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPStorage*>(_self);
  auto self = (THPStorage*)_self;
  TORCH_CHECK(
      THPUtils_checkLong(new_cdata),
      "given an invalid argument to "
      "_set_cdata - expected an int or long, but got ",
      THPUtils_typename(new_cdata));
  c10::StorageImpl* ptr =
      static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(new_cdata));
  c10::StorageImpl* ptr = (c10::StorageImpl*)PyLong_AsVoidPtr(new_cdata);
  self->cdata.~MaybeOwned<c10::Storage>();
  self->cdata = c10::MaybeOwned<c10::Storage>::owned(
      c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr)));
  Py_INCREF(self);
  return reinterpret_cast<PyObject*>(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}
@@ -256,7 +256,7 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) {
        "a file descriptor (int) and storage size (int)");
    return nullptr;
  }
  int tmp_fd = THPUtils_unpackInt(_tmp_fd);
  int tmp_fd = (int)THPUtils_unpackLong(_tmp_fd);
  int64_t size = THPUtils_unpackLong(_size);
  int fd = dup(tmp_fd);
  if (fd == -1) {

@@ -312,8 +312,8 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
    auto shandle =
        c10::cuda::CUDACachingAllocator::shareIpcHandle(storage.mutable_data());
    _handle = PyBytes_FromStringAndSize(
        shandle.handle.c_str(), static_cast<Py_ssize_t>(shandle.handle.size()));
    _offset_bytes = PyLong_FromSsize_t(static_cast<Py_ssize_t>(shandle.offset));
        shandle.handle.c_str(), (Py_ssize_t)shandle.handle.size());
    _offset_bytes = PyLong_FromSsize_t((Py_ssize_t)shandle.offset);

    // Put Storage Data behind new ref counting context
    // See Note [CUDA IPC Refcounting implementation explained]

@@ -334,7 +334,7 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
    }

    _event_handle = PyBytes_FromStringAndSize(
        reinterpret_cast<const char*>(&ipc_event_handle), CUDA_IPC_HANDLE_SIZE);
        (char*)&ipc_event_handle, CUDA_IPC_HANDLE_SIZE);
    _event_sync_required = PyBool_FromLong(sent_data->event_sync_required_);
  }

@@ -385,7 +385,7 @@ static PyObject* THPStorage_releaseIPCCounter(
  }
  std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
  ptrdiff_t ref_counter_offset =
      static_cast<ptrdiff_t>(THPUtils_unpackLong(_ref_counter_offset));
      (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
  // We don't want to break existing code, so resource deletion is best
  // effort basis. Exception expected if producer process terminated
  // before consumer released data.

@@ -446,9 +446,10 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
    return nullptr;
  }

  size_t storage_size = THPUtils_unpackUInt64(_size_bytes) / sizeof(uint8_t);
  size_t storage_size =
      (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(uint8_t);
  ptrdiff_t storage_offset_bytes =
      static_cast<ptrdiff_t>(THPUtils_unpackLong(_offset_bytes));
      (ptrdiff_t)THPUtils_unpackLong(_offset_bytes);

  const auto device = c10::checked_convert<c10::DeviceIndex>(
      THPUtils_unpackLong(_device), "c10::DeviceIndex");

@@ -479,11 +480,11 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
  // Offset the basePtr to reconstruct the real storage
  // devPtr = basePtr + storage_offset
  void* devPtr = basePtr.get();
  devPtr = static_cast<char*>(devPtr) + storage_offset_bytes;
  devPtr = (char*)devPtr + storage_offset_bytes;

  std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
  ptrdiff_t ref_counter_offset =
      static_cast<ptrdiff_t>(THPUtils_unpackLong(_ref_counter_offset));
      (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);

  struct IpcDeleterContext {
    std::string ref_counter_handle;

@@ -577,8 +578,7 @@ static PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'");
  c10::StorageImpl* weak_storage =
      static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(arg));
  c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
  if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) {
    return THPStorage_Wrap(
        c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));

@@ -594,8 +594,7 @@ static PyObject* THPStorage_freeWeakRef(PyObject* _unused, PyObject* arg) {
  }
  TORCH_CHECK(
      THPUtils_checkLong(arg), "_free_weak_ref(): arg must be an 'int'");
  c10::StorageImpl* weak_storage =
      static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(arg));
  c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
  c10::raw::weak_intrusive_ptr::decref(weak_storage);

  Py_RETURN_NONE;

@@ -605,8 +604,7 @@ static PyObject* THPStorage_freeWeakRef(PyObject* _unused, PyObject* arg) {
static PyObject* THPStorage_expired(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "_expired(): arg must be an 'int'");
  c10::StorageImpl* weak_storage =
      static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(arg));
  c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
  return PyBool_FromLong(
      c10::raw::weak_intrusive_ptr::use_count(weak_storage) == 0);
  END_HANDLE_TH_ERRORS
@@ -74,7 +74,7 @@ static PyObject* THPStream_pynew(
    return nullptr;
  }

  THPStream* self = reinterpret_cast<THPStream*>(ptr.get());
  THPStream* self = (THPStream*)ptr.get();

  // If torch.Stream is not created from existing Stream, then create a new one.
  // It requires other device backends override getNewStream method. How the new

@@ -96,7 +96,7 @@ static PyObject* THPStream_pynew(
  self->device_type = static_cast<int64_t>(stream_opt->device_type());
  self->context = nullptr;

  return static_cast<PyObject*>(ptr.release());
  return (PyObject*)ptr.release();
  END_HANDLE_TH_ERRORS
}

@@ -108,7 +108,7 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
    throw python_error();
  }

  THPStream* self = reinterpret_cast<THPStream*>(ptr.get());
  THPStream* self = (THPStream*)ptr.get();
  self->stream_id = stream.id();
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  self->device_index = static_cast<int64_t>(stream.device_index());

@@ -119,7 +119,7 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
}

static void THPStream_dealloc(THPStream* self) {
  Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
  Py_TYPE(self)->tp_free((PyObject*)self);
}

static PyObject* THPStream_get_device(THPStream* self, void* unused) {

@@ -132,7 +132,7 @@ static PyObject* THPStream_get_device(THPStream* self, void* unused) {

static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPStream*>(_self);
  auto self = (THPStream*)_self;

  return PyBool_FromLong(c10::Stream::unpack3(
      self->stream_id,

@@ -146,7 +146,7 @@ static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS {
    pybind11::gil_scoped_release no_gil;
    auto self = reinterpret_cast<THPStream*>(_self);
    auto self = (THPStream*)_self;

    c10::Stream::unpack3(
        self->stream_id,

@@ -160,8 +160,8 @@ static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {

static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {
  HANDLE_TH_ERRORS {
    auto self = reinterpret_cast<THPStream*>(_self);
    auto event = reinterpret_cast<THPEvent*>(_event);
    auto self = (THPStream*)_self;
    auto event = (THPEvent*)_event;
    c10::Stream::unpack3(
        self->stream_id,
        static_cast<c10::DeviceIndex>(self->device_index),

@@ -174,8 +174,8 @@ static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {

static PyObject* THPStream_wait_stream(PyObject* _self, PyObject* _other) {
  HANDLE_TH_ERRORS {
    auto self = reinterpret_cast<THPStream*>(_self);
    auto other_stream = reinterpret_cast<THPStream*>(_other);
    auto self = (THPStream*)_self;
    auto other_stream = (THPStream*)_other;
    c10::Event new_event(
        static_cast<c10::DeviceType>(other_stream->device_type),
        c10::EventFlag::PYTORCH_DEFAULT);

@@ -198,7 +198,7 @@ static PyObject* THPStream_record_event(
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPStream*>(_self);
  auto self = (THPStream*)_self;
  PyObject* _new_event = nullptr;
  PyObject* _event = Py_None;

@@ -222,13 +222,13 @@ static PyObject* THPStream_record_event(
        static_cast<c10::DeviceType>(self->device_type),
        c10::EventFlag::PYTORCH_DEFAULT);
  }
  auto new_event = reinterpret_cast<THPEvent*>(_new_event);
  auto new_event = (THPEvent*)_new_event;
  TORCH_CHECK(new_event, "event must not be null");
  new_event->event.record(c10::Stream::unpack3(
      self->stream_id,
      static_cast<c10::DeviceIndex>(self->device_index),
      static_cast<c10::DeviceType>(self->device_type)));
  return reinterpret_cast<PyObject*>(new_event);
  return (PyObject*)new_event;
  END_HANDLE_TH_ERRORS
}

@@ -260,7 +260,7 @@ static PyObject* THPStream_eq(THPStream* self, THPStream* other) {

static PyObject* THPStream_enter(PyObject* _self, PyObject* unused) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPStream*>(_self);
  auto self = (THPStream*)_self;
  c10::DeviceType stream_device_type =
      static_cast<c10::DeviceType>(self->device_type);
  // No operation is performed if the stream does not belong to an accelerator.

@@ -304,7 +304,7 @@ static PyObject* THPStream_enter(PyObject* _self, PyObject* unused) {

static PyObject* THPStream_exit(PyObject* _self, PyObject* unused) {
  HANDLE_TH_ERRORS
  auto self = reinterpret_cast<THPStream*>(_self);
  auto self = (THPStream*)_self;
  // No operation is performed if the stream does not belong to an accelerator.
  if (C10_UNLIKELY(!at::accelerator::isAccelerator(
          static_cast<c10::DeviceType>(self->device_type)))) {

@@ -323,7 +323,7 @@ static PyObject* THPStream_exit(PyObject* _self, PyObject* unused) {
  auto ctx_device_index = THPObjectPtr(py_device_index);
  TORCH_INTERNAL_ASSERT(
      ctx_stream.get(), "ctx_stream should be present on the context dict.");
  auto prev_stream = reinterpret_cast<THPStream*>(ctx_stream.get());
  auto prev_stream = (THPStream*)(ctx_stream.get());
  TORCH_INTERNAL_ASSERT(
      ctx_device_index.get(),
      "ctx_device_index should be present on the context dict.");

@@ -360,14 +360,10 @@ static PyObject* THPStream_richcompare(
  } else {
    switch (op) {
      case Py_EQ:
        result = THPStream_eq(
            reinterpret_cast<THPStream*>(self),
            reinterpret_cast<THPStream*>(other));
        result = THPStream_eq((THPStream*)self, (THPStream*)other);
        break;
      case Py_NE:
        result = THPStream_ne(
            reinterpret_cast<THPStream*>(self),
            reinterpret_cast<THPStream*>(other));
        result = THPStream_ne((THPStream*)self, (THPStream*)other);
        break;
      default:
        result = Py_False;

@@ -397,11 +393,7 @@ static const std::initializer_list<PyMemberDef> THPStream_members = {
    {nullptr}};

static const std::initializer_list<PyGetSetDef> THPStream_properties = {
    {"device",
     reinterpret_cast<getter>(THPStream_get_device),
     nullptr,
     nullptr,
     nullptr},
    {"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr},
    {nullptr}};

static const std::initializer_list<PyMethodDef> THPStream_methods = {

@@ -413,7 +405,7 @@ static const std::initializer_list<PyMethodDef> THPStream_methods = {
     castPyCFunctionWithKeywords(THPStream_record_event),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__eq__", reinterpret_cast<PyCFunction>(THPStream_eq), METH_O, nullptr},
    {"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
    {"__enter__", THPStream_enter, METH_NOARGS, nullptr},
    {"__exit__", THPStream_exit, METH_VARARGS, nullptr},
    {nullptr}};

@@ -423,16 +415,16 @@ static PyTypeObject THPStreamType = {
    "torch.Stream", /* tp_name */
    sizeof(THPStream), /* tp_basicsize */
    0, /* tp_itemsize */
    reinterpret_cast<destructor>(THPStream_dealloc), /* tp_dealloc */
    (destructor)THPStream_dealloc, /* tp_dealloc */
|
||||
0, /* tp_vectorcall_offset */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
reinterpret_cast<reprfunc>(THPStream_repr), /* tp_repr */
|
||||
(reprfunc)THPStream_repr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
reinterpret_cast<hashfunc>(THPStream_hash), /* tp_hash */
|
||||
(hashfunc)THPStream_hash, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
@ -470,8 +462,7 @@ void THPStream_init(PyObject* module) {
|
||||
throw python_error();
|
||||
}
|
||||
Py_INCREF(&THPStreamType);
|
||||
if (PyModule_AddObject(
|
||||
module, "Stream", reinterpret_cast<PyObject*>(&THPStreamType)) < 0) {
|
||||
if (PyModule_AddObject(module, "Stream", (PyObject*)&THPStreamType) < 0) {
|
||||
throw python_error();
|
||||
}
|
||||
}
|
||||
|
||||
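Note: the Stream.cpp hunks above only toggle between C++ named casts and C-style casts; the behavior is identical. A minimal standalone sketch of the two spellings on a CPython-style wrapper (the struct and field below are hypothetical, not from this patch):

    // Sketch only: both helpers return the same pointer; reinterpret_cast
    // states the intent explicitly, the C-style cast is the terser legacy form.
    #include <Python.h>

    struct MyWrapper {
      PyObject_HEAD
      long long payload; // hypothetical field
    };

    static MyWrapper* as_wrapper(PyObject* obj) {
      return reinterpret_cast<MyWrapper*>(obj); // named cast
    }

    static MyWrapper* as_wrapper_legacy(PyObject* obj) {
      return (MyWrapper*)obj; // C-style cast, same result here
    }
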
@@ -273,34 +273,18 @@ static PyObject* THPIInfo_str(THPIInfo* self) {
 }

 static const std::initializer_list<PyGetSetDef> THPFInfo_properties = {
-    {"bits",
-     reinterpret_cast<getter>(THPDTypeInfo_bits),
-     nullptr,
-     nullptr,
-     nullptr},
-    {"eps", reinterpret_cast<getter>(THPFInfo_eps), nullptr, nullptr, nullptr},
-    {"max", reinterpret_cast<getter>(THPFInfo_max), nullptr, nullptr, nullptr},
-    {"min", reinterpret_cast<getter>(THPFInfo_min), nullptr, nullptr, nullptr},
+    {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
+    {"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
+    {"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
+    {"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
     {"smallest_normal",
-     reinterpret_cast<getter>(THPFInfo_smallest_normal),
-     nullptr,
-     nullptr,
-     nullptr},
-    {"tiny",
-     reinterpret_cast<getter>(THPFInfo_tiny),
-     nullptr,
-     nullptr,
-     nullptr},
-    {"resolution",
-     reinterpret_cast<getter>(THPFInfo_resolution),
-     nullptr,
-     nullptr,
-     nullptr},
-    {"dtype",
-     reinterpret_cast<getter>(THPFInfo_dtype),
+     (getter)THPFInfo_smallest_normal,
      nullptr,
      nullptr,
      nullptr},
+    {"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
+    {"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
+    {"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
     {nullptr}};

 PyTypeObject THPFInfoType = {
@@ -313,13 +297,13 @@ PyTypeObject THPFInfoType = {
     nullptr, /* tp_getattr */
     nullptr, /* tp_setattr */
     nullptr, /* tp_reserved */
-    reinterpret_cast<reprfunc>(THPFInfo_str), /* tp_repr */
+    (reprfunc)THPFInfo_str, /* tp_repr */
     nullptr, /* tp_as_number */
     nullptr, /* tp_as_sequence */
     nullptr, /* tp_as_mapping */
     nullptr, /* tp_hash */
     nullptr, /* tp_call */
-    reinterpret_cast<reprfunc>(THPFInfo_str), /* tp_str */
+    (reprfunc)THPFInfo_str, /* tp_str */
     nullptr, /* tp_getattro */
     nullptr, /* tp_setattro */
     nullptr, /* tp_as_buffer */
@@ -327,7 +311,7 @@ PyTypeObject THPFInfoType = {
     nullptr, /* tp_doc */
     nullptr, /* tp_traverse */
     nullptr, /* tp_clear */
-    reinterpret_cast<richcmpfunc>(THPDTypeInfo_compare), /* tp_richcompare */
+    (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
     nullptr, /* tp_iternext */
@@ -346,18 +330,10 @@ PyTypeObject THPFInfoType = {
 };

 static const std::initializer_list<PyGetSetDef> THPIInfo_properties = {
-    {"bits",
-     reinterpret_cast<getter>(THPDTypeInfo_bits),
-     nullptr,
-     nullptr,
-     nullptr},
-    {"max", reinterpret_cast<getter>(THPIInfo_max), nullptr, nullptr, nullptr},
-    {"min", reinterpret_cast<getter>(THPIInfo_min), nullptr, nullptr, nullptr},
-    {"dtype",
-     reinterpret_cast<getter>(THPIInfo_dtype),
-     nullptr,
-     nullptr,
-     nullptr},
+    {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
+    {"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
+    {"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
+    {"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
     {nullptr}};

 PyTypeObject THPIInfoType = {
@@ -370,13 +346,13 @@ PyTypeObject THPIInfoType = {
     nullptr, /* tp_getattr */
     nullptr, /* tp_setattr */
     nullptr, /* tp_reserved */
-    reinterpret_cast<reprfunc>(THPIInfo_str), /* tp_repr */
+    (reprfunc)THPIInfo_str, /* tp_repr */
     nullptr, /* tp_as_number */
     nullptr, /* tp_as_sequence */
     nullptr, /* tp_as_mapping */
     nullptr, /* tp_hash */
     nullptr, /* tp_call */
-    reinterpret_cast<reprfunc>(THPIInfo_str), /* tp_str */
+    (reprfunc)THPIInfo_str, /* tp_str */
     nullptr, /* tp_getattro */
     nullptr, /* tp_setattro */
     nullptr, /* tp_as_buffer */
@@ -384,7 +360,7 @@ PyTypeObject THPIInfoType = {
     nullptr, /* tp_doc */
     nullptr, /* tp_traverse */
     nullptr, /* tp_clear */
-    reinterpret_cast<richcmpfunc>(THPDTypeInfo_compare), /* tp_richcompare */
+    (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
     nullptr, /* tp_iternext */
@@ -407,16 +383,14 @@ void THPDTypeInfo_init(PyObject* module) {
     throw python_error();
   }
   Py_INCREF(&THPFInfoType);
-  if (PyModule_AddObject(
-          module, "finfo", reinterpret_cast<PyObject*>(&THPFInfoType)) != 0) {
+  if (PyModule_AddObject(module, "finfo", (PyObject*)&THPFInfoType) != 0) {
     throw python_error();
   }
   if (PyType_Ready(&THPIInfoType) < 0) {
     throw python_error();
   }
   Py_INCREF(&THPIInfoType);
-  if (PyModule_AddObject(
-          module, "iinfo", reinterpret_cast<PyObject*>(&THPIInfoType)) != 0) {
+  if (PyModule_AddObject(module, "iinfo", (PyObject*)&THPIInfoType) != 0) {
     throw python_error();
   }
 }

@@ -25,7 +25,7 @@ c10::intrusive_ptr<rpc::Message> CleanupAutogradContextReq::toMessageImpl() && {
 std::unique_ptr<CleanupAutogradContextReq> CleanupAutogradContextReq::
     fromMessage(const rpc::Message& message) {
   // unpickle and get the context_id we need to clean up
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();
   IValue ivalue_context_id = jit::unpickle(
       payload,

@@ -47,7 +47,7 @@ c10::intrusive_ptr<Message> PropagateGradientsReq::toMessageImpl() && {
 std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
     const Message& message) {
   // Unpickle the message and retrieve tupleElements.
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();
   IValue tuple = jit::unpickle(
       payload,

@@ -37,7 +37,7 @@ c10::intrusive_ptr<Message> RRefBackwardReq::toMessageImpl() && {
 std::unique_ptr<RRefBackwardReq> RRefBackwardReq::fromMessage(
     const Message& message) {
   // Unpickle the message and retrieve tupleElements.
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();
   IValue tuple = jit::unpickle(
       payload,

@@ -225,7 +225,7 @@ class File {
     while (count > 0) {
       auto rv = syscall([this, buf, count] { return ::read(fd_, buf, count); });
       SYSASSERT(rv, "read");
-      buf = static_cast<uint8_t*>(buf) + rv;
+      buf = (uint8_t*)buf + rv;
       count -= rv;
     }
   }

@@ -2476,7 +2476,7 @@ static at::Tensor& checkSingleTensor(std::vector<at::Tensor>& tensors) {

 static uint32_t checkTag(int32_t tag) {
   TORCH_CHECK(tag >= 0, "Tag must be nonnegative");
-  return static_cast<uint32_t>(tag);
+  return (uint32_t)tag;
 }

 c10::intrusive_ptr<Work> ProcessGroupGloo::send(

@@ -207,7 +207,7 @@ class SendBuffer {
   SendBuffer(detail::TCPClient& client, detail::QueryType cmd)
       : client(client) {
     buffer.reserve(32); // enough for most commands
-    buffer.push_back(static_cast<uint8_t>(cmd));
+    buffer.push_back((uint8_t)cmd);
   }

   void appendString(const std::string& str) {
@@ -224,7 +224,7 @@ class SendBuffer {

   template <typename T>
   void appendValue(T value) {
-    uint8_t* begin = reinterpret_cast<uint8_t*>(&value);
+    uint8_t* begin = (uint8_t*)&value;
     buffer.insert(buffer.end(), begin, begin + sizeof(T));
     maybeFlush();
   }
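Note: the SendBuffer::appendValue hunk above serializes a value by viewing it as raw bytes. A self-contained sketch of that pattern under the assumption that T is trivially copyable (the helper name is illustrative, not the TCPStore API):

    #include <cstdint>
    #include <type_traits>
    #include <vector>

    template <typename T>
    void append_value(std::vector<uint8_t>& buffer, T value) {
      static_assert(
          std::is_trivially_copyable_v<T>,
          "raw byte serialization requires a trivially copyable type");
      const uint8_t* begin = reinterpret_cast<const uint8_t*>(&value);
      // Append exactly sizeof(T) bytes in host byte order.
      buffer.insert(buffer.end(), begin, begin + sizeof(T));
    }
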
@@ -36,14 +36,14 @@ Other callbacks don't provide exception safety so avoid there.
 // backlog. This should be at least world size to avoid issues on init. We set
 // it to -1 to use the host max value which is controlled by `soconnmax`.
 auto constexpr DEFAULT_BACKLOG = -1;
-auto constexpr MAX_KEY_COUNT = static_cast<size_t>(128 * 1024);
+auto constexpr MAX_KEY_COUNT = size_t(128 * 1024);
 auto constexpr MAX_STRING_LEN = 8 * 1024;
 auto constexpr MAX_PAYLOAD_LEN = 8 * 1024 * 1024;

 // This controls the preferred size for buffers.
 // Too small and we'll need multiple buffers for one request
 // Too big and we might taxing malloc
-auto constexpr ALLOC_BUFFER_SIZE = static_cast<size_t>(4096);
+auto constexpr ALLOC_BUFFER_SIZE = size_t(4096);
 class UvHandle : public c10::intrusive_ptr_target {
  public:
   ~UvHandle() override = default;
@@ -78,7 +78,7 @@ class UvHandle : public c10::intrusive_ptr_target {

  private:
   static c10::intrusive_ptr<UvHandle> reclaim(uv_handle_t* handle) {
-    auto h = static_cast<UvHandle*>(uv_handle_get_data(handle));
+    auto h = (UvHandle*)uv_handle_get_data(handle);
     return c10::intrusive_ptr<UvHandle>::reclaim(h);
   }

@@ -97,8 +97,7 @@ class UvTcpSocket : public UvHandle {
   }

   static c10::intrusive_ptr<UvTcpSocket> borrow(uv_stream_t* handle) {
-    auto h = static_cast<UvTcpSocket*>(
-        uv_handle_get_data(reinterpret_cast<uv_handle_t*>(handle)));
+    auto h = (UvTcpSocket*)uv_handle_get_data((uv_handle_t*)handle);
     return h->iptr();
   }

@@ -108,7 +107,7 @@ class UvTcpSocket : public UvHandle {
       uv_buf_t* buf) {
     suggested_size = std::min(suggested_size, ALLOC_BUFFER_SIZE);
     // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
-    buf->base = static_cast<char*>(malloc(suggested_size));
+    buf->base = (char*)malloc(suggested_size);
     buf->len = suggested_size;
   }

@@ -169,8 +168,7 @@ class UvTcpSocket : public UvHandle {
           formatSockAddr(reinterpret_cast<struct ::sockaddr*>(&addr), addrLen);
     }

-    int res = uv_read_start(
-        reinterpret_cast<uv_stream_t*>(&client), alloc_buffer, read_callback);
+    int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback);
     if (res) {
       C10D_WARNING(
           "Failed to setup read callback. client:{} code:{} name:{} desc:{}.",
@@ -183,12 +181,12 @@ class UvTcpSocket : public UvHandle {
   }

   uv_handle_t* unsafeGetHandle() override {
-    return reinterpret_cast<uv_handle_t*>(&client);
+    return (uv_handle_t*)&client;
   }

  protected:
   uv_stream_t* unsafeGetStream() {
-    return reinterpret_cast<uv_stream_t*>(&client);
+    return (uv_stream_t*)&client;
   }

   uv_tcp_t* unsafeGetSocket() {
@@ -219,7 +217,7 @@ class UvTcpServer : public UvTcpSocket {
     auto res = c10::make_intrusive<UvTcpServer>(loop);
     res->handleReady();
     try {
-      int uv_res = uv_tcp_open(res->unsafeGetSocket(), socket);
+      int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket);
       C10D_CHECK_WITH(
           SocketError,
           uv_res == 0,
@@ -268,11 +266,9 @@ class UvTcpServer : public UvTcpSocket {
     struct sockaddr_storage addr{};
     int uv_res = 0;
     if (useIpv6) {
-      uv_res = uv_ip6_addr(
-          "::", port, reinterpret_cast<struct sockaddr_in6*>(&addr));
+      uv_res = uv_ip6_addr("::", port, (struct sockaddr_in6*)&addr);
     } else {
-      uv_res = uv_ip4_addr(
-          "0.0.0.0", port, reinterpret_cast<struct sockaddr_in*>(&addr));
+      uv_res = uv_ip4_addr("0.0.0.0", port, (struct sockaddr_in*)&addr);
     }
     TORCH_CHECK_WITH(
         DistStoreError,
@@ -290,9 +286,7 @@ class UvTcpServer : public UvTcpSocket {
         uv_strerror(uv_res));

     uv_res = uv_tcp_bind(
-        res->unsafeGetSocket(),
-        reinterpret_cast<const struct ::sockaddr*>(&addr),
-        0);
+        res->unsafeGetSocket(), (const struct ::sockaddr*)&addr, 0);
     C10D_CHECK_WITH(
         SocketError,
         uv_res == 0,
@@ -335,9 +329,8 @@ class UvTcpServer : public UvTcpSocket {
   }

   void accept(const c10::intrusive_ptr<UvTcpSocket>& socket) {
-    int res = uv_accept(
-        unsafeGetStream(),
-        reinterpret_cast<uv_stream_t*>(socket->unsafeGetHandle()));
+    int res =
+        uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle());
     C10D_CHECK_WITH(
         SocketError,
         res == 0,
@@ -359,8 +352,7 @@ class UvTcpServer : public UvTcpSocket {
   }

   static c10::intrusive_ptr<UvTcpServer> borrow(uv_stream_t* handle) {
-    auto h = static_cast<UvTcpServer*>(
-        uv_handle_get_data(reinterpret_cast<uv_handle_t*>(handle)));
+    auto h = (UvTcpServer*)uv_handle_get_data((uv_handle_t*)handle);
     return h->iptr();
   }

@@ -397,8 +389,7 @@ class WriterPayload : public c10::intrusive_ptr_target {
   static c10::intrusive_ptr<WriterPayload> reclaim(uv_write_t* request) {
     /* This method returns a intrusive_ptr that does not increase the refcount.
      */
-    auto h = static_cast<WriterPayload*>(
-        uv_req_get_data(reinterpret_cast<uv_req_t*>(request)));
+    auto h = (WriterPayload*)uv_req_get_data((uv_req_t*)request);
     return c10::intrusive_ptr<WriterPayload>::reclaim(h);
   }

@@ -436,19 +427,15 @@ class WriterPayload : public c10::intrusive_ptr_target {
       std::vector<uint8_t>&& in_data,
       c10::intrusive_ptr<UvHandle> handle)
       : data(std::move(in_data)), handle(std::move(handle)) {
-    uv_req_set_data(reinterpret_cast<uv_req_t*>(&req), this);
+    uv_req_set_data((uv_req_t*)&req, this);
   }

   ~WriterPayload() override = default;

   void send() {
-    buf = uv_buf_init(reinterpret_cast<char*>(data.data()), data.size());
+    buf = uv_buf_init((char*)data.data(), data.size());
     int res = uv_write(
-        &req,
-        reinterpret_cast<uv_stream_t*>(handle->unsafeGetHandle()),
-        &buf,
-        1,
-        write_done);
+        &req, (uv_stream_t*)handle->unsafeGetHandle(), &buf, 1, write_done);

     if (res) {
       C10D_WARNING(
@@ -597,7 +584,7 @@ class ChunkedStream {
     if (available() < size)
       return false;
     str.resize(size);
-    return read_many(str.data(), size);
+    return read_many((char*)str.data(), size);
   }

   bool read_payload(std::vector<uint8_t>& data) {
@@ -617,7 +604,7 @@ class ChunkedStream {
     if (available() < size_in_bytes)
       return false;
     data.resize(size);
-    return read_many(reinterpret_cast<char*>(data.data()), size_in_bytes);
+    return read_many((char*)data.data(), size_in_bytes);
   }

   size_t available() {
@@ -716,15 +703,15 @@ class LibUVStoreDaemon : public BackgroundThread {
   int port_;

   static LibUVStoreDaemon& from_uv(uv_handle_t* stream) {
-    return *static_cast<LibUVStoreDaemon*>(uv_handle_get_data(stream));
+    return *(LibUVStoreDaemon*)uv_handle_get_data(stream);
   }

   static void on_new_connection(uv_stream_t* server, int status) {
-    from_uv(reinterpret_cast<uv_handle_t*>(server)).onConnect(status);
+    from_uv((uv_handle_t*)server).onConnect(status);
   }

   static void on_exit_request(uv_async_t* handle) {
-    from_uv(reinterpret_cast<uv_handle_t*>(handle)).onExitRequest();
+    from_uv((uv_handle_t*)handle).onExitRequest();
   }

   void onConnect(int status);
@@ -752,12 +739,12 @@ class UvClient : public UvTcpSocket {
       if (!stream.read1(command))
         break;
       if (store->isMiscellaneousClient(iptr())) {
-        if (static_cast<QueryType>(command) != QueryType::VALIDATE)
+        if ((QueryType)command != QueryType::VALIDATE)
           return;
         if (!parse_validate_command())
           return;
       } else {
-        switch (static_cast<QueryType>(command)) {
+        switch ((QueryType)command) {
           case QueryType::PING:
             if (!parse_ping_command())
               return;
@@ -996,7 +983,7 @@ class UvClient : public UvTcpSocket {

     if (store->waitKeys(keys, iptr())) {
       StreamWriter sw(iptr());
-      sw.write1(static_cast<uint8_t>(WaitResponseType::STOP_WAITING));
+      sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
       sw.send();
     }

@@ -1115,7 +1102,7 @@ class UvClient : public UvTcpSocket {
     C10D_TRACE("cancel_wait address:{}", this->address());

     StreamWriter sw(iptr());
-    sw.write1(static_cast<uint8_t>(WaitResponseType::WAIT_CANCELED));
+    sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED);
     sw.send();

     return true;
@@ -1200,7 +1187,7 @@ void LibUVStoreDaemon::onConnect(int status) {

 void LibUVStoreDaemon::onExitRequest() {
   C10D_DEBUG("Store exit requested\n");
-  uv_close(reinterpret_cast<uv_handle_t*>(&exit_handle_), nullptr);
+  uv_close((uv_handle_t*)&exit_handle_, nullptr);
   uv_stop(&loop_);
 }

@@ -1241,12 +1228,12 @@ LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) {
       uv_async_init(&loop_, &exit_handle_, LibUVStoreDaemon::on_exit_request) ==
           0,
       "Failed to init uv async event");
-  uv_handle_set_data(reinterpret_cast<uv_handle_t*>(&exit_handle_), this);
+  uv_handle_set_data((uv_handle_t*)&exit_handle_, this);
 }

 LibUVStoreDaemon::~LibUVStoreDaemon() {
   if (!is_running()) {
-    uv_close(reinterpret_cast<uv_handle_t*>(&exit_handle_), nullptr);
+    uv_close((uv_handle_t*)&exit_handle_, nullptr);
     uv_run(&loop_, UV_RUN_NOWAIT);
     if (uv_loop_close(&loop_) != 0) {
       C10D_ERROR("loop cleanup didn't work");
@@ -1490,7 +1477,7 @@ void LibUVStoreDaemon::wakeupWaitingClients(const std::string& key) {
     for (const auto& client : socketsToWait->second) {
       if (--keysAwaited_[client] == 0) {
         StreamWriter sw(client->iptr());
-        sw.write1(static_cast<uint8_t>(WaitResponseType::STOP_WAITING));
+        sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
         sw.send();
       }
     }
@@ -1504,7 +1491,7 @@ void LibUVStoreDaemon::wakeupOneWaitingClient(const std::string& key) {
     for (const auto& client : socketsToWait->second) {
       if (--keysAwaited_[client] == 0) {
         StreamWriter sw(client->iptr());
-        sw.write1(static_cast<uint8_t>(WaitResponseType::STOP_WAITING));
+        sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
         sw.send();
         return;
       }
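Note: many of the hunks above revolve around stashing an owner pointer on a libuv handle and recovering it in static callbacks (the from_uv, borrow, and reclaim helpers). A minimal sketch of that round trip; the Daemon type and its field are illustrative, not from this file:

    #include <uv.h>

    struct Daemon {
      uv_async_t exit_handle{};
    };

    static void on_exit_request(uv_async_t* handle) {
      // Recover the owner previously stored on the handle.
      auto* self = static_cast<Daemon*>(
          uv_handle_get_data(reinterpret_cast<uv_handle_t*>(handle)));
      (void)self; // a real daemon would stop the loop, flush state, etc.
    }

    int main() {
      uv_loop_t loop;
      uv_loop_init(&loop);
      Daemon daemon;
      uv_async_init(&loop, &daemon.exit_handle, on_exit_request);
      // Associate the owner with the handle so callbacks can find it.
      uv_handle_set_data(
          reinterpret_cast<uv_handle_t*>(&daemon.exit_handle), &daemon);
      uv_close(reinterpret_cast<uv_handle_t*>(&daemon.exit_handle), nullptr);
      uv_run(&loop, UV_RUN_NOWAIT);
      return uv_loop_close(&loop);
    }
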
@@ -443,8 +443,7 @@ PyTypeObject* GetReduceOpMetaclass() {
   spec.basicsize = base_metaclass->tp_basicsize;
   spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
   spec.slots = slots;
-  PyTypeObject* metaclass =
-      reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
+  PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec);
   if (!metaclass)
     throw py::error_already_set();
   return metaclass;
@@ -813,10 +812,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
   // `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol.
   // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
   py::class_<::c10d::ReduceOp> reduce_op(
-      module,
-      "ReduceOp",
-      py::metaclass(reinterpret_cast<PyObject*>(GetReduceOpMetaclass())),
-      R"(
+      module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"(
 An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
 ``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``.


@@ -136,9 +136,9 @@ Reducer::Reducer(
   {
     std::set<int> unique_devices;
     for (const auto& v : params_) {
-      auto device_idx = static_cast<int>(v.device().index());
-      auto [_, inserted] = unique_devices.emplace(device_idx);
-      if (inserted) {
+      auto device_idx = int(v.device().index());
+      if (unique_devices.find(device_idx) == unique_devices.end()) {
+        unique_devices.insert(device_idx);
         if (unique_devices.size() > 1) {
           is_multi_device_module_ = true;
           break;
@@ -168,7 +168,7 @@ Reducer::Reducer(
   }

   // All variables are expected to have their `grad_fn` set to the gradient
-  // accumulation function (since they are leaves in the autograd graph).
+  // accumulation function (since they are leafs in the autograd graph).
   // We store pointers to these functions such that we can check if they are
   // used in an autograd pass. If they are not, we know their grad tensors
   // can be marked as ready for reduction.

@@ -76,7 +76,7 @@ class CudaTimer : public Timer {
     if (milliseconds < 0) {
       return std::nullopt;
     }
-    return static_cast<int64_t>(milliseconds * kMilliSecondToNanosSecond);
+    return int64_t(milliseconds * kMilliSecondToNanosSecond);
   }
 };

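Note: the Reducer hunk above switches between two equivalent ways of detecting a first-time insertion into a std::set. A small standalone comparison (the device index values are made up):

    #include <cassert>
    #include <set>

    int main() {
      std::set<int> unique_devices;

      // Style A: emplace() reports whether insertion happened via its bool.
      auto [it, inserted] = unique_devices.emplace(0);
      assert(inserted);

      // Style B: explicit find()-then-insert(), two lookups instead of one.
      if (unique_devices.find(1) == unique_devices.end()) {
        unique_devices.insert(1);
      }
      assert(unique_devices.size() == 2);
      return 0;
    }
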
@@ -220,7 +220,7 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) {
   }
   // if we can't resolve the hostname, display the IP address
   if (addr->sa_family == AF_INET) {
-    struct sockaddr_in* psai = reinterpret_cast<struct sockaddr_in*>(&addr);
+    struct sockaddr_in* psai = (struct sockaddr_in*)&addr;
     // NOLINTNEXTLINE(*array*)
     char ip[INET_ADDRSTRLEN];
     if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) !=
@@ -228,7 +228,7 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) {
       return fmt::format("{}:{}", ip, psai->sin_port);
     }
   } else if (addr->sa_family == AF_INET6) {
-    struct sockaddr_in6* psai = reinterpret_cast<struct sockaddr_in6*>(&addr);
+    struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr;
     // NOLINTNEXTLINE(*array*)
     char ip[INET6_ADDRSTRLEN];
     if (inet_ntop(addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) !=

@@ -178,7 +178,7 @@ std::vector<int> IpcChannel::all_gather_fds(
     int rank,
     const std::vector<int>& pids,
     int fd) {
-  int world_size = static_cast<int>(pids.size());
+  int world_size = (int)pids.size();
   std::vector<int> fds(pids.size());
   fds[rank] = fd;

@@ -197,7 +197,7 @@ int IpcChannel::broadcast_fds(
     int src_rank,
     const std::vector<int>& pids,
     int fd) {
-  int world_size = static_cast<int>(pids.size());
+  int world_size = (int)pids.size();

   if (rank == src_rank) {
     for (int dst_rank = 0; dst_rank < world_size; ++dst_rank) {

@@ -125,7 +125,7 @@ static at::Tensor empty_strided_p2p_persistent(
   const size_t numel = std::accumulate(
       size.begin(),
       size.end(),
-      static_cast<size_t>(1),
+      size_t(1),
       // NOLINTNEXTLINE(modernize-use-transparent-functors)
       std::multiplies<size_t>());
   const size_t element_size = c10::elementSize(dtype);
@@ -230,7 +230,7 @@ at::Tensor empty_strided_p2p(
   const size_t numel = std::accumulate(
       size.begin(),
       size.end(),
-      static_cast<size_t>(1),
+      size_t(1),
       // NOLINTNEXTLINE(modernize-use-transparent-functors)
       std::multiplies<size_t>());
   const size_t element_size = c10::elementSize(dtype);
@@ -23,8 +23,7 @@ std::unordered_map<std::string, worker_id_t> collectNames(
   }
   std::vector<uint8_t> workerNameVector = store.get(std::to_string(workerId));
   std::string workerName(
-      reinterpret_cast<char*>(workerNameVector.data()),
-      workerNameVector.size());
+      (char*)workerNameVector.data(), workerNameVector.size());

   TORCH_CHECK(
       nameToId.find(workerName) == nameToId.end(),
@@ -92,8 +91,7 @@ std::unordered_map<std::string, worker_id_t> collectCurrentNames(
     // Get the current list of workers
     std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
     allWorkerInfos = std::string(
-        reinterpret_cast<const char*>(allWorkerInfosKeyVector.data()),
-        allWorkerInfosKeyVector.size());
+        (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
     // workerInfos are comma separated with a comma at the end (e.g.
     // "Name1-Rank1,Name2-Rank2,Name3-Rank2,") parse list of workers.
     if (!allWorkerInfos.empty()) {
@@ -134,8 +132,7 @@ void removeCurrentName(
   // Get current list of names/ranks
   std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
   std::string allWorkerInfos = std::string(
-      reinterpret_cast<const char*>(allWorkerInfosKeyVector.data()),
-      allWorkerInfosKeyVector.size());
+      (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());

   // Remove the current name and rank
   std::string str_to_erase = fmt::format("{}-{},", selfName, selfId);

@@ -149,13 +149,13 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
               py::call_guard<py::gil_scoped_release>())
           .def(
               "get_worker_info",
-              static_cast<const WorkerInfo& (RpcAgent::*)(void) const>(
-                  &RpcAgent::getWorkerInfo),
+              (const WorkerInfo& (RpcAgent::*)(void) const) &
+                  RpcAgent::getWorkerInfo,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "get_worker_info",
-              static_cast<const WorkerInfo& (RpcAgent::*)(const std::string&)
-                              const>(&RpcAgent::getWorkerInfo),
+              (const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
+                  RpcAgent::getWorkerInfo,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "get_worker_infos",
@@ -611,28 +611,28 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
               py::call_guard<py::gil_scoped_release>())
           .def(
               "get_worker_info",
-              static_cast<const WorkerInfo& (TensorPipeAgent::*)(void) const>(
-                  &RpcAgent::getWorkerInfo),
+              (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
+                  RpcAgent::getWorkerInfo,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "get_worker_info",
-              static_cast<const WorkerInfo& (TensorPipeAgent::*)(const std::string&)
-                              const>(&TensorPipeAgent::getWorkerInfo),
+              (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
+                  TensorPipeAgent::getWorkerInfo,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "get_worker_info",
-              static_cast<const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id)
-                              const>(&TensorPipeAgent::getWorkerInfo),
+              (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
+                  TensorPipeAgent::getWorkerInfo,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "get_worker_infos",
-              static_cast<std::vector<WorkerInfo> (TensorPipeAgent::*)() const>(
-                  &TensorPipeAgent::getWorkerInfos),
+              (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
+                  TensorPipeAgent::getWorkerInfos,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "_get_device_map",
-              static_cast<DeviceMap (TensorPipeAgent::*)(const WorkerInfo& dst)
-                              const>(&TensorPipeAgent::getDeviceMap),
+              (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst)
                   const)&TensorPipeAgent::getDeviceMap,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "_get_backend_options",
@@ -32,7 +32,7 @@ c10::intrusive_ptr<Message> PythonRemoteCall::toMessageImpl() && {

 std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
     const Message& message) {
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();

   auto value = jit::unpickle(

@@ -74,7 +74,7 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
       [this,
        // std::function must be copyable, hence hae to cast the unique_ptr to
        // a shared_ptr here.
-       rpc = std::shared_ptr<RpcCommandBase>(std::move(rpc)),
+       rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
        messageType = request.type(),
        streams = std::move(streams)](JitFuture& /* unused */) mutable {
         // The cost of pre-request check is minimal thanks to

@@ -13,7 +13,7 @@ RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
 }

 WorkerInfo::WorkerInfo(std::string name, int64_t id)
-    : WorkerInfo(std::move(name), static_cast<worker_id_t>(id)) {
+    : WorkerInfo(std::move(name), (worker_id_t)id) {
   TORCH_CHECK(
       id <= std::numeric_limits<worker_id_t>::max(),
       "RPC worker id ",

@@ -15,7 +15,7 @@ c10::ivalue::TupleElements toIValues(const Message& message, MessageType type) {
       type,
       ", but got ",
       message.type());
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();

   auto value = jit::unpickle(
@@ -87,7 +87,7 @@ std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage(
       id <= std::numeric_limits<worker_id_t>::max(),
       "ScriptRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
   return std::make_unique<ScriptRRefFetchCall>(
-      static_cast<worker_id_t>(id), RRefId::fromIValue(values[0]));
+      worker_id_t(id), RRefId::fromIValue(values[0]));
 }

 c10::intrusive_ptr<Message> PythonRRefFetchCall::toMessageImpl() && {
@@ -109,7 +109,7 @@ std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage(
       id <= std::numeric_limits<worker_id_t>::max(),
       "PythonRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
   return std::make_unique<PythonRRefFetchCall>(
-      static_cast<worker_id_t>(id), RRefId::fromIValue(values[0]));
+      worker_id_t(id), RRefId::fromIValue(values[0]));
 }

 const std::vector<at::IValue>& RRefFetchRet::values() {

@@ -127,7 +127,7 @@ c10::intrusive_ptr<Message> ScriptCall::toMessageImpl() && {
 }

 std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();
   auto value = jit::unpickle(
       payload,

@@ -65,7 +65,7 @@ c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {

 std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
     const Message& message) {
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();

   auto value = jit::unpickle(

@@ -20,7 +20,7 @@ c10::intrusive_ptr<Message> ScriptResp::toMessageImpl() && {
 }

 std::unique_ptr<ScriptResp> ScriptResp::fromMessage(const Message& message) {
-  auto payload = message.payload().data();
+  auto payload = static_cast<const char*>(message.payload().data());
   auto payload_size = message.payload().size();
   auto value = jit::unpickle(
       payload,
@@ -304,10 +304,9 @@ void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) {
 }

 float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const {
-  return currentCount_ == 0 ? 0
-                            : static_cast<float>(
-                                  static_cast<double>(currentSum_) /
-                                  static_cast<double>(currentCount_));
+  return currentCount_ == 0
+      ? 0
+      : static_cast<float>((double)currentSum_ / (double)currentCount_);
 }

 //////////////////////// TensorpipeRpcAgent /////////////////////////////////
@@ -504,9 +503,8 @@ void TensorPipeAgent::startImpl() {
   for (const auto& p : workerNameToInfo_) {
     const auto& name = p.first;
     auto nodeAddrData = nameToAddressStore_.get(name);
-    auto nodeAddrStr = std::string(
-        reinterpret_cast<const char*>(nodeAddrData.data()),
-        nodeAddrData.size());
+    auto nodeAddrStr =
+        std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
     workerNameToURL_.insert({name, nodeAddrStr});
   }

@@ -1242,9 +1240,8 @@ void TensorPipeAgent::updateGroupMembership(
     // TODO: we should get nodeAddrStr in the joining process, then pass in as
     // an argument rather than getting from store each time
     auto nodeAddrData = nameToAddressStore_.get(name);
-    auto nodeAddrStr = std::string(
-        reinterpret_cast<const char*>(nodeAddrData.data()),
-        nodeAddrData.size());
+    auto nodeAddrStr =
+        std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
     workerNameToURL_.insert({name, nodeAddrStr});

     for (const auto& it : reverseDeviceMaps) {

@@ -106,23 +106,23 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
           py::call_guard<py::gil_scoped_release>())
       .def(
           "get_worker_info",
-          static_cast<const WorkerInfo& (TensorPipeAgent::*)(void) const>(
-              &RpcAgent::getWorkerInfo),
+          (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
+              RpcAgent::getWorkerInfo,
           py::call_guard<py::gil_scoped_release>())
       .def(
           "get_worker_info",
-          static_cast<const WorkerInfo& (TensorPipeAgent::*)(const std::string&)
-                          const>(&TensorPipeAgent::getWorkerInfo),
+          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
+              TensorPipeAgent::getWorkerInfo,
           py::call_guard<py::gil_scoped_release>())
      .def(
           "get_worker_info",
-          static_cast<const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id)
-                          const>(&TensorPipeAgent::getWorkerInfo),
+          (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
+              TensorPipeAgent::getWorkerInfo,
           py::call_guard<py::gil_scoped_release>())
      .def(
           "get_worker_infos",
-          static_cast<std::vector<WorkerInfo> (TensorPipeAgent::*)() const>(
-              &TensorPipeAgent::getWorkerInfos),
+          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
+              TensorPipeAgent::getWorkerInfos,
           py::call_guard<py::gil_scoped_release>());
 #endif // USE_TENSORPIPE


@@ -507,7 +507,8 @@ std::vector<at::IValue> readWrappedPayload(
       " but additional payload size is ",
       additionalPayloadSize);
   auto wrappedPayloadBegin =
-      message.payload().data() + payload.size() - additionalPayloadSize;
+      static_cast<const char*>(message.payload().data()) + payload.size() -
+      additionalPayloadSize;
   std::vector<torch::Tensor> tensorTable;
   IValue tuple = jit::unpickle(
       wrappedPayloadBegin,
@@ -257,7 +257,7 @@ void THPStorage_writeFileRaw(
         at::device(self->device()).dtype(c10::kByte),
         {self->device()});
     cpu_tensor = device_tensor.to(at::kCPU);
-    data = static_cast<uint8_t*>(cpu_tensor.data_ptr());
+    data = (uint8_t*)cpu_tensor.data_ptr();
   }
   if (save_size) {
     if (torch::utils::THP_nativeByteOrder() ==
@@ -266,8 +266,8 @@ void THPStorage_writeFileRaw(
     else {
       int64_t nsize{}; // convert big endian cpu to little endian storage
       torch::utils::THP_encodeBuffer(
-          reinterpret_cast<uint8_t*>(&nsize),
-          reinterpret_cast<const int64_t*>(&numel),
+          (uint8_t*)&nsize,
+          (const int64_t*)&numel,
           torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
           1);
       doWrite(fd, &nsize, sizeof(int64_t));
@@ -279,7 +279,7 @@ void THPStorage_writeFileRaw(
       torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
     doWrite(fd, data, size_bytes);
   } else {
-    size_t buffer_size = std::min(numel, static_cast<size_t>(5000));
+    size_t buffer_size = std::min(numel, (size_t)5000);
     std::vector<uint8_t> le_buffer;
     le_buffer.resize(buffer_size * element_size);
     for (size_t i = 0; i < numel; i += buffer_size) {
@@ -287,19 +287,19 @@ void THPStorage_writeFileRaw(
       if (element_size == 2) {
         torch::utils::THP_encodeBuffer(
             le_buffer.data(),
-            reinterpret_cast<const int16_t*>(data) + i,
+            (const int16_t*)data + i,
             torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
             to_convert);
       } else if (element_size == 4) {
         torch::utils::THP_encodeBuffer(
             le_buffer.data(),
-            reinterpret_cast<const int32_t*>(data) + i,
+            (const int32_t*)data + i,
             torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
             to_convert);
       } else if (element_size == 8) {
         torch::utils::THP_encodeBuffer(
             le_buffer.data(),
-            reinterpret_cast<const int64_t*>(data) + i,
+            (const int64_t*)data + i,
             torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
             to_convert);
       }
@@ -333,8 +333,7 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
   if (torch::utils::THP_nativeByteOrder() ==
       torch::utils::THPByteOrder::THP_BIG_ENDIAN) {
     int64_t tsize = size; // convert little endian storage to big endian cpu
-    torch::utils::THP_decodeBuffer(
-        &size, reinterpret_cast<const uint8_t*>(&tsize), true, 1);
+    torch::utils::THP_decodeBuffer(&size, (const uint8_t*)&tsize, true, 1);
   }
   size_t nbytes = element_size * size;
   if (!storage.defined()) {
@@ -359,7 +358,7 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
     data = static_cast<uint8_t*>(storage->mutable_data());
   } else {
     cpu_data.resize(nbytes);
-    data = reinterpret_cast<uint8_t*>(cpu_data.data());
+    data = (uint8_t*)cpu_data.data();
   }

   // fast track for bytes and little endian
@@ -368,7 +367,7 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
       torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
     doRead(file, data, storage->nbytes());
   } else {
-    int64_t buffer_size = std::min(size, static_cast<int64_t>(5000));
+    int64_t buffer_size = std::min(size, (int64_t)5000);
     std::vector<uint8_t> le_buffer;
     le_buffer.resize(buffer_size * element_size);

@@ -379,22 +378,13 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
       // NOLINTNEXTLINE(bugprone-branch-clone)
       if (element_size == 2) {
         torch::utils::THP_decodeBuffer(
-            reinterpret_cast<int16_t*>(data) + i,
-            le_buffer.data(),
-            true,
-            to_convert);
+            (int16_t*)data + i, le_buffer.data(), true, to_convert);
       } else if (element_size == 4) {
         torch::utils::THP_decodeBuffer(
-            reinterpret_cast<int32_t*>(data) + i,
-            le_buffer.data(),
-            true,
-            to_convert);
+            (int32_t*)data + i, le_buffer.data(), true, to_convert);
       } else if (element_size == 8) {
         torch::utils::THP_decodeBuffer(
-            reinterpret_cast<int64_t*>(data) + i,
-            le_buffer.data(),
-            true,
-            to_convert);
+            (int64_t*)data + i, le_buffer.data(), true, to_convert);
       }
     }
   }
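Note: the THPStorage hunks above stream storage contents through a bounded conversion buffer (at most 5000 elements at a time) rather than byte-swapping the whole allocation at once. A simplified sketch of that chunking idea; the swap16 helper stands in for THP_encodeBuffer and is not the real API:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    static uint16_t swap16(uint16_t v) {
      return static_cast<uint16_t>((v << 8) | (v >> 8)); // toy byte swap
    }

    std::vector<uint16_t> to_little_endian(const uint16_t* data, size_t numel) {
      std::vector<uint16_t> out;
      out.reserve(numel);
      const size_t buffer_size = std::min(numel, size_t(5000));
      std::vector<uint16_t> le_buffer(buffer_size);
      for (size_t i = 0; i < numel; i += buffer_size) {
        const size_t to_convert = std::min(buffer_size, numel - i);
        for (size_t j = 0; j < to_convert; ++j) {
          le_buffer[j] = swap16(data[i + j]); // convert one bounded chunk
        }
        out.insert(out.end(), le_buffer.begin(), le_buffer.begin() + to_convert);
      }
      return out;
    }
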
@@ -5,10 +5,10 @@
 #include <cstdint>
 #include <optional>
 #include <string>
+#include <vector>

 #include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
 #include <torch/headeronly/core/ScalarType.h>
-#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

 namespace torch::stable {

@@ -60,7 +60,7 @@ inline torch::stable::Tensor narrow(
 // only dtype information.
 inline torch::stable::Tensor new_empty(
     const torch::stable::Tensor& self,
-    torch::headeronly::IntHeaderOnlyArrayRef size,
+    std::vector<int64_t> size,
     std::optional<c10::ScalarType> dtype = std::nullopt) {
   int32_t device_type;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@@ -98,7 +98,7 @@ inline torch::stable::Tensor new_empty(
 // only dtype information.
 inline torch::stable::Tensor new_zeros(
     const torch::stable::Tensor& self,
-    torch::headeronly::IntHeaderOnlyArrayRef size,
+    std::vector<int64_t> size,
     std::optional<c10::ScalarType> dtype = std::nullopt) {
   int32_t device_type;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@@ -134,10 +134,12 @@ inline torch::stable::Tensor new_zeros(

 // We expect this to be the stable version of the pad.default op.
 // pad.default takes in a SymInt[] as the pad argument however pad is typed as
-// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only.
+// use std::vector<int64_t> because
+// (1) IntArrayRef is not yet header-only
+// (2) SymInt is not yet header-only
 inline torch::stable::Tensor pad(
     const torch::stable::Tensor& self,
-    torch::headeronly::IntHeaderOnlyArrayRef pad,
+    std::vector<int64_t> pad,
     const std::string& mode = "constant",
     double value = 0.0) {
   AtenTensorHandle ret0 = nullptr;
@@ -169,10 +171,11 @@ inline torch::stable::Tensor amax(
 // This function is an overload to compute the maximum value along each slice of
 // `self` reducing over all the dimensions in the vector `dims`. The
 // amax.default op takes in a SymInt[] as the dims argument, however dims is
-// typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only
+// typed as use std::vector<int64_t> here because (1) IntArrayRef is not yet
+// header-only (2) SymInt is not yet header-only
 inline torch::stable::Tensor amax(
     const torch::stable::Tensor& self,
-    torch::headeronly::IntHeaderOnlyArrayRef dims,
+    std::vector<int64_t> dims,
     bool keepdim = false) {
   AtenTensorHandle ret = nullptr;
   TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
@@ -84,7 +84,7 @@ std::vector<int> THPUtils_unpackIntTuple(PyObject* arg) {
   TORCH_CHECK(THPUtils_checkIntTuple(arg), "Couldn't unpack int tuple");
   std::vector<int> values(PyTuple_GET_SIZE(arg));
   for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
-    values[i] = THPUtils_unpackInt(PyTuple_GET_ITEM(arg, i));
+    values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i));
   }
   return values;
 }
@@ -84,7 +84,6 @@ from torch.utils._sympy.functions import (
     IsNonOverlappingAndDenseIndicator,
     Max,
     Mod,
-    OrderedAnd,
     PythonMod,
     TruncToInt,
 )
@@ -3766,8 +3765,6 @@ class ShapeEnv:
         self.guards: list[ShapeGuard] = []
         self.axioms: dict[sympy.Expr, sympy.Expr] = {}

-        self.ra_prelude: Optional[sympy.Expr] = None
-
         # A set of ids that have already been allocated. This is used
         # for when we allocate symbol ids using the hash of the source
         # names to ensure we don't have collisions via linear probing
@@ -6285,15 +6282,6 @@ class ShapeEnv:
         self, e: SympyBoolean
     ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
         """Given a expression, it returns a list of predicates that follow from it"""

-        if isinstance(e, OrderedAnd):
-            # Because SymPy's default And does not preserve operand order,
-            # we introduced OrderedAnd to maintain order. As a result, we
-            # cannot make additional global logical assumptions about the
-            # conjunction as a whole, since the semantics of OrderedAnd are
-            # intentionally more restrictive.
-            return tuple()
-
         equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}

         def add_expr(expr: SympyBoolean) -> None:
@@ -7799,9 +7787,6 @@ class ShapeEnv:
         """
         expr = orig_expr

-        if self.ra_prelude is not None:
-            expr = OrderedAnd(self.ra_prelude, expr)
-
         # TODO: split conjunctions and evaluate them separately

         static_expr = self._maybe_evaluate_static(expr)
@@ -7954,25 +7939,6 @@ class ShapeEnv:
             "constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper
         )

-    @contextmanager
-    def patch_ra_prelude(self, prelude: sympy.Expr) -> Iterator[None]:
-        """
-        Context manager that ensures all runtime asserts generated while this context manager
-        is active include a prelude expression. This is mainly used in torch.cond to guarantee
-        that runtime asserts in subgraphs are guarded by the original cond predicate, preventing
-        them from leaking into the main graph.
-        """
-        prev = self.ra_prelude
-
-        if prev is not None:
-            prelude = OrderedAnd(prev, prelude)
-
-        self.ra_prelude = prelude
-        try:
-            yield
-        finally:
-            self.ra_prelude = prev
-

 def _is_int(expr: object) -> bool:
     return isinstance(expr, SymInt) and expr.node.expr.is_number
@@ -178,24 +178,7 @@ def insert_deferred_runtime_asserts(
                 assert isinstance(node.target, str)
                 target = getattr(fake_args[0], node.target)
                 fake_args = fake_args[1:]
-
-            # The OrderedAnd function in torch.utils._sympy.functions combines
-            # `not` and `any` operations to generate runtime assertions correctly
-            # in the code. For these specific operations, we avoid evaluating the
-            # function directly, as doing so could unnecessarily trigger
-            # data-dependent errors. For example, if there's a runtime
-            # assertion `u0 <= 0`, evaluating the meta of not(u0 <= 0) would
-            # cause us to guard on the inner expression and potentially raise a
-            # data-dependent error. Therefore, we choose not to compute the meta
-            # in these cases, since it's not essential.
-            calculate_meta = (
-                node.target != operator.not_
-                and node.target != any
-                and not any(hasattr(a, "target") and a.target == any for a in node.args)  # type: ignore[union-attr]
-            )
-            if calculate_meta:
-                node.meta[val_key] = target(*fake_args)  # type: ignore[operator]
-
+            node.meta[val_key] = target(*fake_args)  # type: ignore[operator]
         except NotImplementedError:
             # This can happen when attempting to reify a symbol with an unsupported call_function node,
             # e.g. with NestedTensors + sym_size.int via match_symbol().
@@ -212,12 +195,11 @@ def insert_deferred_runtime_asserts(

     Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis

-    def _sympy_interp(expr_to_proxy, expr, graph):
+    def _sympy_interp(expr_to_proxy, expr):
         # sympy_interp() with hash consing
         from sympy import Integer, Number, Symbol
         from sympy.logic.boolalg import BooleanAtom

-        from torch.utils._sympy.functions import OrderedAnd
         from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp

         # hash cons
@@ -227,31 +209,10 @@ def insert_deferred_runtime_asserts(
         if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
             return sympy_interp(Analysis, expr_to_proxy, expr)

-        if isinstance(expr, OrderedAnd):
-            predicate = (_sympy_interp(expr_to_proxy, expr.args[0], graph)).node
-            runtime_assert = _sympy_interp(expr_to_proxy, expr.args[1], graph).node
-
-            not_predicate = fx.Proxy(
-                graph.call_function(operator.not_, (predicate,)), tracer=tracer
-            ).node
-
-            return fx.Proxy(
-                graph.call_function(
-                    any,
-                    (
-                        [
-                            not_predicate,
-                            runtime_assert,
-                        ],
-                    ),
-                ),
-                tracer=tracer,
-            )
-
         # hash cons on arguments, run expr handler
         expr_to_proxy[expr] = _run_sympy_handler(
             Analysis,
-            [_sympy_interp(expr_to_proxy, arg, graph) for arg in expr.args],
+            [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
             expr,
         )
         return expr_to_proxy[expr]
@@ -297,7 +258,7 @@ def insert_deferred_runtime_asserts(
                 # Convert the sympy expression into a sequence of FX
                 # nodes
                 with _set_node_metadata_hook(gm, _node_metadata_hook):
-                    res = _sympy_interp(expr_to_proxy, ra.expr, graph).node
+                    res = _sympy_interp(expr_to_proxy, ra.expr).node

                 graph.call_function(
                     torch.ops.aten._assert_scalar.default,
@@ -449,7 +410,6 @@ def insert_deferred_runtime_asserts(
                             expr_to_proxy,
                             # pyrefly: ignore  # unbound-name
                             sym_expr,
-                            graph,
                         )  # type: ignore[arg-type]
                         # won't try DCE-ing tensor compute here
                         hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
@@ -667,9 +627,7 @@ def insert_deferred_runtime_asserts(
                         ),
                     ):
                         if (min_val := convert(vr.lower)) is not None:
-                            ge = _sympy_interp(
-                                expr_to_proxy, i0 >= min_val, graph
-                            ).node
+                            ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
                             graph.call_function(
                                 torch.ops.aten._assert_scalar.default,
                                 (
@@ -679,9 +637,7 @@ def insert_deferred_runtime_asserts(
                             )
                             added_asserts.add(i0 >= min_val)
                         if (max_val := convert(vr.upper)) is not None:
-                            le = _sympy_interp(
-                                expr_to_proxy, i0 <= max_val, graph
-                            ).node
+                            le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
                             graph.call_function(
                                 torch.ops.aten._assert_scalar.default,
                                 (
@@ -42,9 +42,6 @@ fp16_ieee_to_fp32_value
 # fp32_from_bits called from fp16_ieee_to_fp32_value
 # fp32_to_bits called from fp16_ieee_from_fp32_value

-# torch/headeronly/util/HeaderOnlyArrayRef.h
-HeaderOnlyArrayRef
-
 # c10/util/complex.h, torch/headeronly/util/complex.h
 complex

@@ -136,5 +133,3 @@ AT_FORALL_SCALAR_TYPES_AND7
 AT_FORALL_QINT_TYPES
 AT_FORALL_FLOAT8_TYPES
 AT_FORALL_COMPLEX_TYPES
-toString
-<<
@ -63,15 +63,15 @@ struct dummy_int1_7_t {};
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(c10::Half, Half) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(c10::BFloat16, BFloat16) \
|
||||
_(c10::Float8_e5m2, Float8_e5m2) \
|
||||
_(c10::Float8_e4m3fn, Float8_e4m3fn)
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn)
|
||||
|
||||
// This macro controls many of our C++ APIs, including constructors
|
||||
// for Scalar as well as the data() and item() accessors on Tensor
|
||||
@@ -81,19 +81,19 @@ struct dummy_int1_7_t {};
   _(int16_t, Short) \
   _(int, Int) \
   _(int64_t, Long) \
-  _(c10::Half, Half) \
+  _(at::Half, Half) \
   _(float, Float) \
   _(double, Double) \
   _(c10::complex<c10::Half>, ComplexHalf) \
   _(c10::complex<float>, ComplexFloat) \
   _(c10::complex<double>, ComplexDouble) \
   _(bool, Bool) \
-  _(c10::BFloat16, BFloat16) \
-  _(c10::Float8_e5m2, Float8_e5m2) \
-  _(c10::Float8_e4m3fn, Float8_e4m3fn) \
-  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
-  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
-  _(c10::Float8_e8m0fnu, Float8_e8m0fnu)
+  _(at::BFloat16, BFloat16) \
+  _(at::Float8_e5m2, Float8_e5m2) \
+  _(at::Float8_e4m3fn, Float8_e4m3fn) \
+  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)

 // NB: Order matters for this macro; it is relied upon in
 // _promoteTypesLookup and the serialization format.
@@ -103,7 +103,7 @@ struct dummy_int1_7_t {};
   _(int16_t, Short) /* 2 */ \
   _(int, Int) /* 3 */ \
   _(int64_t, Long) /* 4 */ \
-  _(c10::Half, Half) /* 5 */ \
+  _(at::Half, Half) /* 5 */ \
   _(float, Float) /* 6 */ \
   _(double, Double) /* 7 */ \
   _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
@@ -113,7 +113,7 @@ struct dummy_int1_7_t {};
   _(c10::qint8, QInt8) /* 12 */ \
   _(c10::quint8, QUInt8) /* 13 */ \
   _(c10::qint32, QInt32) /* 14 */ \
-  _(c10::BFloat16, BFloat16) /* 15 */ \
+  _(at::BFloat16, BFloat16) /* 15 */ \
   _(c10::quint4x2, QUInt4x2) /* 16 */ \
   _(c10::quint2x4, QUInt2x4) /* 17 */ \
   _(c10::bits1x8, Bits1x8) /* 18 */ \
@@ -176,19 +176,24 @@ struct dummy_int1_7_t {};
   _(int64_t, Long) \
   _(float, Float) \
   _(double, Double) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE>, SCALARTYPE)
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE>::t), \
+    SCALARTYPE)

-#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
-  _(uint8_t, Byte) \
-  _(int8_t, Char) \
-  _(int16_t, Short) \
-  _(int, Int) \
-  _(int64_t, Long) \
-  _(float, Float) \
-  _(double, Double) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
-    SCALARTYPE1) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, SCALARTYPE2)
+#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
+  _(uint8_t, Byte) \
+  _(int8_t, Char) \
+  _(int16_t, Short) \
+  _(int, Int) \
+  _(int64_t, Long) \
+  _(float, Float) \
+  _(double, Double) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE1>::t), \
+    SCALARTYPE1) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE2>::t), \
+    SCALARTYPE2)

 #define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
   _(uint8_t, Byte) \
@@ -198,41 +203,53 @@ struct dummy_int1_7_t {};
   _(int64_t, Long) \
   _(float, Float) \
   _(double, Double) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE1>::t), \
     SCALARTYPE1) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE2>::t), \
     SCALARTYPE2) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, SCALARTYPE3)
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE3>::t), \
+    SCALARTYPE3)

-#define AT_FORALL_SCALAR_TYPES_AND7( \
-    SCALARTYPE1, \
-    SCALARTYPE2, \
-    SCALARTYPE3, \
-    SCALARTYPE4, \
-    SCALARTYPE5, \
-    SCALARTYPE6, \
-    SCALARTYPE7, \
-    _) \
-  _(uint8_t, Byte) \
-  _(int8_t, Char) \
-  _(int16_t, Short) \
-  _(int, Int) \
-  _(int64_t, Long) \
-  _(float, Float) \
-  _(double, Double) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
-    SCALARTYPE1) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
-    SCALARTYPE2) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, \
-    SCALARTYPE3) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE4>, \
-    SCALARTYPE4) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE5>, \
-    SCALARTYPE5) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE6>, \
-    SCALARTYPE6) \
-  _(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE7>, SCALARTYPE7)
+#define AT_FORALL_SCALAR_TYPES_AND7( \
+    SCALARTYPE1, \
+    SCALARTYPE2, \
+    SCALARTYPE3, \
+    SCALARTYPE4, \
+    SCALARTYPE5, \
+    SCALARTYPE6, \
+    SCALARTYPE7, \
+    _) \
+  _(uint8_t, Byte) \
+  _(int8_t, Char) \
+  _(int16_t, Short) \
+  _(int, Int) \
+  _(int64_t, Long) \
+  _(float, Float) \
+  _(double, Double) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE1>::t), \
+    SCALARTYPE1) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE2>::t), \
+    SCALARTYPE2) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE3>::t), \
+    SCALARTYPE3) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE4>::t), \
+    SCALARTYPE4) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE5>::t), \
+    SCALARTYPE5) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE6>::t), \
+    SCALARTYPE6) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType< \
+             ::c10::ScalarType::SCALARTYPE7>::t), \
+    SCALARTYPE7)

 #define AT_FORALL_QINT_TYPES(_) \
   _(c10::qint8, QInt8) \
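All of the `AT_FORALL_*` lists above are X-macros: the caller supplies a macro `_` that is expanded once per `(cpp_type, enum_name)` pair, which is how the codebase stamps out enums, switch cases, and template specializations from a single list. A self-contained sketch of the pattern (demo names only, not code from this diff):

    #include <cstdint>
    #include <iostream>

    // A tiny X-macro list in the style of the AT_FORALL_* lists above.
    #define FORALL_DEMO_TYPES(_) \
      _(uint8_t, Byte)           \
      _(int32_t, Int)            \
      _(double, Double)

    // Stamp out one enumerator per (cpp_type, enum_name) pair.
    enum class DemoType {
    #define DEFINE_ENUM(cpp_type, name) name,
      FORALL_DEMO_TYPES(DEFINE_ENUM)
    #undef DEFINE_ENUM
    };

    // Stamp out one switch case per pair, mirroring the real toString below.
    const char* toStringDemo(DemoType t) {
    #define DEFINE_CASE(cpp_type, name) \
      case DemoType::name:              \
        return #name;
      switch (t) {
        FORALL_DEMO_TYPES(DEFINE_CASE)
        default:
          return "UNKNOWN";
      }
    #undef DEFINE_CASE
    }

    int main() {
      std::cout << toStringDemo(DemoType::Int) << "\n";  // prints "Int"
    }

The payoff is that adding a scalar type to the one list updates every consumer automatically; the diff above only changes how each pair's C++ type is spelled, not the mechanism.
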
@@ -241,12 +258,12 @@ struct dummy_int1_7_t {};
   _(c10::quint4x2, QUInt4x2) \
   _(c10::quint2x4, QUInt2x4)

-#define AT_FORALL_FLOAT8_TYPES(_) \
-  _(c10::Float8_e5m2, Float8_e5m2) \
-  _(c10::Float8_e4m3fn, Float8_e4m3fn) \
-  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
-  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
-  _(c10::Float8_e8m0fnu, Float8_e8m0fnu)
+#define AT_FORALL_FLOAT8_TYPES(_) \
+  _(at::Float8_e5m2, Float8_e5m2) \
+  _(at::Float8_e4m3fn, Float8_e4m3fn) \
+  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)

 #define AT_FORALL_COMPLEX_TYPES(_) \
   _(c10::complex<float>, ComplexFloat) \
@@ -270,10 +287,19 @@ namespace impl {
 template <c10::ScalarType N>
 struct ScalarTypeToCPPType;

-#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
-  template <> \
-  struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
-    using type = cpp_type; \
+#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
+  template <> \
+  struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
+    using type = cpp_type; \
+    \
+    /* This is a workaround for the CUDA bug which prevents */ \
+    /* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
+    /* ambiguous reference which can't to be resolved. For some reason it */ \
+    /* can't pick between at::detail and at::cuda::detail. */ \
+    /* For repro example, please see: */ \
+    /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
+    /* TODO: remove once the bug is fixed. */ \
+    static type t; \
   };

 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
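The restored `static type t;` member exists only so callers can recover the C++ type as `decltype(ScalarTypeToCPPType<...>::t)` instead of naming the nested `::type`, which the CUDA comment above describes as resolving ambiguously. A reduced sketch showing that the two spellings agree (illustrative enum and specialization, not the real header):

    #include <cstdint>
    #include <type_traits>

    enum class ScalarType : int8_t { Byte, Int, Float };

    template <ScalarType N>
    struct ScalarTypeToCPPType;

    template <>
    struct ScalarTypeToCPPType<ScalarType::Float> {
      using type = float;
      // Declared but never defined: it is only ever named inside decltype,
      // which does not odr-use it, so no out-of-line definition is needed.
      static type t;
    };

    // Both spellings recover the same C++ type:
    static_assert(std::is_same_v<
                  decltype(ScalarTypeToCPPType<ScalarType::Float>::t),
                  ScalarTypeToCPPType<ScalarType::Float>::type>);

    int main() {}
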
@@ -285,25 +311,6 @@ using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;

 } // namespace impl

-inline const char* toString(ScalarType t) {
-#define DEFINE_CASE(_, name) \
-  case ScalarType::name:     \
-    return #name;
-
-  switch (t) {
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
-    default:
-      return "UNKNOWN_SCALAR";
-  }
-#undef DEFINE_CASE
-}
-
-inline std::ostream& operator<<(
-    std::ostream& stream,
-    at::ScalarType scalar_type) {
-  return stream << toString(scalar_type);
-}
-
 } // namespace c10

 namespace torch::headeronly {
@@ -314,6 +321,4 @@ using c10::ScalarType;
 namespace impl {
 using c10::impl::ScalarTypeToCPPTypeT;
 } // namespace impl
-using c10::toString;
-using c10::operator<<;
 } // namespace torch::headeronly

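Before this change the headeronly file both defined `toString`/`operator<<` and re-exported them into `torch::headeronly`; afterwards they live only on the `c10` side. A minimal sketch of the surface the removed lines provided (simplified enum and bodies, not the real header):

    #include <iostream>

    namespace c10 {
    enum class ScalarType { Float, Double };

    inline const char* toString(ScalarType t) {
      return t == ScalarType::Float ? "Float" : "Double";
    }

    inline std::ostream& operator<<(std::ostream& stream, ScalarType t) {
      return stream << toString(t);
    }
    } // namespace c10

    namespace torch::headeronly {
    // The removed using-declarations made the qualified names
    // torch::headeronly::toString / torch::headeronly::operator<< valid.
    using c10::ScalarType;
    using c10::toString;
    using c10::operator<<;
    } // namespace torch::headeronly

    int main() {
      torch::headeronly::ScalarType t = torch::headeronly::ScalarType::Float;
      std::cout << torch::headeronly::toString(t) << "\n";  // "Float"
      std::cout << t << "\n";  // ADL also finds c10::operator<< via the enum
    }
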
@@ -1,247 +0,0 @@
#pragma once

#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>

#include <array>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <vector>

namespace c10 {

/// HeaderOnlyArrayRef - A subset of ArrayRef that is implemented only
/// in headers. This will be a base class from which ArrayRef inherits, so that
/// we can keep much of the implementation shared.
///
/// [HeaderOnlyArrayRef vs ArrayRef note]
/// As HeaderOnlyArrayRef is a subset of ArrayRef, it has slightly less
/// functionality than ArrayRef. We document the minor differences below:
/// 1. ArrayRef has an extra convenience constructor for SmallVector.
/// 2. ArrayRef uses TORCH_CHECK. HeaderOnlyArrayRef uses header-only
///    STD_TORCH_CHECK, which will output a std::runtime_error vs a
///    c10::Error. Consequently, you should use ArrayRef when possible
///    and HeaderOnlyArrayRef only when necessary to support headeronly code.
/// In all other aspects, HeaderOnlyArrayRef is identical to ArrayRef, with the
/// positive benefit of being header-only and thus independent of libtorch.so.
template <typename T>
class HeaderOnlyArrayRef {
 public:
  using iterator = const T*;
  using const_iterator = const T*;
  using size_type = size_t;
  using value_type = T;

  using reverse_iterator = std::reverse_iterator<iterator>;

 protected:
  /// The start of the array, in an external buffer.
  const T* Data;

  /// The number of elements.
  size_type Length;

 public:
  /// @name Constructors
  /// @{

  /// Construct an empty HeaderOnlyArrayRef.
  /* implicit */ constexpr HeaderOnlyArrayRef() : Data(nullptr), Length(0) {}

  /// Construct a HeaderOnlyArrayRef from a single element.
  // TODO Make this explicit
  constexpr HeaderOnlyArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}

  /// Construct a HeaderOnlyArrayRef from a pointer and length.
  constexpr HeaderOnlyArrayRef(const T* data, size_t length)
      : Data(data), Length(length) {}

  /// Construct a HeaderOnlyArrayRef from a range.
  constexpr HeaderOnlyArrayRef(const T* begin, const T* end)
      : Data(begin), Length(end - begin) {}

  template <
      typename Container,
      typename U = decltype(std::declval<Container>().data()),
      typename = std::enable_if_t<
          (std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
  /* implicit */ HeaderOnlyArrayRef(const Container& container)
      : Data(container.data()), Length(container.size()) {}

  /// Construct a HeaderOnlyArrayRef from a std::vector.
  // The enable_if stuff here makes sure that this isn't used for
  // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
  // bitfield.
  template <typename A>
  /* implicit */ HeaderOnlyArrayRef(const std::vector<T, A>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    static_assert(
        !std::is_same_v<T, bool>,
        "HeaderOnlyArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
  }

  /// Construct a HeaderOnlyArrayRef from a std::array
  template <size_t N>
  /* implicit */ constexpr HeaderOnlyArrayRef(const std::array<T, N>& Arr)
      : Data(Arr.data()), Length(N) {}

  /// Construct a HeaderOnlyArrayRef from a C array.
  template <size_t N>
  // NOLINTNEXTLINE(*c-arrays*)
  /* implicit */ constexpr HeaderOnlyArrayRef(const T (&Arr)[N])
      : Data(Arr), Length(N) {}

  /// Construct a HeaderOnlyArrayRef from a std::initializer_list.
  /* implicit */ constexpr HeaderOnlyArrayRef(
      const std::initializer_list<T>& Vec)
      : Data(
            std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
                                             : std::begin(Vec)),
        Length(Vec.size()) {}

  /// @}
  /// @name Simple Operations
  /// @{

  constexpr iterator begin() const {
    return this->Data;
  }
  constexpr iterator end() const {
    return this->Data + this->Length;
  }

  // These are actually the same as iterator, since ArrayRef only
  // gives you const iterators.
  constexpr const_iterator cbegin() const {
    return this->Data;
  }
  constexpr const_iterator cend() const {
    return this->Data + this->Length;
  }

  constexpr reverse_iterator rbegin() const {
    return reverse_iterator(end());
  }
  constexpr reverse_iterator rend() const {
    return reverse_iterator(begin());
  }

  /// Check if all elements in the array satisfy the given expression
  constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
    return std::all_of(cbegin(), cend(), pred);
  }

  /// empty - Check if the array is empty.
  constexpr bool empty() const {
    return this->Length == 0;
  }

  constexpr const T* data() const {
    return this->Data;
  }

  /// size - Get the array size.
  constexpr size_t size() const {
    return this->Length;
  }

  /// front - Get the first element.
  constexpr const T& front() const {
    STD_TORCH_CHECK(
        !this->empty(),
        "HeaderOnlyArrayRef: attempted to access front() of empty list");
    return this->Data[0];
  }

  /// back - Get the last element.
  constexpr const T& back() const {
    STD_TORCH_CHECK(
        !this->empty(),
        "HeaderOnlyArrayRef: attempted to access back() of empty list");
    return this->Data[this->Length - 1];
  }

  /// equals - Check for element-wise equality.
  constexpr bool equals(HeaderOnlyArrayRef RHS) const {
    return this->Length == RHS.Length &&
        std::equal(begin(), end(), RHS.begin());
  }

  /// slice(n, m) - Take M elements of the array starting at element N
  constexpr HeaderOnlyArrayRef<T> slice(size_t N, size_t M) const {
    STD_TORCH_CHECK(
        N + M <= this->size(),
        "HeaderOnlyArrayRef: invalid slice, N = ",
        N,
        "; M = ",
        M,
        "; size = ",
        this->size());
    return HeaderOnlyArrayRef<T>(this->data() + N, M);
  }

  /// slice(n) - Chop off the first N elements of the array.
  constexpr HeaderOnlyArrayRef<T> slice(size_t N) const {
    STD_TORCH_CHECK(
        N <= this->size(),
        "HeaderOnlyArrayRef: invalid slice, N = ",
        N,
        "; size = ",
        this->size());
    return slice(N, this->size() - N);
  }

  /// @}
  /// @name Operator Overloads
  /// @{
  constexpr const T& operator[](size_t Index) const {
    return this->Data[Index];
  }

  /// Vector compatibility
  constexpr const T& at(size_t Index) const {
    STD_TORCH_CHECK(
        Index < this->Length,
        "HeaderOnlyArrayRef: invalid index Index = ",
        Index,
        "; Length = ",
        this->Length);
    return this->Data[Index];
  }

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, HeaderOnlyArrayRef<T>>& operator=(
      // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
      U&& Temporary) = delete;

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, HeaderOnlyArrayRef<T>>& operator=(
      std::initializer_list<U>) = delete;

  /// @}
  /// @name Expensive Operations
  /// @{
  std::vector<T> vec() const {
    return std::vector<T>(this->Data, this->Data + this->Length);
  }

  /// @}
};

} // namespace c10

namespace torch::headeronly {
using c10::HeaderOnlyArrayRef;
using IntHeaderOnlyArrayRef = HeaderOnlyArrayRef<int64_t>;
} // namespace torch::headeronly
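For reference, the deleted `HeaderOnlyArrayRef` was used exactly like `c10::ArrayRef`: a non-owning `(pointer, length)` view over contiguous storage. A simplified stand-in showing the view semantics callers relied on (illustrative class, not the deleted header):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Simplified stand-in for the deleted HeaderOnlyArrayRef: a non-owning
    // (pointer, length) view over contiguous elements.
    template <typename T>
    class ArrayView {
     public:
      constexpr ArrayView(const T* data, size_t length)
          : data_(data), length_(length) {}
      template <typename A>
      /* implicit */ ArrayView(const std::vector<T, A>& vec)
          : data_(vec.data()), length_(vec.size()) {}

      constexpr size_t size() const { return length_; }
      constexpr const T& operator[](size_t i) const { return data_[i]; }
      // slice(n, m): a view of m elements starting at element n; no copy.
      constexpr ArrayView slice(size_t n, size_t m) const {
        return ArrayView(data_ + n, m);
      }

     private:
      const T* data_;
      size_t length_;
    };

    int main() {
      std::vector<int64_t> sizes = {2, 3, 5, 7};
      ArrayView<int64_t> view = sizes;   // implicit, as with HeaderOnlyArrayRef
      assert(view.size() == 4);
      assert(view.slice(1, 2)[0] == 3);  // the view aliases `sizes`; no copies
      return 0;
    }

As with ArrayRef, the design point is that views are cheap to copy and never own memory, so they must not outlive the container that backs them.
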
@@ -1461,16 +1461,3 @@ def make_opaque_bitwise_fn(name, real_op_name):

 BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
 BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")
-
-
-from sympy.logic.boolalg import BooleanFunction
-
-
-class OrderedAnd(BooleanFunction):
-    @classmethod
-    def eval(cls, *args):
-        # Returning None tells SymPy not to simplify further
-        return None
-
-    def _sympystr(self, printer):
-        return " and ".join(printer._print(a) for a in self.args)

@@ -34,7 +34,6 @@ from .functions import (
     Mod,
     ModularIndexing,
     OpaqueUnaryFn_log2,
-    OrderedAnd,
     PowByNatural,
     PythonMod,
     RoundDecimal,
@@ -109,7 +108,6 @@ def handlers():
         OpaqueUnaryFn_log2: "log2",
         BitwiseFn_bitwise_and: "bitwise_and",
         BitwiseFn_bitwise_or: "bitwise_or",
-        OrderedAnd: "ordered_and",
     }
     # TODO: This is kind of pointless, we shouldn't be generating sympy.sin
     # for these functions, they should be Opaque instead

@@ -513,10 +513,6 @@ class SymPyValueRangeAnalysis:
     def and_(a, b):
         return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)

-    @staticmethod
-    def ordered_and(a, b):
-        return ValueRanges.unknown()
-
    @staticmethod
     def _bool_to_int(x):
         if x.is_singleton():

@@ -639,8 +639,6 @@ def is_pytorch_file(rel_filepath):
         return True
     if rel_filepath.startswith("third_party/nvfuser/"):
         return True
-    if rel_filepath.startswith("third_party/fbgemm/"):
-        return True
     if rel_filepath.startswith("tools/autograd/templates/"):
         return True
     return False