Use a non-blocking custom_op, fuse multiple offset reservations into a single op, and use lookup_seed to look up each offset

1. Replace the Triton helper with a non-blocking custom_op that writes offsets directly to device tensors, removing any dependence on BLOCK/XBLOCK and eliminating synchronization on the H2D copy.
2. Fuse multiple offset reservations into a single kernel/op to cut launch overhead and avoid the issue where the offset op is only generated once; expose a vectorized rand_eager_offsets and register_fake implementations for meta shape inference (see the sketch after this list).
3. Plumb the offsets through inductor_lookup_seed (GPU-resident) so downstream RNG kernels index by offset, preserving determinism between eager and torch.compile.
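
A minimal sketch of the pattern above, consolidated from the diff below (the op name demo_ops::reserve_offsets is hypothetical; it assumes a CUDA build of PyTorch whose generators expose get_offset/set_offset, as the code in this commit does):

import torch
from torch.library import custom_op

@custom_op("demo_ops::reserve_offsets", mutates_args={})
def reserve_offsets(increments: list[int], device: torch.device) -> torch.Tensor:
    # Reserve Philox counter space on the default CUDA generator for each request
    # and remember the base offset each downstream RNG kernel should start from.
    gen = torch.cuda.default_generators[device.index]
    bases = []
    for inc in increments:
        old = int(gen.get_offset())
        gen.set_offset(old + int(inc))
        bases.append(old // 4)
    # Stage the bases in pinned host memory and copy asynchronously so the host
    # never blocks on the H2D transfer.
    cpu = torch.tensor(bases, dtype=torch.int64).pin_memory()
    out = torch.empty_like(cpu, device=device)
    out.copy_(cpu, non_blocking=True)
    return out

@reserve_offsets.register_fake
def _(increments: list[int], device: torch.device) -> torch.Tensor:
    # Shape-only meta implementation so fake-tensor tracing under torch.compile
    # can infer the (len(increments),) output without touching the real generator.
    return torch.empty((len(increments),), dtype=torch.int64, device=device)

Downstream, a single inductor_lookup_seed(offsets, i) indexes into this tensor on the GPU, so each fused-away per-op reservation becomes one cheap lookup.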
Eric.Chin.AMD
2025-10-07 21:12:18 +08:00
parent d8ebb27471
commit 7549a60771


@@ -14,7 +14,6 @@ from ..pattern_matcher import (
    register_graph_pattern,
)
from ..virtualized import V
from . import custom_philox_rand
log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
@@ -27,78 +26,7 @@ from typing import Sequence, Optional
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton, register_fake, register_kernel
BLOCK = 256
@triton.jit
def _pick_lane(u0, u1, u2, u3, lane):
    v = tl.where(lane == 0, u0, u1)
    v = tl.where(lane == 1, u1, v)
    v = tl.where(lane == 2, u2, v)
    v = tl.where(lane == 3, u3, v)
    return v
@triton.jit
def _philox_fill_uniform_gridstride(out_ptr, n_elements, seed, offset_blocks, lane_shift, threads_per_round, BLOCK: tl.constexpr = BLOCK):
    UNROLL = 4
    pid = tl.program_id(0)  # [0, grid_x)
    tid = pid * BLOCK + tl.arange(0, BLOCK)  # [0, BLOCK*grid_x)
    inv = 1.0 / 4294967296.0
    half = inv * 0.5
    # rounds_total = ceil(n / (threads_per_round * UNROLL))
    rounds_total = (n_elements + threads_per_round * UNROLL - 1) // (threads_per_round * UNROLL)
    # tl.device_print("rand_philox offset_blocks %d rounds_total %d\n", offset_blocks, rounds_total)
    for r in range(rounds_total):
        subseq = (tid).to(tl.uint64)
        lane = ((tid + lane_shift) % 4).to(tl.uint64)
        offblk = tl.full(subseq.shape, (offset_blocks + r), tl.uint64)
        u0, u1, u2, u3 = tl.philox(
            seed,
            (offblk & 0xFFFFFFFF).to(tl.uint32),
            ((offblk >> 32) & 0xFFFFFFFF).to(tl.uint32),
            (subseq & 0xFFFFFFFF).to(tl.uint32),
            ((subseq >> 32) & 0xFFFFFFFF).to(tl.uint32),
        )
        inv = 1.0 / 4294967296.0  # 2^-32
        half = inv * 0.5
        base = tid * 4
        stride = threads_per_round
        # k=0
        i0 = base + (r * UNROLL) * stride
        m0 = i0 < n_elements
        lane0 = tl.full(tid.shape, (lane_shift + 0) % 4, tl.uint32)
        f0 = _pick_lane(u0, u1, u2, u3, lane0).to(tl.float32) * inv + half
        tl.store(out_ptr + i0, 1.0 - f0, mask=m0)
        # k=1
        i1 = base + 1 + (r * UNROLL) * stride
        m1 = i1 < n_elements
        lane1 = tl.full(tid.shape, (lane_shift + 1) % 4, tl.uint32)
        f1 = _pick_lane(u0, u1, u2, u3, lane1).to(tl.float32) * inv + half
        tl.store(out_ptr + i1, 1.0 - f1, mask=m1)
        # k=2
        i2 = base + 2 + (r * UNROLL) * stride
        m2 = i2 < n_elements
        lane2 = tl.full(tid.shape, (lane_shift + 2) % 4, tl.uint32)
        f2 = _pick_lane(u0, u1, u2, u3, lane2).to(tl.float32) * inv + half
        tl.store(out_ptr + i2, 1.0 - f2, mask=m2)
        # k=3
        i3 = base + 3 + (r * UNROLL) * stride
        m3 = i3 < n_elements
        lane3 = tl.full(tid.shape, (lane_shift + 3) % 4, tl.uint32)
        f3 = _pick_lane(u0, u1, u2, u3, lane3).to(tl.float32) * inv + half
        tl.store(out_ptr + i3, 1.0 - f3, mask=m3)
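# Note on the uniform mapping used above (illustrative arithmetic, not part of the
# diff): with inv = 2**-32 and half = 2**-33, f = u * inv + half maps a raw 32-bit
# Philox word u in [0, 2**32) into the open interval (0, 1), and storing 1.0 - f
# keeps the result inside (0, 1); the flip appears to be there to match the
# convention of the eager Philox-based uniform kernel this path is aligned with.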
from torch.library import triton_op, wrap_triton, register_fake, register_kernel, custom_op
# ---- host helpers ----
def _compute_grid_x(nelem: int, block: int, device_index: int) -> int:
@@ -108,48 +36,49 @@ def _compute_grid_x(nelem: int, block: int, device_index: int) -> int:
    need_blocks = (nelem + block - 1) // block
    return min(max_blocks, need_blocks)
def _reserve_seed_and_offset_gridstride(x_device_index: int, nelem: int, block: int):
def _shape_to_offset(size, device: torch.device) -> int:
    nelem = 1
    for s in size:
        nelem *= int(s)
    UNROLL = 4
    gen = torch.cuda.default_generators[x_device_index]
    seed = int(gen.initial_seed())
    grid_x = _compute_grid_x(nelem, block, x_device_index)
    rounds_per_thread = (nelem + (block * grid_x * UNROLL) - 1) // (block * grid_x * UNROLL)
    counter_offset_per_thread = rounds_per_thread * UNROLL
    used_32 = counter_offset_per_thread  # * block * grid_x
    prop = torch.cuda.get_device_properties(device)
    threads_per_round = prop.multi_processor_count * prop.max_threads_per_multi_processor
    rounds_per_thread = (nelem + threads_per_round * UNROLL - 1) // (threads_per_round * UNROLL)
    used_32 = rounds_per_thread * UNROLL
    return used_32
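# Worked example of the accounting above (hypothetical device numbers): with 80 SMs
# and 2048 resident threads per SM, threads_per_round = 163840; for nelem = 1_000_000
# the grid-stride loop needs ceil(1_000_000 / (163840 * 4)) = 2 rounds per thread, so
# used_32 = 2 * 4 = 8 and the generator offset must advance by 8 for this tensor.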
#@torch._dynamo.disable
def _reserve_offset(device: torch.device, used_32: int) -> int:
    dev_index = device.index if isinstance(device, torch.device) else int(device)
    gen = torch.cuda.default_generators[dev_index]
    old_off = int(gen.get_offset())
    gen.set_offset(old_off + used_32)
    return seed, (old_off // 4), (old_off % 4), grid_x
    gen.set_offset(old_off + used_32)
    return old_off // 4
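# The reservation above is counted in 32-bit Philox outputs (used_32 is a multiple of
# UNROLL = 4), and each Philox call yields four such outputs, so old_off // 4 is the
# first unused counter block, i.e. the value the removed kernel called offset_blocks
# and the starting point for the downstream RNG kernels.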
def get_and_acc_base_offset():
    pass
@triton.jit
def _write_offset(out_ptr, base_offset):
    # BLOCK = 64
    # idx = tl.arange(0, BLOCK)
    # offset_vec = tl.full(idx.shape, base_offset, tl.uint32)
    # mask = (idx == 0)
    # tl.store(out_ptr + idx, offset_vec, mask=mask)
    tl.store(out_ptr + 0, base_offset, mask=True)
@triton_op("triton_op::rand_eager_offset", mutates_args={})
def rand_eager_offset(
    shape: Sequence[int],  # *,
    device: torch.device,
) -> torch.Tensor:
    out = torch.empty(1, dtype=torch.uint32, device=device)
    n = 1
    for d in shape:
        n *= d
    seed_val, base_offset, lane_shift, grid_x = _reserve_seed_and_offset_gridstride(device.index, n, BLOCK)
    grid = lambda meta: (1,)
    wrap_triton(_write_offset)[grid](out, base_offset)
@custom_op("my_ops::rand_eager_offset", mutates_args={})
def rand_eager_offset(offset: int, device: torch.device) -> torch.Tensor:
    base = _reserve_offset(device, int(offset))
    out = torch.empty(1, dtype=torch.int64, device=device)
    out.fill_(int(base))
    return out
@custom_op("my_ops::rand_eager_offsets", mutates_args={})
def rand_eager_offsets(offsets: list[int], device: torch.device) -> torch.Tensor:
    bases = [int(_reserve_offset(device, int(off))) for off in offsets]
    cpu = torch.tensor(bases, dtype=torch.int64).pin_memory()
    out = torch.empty_like(cpu, device=device)
    out.copy_(cpu, non_blocking=True)
    return out
@rand_eager_offset.register_fake
def _(offset: int, device: torch.device):
    return torch.empty((1,), dtype=torch.int64, device=device)
@rand_eager_offsets.register_fake
def _(offsets: list[int], device: torch.device):
    return torch.empty((len(offsets),), dtype=torch.int64, device=device)
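# Illustrative usage of the ops above (assumes a CUDA device; not part of the diff):
#   offs = torch.ops.my_ops.rand_eager_offsets([8, 8, 4], torch.device("cuda:0"))
#   base0 = offs[0]  # int64 scalar tensor holding the first reserved base offset
# Under fake/meta tracing the register_fake overloads return correctly shaped empty
# tensors, which is what lets torch.compile trace through these ops without running
# the real generator bookkeeping.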
def replace_random_passes(gm: torch.fx.GraphModule):
    """Modify the given FX graph to use backend-native random ops"""
@@ -159,9 +88,57 @@ def replace_random_passes(gm: torch.fx.GraphModule):
    count = patterns.apply(gm)
    with GraphTransformObserver(gm, "fuse_seed_creation_pass"):
        count += fuse_seed_creation_pass(gm.graph)
    if config.align_random_eager:
        with GraphTransformObserver(gm, "fuse_offset_creation_pass"):
            count += fuse_offset_creation_pass(gm.graph)
    # print(f"After replace_random_passes: {gm}")
    return count
def fuse_offset_creation_pass(graph: torch.fx.Graph):
    """
    Horizontally fuse all the offset reservations on each device

        a = my_ops.rand_eager_offset(offset1, dev)
        b = my_ops.rand_eager_offset(offset2, dev)

    Becomes:

        offsets = my_ops.rand_eager_offsets([offset1, offset2, ...], dev)
        a = inductor_lookup_seed(offsets, 0)
        b = inductor_lookup_seed(offsets, 1)

    We do this because offset reservation, like seed creation, is entirely launch
    overhead bound.
    """
    device_offsets = collections.defaultdict(list)
    for node in graph.nodes:
        if CallFunctionVarArgs(torch.ops.my_ops.rand_eager_offset).match(node):
            device_offsets[node.args[1]].append(node)
    if not device_offsets:
        return 0
    for device, offsets in device_offsets.items():
        with graph.inserting_before(offsets[0]):
            print(f"len(offsets) = {len(offsets)}")
            offs = [n.args[0] for n in offsets]
            combined = graph.call_function(torch.ops.my_ops.rand_eager_offsets.default, (offs, device))
            with V.fake_mode:
                combined.meta["val"] = torch.empty(
                    [len(offsets)], device=device, dtype=torch.int64
                )
                combined.meta["tensor_meta"] = _extract_tensor_metadata(
                    combined.meta["val"]
                )
        for idx, offset in enumerate(offsets):
            with graph.inserting_before(offset):
                new_offset = graph.call_function(
                    inductor_prims.lookup_seed, (combined, idx)
                )
            offset.replace_all_uses_with(new_offset)
            new_offset.meta.update(offset.meta)
            graph.erase_node(offset)
    return len(device_offsets)
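# Effect of the pass on a toy graph (hypothetical node values, not part of the diff):
#   before: a = torch.ops.my_ops.rand_eager_offset(8, cuda0)
#           b = torch.ops.my_ops.rand_eager_offset(4, cuda0)
#   after:  offsets = torch.ops.my_ops.rand_eager_offsets([8, 4], cuda0)
#           a = inductor_prims.lookup_seed(offsets, 0)
#           b = inductor_prims.lookup_seed(offsets, 1)
# One launch reserves every offset for the device and each original call collapses
# into a GPU-resident lookup.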
def fuse_seed_creation_pass(graph: torch.fx.Graph):
    """
@@ -235,25 +212,6 @@ def replace_random(
    if generator is not None:
        return
    mode = {
        aten.rand: "rand",
        aten.randn: "randn",
    }[
        match.output_node().target.overloadpacket  # type: ignore[union-attr]
    ]  # type: ignore[union-attr]
    device = get_device(device)
    # For uniform rand (e.g., used by dropout), call our custom Triton op directly
    if mode == "rand":
        def replacement(size):
            # dtype: keep caller's dtype if provided, else default fp32
            use_dtype = dtype if dtype is not None else torch.float32
            return torch.ops.my_triton_op.philox_rand(size, device, use_dtype)
        match.replace_by_example(replacement, [size])
        return
    # Fallback (e.g., randn) keeps existing inductor behavior
    def replacement(size):
        result = inductor_prims.random(
            size, inductor_prims.seed(device), mode, **default_kwargs(device)
@@ -272,8 +230,9 @@ def replace_random(
    if mode == "rand" and config.align_random_eager:
        def replacement(size):
            offset = _shape_to_offset(size, device)
            result = inductor_prims.random(
                size, inductor_prims.lookup_seed(torch.ops.my_triton_op.rand_eager_offset(size, device), 0), mode, **default_kwargs(device)
                size, torch.ops.my_ops.rand_eager_offset(offset, device), mode, **default_kwargs(device)
            )
            if dtype is not None:
                result = result.to(dtype)