Revert "[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)"

This reverts commit ef6296e7f20d744a0cfed81cab573d60204e7626.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/izaitsevfb due to reverted internally, see D71292427 ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2731114626))
This commit is contained in:
PyTorch MergeBot
2025-03-17 22:43:15 +00:00
parent a16ada41b9
commit afa1eda901
11 changed files with 362 additions and 411 deletions

View File

@ -224,7 +224,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
const std::optional<at::Tensor>& sparse_indices,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
@ -232,7 +231,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -252,14 +250,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -281,7 +277,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t,
int64_t,
bool,
int64_t)>();
auto work = op.call(
tensors,
@ -289,7 +284,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -312,14 +306,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -371,19 +363,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
auto work = op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensorLists) {
@ -408,14 +399,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
.typed<c10::intrusive_ptr<Work>(
const at::TensorList,
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
@ -436,14 +425,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
bool,
int64_t)>();
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -500,14 +487,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -561,7 +546,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = op.call(
@ -569,7 +553,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -594,7 +577,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
std::vector<int64_t>,
std::vector<int64_t>,
bool,
int64_t)>();
auto work = op.call(
outputBuffer,
@ -602,7 +584,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -623,13 +604,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -799,14 +778,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::Tensor,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const std::vector<int64_t>&,
bool,
int64_t)>();
auto work = op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(tensor, work);