[micro_pipeline_tp] support all _scaled_mm args (#131984)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131984
Approved by: https://github.com/weifengpy

Committed by: PyTorch MergeBot
Parent: 2b5e31d099
Commit: ea42027e0e
@@ -223,25 +223,31 @@ class SymmetricMemoryTest(MultiProcessTestCase):
             torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)
         ]
         B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]
-        output_dtypes = [None, torch.bfloat16, torch.float32]
+        out_dtypes = [None, torch.bfloat16, torch.float32]
 
         ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
             A_shard,
             Bs,
             A_scale,
             B_scales,
-            output_dtypes,
             gather_dim=gather_dim,
             group_name=group.group_name,
+            biases=[None] * len(Bs),
+            result_scales=[None] * len(Bs),
+            out_dtypes=out_dtypes,
+            use_fast_accum=[None] * len(Bs),
         )
         ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
             A_shard,
             Bs,
             A_scale,
             B_scales,
-            output_dtypes,
             gather_dim=gather_dim,
             group_name=group.group_name,
+            biases=[None] * len(Bs),
+            result_scales=[None] * len(Bs),
+            out_dtypes=out_dtypes,
+            use_fast_accum=[None] * len(Bs),
         )
 
         self.assertTrue(
@@ -317,20 +323,20 @@ class SymmetricMemoryTest(MultiProcessTestCase):
             B,
             A_scale,
             B_scale,
-            torch.bfloat16,
             "avg",
-            scatter_dim=scatter_dim,
-            group_name=group.group_name,
+            scatter_dim,
+            group.group_name,
+            out_dtype=torch.bfloat16,
         )
         output_1 = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
             A,
             B,
             A_scale,
             B_scale,
-            torch.bfloat16,
             "avg",
-            scatter_dim=scatter_dim,
-            group_name=group.group_name,
+            scatter_dim,
+            group.group_name,
+            out_dtype=torch.bfloat16,
         )
 
         assert torch.allclose(output_0, output_1)
@@ -2,7 +2,7 @@
 import operator
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import cast, Dict, List, Optional, Set
+from typing import Any, cast, Dict, List, Optional, Set
 
 import torch
 
@@ -347,7 +347,6 @@ class _Matmul:
         assert len(match) in (1, 3)
         assert match[0].target in (
             aten.mm.default,
-            aten._scaled_mm.default,
             aten.reshape.default,
         )
         mm_node = match[0] if len(match) == 1 else match[1]
@@ -362,7 +361,10 @@
 class _ScaledMatmul(_Matmul):
     A_scale_node: torch.fx.Node
     B_scale_node: torch.fx.Node
+    bias_node: Optional[torch.fx.Node]
+    result_scale_node: Optional[torch.fx.Node]
     out_dtype: Optional[torch.dtype]
+    use_fast_accum: bool
 
     def __post_init__(self):
         super().__post_init__()
@@ -373,22 +375,26 @@ class _ScaledMatmul(_Matmul):
     def from_match(cls, match: List[torch.fx.Node]) -> "_ScaledMatmul":
         assert len(match) in (1, 3)
         assert match[0].target in (
             aten.mm.default,
             aten._scaled_mm.default,
             aten.reshape.default,
         )
         mm_node = match[0] if len(match) == 1 else match[1]
-        out_dtype = (
-            cast(torch.dtype, mm_node.args[6]) if len(mm_node.args) > 6 else None
-        )
-        assert isinstance(out_dtype, (torch.dtype, type(None)))
 
+        def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
+            if idx >= len(node.args):
+                return default
+            return node.args[idx]
+
         return _ScaledMatmul(
             nodes=match,
             A_node=cast(torch.fx.Node, match[0].args[0]),
             B_node=cast(torch.fx.Node, mm_node.args[1]),
             A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
             B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
-            out_dtype=out_dtype,
+            bias_node=get_arg(mm_node, 4, None),
+            result_scale_node=get_arg(mm_node, 5, None),
+            out_dtype=get_arg(mm_node, 6, None),
+            use_fast_accum=get_arg(mm_node, 7, False),
         )
 
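Aside (not part of the diff): a minimal sketch of the defaulting behavior that `get_arg` provides when a captured `aten._scaled_mm.default` node carries only the required leading arguments. The tuple below is a hypothetical stand-in for `node.args`; positions 4 and 7 mirror the `bias_node` and `use_fast_accum` lookups above.

    from typing import Any, Sequence

    def get_arg(args: Sequence[Any], idx: int, default: Any) -> Any:
        # Same logic as the helper above, applied to a plain tuple.
        return args[idx] if idx < len(args) else default

    # A _scaled_mm call recorded with only (A, B, scale_a, scale_b):
    args = ("A_node", "B_node", "A_scale_node", "B_scale_node")
    assert get_arg(args, 4, None) is None    # bias not present -> None
    assert get_arg(args, 7, False) is False  # use_fast_accum not present -> False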
@@ -477,20 +483,19 @@ def _insert_fused_all_gather_matmul(
         )
     elif mm_type == _ScaledMatmul:
         scaled_matmuls = cast(List[_ScaledMatmul], matmuls)
-        B_nodes = [matmul.B_node for matmul in scaled_matmuls]
-        A_scale_node = scaled_matmuls[0].A_scale_node
-        B_scale_nodes = [matmul.B_scale_node for matmul in scaled_matmuls]
-        out_dtypes = [matmul.out_dtype for matmul in scaled_matmuls]
         return graph.call_function(
             torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
             args=(
                 shard_node,
-                B_nodes,
-                A_scale_node,
-                B_scale_nodes,
-                out_dtypes,
+                [matmul.B_node for matmul in scaled_matmuls],
+                scaled_matmuls[0].A_scale_node,
+                [matmul.B_scale_node for matmul in scaled_matmuls],
                 gather_dim,
                 group_name,
+                [matmul.bias_node for matmul in scaled_matmuls],
+                [matmul.result_scale_node for matmul in scaled_matmuls],
+                [matmul.out_dtype for matmul in scaled_matmuls],
+                [matmul.use_fast_accum for matmul in scaled_matmuls],
             ),
         )
     else:
@@ -647,10 +652,13 @@ def _insert_fused_matmul_reduce_scatter(
                 matmul.B_node,
                 matmul.A_scale_node,
                 matmul.B_scale_node,
-                matmul.out_dtype,
                 reduce_op,
                 scatter_dim,
                 group_name,
+                matmul.bias_node,
+                matmul.result_scale_node,
+                matmul.out_dtype,
+                matmul.use_fast_accum,
             ),
         )
     else:
@@ -260,7 +260,11 @@ lib.define(
 lib.define(
     "fused_all_gather_scaled_matmul("
     "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, "
-    "ScalarType?[] out_dtypes, int gather_dim, str group_name) -> (Tensor, Tensor[])"
+    "int gather_dim, str group_name, "
+    "Tensor?[] biases, "
+    "Tensor?[] result_scales, "
+    "ScalarType?[] out_dtypes, "
+    "bool[] use_fast_accum) -> (Tensor, Tensor[])"
 )
 lib.define(
     "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor"
@@ -268,7 +272,11 @@ lib.define(
 lib.define(
     "fused_scaled_matmul_reduce_scatter("
     "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, "
-    "ScalarType out_dtype, str reduce_op, int scatter_dim, str group_name) -> Tensor"
+    "str reduce_op, int scatter_dim, str group_name, "
+    "Tensor? bias = None, "
+    "Tensor? result_scale = None, "
+    "ScalarType? out_dtype = None, "
+    "bool use_fast_accum = False) -> Tensor"
 )
 lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor")
 lib.define(
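Aside (not part of the diff): a usage sketch of the `fused_scaled_matmul_reduce_scatter` schema defined above, mirroring the updated test. It assumes an initialized process group named `group_name`, fp8 inputs with matching scales, and that the module registering the `symm_mem` ops has been imported (e.g. `torch.distributed._symmetric_memory`).

    import torch

    def scaled_mm_reduce_scatter_sketch(A, B, A_scale, B_scale, group_name):
        # out_dtype is passed by keyword; bias, result_scale and use_fast_accum
        # keep their schema defaults (None, None, False).
        return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            A,
            B,
            A_scale,
            B_scale,
            "avg",       # reduce_op
            0,           # scatter_dim
            group_name,
            out_dtype=torch.bfloat16,
        )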
@@ -276,6 +284,65 @@ lib.define(
 )
 
 
+def _fused_all_gather_matmul_impl(
+    mm_out_op: torch._ops.OpOverload,
+    A_shard: torch.Tensor,
+    Bs: List[torch.Tensor],
+    kwargs_list: List[Dict[str, Any]],
+    out_dtypes: List[Optional[torch.dtype]],
+    gather_dim: int,
+    group_name: str,
+) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+    if A_shard.dim() < 2:
+        raise ValueError("A_shard must be a matrix")
+    for B in Bs:
+        if B.dim() != 2:
+            raise ValueError("B must be a matrix")
+    if len(out_dtypes) != len(Bs):
+        raise ValueError("len(out_dtypes) must be the same as len(Bs)")
+    if len(kwargs_list) != len(Bs):
+        raise ValueError("len(kwargs_list) must be the same as len(Bs)")
+    if gather_dim < 0 or gather_dim >= A_shard.dim():
+        raise ValueError("Invalid gather_dim")
+
+    group = c10d._resolve_process_group(group_name)
+
+    # Move the gather_dim to the front and flatten the tensor into a 2D matrix.
+    # The flattened tensor doesn't need to be contiguous (for computation
+    # efficiency), as _pipelined_all_gather_and_consume guarantees that shards
+    # passed to shard_consumer are contiguous.
+    x = A_shard.movedim(gather_dim, 0)
+    leading_dims = [group.size()] + list(x.shape[:-1])
+    x = x.flatten(0, -2)
+
+    # Helper function for reverting the above transformation
+    def unflatten(t: torch.Tensor) -> torch.Tensor:
+        return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
+
+    ag_out = x.new_empty(
+        x.shape[0] * group.size(),
+        x.shape[1],
+    )
+    outputs = [
+        x.new_empty(x.shape[0] * group.size(), B.shape[1], dtype=out_dtype or B.dtype)
+        for B, out_dtype in zip(Bs, out_dtypes)
+    ]
+    output_shards = [output.chunk(group.size()) for output in outputs]
+
+    # Computing block-wise matmul along the first dim of A
+    def shard_consumer(shard: torch.Tensor, rank: int) -> None:
+        for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
+            mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank])
+
+    _pipelined_all_gather_and_consume(
+        x,
+        shard_consumer,
+        ag_out,
+        group_name,
+    )
+    return unflatten(ag_out), [unflatten(output) for output in outputs]
+
+
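Aside (not part of the diff): the layout transformation used by `_fused_all_gather_matmul_impl` above can be checked in isolation. A small sketch with a plain CPU tensor and a hypothetical group size; the simulated all-gather simply stacks `group_size` copies of the flattened shard.

    import torch

    group_size = 4                    # hypothetical world size
    gather_dim = 1
    A_shard = torch.randn(8, 6, 16)   # arbitrary shard shape

    # Forward transform: move gather_dim to the front, flatten to 2D.
    x = A_shard.movedim(gather_dim, 0)
    leading_dims = [group_size] + list(x.shape[:-1])
    x = x.flatten(0, -2)

    # Stand-in for the all-gather output (group_size stacked shards).
    ag_out = x.repeat(group_size, 1)

    # Reverse transform, matching the unflatten() helper above.
    result = ag_out.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
    assert result.shape == (8, 6 * group_size, 16)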
 @torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
 def _fused_all_gather_matmul_fallback(
     A_shard: torch.Tensor,
@@ -313,54 +380,140 @@ def _fused_all_gather_matmul(
     """
     if _is_test_mode:
         return _fused_all_gather_matmul_fallback(A_shard, Bs, gather_dim, group_name)
-    if A_shard.dim() < 2:
-        raise ValueError("A_shard must be a matrix")
-    for B in Bs:
-        if B.dim() != 2:
-            raise ValueError("B must be a matrix")
-    if gather_dim < 0 or gather_dim >= A_shard.dim():
-        raise ValueError("Invalid gather_dim")
-
-    group = c10d._resolve_process_group(group_name)
-
     with torch.profiler.record_function("fused_all_gather_matmul"):
-        # Move the gather_dim to the front and flatten the tensor into a 2D matrix.
-        # The flattened tensor doesn't need to be contiguous (for computation
-        # efficiency), as _pipelined_all_gather_and_consume guarantees that shards
-        # passed to shard_consumer are contiguous.
-        x = A_shard.movedim(gather_dim, 0)
-        leading_dims = [group.size()] + list(x.shape[:-1])
-        x = x.flatten(0, -2)
-
-        # Helper function for reverting the above transformation
-        def unflatten(t: torch.Tensor) -> torch.Tensor:
-            return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
-
-        ag_out = x.new_empty(
-            x.shape[0] * group.size(),
-            x.shape[1],
-        )
-        outputs = [
-            x.new_empty(
-                x.shape[0] * group.size(),
-                B.shape[1],
-            )
-            for B in Bs
-        ]
-        output_shards = [output.chunk(group.size()) for output in outputs]
-
-        # Computing block-wise matmul along the first dim of A
-        def shard_consumer(shard: torch.Tensor, rank: int) -> None:
-            for idx, B in enumerate(Bs):
-                torch.mm(shard, B, out=output_shards[idx][rank])
-
-        _pipelined_all_gather_and_consume(
-            x,
-            shard_consumer,
-            ag_out,
-            group_name,
-        )
-        return unflatten(ag_out), [unflatten(output) for output in outputs]
+        return _fused_all_gather_matmul_impl(
+            torch.ops.aten.mm.out,
+            A_shard,
+            Bs,
+            [{} for B in Bs],
+            [B.dtype for B in Bs],
+            gather_dim,
+            group_name,
+        )
+
+
+@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta")
+def _fused_all_gather_scaled_matmul_fallback(
+    A_shard: torch.Tensor,
+    Bs: List[torch.Tensor],
+    A_scale: torch.Tensor,
+    B_scales: List[torch.Tensor],
+    gather_dim: int,
+    group_name: str,
+    biases: List[Optional[torch.Tensor]],
+    result_scales: List[Optional[torch.Tensor]],
+    out_dtypes: List[Optional[torch.dtype]],
+    use_fast_accum: List[bool],
+) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+    out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
+
+    group_size = c10d._get_group_size_by_name(group_name)
+    A = torch.ops._c10d_functional.all_gather_into_tensor(
+        A_shard.contiguous(), group_size, group_name
+    )
+    A = torch.ops._c10d_functional.wait_tensor(A)
+    A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
+
+    def scaled_matmul(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        A_scale: torch.Tensor,
+        B_scale: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        result_scale: Optional[torch.Tensor],
+        out_dtype: Optional[torch.dtype],
+        use_fast_accum: bool,
+    ) -> torch.Tensor:
+        leading_dims = A.shape[:-1]
+        res = torch.ops.aten._scaled_mm(
+            A.flatten(0, -2), B, A_scale, B_scale, out_dtype=out_dtype
+        )
+        return res.unflatten(0, leading_dims)
+
+    return A.movedim(0, gather_dim), [
+        scaled_matmul(
+            A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum
+        ).movedim(0, gather_dim)
+        for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip(
+            Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum
+        )
+    ]
+
+
+@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA")
+def _fused_all_gather_scaled_matmul(
+    A_shard: torch.Tensor,
+    Bs: List[torch.Tensor],
+    A_scale: torch.Tensor,
+    B_scales: List[torch.Tensor],
+    gather_dim: int,
+    group_name: str,
+    biases: List[Optional[torch.Tensor]],
+    result_scales: List[Optional[torch.Tensor]],
+    out_dtypes: List[Optional[torch.dtype]],
+    use_fast_accum: List[bool],
+) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+    """
+    Perform the following logic with micro-pipelined computation and
+    communication:
+
+        A = all_gather_tensor(A_shard, gather_dim, group_name)
+        leading_dims = A.shape[:-1]
+        res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale)
+        res = res.unflatten(0, leading_dims)
+
+    Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is
+    contiguous, no extra copy is required for input layout transformation.
+    Otherwise A_shard needs to be copied once.
+    """
+    out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
+
+    if len(biases) != len(Bs):
+        raise ValueError("len(biases) must be the same as len(Bs)")
+    if len(result_scales) != len(Bs):
+        raise ValueError("len(result_scales) must be the same as len(Bs)")
+    if len(out_dtypes) != len(Bs):
+        raise ValueError("len(out_dtypes) must be the same as len(Bs)")
+    if len(use_fast_accum) != len(Bs):
+        raise ValueError("len(use_fast_accum) must be the same as len(Bs)")
+
+    if _is_test_mode:
+        return _fused_all_gather_scaled_matmul_fallback(
+            A_shard,
+            Bs,
+            A_scale,
+            B_scales,
+            gather_dim,
+            group_name,
+            biases,
+            result_scales,
+            out_dtypes,
+            use_fast_accum,
+        )
+
+    with torch.profiler.record_function("fused_all_gather_scaled_matmul"):
+        return _fused_all_gather_matmul_impl(
+            torch.ops.aten._scaled_mm.out,
+            A_shard,
+            Bs,
+            [
+                {
+                    "scale_a": A_scale,
+                    "scale_b": B_scale,
+                    "bias": bias,
+                    "scale_result": result_scale,
+                    "out_dtype": out_dtype,
+                    "use_fast_accum": fast_accum,
+                }
+                for B_scale, bias, result_scale, out_dtype, fast_accum in zip(
+                    B_scales, biases, result_scales, out_dtypes, use_fast_accum
+                )
+            ],
+            out_dtypes,
+            gather_dim,
+            group_name,
+        )
 
 
 def make_contiguous_for_perm(
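Aside (not part of the diff): the stride-order remark in the docstring of `_fused_all_gather_scaled_matmul` above can be verified directly. A small sketch with hypothetical shapes; only the second shard avoids the extra copy.

    import torch

    gather_dim = 1

    # Row-major shard: gather_dim is not outermost, so an extra copy is needed.
    A_shard = torch.randn(4, 8, 16)
    assert not A_shard.movedim(gather_dim, 0).is_contiguous()

    # Shard laid out with gather_dim outermost: no copy required.
    A_shard_opt = torch.randn(8, 4, 16).movedim(0, gather_dim)
    assert A_shard_opt.movedim(gather_dim, 0).is_contiguous()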
@@ -499,17 +652,23 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
     B: torch.Tensor,
     A_scale: torch.Tensor,
     B_scale: torch.Tensor,
-    out_dtype: Optional[torch.dtype],
     reduce_op: str,
     scatter_dim: int,
     group_name: str,
+    bias: Optional[torch.Tensor] = None,
+    result_scale: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
 ) -> torch.Tensor:
     C = torch._scaled_mm(
         A.flatten(0, -2).contiguous(),
         B,
         A_scale,
         B_scale,
-        out_dtype=out_dtype,
+        bias,
+        result_scale,
+        out_dtype,
+        use_fast_accum,
     )
     C = C.view(*A.shape[:-1], B.shape[1])
     res = funcol.reduce_scatter_tensor(
@@ -528,21 +687,41 @@ def _fused_scaled_matmul_reduce_scatter(
     B: torch.Tensor,
     A_scale: torch.Tensor,
     B_scale: torch.Tensor,
-    out_dtype: Optional[torch.dtype],
     reduce_op: str,
     scatter_dim: int,
     group_name: str,
+    bias: Optional[torch.Tensor] = None,
+    result_scale: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
 ) -> torch.Tensor:
     if _is_test_mode:
         return _fused_scaled_matmul_reduce_scatter_fallback(
-            A, B, A_scale, B_scale, out_dtype, reduce_op, scatter_dim, group_name
+            A,
+            B,
+            A_scale,
+            B_scale,
+            reduce_op,
+            scatter_dim,
+            group_name,
+            bias,
+            result_scale,
+            out_dtype,
+            use_fast_accum,
         )
     with torch.profiler.record_function("fused_matmul_reduce_scatter"):
         return _fused_matmul_reduce_scatter_impl(
             mm_out_op=torch.ops.aten._scaled_mm.out,
             A=A,
             B=B,
-            kwargs={"scale_a": A_scale, "scale_b": B_scale, "out_dtype": out_dtype},
+            kwargs={
+                "scale_a": A_scale,
+                "scale_b": B_scale,
+                "bias": bias,
+                "scale_result": result_scale,
+                "out_dtype": out_dtype,
+                "use_fast_accum": use_fast_accum,
+            },
             out_dtype=out_dtype,
             reduce_op=reduce_op,
             scatter_dim=scatter_dim,
@@ -608,133 +787,6 @@ def _maybe_convert_scalar_types_to_dtypes(
     return dtypes
 
 
-@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta")
-def _fused_all_gather_scaled_matmul_fallback(
-    A_shard: torch.Tensor,
-    Bs: List[torch.Tensor],
-    A_scale: torch.Tensor,
-    B_scales: List[torch.Tensor],
-    out_dtypes: List[Optional[torch.dtype]],
-    gather_dim: int,
-    group_name: str,
-) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-    out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
-
-    group_size = c10d._get_group_size_by_name(group_name)
-    A = torch.ops._c10d_functional.all_gather_into_tensor(
-        A_shard.contiguous(), group_size, group_name
-    )
-    A = torch.ops._c10d_functional.wait_tensor(A)
-    A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
-
-    def scaled_matmul(
-        A: torch.Tensor,
-        B: torch.Tensor,
-        A_scale: torch.Tensor,
-        B_scale: torch.Tensor,
-        out_dtype: Optional[torch.dtype],
-    ) -> torch.Tensor:
-        leading_dims = A.shape[:-1]
-        res = torch.ops.aten._scaled_mm(
-            A.flatten(0, -2), B, A_scale, B_scale, out_dtype=out_dtype
-        )
-        return res.unflatten(0, leading_dims)
-
-    return A.movedim(0, gather_dim), [
-        scaled_matmul(A, B, A_scale, B_scale, out_dtype).movedim(0, gather_dim)
-        for B, B_scale, out_dtype in zip(Bs, B_scales, out_dtypes)
-    ]
-
-
-@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA")
-def _fused_all_gather_scaled_matmul(
-    A_shard: torch.Tensor,
-    Bs: List[torch.Tensor],
-    A_scale: torch.Tensor,
-    B_scales: List[torch.Tensor],
-    out_dtypes: List[Optional[torch.dtype]],
-    gather_dim: int,
-    group_name: str,
-) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-    """
-    Perform the following logic with micro-pipelined computation and
-    communication:
-
-        A = all_gather_tensor(A_shard, gather_dim, group_name)
-        leading_dims = A.shape[:-1]
-        res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale)
-        res = res.unflatten(0, leading_dims)
-
-    Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is
-    contiguous, no extra copy is required for input layout transformation.
-    Otherwise A_shard needs to be copied once.
-    """
-    out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
-
-    if _is_test_mode:
-        return _fused_all_gather_scaled_matmul_fallback(
-            A_shard, Bs, A_scale, B_scales, out_dtypes, gather_dim, group_name
-        )
-    if A_shard.dim() < 2:
-        raise ValueError("A_shard must be a matrix")
-    for B in Bs:
-        if B.dim() != 2:
-            raise ValueError("B must be a matrix")
-    if gather_dim < 0 or gather_dim >= A_shard.dim():
-        raise ValueError("Invalid gather_dim")
-
-    group = c10d._resolve_process_group(group_name)
-
-    with torch.profiler.record_function("fused_all_gather_scaled_matmul"):
-        # Move the gather_dim to the front and flatten the tensor into a 2D matrix.
-        # The flattened tensor doesn't need to be contiguous (for computation
-        # efficiency), as _pipelined_all_gather_and_consume guarantees that shards
-        # passed to shard_consumer are contiguous.
-        x = A_shard.movedim(gather_dim, 0)
-        leading_dims = [group.size()] + list(x.shape[:-1])
-        x = x.flatten(0, -2)
-
-        # Helper function for reverting the above transformation
-        def unflatten(t: torch.Tensor) -> torch.Tensor:
-            return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
-
-        ag_out = x.new_empty(
-            x.shape[0] * group.size(),
-            x.shape[1],
-        )
-        outputs = [
-            x.new_empty(
-                x.shape[0] * group.size(),
-                B.shape[1],
-                dtype=out_dtype if out_dtype is not None else A_shard.dtype,
-            )
-            for B, out_dtype in zip(Bs, out_dtypes)
-        ]
-        output_shards = [output.chunk(group.size()) for output in outputs]
-
-        # Computing block-wise matmul along the first dim of A
-        def shard_consumer(shard: torch.Tensor, rank: int) -> None:
-            for idx, (B, B_scale, out_dtype) in enumerate(
-                zip(Bs, B_scales, out_dtypes)
-            ):
-                torch.ops.aten._scaled_mm(
-                    shard,
-                    B,
-                    A_scale,
-                    B_scale,
-                    out_dtype=out_dtype,
-                    out=output_shards[idx][rank],
-                )
-
-        _pipelined_all_gather_and_consume(
-            x,
-            shard_consumer,
-            ag_out,
-            group_name,
-        )
-        return unflatten(ag_out), [unflatten(output) for output in outputs]
-
-
 class Work(_Work):
     def __init__(self) -> None:
         super().__init__()