Compare commits

...

2 Commits

Author SHA1 Message Date
998e02cc96 fall back deterministic index_copy to index_put 2025-11-14 19:30:09 +08:00
226850cc66 [ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)
This PR add a sm_121a flag for row-wise scaled matmuls on DGX Spark.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167734
Approved by: https://github.com/eqy, https://github.com/cyyever
2025-11-14 08:44:04 +00:00
2 changed files with 8 additions and 2 deletions

View File

@ -1087,7 +1087,8 @@ TORCH_IMPL_FUNC(index_copy_out)
result.copy_(self);
// See Note [Enabling Deterministic Operations]
if (result.is_cuda() && globalContext().deterministicAlgorithms()) {
if ((result.is_cuda() || result.is_xpu()) &&
globalContext().deterministicAlgorithms()) {
torch::List<std::optional<Tensor>> indices;
indices.resize(dim + 1);
indices.set(dim, index);

View File

@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
endif()
endif()
if("${_arch}" STREQUAL "121a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
endif()
endif()
endforeach()
list(JOIN _file_compile_flags " " _file_compile_flags)
@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a")
"89;90a;100a;103a;120a;121a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")