# Owner(s): ["module: inductor"]
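# This file tests Inductor's compute/communication reordering: the sink_waits,
# raise_comms, and reorder_compute_for_overlap passes, a GroupedSchedulerNode-based
# pre-fusion custom pass, the default ordering of collectives, and the shapes of
# the NCCL heuristic tables. Expected schedules are verified by running FileCheck
# on the generated code.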
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import ir, scheduler
from torch._inductor.comm_analysis import (
    baseLat,
    hwLat,
    llMaxBws,
    NCCL_ALGO,
    NCCL_HW,
    NCCL_PROTO,
    NVIDIA_GPU_TYPE,
)
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    at_least_x_gpu,
    DynamoDistributedMultiProcTestCase,
    requires_accelerator_dist_backend,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.inductor_utils import HAS_GPU


device_type = str(get_devtype())

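# Custom runtime cost model used by test_reorder_compute_for_overlap_custom_runtime_estimation
# below via @patch.object(torch._inductor.config, "estimate_op_runtime", ...):
# collectives are treated as expensive, waits as free, extern (mm) kernels as
# moderately expensive, and everything else as cheap.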
def get_snode_runtime_for_reorder_compute_test(snode):
    # NOTE: custom cost model to show that the compute reordering algorithm is working
    # Collective kernels
    if isinstance(snode.node, ir._CollectiveKernel):
        return 100
    elif isinstance(snode.node, ir._WaitKernel):
        return 0
    # High-arithmetic-intensity compute kernels
    elif isinstance(snode.node, ir.ExternKernel):
        return 5
    # All other kernels
    return 1

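# Pre-fusion custom pass used by test_grouped_scheduler_node below: wraps the
# single in-place all_reduce_ snode and its single dependency into a
# GroupedSchedulerNode and moves that group to the front of the schedule.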
def create_grouped_node_for_allreduce_and_its_deps(snodes):
    name_to_snode = {snode.node.name: snode for snode in snodes}
    all_reduce_snodes = [
        snode
        for snode in snodes
        if isinstance(snode.node, ir._CollectiveKernel)
        and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
    ]
    assert len(all_reduce_snodes) == 1
    all_reduce_snode = all_reduce_snodes[0]
    all_reduce_dep_snodes = [
        name_to_snode[node.name] for node in all_reduce_snode.node.inputs
    ]
    assert len(all_reduce_dep_snodes) == 1
    all_reduce_dep_snode = all_reduce_dep_snodes[0]

    grouped_snode = scheduler.GroupedSchedulerNode.create(
        [all_reduce_dep_snode, all_reduce_snode]
    )
    new_snode_order = []
    new_snode_order.append(grouped_snode)
    for snode in snodes:
        if snode in grouped_snode.snodes:
            continue
        new_snode_order.append(snode)
    return new_snode_order

@requires_accelerator_dist_backend()
@unittest.skipIf(
    torch._inductor.config.triton.native_matmul,
    "native matmul is fused with surrounding ops",
)
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in the multi-proc runner; mark each test with the
    minimum number of GPUs it needs to run.
    """

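    # (tag, ranks, group_size) info for the default process group; passed as
    # keyword args to the test funcs and to the torch.ops.c10d_functional
    # collectives used in the tests below.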
    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around an issue with skipif<2 and workers with an unpredictable number of GPUs
        return 2

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
        ],
    )
    def test_sink_waits(self):
        def func(a):
            ar = _functional_collectives.all_reduce(a, "sum", "0")
            b = torch.matmul(a, a)
            return torch.matmul(ar, b)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "raise_comms",
        ],
    )
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. Note that the all_reduce_ directly
            # writes to the output buffer of the 1st matmul, which is an input
            # to the first relu. Therefore, the all_reduce_ should be scheduled
            # after the first relu.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check_same("buf0")
                # the mm before the wait_tensor must not use buf0
                .check("extern_kernels.mm")
                .check_not("buf0")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
            "raise_comms",
        ],
    )
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ should not be fused with
            #   any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_all_reduce_0")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "runtime_estimations_mms_benchmark",
        False,
    )
    def test_reorder_compute_for_overlap(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for second all_reduce but DO NOT depend on first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for second all_reduce and DO NOT depend on first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for second all_reduce and DO depend on first all_reduce.
            # and then, we schedule the second all_reduce. And then schedule all ops that depend on second all_reduce.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_all_reduce_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "estimate_op_runtime",
        get_snode_runtime_for_reorder_compute_test,
    )
    def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for second all_reduce but DO NOT depend on first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for second all_reduce and DO NOT depend on first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for second all_reduce and DO depend on first all_reduce.
            # and then, we schedule the second all_reduce. And then schedule all ops that depend on second all_reduce.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_all_reduce_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @unittest.skipIf(
        torch._inductor.config.triton.native_matmul,
        "native matmul is fused with surrounding ops",
    )
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(
        torch._inductor.config,
        "_pre_fusion_custom_pass",
        create_grouped_node_for_allreduce_and_its_deps,
    )
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
            # but here in this unit test, we intentionally put `add`, `div` and `ar` computation
            # into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
            #    GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
            FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
                "_c10d_functional.all_reduce_."
            ).check("triton_poi_fused_mul_1.").run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_inductor_default_comms_ordering(self):
        pg_info = self.get_world_trs()
        tag = pg_info["tag"]
        ranks = pg_info["ranks"]
        group_size = pg_info["group_size"]

        g1 = torch.ones(10, 10, device=device_type)
        g2 = torch.ones(11, 11, device=device_type)
        g3 = torch.ones(12, 12, device=device_type)

        def assert_pass(graph):
            # all_reduces need to remain in order!
            self.assertExpectedInline(
                graph,
                """\
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg0_1, avg, 0), kwargs = {})
    %all_reduce_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg1_1, avg, 0), kwargs = {})
    %all_reduce_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg2_1, avg, 0), kwargs = {})
    %wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_2,), kwargs = {})
    %wait_tensor_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_1,), kwargs = {})
    %wait_tensor_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
    return (wait_tensor, wait_tensor_1, wait_tensor_2)""",  # noqa: B950
            )

        torch._inductor.config.post_grad_custom_post_pass = assert_pass

        @torch.compile
        def fn(g1, g2, g3):
            handle1 = torch.ops.c10d_functional.all_reduce(
                g1, "avg", tag, ranks, group_size
            )
            handle2 = torch.ops.c10d_functional.all_reduce(
                g2, "avg", tag, ranks, group_size
            )
            handle3 = torch.ops.c10d_functional.all_reduce(
                g3, "avg", tag, ranks, group_size
            )

            # wait on them in a different order
            grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
            grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
            grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
            return grad3, grad2, grad1

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, self.backend(device_type), fake_pg=True
        ):
            fn(g1, g2, g3)

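    # Sanity-check the shapes of the NCCL latency/bandwidth tables used by
    # Inductor's comm_analysis cost model: baseLat is indexed by (algo, proto),
    # hwLat by (hw, algo, proto), and llMaxBws by NVIDIA GPU type.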
    def test_nccl_heuristics(self):
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)

        assert len(hwLat) == len(NCCL_HW)
        assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
        assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)

        assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()