Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-23 06:34:55 +08:00)

Compare commits: 11 commits, cpp-docs-d...viable/str
Commits (SHA1): e7592f4005, d334c3649d, 9f82535c5a, 5b35fc8777, 2f38eece7c, 830e789a55, ad4dc52bf6, dac9ed9790, 1c7fe8f861, 4e643422f6, 3c3b278872
@@ -6,7 +6,7 @@ dependencies = [
    "GitPython==3.1.45",
    "docker==7.1.0",
    "pytest==7.3.2",
-   "uv==0.8.6"
+   "uv==0.9.5"
]

[tool.setuptools]
.github/workflows/periodic.yml (vendored, 15 lines changed)
@@ -147,15 +147,16 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit
.github/workflows/trunk-tagging.yml (vendored, 147 lines changed)
@@ -58,8 +58,10 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"

- name: Validate commit SHA
run: |
@@ -87,7 +89,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi

- name: Create and push tag with retry
- name: Create and push tag(s) with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@@ -112,14 +114,23 @@ jobs:
return 1
}

# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0

echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT

# Retry configuration
MAX_RETRIES=5
@@ -194,31 +205,111 @@ jobs:
}
}

# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.

# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true

if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after

# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"

if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi

commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf '  %s\n' "${sha}"
done < "${commits_file}"

while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"

# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi

echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"

rm -f "${commits_file}"

if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
exit 0
else
echo "Tag creation failed after all retry attempts"
exit 1
# workflow_dispatch path (single SHA tagging preserved)

# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi

echo "Tag ${TAG_NAME} does not exist, proceeding with creation"

if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
fi

- name: Tag creation summary
if: always()
run: |
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi

echo ""
echo "Tag details:"
echo "  Name: ${{ steps.commit.outputs.tag_name }}"
echo "  Commit: ${{ steps.commit.outputs.sha }}"
echo "  Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
echo ""
echo "Tag details:"
echo "  Name: ${{ steps.commit.outputs.tag_name }}"
echo "  Commit: ${{ steps.commit.outputs.sha }}"
echo "  Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo "  Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
fi
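The push path above reduces to a small algorithm: list the commits in `before..after`, skip any commit whose `trunk/<sha>` tag already exists, create the missing tags with retry and backoff, and report created/skipped/failed counts. The sketch below restates that flow in Python for illustration only; it assumes a local checkout with `git` on PATH, and `create_tag_with_retry` is a hypothetical stand-in for the workflow's `retry_with_backoff`/`tag_with_retry` pair.

```python
import subprocess
import time

def run(*args: str) -> str:
    # Run a git command and return stdout (raises on failure).
    return subprocess.run(["git", *args], check=True, capture_output=True, text=True).stdout

def tag_exists(tag: str) -> bool:
    # Local check only; the workflow also consults the remote for idempotency.
    return subprocess.run(["git", "rev-parse", "-q", "--verify", f"refs/tags/{tag}"],
                          capture_output=True).returncode == 0

def create_tag_with_retry(tag: str, sha: str, max_retries: int = 5) -> bool:
    # Hypothetical helper mirroring retry_with_backoff + tag_with_retry.
    for attempt in range(max_retries):
        try:
            run("tag", tag, sha)
            run("push", "origin", tag)
            return True
        except subprocess.CalledProcessError:
            time.sleep(2 ** attempt)  # exponential backoff between attempts
    return False

def tag_push_range(before_sha: str, after_sha: str) -> dict:
    created = skipped = failed = 0
    # Oldest-first listing of the commits introduced by the push.
    for sha in run("rev-list", "--reverse", f"{before_sha}..{after_sha}").split():
        tag = f"trunk/{sha}"
        if tag_exists(tag):
            skipped += 1
            continue
        if create_tag_with_retry(tag, sha):
            created += 1
        else:
            failed += 1
    return {"created": created, "skipped": skipped, "failed": failed}
```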
@@ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}

static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
/*
 * Checks whether DISABLE_ADDMM_CUDA_LT is set.
 * Additionally, for ROCm we test whether the architecture supports the Lt.
 */
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif

// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}

return false;
}
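The new helper keeps the same environment kill switch as the old `getDisableAddmmCudaLt`. Because the value is read into a function-local `static`, it is captured once per process, so the variable has to be set before the first CUDA matmul runs. A minimal usage sketch in Python (assuming a CUDA build; which kernel actually runs is otherwise an internal dispatch detail):

```python
import os

# Must be set before the first addmm/linear on CUDA, since the C++ side
# caches the value in a static on first use.
os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"

import torch

x = torch.randn(32, 64, device="cuda", dtype=torch.float16)
w = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(128, device="cuda", dtype=torch.float16)
# With the switch set, this addmm takes the plain cuBLAS GEMM path
# instead of the cublasLt/hipblasLt bias-epilogue path.
y = torch.addmm(b, x, w.t())
```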

#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
/*
 * Check whether, for the given input, we want to enable the Lt interface
 */
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies 2D bias which we currently do not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be relaxed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}

#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif

const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use the Lt interface only when self is a bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last condition is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif

// no compliance by default
return false;
}
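In plain terms, the Lt bias-epilogue path is only considered when the output is not aliased with `self`, `self` is a contiguous 1-D bias of length `mat2.size(1)`, `beta == 1`, the result is 2-D and contiguous, the dtype is supported, and neither matrix has a size-1 dimension. A quick Python illustration of inputs that satisfy those conditions versus ones that do not (this only shows the shapes and flags the predicate looks at; the actual kernel choice remains an internal detail):

```python
import torch

m, k, n = 32, 64, 128
mat1 = torch.randn(m, k, device="cuda", dtype=torch.float16)
mat2 = torch.randn(k, n, device="cuda", dtype=torch.float16)

bias = torch.randn(n, device="cuda", dtype=torch.float16)   # 1-D, length n, contiguous
out_ok = torch.addmm(bias, mat1, mat2)        # beta=1, 1-D bias: Lt-eligible shape

bias_2d = torch.randn(m, n, device="cuda", dtype=torch.float16)
out_no = torch.addmm(bias_2d, mat1, mat2)     # 2-D "bias": not Lt-eligible

col = torch.randn(k, 1, device="cuda", dtype=torch.float16)
out_thin = torch.addmm(torch.randn(1, device="cuda", dtype=torch.float16), mat1, col)
# mat2 has a size-1 dimension here: excluded to avoid the
# CUBLAS_STATUS_INVALID_VALUE heuristic failure mentioned above.
```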
#endif

template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@@ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}

template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();

const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}

return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}

template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}

Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@@ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)

if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks

// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);

IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
// if lt path fails, we recurse back into this function here and force the lt path to off
// we cannot update variable disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use the Lt interface only when self is a bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions are to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}

if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});

const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// When in non-Lt path we do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
}
}


IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
// Short circuit on empty result
if (result.numel() == 0) {
return result;
}

cublasCommonArgs args(mat1, mat2, result);

if (mat1.numel() == 0) {
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@@ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
}

cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
);
} // end is_float_output_with_half_input

if (!lt_success) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
// end Lt path
} else {
// No Lt, we use a GEMM instead
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();

float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@@ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
}

// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@@ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
}
} // end GEMM path

// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (useLtInterface && activation == Activation::GELU) {
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif

@@ -23,7 +23,7 @@ namespace at::native {

// The maximum number of threads in a block
#if defined(USE_ROCM)
-constexpr int MAX_BLOCK_SIZE = 256;
+constexpr int MAX_BLOCK_SIZE = 1024;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
-int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
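`getNumThreads` picks the smallest entry in `threadSizes` that covers `nElem`, falling back to `MAX_BLOCK_SIZE`; the change above only widens the ROCm ladder from 16..256 to 64..1024. A small Python sketch of that selection rule (assuming the usual "first ladder entry >= nElem" behaviour of this helper):

```python
MAX_BLOCK_SIZE_ROCM = 1024
THREAD_SIZES_ROCM = [64, 128, 256, 512, MAX_BLOCK_SIZE_ROCM]

def get_num_threads(n_elem: int, ladder=THREAD_SIZES_ROCM) -> int:
    # Smallest block size in the ladder that covers n_elem,
    # otherwise the maximum block size.
    for size in ladder:
        if n_elem <= size:
            return size
    return ladder[-1]

assert get_num_threads(50) == 64
assert get_num_threads(300) == 512
assert get_num_threads(5000) == 1024
```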

@@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}

__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
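`reflect_index` folds an arbitrary (possibly negative or out-of-range) coordinate back into [0, len) with reflection at both ends, using the period 2*(len-1). A quick Python check of the same formula against `F.pad`'s reflect mode (CPU is enough to verify the index math):

```python
import torch
import torch.nn.functional as F

def reflect_index(x: int, length: int) -> int:
    two = (length - 1) * 2
    if two <= 0:
        return 0
    m = x % two                    # Python % already yields a non-negative remainder
    return m if m < length else two - m

x = torch.arange(5.0)              # [0, 1, 2, 3, 4]
pad_l, pad_r = 3, 3
ref = F.pad(x.view(1, 1, -1), (pad_l, pad_r), mode="reflect").view(-1)

manual = torch.stack([
    x[reflect_index(i - pad_l, x.numel())]
    for i in range(x.numel() + pad_l + pad_r)
])
assert torch.equal(ref, manual)    # both give [3,2,1, 0,1,2,3,4, 3,2,1]
```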

template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel(
}
}

template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {

const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;

const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;

for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}

template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;

dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);

Tensor input = input_.contiguous();

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();

const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));

const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);

if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);

reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}

C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,

@@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;

C10_DEVICE __forceinline__ void operator()(
-int chunk_size,
+int64_t chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {

} // namespace

} // namespace at::native
} // namespace at::native
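The functor above backs the fused Adagrad step; the change only widens `chunk_size` to `int64_t`. For reference, the fused path is opt-in from Python; a minimal sketch, assuming a CUDA build recent enough that `torch.optim.Adagrad` accepts `fused=True`:

```python
import torch

model = torch.nn.Linear(16, 4, device="cuda")
# fused=True requests the fused CUDA kernel that FusedAdagradMathFunctor
# implements; builds/devices without the fused path raise instead.
opt = torch.optim.Adagrad(model.parameters(), lr=0.1, fused=True)

loss = model(torch.randn(8, 16, device="cuda")).sum()
loss.backward()
opt.step()
opt.zero_grad()
```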

@@ -1,8 +1,8 @@
-add_loop_eager,compile_time_instruction_count,3070000000,0.1
+add_loop_eager,compile_time_instruction_count,3184000000,0.1

-add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
+add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1

@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1

-basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1

@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,

-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1

@@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000

-update_hint_regression,compile_time_instruction_count,1719000000,0.1
+update_hint_regression,compile_time_instruction_count,1645000000,0.1

-sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
+sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1

@@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1

-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1

-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1

-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1

-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1

-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1

-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1

-mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
+mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1

@@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1

-basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
+basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1

-basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
+basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
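Each row in this expected-results file is benchmark_name, metric, expected_value, relative_tolerance, so a run passes while the measured instruction count stays within ±10% of the pinned value. A small sketch of that check (the file name and the `measured` dict are illustrative, not the actual CI harness):

```python
import csv

def within_tolerance(actual: float, expected: float, rel_tol: float) -> bool:
    # Pass if the measurement deviates from the pinned value by at most
    # rel_tol (0.1 == ±10% in the rows above).
    return abs(actual - expected) <= rel_tol * expected

measured = {"add_loop_eager": 3_200_000_000}          # hypothetical new measurement
with open("expected_ci_perf_results.csv") as f:        # illustrative file name
    for row in csv.reader(f):
        if len(row) != 4:
            continue                                    # skip blank separator lines
        name, metric, expected, tol = row
        if name in measured:
            assert within_tolerance(measured[name], float(expected), float(tol)), name
```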

@@ -67,7 +67,21 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ across hardware
lin = torch.nn.Linear(768, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# which alters peak memory usage since the cublaslt workspace is no longer accounted for.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 has nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()
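The comment boils down to: which cuBLAS backend serves the gemm for a (1, 768) versus a (2, 768) input decides whether a cublasLt workspace shows up in peak memory, so the test pins the shape that keeps `base_mem_mb` stable. A standalone way to observe the shape effect on peak memory (illustrative only; absolute numbers depend on GPU, driver, and workspace configuration):

```python
import torch

def peak_mb_for_batch(batch: int) -> float:
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    lin = torch.nn.Linear(768, 768, device="cuda")
    inp = torch.randn(batch, 768, device="cuda")
    lin(inp).sum().backward()          # forward gemm + backward gemms
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e6

# Compare the two shapes discussed in the comment above.
print("batch=1:", peak_mb_for_batch(1), "MB")
print("batch=2:", peak_mb_for_batch(2), "MB")
```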

@@ -288,6 +288,18 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)

def test_graph_break(self):
def fn(x):
with torch.fx.traceback.annotate({"pp_stage": 0}):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x

opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
self.assertEqual(fn(x), opt_fn(x))


if __name__ == "__main__":
run_tests()
@@ -6,6 +6,7 @@ import builtins
import collections
import contextlib
import copy
import gc
import functools
import inspect
import io
@@ -19,6 +20,7 @@ import traceback
import types
import typing
import unittest
import weakref
import warnings
from math import sqrt
from torch.multiprocessing import Process
@@ -1624,6 +1626,25 @@ class TestFX(JitTestCase):

self.assertTrue(neg not in relu.users)

@skipIfTorchDynamo("Dynamo does not free right away")
def test_prepend_does_not_leak(self):
g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))

relu.prepend(neg)

ref = weakref.ref(neg)
g.erase_node(neg)
del g
del x
del relu
del neg
gc.collect()

self.assertIsNone(ref())

def test_remove_uses_with_custom_filter(self):
g: torch.fx.Graph = Graph()
x: torch.fx.Node = g.placeholder("x")
@@ -7381,6 +7381,10 @@ torch.cuda.synchronize()
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@parametrize("use_legacy_api", [True, False])
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_dummy_mha_with_nt(self, device, use_legacy_api):
bs = 3
d1 = 2
@@ -8490,6 +8490,14 @@ class TestNNDeviceType(NNTestCase):
y_cuda_contig = pool(x_cuda.contiguous())
self.assertEqual(y_cuda_ch_last, y_cuda_contig)

@onlyCUDA
def test_large_reflect_pad(self, device):
# https://github.com/pytorch/pytorch/issues/165861
x = torch.rand(2**16, 2, device="cuda")
c = F.pad(x, (1, 1), mode="reflect")
c_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
self.assertEqual(c, c_cpu)

@onlyCUDA
@largeTensorTest("48GB", "cpu")
@largeTensorTest("48GB", "cuda")
@@ -247,6 +247,10 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)

@@ -576,6 +580,10 @@ class TestSparseSemiStructuredTraining(TestCase):

@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_prune_dense_static_sort(self, dtype) -> None:
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@@ -621,6 +629,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
inp = torch.tensor(
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@@ -658,6 +670,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
M, N = 128, 256
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
@@ -692,6 +708,10 @@ class TestSparseSemiStructuredTraining(TestCase):

@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_id(self, dtype) -> None:
N = 512
torch.manual_seed(0)
@@ -729,6 +749,10 @@ class TestSparseSemiStructuredTraining(TestCase):

@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_edge_case1(self, dtype) -> None:
# In this case, the heuristic will keep 7 values out of 16
# instead of 8. let's see how the kernel handles this
@@ -754,6 +778,10 @@ class TestSparseSemiStructuredTraining(TestCase):

@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -770,6 +798,10 @@ class TestSparseSemiStructuredTraining(TestCase):

@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply_dense(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -808,6 +840,10 @@ class TestSparseSemiStructuredTraining(TestCase):

@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls(self, dtype) -> None:
M, N, K = 64, 256, 1024
a = torch.randn([M, K], device="cuda", dtype=dtype)
@@ -843,6 +879,10 @@ class TestSparseSemiStructuredTraining(TestCase):
)

@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_mat_vec(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128], device="cuda", dtype=torch.float16)
@@ -853,6 +893,10 @@ class TestSparseSemiStructuredTraining(TestCase):
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_bmm(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
@@ -2758,6 +2758,12 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...

class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...
@@ -2810,5 +2810,15 @@
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
],
"GB0279": [
{
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
"Context": "str(self)",
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}
@@ -1295,6 +1295,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
def fn_name(self):
return "annotate"

def reconstruct_type(self, codegen: "PyCodegen"):
unimplemented_v2(
gb_type="torch.fx.traceback.annotate escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)


class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""
@ -1,11 +1,15 @@
|
||||
#include <torch/csrc/fx/node.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <structmember.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace {
|
||||
|
||||
using NodeSortKey = c10::SmallVector<int64_t, 4>;
|
||||
struct NodeBase;
|
||||
|
||||
// Thrown to exit out of a C++ function and return an error to Python.
|
||||
@ -163,7 +167,41 @@ struct NodeBase {
|
||||
PyObject* users;
|
||||
PyObject* _repr_fn;
|
||||
PyObject* meta;
|
||||
PyObject* _sort_key;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
|
||||
|
||||
inline NodeSortKey& sort_key() {
|
||||
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
|
||||
}
|
||||
|
||||
inline void set_prev(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _prev;
|
||||
_prev = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
inline void set_next(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _next;
|
||||
_next = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
// Equivalent to:
|
||||
// p, n = self._prev, self._next
|
||||
// p._next, n._prev = n, p
|
||||
inline void remove_from_list() {
|
||||
if (this->_prev == this && this->_next == this) {
|
||||
return;
|
||||
}
|
||||
NodeBase* p = this->_prev;
|
||||
NodeBase* n = this->_next;
|
||||
p->set_next(n);
|
||||
n->set_prev(p);
|
||||
}
|
||||
};
|
||||
|
||||
static PyObject* NodeBase_new(
|
||||
@ -173,6 +211,8 @@ static PyObject* NodeBase_new(
|
||||
PyObject* self = type->tp_alloc(type, 0);
|
||||
if (!self)
|
||||
return nullptr;
|
||||
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
|
||||
NodeSortKey(); // placement new does not allocate
|
||||
return self;
|
||||
}
|
||||
|
||||
@ -201,7 +241,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
  self->users = PyDict_New();
  self->_repr_fn = Py_NewRef(Py_None);
  self->meta = PyDict_New();
  self->_sort_key = PyTuple_New(0);
  return 0;
}

@ -221,7 +260,6 @@ static struct PyMemberDef NodeBase_members[] = {
    {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
    {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
    {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
    {"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
    {nullptr} /* Sentinel */
};

@ -239,7 +277,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
  Py_VISIT(self->users);
  Py_VISIT(self->_repr_fn);
  Py_VISIT(self->meta);
  Py_VISIT(self->_sort_key);
  return 0;
}

@ -257,12 +294,12 @@ static int NodeBase_clear(NodeBase* self) {
  Py_CLEAR(self->users);
  Py_CLEAR(self->_repr_fn);
  Py_CLEAR(self->meta);
  Py_CLEAR(self->_sort_key);
  return 0;
}

static void NodeBase_dealloc(PyObject* self) {
  PyObject_GC_UnTrack(self);
  reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
  (void)NodeBase_clear((NodeBase*)self);
  Py_TYPE(self)->tp_free(self);
}
@ -321,15 +358,195 @@ static PyObject* NodeBase__update_args_kwargs(
  }
}

static PyObject* NodeBase__remove_from_list(
    PyObject* self,
    PyObject* _ignored) {
  reinterpret_cast<NodeBase*>(self)->remove_from_list();
  Py_RETURN_NONE;
}

static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
  if (self_ == arg) {
    Py_RETURN_NONE;
  }
  if (!is_node(arg)) {
    PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
    return nullptr;
  }
  NodeBase* self = reinterpret_cast<NodeBase*>(self_);
  NodeBase* x = reinterpret_cast<NodeBase*>(arg);
  if (self->graph != x->graph) {
    PyErr_SetString(
        PyExc_AssertionError,
        "Attempting to move a Node into a different Graph");
    return nullptr;
  }

  x->remove_from_list();
  NodeBase* p = self->_prev;
  p->set_next(x);
  x->set_prev(p);
  x->set_next(self);
  self->set_prev(x);

  // Now compute x.sort_key()
  const NodeSortKey& psk = x->_prev->sort_key();
  const NodeSortKey& nsk = x->_next->sort_key();
  if (psk.size() > nsk.size()) {
    // prefix = psk[: len(nsk)+1]
    size_t slice_len = nsk.size() + 1;
    NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
    // last element is idx => increment by 1
    prefix.back()++;
    x->sort_key() = std::move(prefix);
  } else if (psk.size() < nsk.size()) {
    // prefix = nsk[: len(psk)+1]
    size_t slice_len = psk.size() + 1;
    NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
    // last element is idx => decrement by 1
    prefix.back()--;
    x->sort_key() = std::move(prefix);
  } else {
    // same length => add a 0
    x->sort_key() = psk;
    x->sort_key().emplace_back(0);
  }
  Py_RETURN_NONE;
}

// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
  // METH_O => one argument: 'other'
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
  bool less = std::lexicographical_compare(
      lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
  if (less)
    Py_RETURN_TRUE;
  Py_RETURN_FALSE;
}

// __gt__(self, other): Return self.sort_key > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
  if (!is_node(other)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
  const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
  // "a > b" is equivalent to "b < a"
  bool greater = std::lexicographical_compare(
      rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
  if (greater)
    Py_RETURN_TRUE;
  Py_RETURN_FALSE;
}

static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
  if (self == other) {
    Py_RETURN_TRUE;
  }
  return NodeBase___gt__(self, other);
}

// __le__(self, other): Return not (self > other)
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
  if (self == other) {
    Py_RETURN_TRUE;
  }
  return NodeBase___lt__(self, other);
}

// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  const NodeSortKey& vec = node->sort_key();
  Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
  THPObjectPtr tuple(PyTuple_New(n));
  if (!tuple) {
    return nullptr; // Out of memory
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    PyObject* value = PyLong_FromSsize_t(vec[i]);
    if (!value) {
      return nullptr;
    }
    PyTuple_SET_ITEM(tuple.get(), i, value);
  }
  return tuple.release();
}

// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
    PyObject* self,
    PyObject* value,
    void* /*closure*/) {
  NodeBase* node = reinterpret_cast<NodeBase*>(self);
  if (!PyTuple_Check(value)) {
    PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
    return -1;
  }
  Py_ssize_t size = PyTuple_GET_SIZE(value);
  NodeSortKey new_vec;
  new_vec.reserve(size);
  for (Py_ssize_t i = 0; i < size; i++) {
    int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
    if (val == -1 && PyErr_Occurred()) {
      return -1;
    }
    new_vec.emplace_back(val);
  }
  node->sort_key() = std::move(new_vec);
  return 0;
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
    {"_update_args_kwargs",
     (PyCFunction)(void*)(NodeBase__update_args_kwargs),
     METH_FASTCALL,
     "Internal method: do not call directly."},
    {"_remove_from_list",
     (PyCFunction)(void*)(NodeBase__remove_from_list),
     METH_NOARGS,
     "Internal method: do not call directly."},
    {"_prepend",
     (PyCFunction)(void*)(NodeBase__prepend),
     METH_O,
     "Internal method: do not call directly."},
    {"__lt__",
     (PyCFunction)(void*)NodeBase___lt__,
     METH_O,
     "Return True if self.sort_key < other.sort_key"},
    {"__gt__",
     (PyCFunction)(void*)NodeBase___gt__,
     METH_O,
     "Return True if self.sort_key > other.sort_key"},
    {"__ge__",
     (PyCFunction)(void*)NodeBase___ge__,
     METH_O,
     "Return True if self.sort_key >= other.sort_key"},
    {"__le__",
     (PyCFunction)(void*)NodeBase___le__,
     METH_O,
     "Return True if self.sort_key <= other.sort_key"},
    {nullptr, nullptr, 0, nullptr} // Sentinel
};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
    {"_sort_key", // attribute name in Python
     (getter)NodeBase_get_sort_key, // C getter function
     (setter)NodeBase_set_sort_key, // C setter function
     (char*)"The sort key as a tuple of ints", // docstring
     nullptr},
    {nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};

PyTypeObject NodeBaseType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "torch._C._NodeBase", /* tp_name */
@ -361,7 +578,7 @@ PyTypeObject NodeBaseType = {
    nullptr, /* tp_iternext */
    NodeBase_methods, /* tp_methods */
    NodeBase_members, /* tp_members */
    nullptr, /* tp_getset */
    NodeBase_getset, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */

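Before the Python-side diff below, here is a small pure-Python sketch of the prefix-based sort-key scheme that NodeBase__prepend implements above. This is an illustration under assumptions, not code from the patch; the helper name key_between is invented here. It mirrors the prefix-bump logic so that a node inserted between prev and next always compares strictly between them under tuple (lexicographic) order.

# Pure-Python illustration of the prefix-based sort keys (helper name is hypothetical).
def key_between(psk, nsk):
    # psk/nsk are the sort-key tuples of the previous and next nodes.
    if len(psk) > len(nsk):
        prefix = list(psk[: len(nsk) + 1])
        prefix[-1] += 1          # bump the last kept element of the longer key
        return tuple(prefix)
    elif len(psk) < len(nsk):
        prefix = list(nsk[: len(psk) + 1])
        prefix[-1] -= 1
        return tuple(prefix)
    else:
        return (*psk, 0)         # equal lengths: extend with a trailing 0

k = key_between((0,), (1,))
assert (0,) < k < (1,)           # (0, 0) sorts strictly between the neighbours
assert (0, 0) < key_between((0, 0), (1,)) < (1,)
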
@ -385,41 +385,7 @@ class Node(_NodeBase):
        Args:
            x (Node): The node to put before this node. Must be a member of the same graph.
        """
        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
        if self == x:
            log.debug(
                "Trying to prepend a node to itself. This behavior has no effect on the graph."
            )
            return
        x._remove_from_list()
        p = self._prev
        p._next, x._prev = x, p
        x._next, self._prev = self, x

        # compute x._sort_key
        psk = x._prev._sort_key
        nsk = x._next._sort_key
        if len(psk) > len(nsk):
            idx: int
            *prefix, idx = psk[: len(nsk) + 1]
            x._sort_key = (*prefix, idx + 1)
        elif len(psk) < len(nsk):
            *prefix, idx = nsk[: len(psk) + 1]
            x._sort_key = (*prefix, idx - 1)
        else:  # same length, increase length by 1
            x._sort_key = (*psk, 0)

    def __gt__(self, other: "Node") -> bool:
        return self._sort_key > other._sort_key

    def __lt__(self, other: "Node") -> bool:
        return self._sort_key < other._sort_key

    def __ge__(self, other: "Node") -> bool:
        return self > other or self == other

    def __le__(self, other: "Node") -> bool:
        return self < other or self == other
        self._prepend(x)

    @compatibility(is_backward_compatible=True)
    def append(self, x: "Node") -> None:
@ -430,11 +396,7 @@ class Node(_NodeBase):
        Args:
            x (Node): The node to put after this node. Must be a member of the same graph.
        """
        self._next.prepend(x)

    def _remove_from_list(self) -> None:
        p, n = self._prev, self._next
        p._next, n._prev = n, p
        self._next._prepend(x)

    @property
    def args(self) -> tuple[Argument, ...]:

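A short usage sketch to close out: it assumes an FX graph traced from a toy function, and nothing in it comes from the patch itself. It shows that the Python-visible behaviour of prepend and the rich comparisons is unchanged, now backed by the C++ _prepend and sort keys.

# Usage sketch under assumptions; the traced function and variable names are illustrative.
import torch
import torch.fx as fx

def g(x):
    return x.relu() + 1

gm = fx.symbolic_trace(g)
nodes = list(gm.graph.nodes)            # placeholder, relu, add, output
assert all(a < b for a, b in zip(nodes, nodes[1:]))  # iteration order matches sort keys

# Re-prepending a node before its current successor leaves the ordering intact.
nodes[2].prepend(nodes[1])
assert nodes[1] < nodes[2]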