[micro_pipeline_tp] support all _scaled_mm args (#131984)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131984
Approved by: https://github.com/weifengpy
Author: Yifu Wang
Date: 2024-08-01 13:32:53 -07:00
Committed by: PyTorch MergeBot
Parent: 2b5e31d099
Commit: ea42027e0e
3 changed files with 278 additions and 212 deletions
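
For reference, the full torch._scaled_mm argument set that the fused ops below now forward is scale_a, scale_b, bias, scale_result, out_dtype and use_fast_accum. A minimal single-GPU sketch (illustrative shapes and scales, not part of this diff; assumes a float8-capable CUDA device, SM 8.9 or newer):

# Illustrative sketch, not part of this commit: exercises every _scaled_mm argument
# that the fused collectives now accept. Shapes and scale values are arbitrary.
import torch

if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    M, K, N = 64, 128, 256
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    # _scaled_mm expects the second operand in column-major layout, hence the .T
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).T
    A_scale = torch.tensor(0.1, device="cuda")
    B_scale = torch.tensor(0.1, device="cuda")
    bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)

    out = torch._scaled_mm(
        A,
        B,
        scale_a=A_scale,
        scale_b=B_scale,
        bias=bias,                 # optional, defaults to None
        scale_result=None,         # optional scaling applied to the result
        out_dtype=torch.bfloat16,  # None keeps the operand dtype
        use_fast_accum=True,
    )
    print(out.shape)  # torch.Size([64, 256])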

View File

@@ -223,25 +223,31 @@ class SymmetricMemoryTest(MultiProcessTestCase):
torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)
]
B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]
output_dtypes = [None, torch.bfloat16, torch.float32]
out_dtypes = [None, torch.bfloat16, torch.float32]
ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
A_shard,
Bs,
A_scale,
B_scales,
output_dtypes,
gather_dim=gather_dim,
group_name=group.group_name,
)
ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
A_shard,
Bs,
A_scale,
B_scales,
output_dtypes,
gather_dim=gather_dim,
group_name=group.group_name,
biases=[None] * len(Bs),
result_scales=[None] * len(Bs),
out_dtypes=out_dtypes,
use_fast_accum=[None] * len(Bs),
)
ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
A_shard,
Bs,
A_scale,
B_scales,
gather_dim=gather_dim,
group_name=group.group_name,
biases=[None] * len(Bs),
result_scales=[None] * len(Bs),
out_dtypes=out_dtypes,
use_fast_accum=[None] * len(Bs),
)
self.assertTrue(
@@ -317,20 +323,20 @@ class SymmetricMemoryTest(MultiProcessTestCase):
B,
A_scale,
B_scale,
torch.bfloat16,
"avg",
scatter_dim=scatter_dim,
group_name=group.group_name,
scatter_dim,
group.group_name,
out_dtype=torch.bfloat16,
)
output_1 = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
A,
B,
A_scale,
B_scale,
torch.bfloat16,
"avg",
scatter_dim=scatter_dim,
group_name=group.group_name,
scatter_dim,
group.group_name,
out_dtype=torch.bfloat16,
)
assert torch.allclose(output_0, output_1)
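
The test changes above capture the new calling conventions: gather_dim/scatter_dim and group_name now precede the optional _scaled_mm arguments, which are per-matmul lists for fused_all_gather_scaled_matmul and keyword arguments with defaults for fused_scaled_matmul_reduce_scatter. A hedged launcher sketch follows; it assumes torchrun with two float8-capable, P2P-connected GPUs, and the group setup via enable_symm_mem_for_group mirrors the usual symmetric-memory test setup rather than anything in this diff:

# Illustrative sketch, not part of this commit.
# Run with: torchrun --nproc-per-node=2 <this_file>
import os
import torch
import torch.distributed as dist
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
group = dist.group.WORLD
enable_symm_mem_for_group(group.group_name)  # assumed setup, as in the symmetric memory tests

M, N, K = 256, 256, 128
A_shard = torch.rand(M // group.size(), K, device="cuda").to(torch.float8_e4m3fn)
A_scale = torch.tensor(0.1, device="cuda")
Bs = [torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)]
B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]

# All-gather + scaled matmul: the optional _scaled_mm arguments are per-B lists.
ag, mm_outs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
    A_shard,
    Bs,
    A_scale,
    B_scales,
    gather_dim=0,
    group_name=group.group_name,
    biases=[None] * len(Bs),
    result_scales=[None] * len(Bs),
    out_dtypes=[torch.bfloat16] * len(Bs),
    use_fast_accum=[False] * len(Bs),
)

# Scaled matmul + reduce-scatter: the optional _scaled_mm arguments have defaults.
A = torch.rand(M, K, device="cuda").to(torch.float8_e4m3fn)
scattered = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
    A,
    Bs[0],
    A_scale,
    B_scales[0],
    "avg",
    0,  # scatter_dim
    group.group_name,
    out_dtype=torch.bfloat16,
)

dist.destroy_process_group()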

View File

@@ -2,7 +2,7 @@
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import cast, Dict, List, Optional, Set
from typing import Any, cast, Dict, List, Optional, Set
import torch
@@ -347,7 +347,6 @@ class _Matmul:
assert len(match) in (1, 3)
assert match[0].target in (
aten.mm.default,
aten._scaled_mm.default,
aten.reshape.default,
)
mm_node = match[0] if len(match) == 1 else match[1]
@@ -362,7 +361,10 @@
class _ScaledMatmul(_Matmul):
A_scale_node: torch.fx.Node
B_scale_node: torch.fx.Node
bias_node: Optional[torch.fx.Node]
result_scale_node: Optional[torch.fx.Node]
out_dtype: Optional[torch.dtype]
use_fast_accum: bool
def __post_init__(self):
super().__post_init__()
@@ -373,22 +375,26 @@
def from_match(cls, match: List[torch.fx.Node]) -> "_ScaledMatmul":
assert len(match) in (1, 3)
assert match[0].target in (
aten.mm.default,
aten._scaled_mm.default,
aten.reshape.default,
)
mm_node = match[0] if len(match) == 1 else match[1]
out_dtype = (
cast(torch.dtype, mm_node.args[6]) if len(mm_node.args) > 6 else None
)
assert isinstance(out_dtype, (torch.dtype, type(None)))
def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
if idx >= len(node.args):
return default
return node.args[idx]
return _ScaledMatmul(
nodes=match,
A_node=cast(torch.fx.Node, match[0].args[0]),
B_node=cast(torch.fx.Node, mm_node.args[1]),
A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
out_dtype=out_dtype,
bias_node=get_arg(mm_node, 4, None),
result_scale_node=get_arg(mm_node, 5, None),
out_dtype=get_arg(mm_node, 6, None),
use_fast_accum=get_arg(mm_node, 7, False),
)
@@ -477,20 +483,19 @@ def _insert_fused_all_gather_matmul(
)
elif mm_type == _ScaledMatmul:
scaled_matmuls = cast(List[_ScaledMatmul], matmuls)
B_nodes = [matmul.B_node for matmul in scaled_matmuls]
A_scale_node = scaled_matmuls[0].A_scale_node
B_scale_nodes = [matmul.B_scale_node for matmul in scaled_matmuls]
out_dtypes = [matmul.out_dtype for matmul in scaled_matmuls]
return graph.call_function(
torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
args=(
shard_node,
B_nodes,
A_scale_node,
B_scale_nodes,
out_dtypes,
[matmul.B_node for matmul in scaled_matmuls],
scaled_matmuls[0].A_scale_node,
[matmul.B_scale_node for matmul in scaled_matmuls],
gather_dim,
group_name,
[matmul.bias_node for matmul in scaled_matmuls],
[matmul.result_scale_node for matmul in scaled_matmuls],
[matmul.out_dtype for matmul in scaled_matmuls],
[matmul.use_fast_accum for matmul in scaled_matmuls],
),
)
else:
@@ -647,10 +652,13 @@ def _insert_fused_matmul_reduce_scatter(
matmul.B_node,
matmul.A_scale_node,
matmul.B_scale_node,
matmul.out_dtype,
reduce_op,
scatter_dim,
group_name,
matmul.bias_node,
matmul.result_scale_node,
matmul.out_dtype,
matmul.use_fast_accum,
),
)
else:
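
The get_arg helper in _ScaledMatmul.from_match above recovers the optional positional _scaled_mm arguments (bias, result_scale, out_dtype, use_fast_accum) from an FX node, falling back to a default when the call never recorded that argument. A small self-contained sketch of the same pattern; the demo function and the torch.mm node are illustrative stand-ins, not taken from this diff:

# Illustrative sketch, not part of this commit: positional args beyond len(node.args)
# fall back to the supplied default, mirroring _ScaledMatmul.from_match above.
from typing import Any

import torch


def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
    # Return the idx-th positional arg if it was recorded, otherwise the default.
    if idx >= len(node.args):
        return default
    return node.args[idx]


def demo() -> None:
    graph = torch.fx.Graph()
    a = graph.placeholder("a")
    b = graph.placeholder("b")
    # Hypothetical call_function node with only the leading positional args filled in.
    mm_node = graph.call_function(torch.mm, args=(a, b))
    assert get_arg(mm_node, 1, None) is b        # present, returned as-is
    assert get_arg(mm_node, 6, None) is None     # absent, e.g. out_dtype
    assert get_arg(mm_node, 7, False) is False   # absent, e.g. use_fast_accum


if __name__ == "__main__":
    demo()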

View File

@@ -260,7 +260,11 @@ lib.define(
lib.define(
"fused_all_gather_scaled_matmul("
"Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, "
"ScalarType?[] out_dtypes, int gather_dim, str group_name) -> (Tensor, Tensor[])"
"int gather_dim, str group_name, "
"Tensor?[] biases, "
"Tensor?[] result_scales, "
"ScalarType?[] out_dtypes, "
"bool[] use_fast_accum) -> (Tensor, Tensor[])"
)
lib.define(
"fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor"
@@ -268,7 +272,11 @@ lib.define(
lib.define(
"fused_scaled_matmul_reduce_scatter("
"Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, "
"ScalarType out_dtype, str reduce_op, int scatter_dim, str group_name) -> Tensor"
"str reduce_op, int scatter_dim, str group_name, "
"Tensor? bias = None, "
"Tensor? result_scale = None, "
"ScalarType? out_dtype = None, "
"bool use_fast_accum = False) -> Tensor"
)
lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor")
lib.define(
@@ -276,6 +284,65 @@ lib.define(
)
def _fused_all_gather_matmul_impl(
mm_out_op: torch._ops.OpOverload,
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
kwargs_list: List[Dict[str, Any]],
out_dtypes: List[Optional[torch.dtype]],
gather_dim: int,
group_name: str,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
if A_shard.dim() < 2:
raise ValueError("A_shard must be a matrix")
for B in Bs:
if B.dim() != 2:
raise ValueError("B must be a matrix")
if len(out_dtypes) != len(Bs):
raise ValueError("len(out_types) must be the same as len(Bs)")
if len(kwargs_list) != len(Bs):
raise ValueError("len(kwargs_list) must be the same as len(Bs)")
if gather_dim < 0 or gather_dim >= A_shard.dim():
raise ValueError("Invalid gather_dim")
group = c10d._resolve_process_group(group_name)
# Move the gather_dim to the front and flatten the tensor into a 2D matrix.
# The flattened tensor doesn't need to be contiguous (for computation
# efficiency), as _pipelined_all_gather_and_consume guarantees that shards
# passed to shard_consumer are contiguous.
x = A_shard.movedim(gather_dim, 0)
leading_dims = [group.size()] + list(x.shape[:-1])
x = x.flatten(0, -2)
# Helper function for reverting the above transformation
def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
ag_out = x.new_empty(
x.shape[0] * group.size(),
x.shape[1],
)
outputs = [
x.new_empty(x.shape[0] * group.size(), B.shape[1], dtype=out_dtype or B.dtype)
for B, out_dtype in zip(Bs, out_dtypes)
]
output_shards = [output.chunk(group.size()) for output in outputs]
# Computing block-wise matmul along the first dim of A
def shard_consumer(shard: torch.Tensor, rank: int) -> None:
for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank])
_pipelined_all_gather_and_consume(
x,
shard_consumer,
ag_out,
group_name,
)
return unflatten(ag_out), [unflatten(output) for output in outputs]
@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
def _fused_all_gather_matmul_fallback(
A_shard: torch.Tensor,
@@ -313,54 +380,140 @@ def _fused_all_gather_matmul(
"""
if _is_test_mode:
return _fused_all_gather_matmul_fallback(A_shard, Bs, gather_dim, group_name)
if A_shard.dim() < 2:
raise ValueError("A_shard must be a matrix")
for B in Bs:
if B.dim() != 2:
raise ValueError("B must be a matrix")
if gather_dim < 0 or gather_dim >= A_shard.dim():
raise ValueError("Invalid gather_dim")
group = c10d._resolve_process_group(group_name)
with torch.profiler.record_function("fused_all_gather_matmul"):
# Move the gather_dim to the front and flatten the tensor into a 2D matrix.
# The flattened tensor doesn't need to be contiguous (for computation
# efficiency), as _pipelined_all_gather_and_consume guarantees that shards
# passed to shard_consumer are contiguous.
x = A_shard.movedim(gather_dim, 0)
leading_dims = [group.size()] + list(x.shape[:-1])
x = x.flatten(0, -2)
# Helper function for reverting the above transformation
def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
ag_out = x.new_empty(
x.shape[0] * group.size(),
x.shape[1],
)
outputs = [
x.new_empty(
x.shape[0] * group.size(),
B.shape[1],
)
for B in Bs
]
output_shards = [output.chunk(group.size()) for output in outputs]
# Computing block-wise matmul along the first dim of A
def shard_consumer(shard: torch.Tensor, rank: int) -> None:
for idx, B in enumerate(Bs):
torch.mm(shard, B, out=output_shards[idx][rank])
_pipelined_all_gather_and_consume(
x,
shard_consumer,
ag_out,
return _fused_all_gather_matmul_impl(
torch.ops.aten.mm.out,
A_shard,
Bs,
[{} for B in Bs],
[B.dtype for B in Bs],
gather_dim,
group_name,
)
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta")
def _fused_all_gather_scaled_matmul_fallback(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
A_scale: torch.Tensor,
B_scales: List[torch.Tensor],
gather_dim: int,
group_name: str,
biases: List[Optional[torch.Tensor]],
result_scales: List[Optional[torch.Tensor]],
out_dtypes: List[Optional[torch.dtype]],
use_fast_accum: List[bool],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
group_size = c10d._get_group_size_by_name(group_name)
A = torch.ops._c10d_functional.all_gather_into_tensor(
A_shard.contiguous(), group_size, group_name
)
A = torch.ops._c10d_functional.wait_tensor(A)
A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
def scaled_matmul(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
bias: Optional[torch.Tensor],
result_scale: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
use_fast_accum: bool,
) -> torch.Tensor:
leading_dims = A.shape[:-1]
res = torch.ops.aten._scaled_mm(
A.flatten(0, -2), B, A_scale, B_scale, out_dtype=out_dtype
)
return res.unflatten(0, leading_dims)
return A.movedim(0, gather_dim), [
scaled_matmul(
A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum
).movedim(0, gather_dim)
for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip(
Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum
)
]
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA")
def _fused_all_gather_scaled_matmul(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
A_scale: torch.Tensor,
B_scales: List[torch.Tensor],
gather_dim: int,
group_name: str,
biases: List[Optional[torch.Tensor]],
result_scales: List[Optional[torch.Tensor]],
out_dtypes: List[Optional[torch.dtype]],
use_fast_accum: List[bool],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Perform the following logic with micro-pipelined computation and
communication:
A = all_gather_tensor(A_shard, gather_dim, group_name)
leading_dims = A.shape[:-1]
res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale)
res = res.unflatten(0, leading_dims)
Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is
contiguous, no extra copy is required for input layout transformation.
Otherwise A_shard needs to be copied once.
"""
out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
if len(biases) != len(Bs):
raise ValueError("len(biases) must be the same as len(Bs)")
if len(result_scales) != len(Bs):
raise ValueError("len(result_scales) must be the same as len(Bs)")
if len(out_dtypes) != len(Bs):
raise ValueError("len(out_dtypes) must be the same as len(Bs)")
if len(use_fast_accum) != len(Bs):
raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)")
if _is_test_mode:
return _fused_all_gather_scaled_matmul_fallback(
A_shard,
Bs,
A_scale,
B_scales,
gather_dim,
group_name,
biases,
result_scales,
out_dtypes,
use_fast_accum,
)
with torch.profiler.record_function("fused_all_gather_scaled_matmul"):
return _fused_all_gather_matmul_impl(
torch.ops.aten._scaled_mm.out,
A_shard,
Bs,
[
{
"scale_a": A_scale,
"scale_b": B_scale,
"bias": bias,
"scale_result": result_scale,
"out_dtype": out_dtype,
"use_fast_accum": fast_accum,
}
for B_scale, bias, result_scale, out_dtype, fast_accum in zip(
B_scales, biases, result_scales, out_dtypes, use_fast_accum
)
],
out_dtypes,
gather_dim,
group_name,
)
return unflatten(ag_out), [unflatten(output) for output in outputs]
def make_contiguous_for_perm(
@@ -499,17 +652,23 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
out_dtype: Optional[torch.dtype],
reduce_op: str,
scatter_dim: int,
group_name: str,
bias: Optional[torch.Tensor] = None,
result_scale: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
C = torch._scaled_mm(
A.flatten(0, -2).contiguous(),
B,
A_scale,
B_scale,
out_dtype=out_dtype,
bias,
result_scale,
out_dtype,
use_fast_accum,
)
C = C.view(*A.shape[:-1], B.shape[1])
res = funcol.reduce_scatter_tensor(
@@ -528,21 +687,41 @@ def _fused_scaled_matmul_reduce_scatter(
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
out_dtype: Optional[torch.dtype],
reduce_op: str,
scatter_dim: int,
group_name: str,
bias: Optional[torch.Tensor] = None,
result_scale: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
if _is_test_mode:
return _fused_scaled_matmul_reduce_scatter_fallback(
A, B, A_scale, B_scale, out_dtype, reduce_op, scatter_dim, group_name
A,
B,
A_scale,
B_scale,
reduce_op,
scatter_dim,
group_name,
bias,
result_scale,
out_dtype,
use_fast_accum,
)
with torch.profiler.record_function("fused_matmul_reduce_scatter"):
return _fused_matmul_reduce_scatter_impl(
mm_out_op=torch.ops.aten._scaled_mm.out,
A=A,
B=B,
kwargs={"scale_a": A_scale, "scale_b": B_scale, "out_dtype": out_dtype},
kwargs={
"scale_a": A_scale,
"scale_b": B_scale,
"bias": bias,
"scale_result": result_scale,
"out_dtype": out_dtype,
"use_fast_accum": use_fast_accum,
},
out_dtype=out_dtype,
reduce_op=reduce_op,
scatter_dim=scatter_dim,
@@ -608,133 +787,6 @@ def _maybe_convert_scalar_types_to_dtypes(
return dtypes
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta")
def _fused_all_gather_scaled_matmul_fallback(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
A_scale: torch.Tensor,
B_scales: List[torch.Tensor],
out_dtypes: List[Optional[torch.dtype]],
gather_dim: int,
group_name: str,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
group_size = c10d._get_group_size_by_name(group_name)
A = torch.ops._c10d_functional.all_gather_into_tensor(
A_shard.contiguous(), group_size, group_name
)
A = torch.ops._c10d_functional.wait_tensor(A)
A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
def scaled_matmul(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
leading_dims = A.shape[:-1]
res = torch.ops.aten._scaled_mm(
A.flatten(0, -2), B, A_scale, B_scale, out_dtype=out_dtype
)
return res.unflatten(0, leading_dims)
return A.movedim(0, gather_dim), [
scaled_matmul(A, B, A_scale, B_scale, out_dtype).movedim(0, gather_dim)
for B, B_scale, out_dtype in zip(Bs, B_scales, out_dtypes)
]
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA")
def _fused_all_gather_scaled_matmul(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
A_scale: torch.Tensor,
B_scales: List[torch.Tensor],
out_dtypes: List[Optional[torch.dtype]],
gather_dim: int,
group_name: str,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Perform the following logic with micro-pipelined computation and
communication:
A = all_gather_tensor(A_shard, gather_dim, group_name)
leading_dims = A.shape[:-1]
res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale)
res = res.unflatten(0, leading_dims)
Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is
contiguous, no extra copy is required for input layout transformation.
Otherwise A_shard needs to be copied once.
"""
out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
if _is_test_mode:
return _fused_all_gather_scaled_matmul_fallback(
A_shard, Bs, A_scale, B_scales, out_dtypes, gather_dim, group_name
)
if A_shard.dim() < 2:
raise ValueError("A_shard must be a matrix")
for B in Bs:
if B.dim() != 2:
raise ValueError("B must be a matrix")
if gather_dim < 0 or gather_dim >= A_shard.dim():
raise ValueError("Invalid gather_dim")
group = c10d._resolve_process_group(group_name)
with torch.profiler.record_function("fused_all_gather_scaled_matmul"):
# Move the gather_dim to the front and flatten the tensor into a 2D matrix.
# The flattened tensor doesn't need to be contiguous (for computation
# efficiency), as _pipelined_all_gather_and_consume guarantees that shards
# passed to shard_consumer are contiguous.
x = A_shard.movedim(gather_dim, 0)
leading_dims = [group.size()] + list(x.shape[:-1])
x = x.flatten(0, -2)
# Helper function for reverting the above transformation
def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
ag_out = x.new_empty(
x.shape[0] * group.size(),
x.shape[1],
)
outputs = [
x.new_empty(
x.shape[0] * group.size(),
B.shape[1],
dtype=out_dtype if out_dtype is not None else A_shard.dtype,
)
for B, out_dtype in zip(Bs, out_dtypes)
]
output_shards = [output.chunk(group.size()) for output in outputs]
# Computing block-wise matmul along the first dim of A
def shard_consumer(shard: torch.Tensor, rank: int) -> None:
for idx, (B, B_scale, out_dtype) in enumerate(
zip(Bs, B_scales, out_dtypes)
):
torch.ops.aten._scaled_mm(
shard,
B,
A_scale,
B_scale,
out_dtype=out_dtype,
out=output_shards[idx][rank],
)
_pipelined_all_gather_and_consume(
x,
shard_consumer,
ag_out,
group_name,
)
return unflatten(ag_out), [unflatten(output) for output in outputs]
class Work(_Work):
def __init__(self) -> None:
super().__init__()
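
One closing note on the shared _fused_all_gather_matmul_impl helper introduced above: the movedim/flatten transformation and its unflatten inverse can be checked in isolation on CPU. The sketch below uses illustrative shapes and a group_size constant standing in for group.size(); it is not part of the commit:

# Illustrative sketch, not part of this commit: the layout round trip used by
# _fused_all_gather_matmul_impl. gather_dim is moved to the front and the tensor is
# flattened to 2-D for the block-wise matmul; unflatten undoes this on the results.
import torch

group_size = 4   # stands in for group.size()
gather_dim = 1
A_shard = torch.randn(8, 3, 5, 16)  # arbitrary >2-D shard, illustrative shape

x = A_shard.movedim(gather_dim, 0)
leading_dims = [group_size] + list(x.shape[:-1])
x = x.flatten(0, -2)  # the 2-D matrix fed to the matmuls

def unflatten(t: torch.Tensor) -> torch.Tensor:
    # Inverse transformation applied to the gathered input and the matmul outputs.
    return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)

# Emulate the all-gather along dim 0 of the flattened shard: group_size stacked shards.
ag_out = torch.cat([x] * group_size, dim=0)
A_full = unflatten(ag_out)

# gather_dim grows by group_size; every other dimension is unchanged.
assert A_full.shape == (8, 3 * group_size, 5, 16)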