Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-02 06:24:59 +08:00)
Use a non-blocking custom_op, fuse multiple offset reservations into one kernel, and use lookup_seed to look up offsets
1. Replace the Triton helper with a non-blocking custom_op that writes offsets directly to device tensors, removing any dependence on BLOCK/XBLOCK and eliminating host/device synchronization on the H2D copy.
2. Fuse multiple offset reservations into a single kernel/op to cut launch overhead and to avoid the issue where the offset op would otherwise be generated only once; expose a vectorized rand_eager_offsets plus a register_fake for meta shape inference.
3. Plumb offsets through inductor_lookup_seed (GPU-resident) so downstream RNG kernels index by offset, preserving determinism between eager mode and torch.compile.
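For orientation, a minimal sketch of the custom_op-plus-register_fake pattern this commit relies on (the op name demo::device_scalar is made up for illustration, not from the diff): the real op writes a reserved offset into a device tensor without a blocking copy, and the fake kernel lets torch.compile infer shapes without running the op.

import torch
from torch.library import custom_op

@custom_op("demo::device_scalar", mutates_args={})
def device_scalar(value: int, device: torch.device) -> torch.Tensor:
    # fill_ with a Python int enqueues a kernel; it does not block the host.
    out = torch.empty(1, dtype=torch.int64, device=device)
    out.fill_(int(value))
    return out

@device_scalar.register_fake
def _(value: int, device: torch.device):
    # Meta implementation: shape/dtype only, never touches real data.
    return torch.empty((1,), dtype=torch.int64, device=device)

if torch.cuda.is_available():
    t = device_scalar(7, torch.device("cuda"))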
@@ -14,7 +14,6 @@ from ..pattern_matcher import (
     register_graph_pattern,
 )
 from ..virtualized import V
-from . import custom_philox_rand
 
 log = logging.getLogger(__name__)
 patterns = PatternMatcherPass()
@@ -27,78 +26,7 @@ from typing import Sequence, Optional
 
 import torch
-import triton
-import triton.language as tl
-from torch.library import triton_op, wrap_triton, register_fake, register_kernel
-
-BLOCK = 256
-
-
-@triton.jit
-def _pick_lane(u0, u1, u2, u3, lane):
-    v = tl.where(lane == 0, u0, u1)
-    v = tl.where(lane == 1, u1, v)
-    v = tl.where(lane == 2, u2, v)
-    v = tl.where(lane == 3, u3, v)
-    return v
-
-
-@triton.jit
-def _philox_fill_uniform_gridstride(out_ptr, n_elements, seed, offset_blocks, lane_shift, threads_per_round, BLOCK: tl.constexpr = BLOCK):
-    UNROLL = 4
-    pid = tl.program_id(0)                    # [0, grid_x)
-    tid = pid * BLOCK + tl.arange(0, BLOCK)   # [0, BLOCK*grid_x)
-    inv = 1.0 / 4294967296.0
-    half = inv * 0.5
-
-    # rounds_total = ceil(n / (threads_per_round * UNROLL))
-    rounds_total = (n_elements + threads_per_round * UNROLL - 1) // (threads_per_round * UNROLL)
-
-    # tl.device_print("rand_philox offset_blocks %d rounds_total %d\n", offset_blocks, rounds_total)
-
-    for r in range(rounds_total):
-        subseq = tid.to(tl.uint64)
-        lane = ((tid + lane_shift) % 4).to(tl.uint64)
-
-        offblk = tl.full(subseq.shape, (offset_blocks + r), tl.uint64)
-        u0, u1, u2, u3 = tl.philox(
-            seed,
-            (offblk & 0xFFFFFFFF).to(tl.uint32),
-            ((offblk >> 32) & 0xFFFFFFFF).to(tl.uint32),
-            (subseq & 0xFFFFFFFF).to(tl.uint32),
-            ((subseq >> 32) & 0xFFFFFFFF).to(tl.uint32),
-        )
-
-        inv = 1.0 / 4294967296.0  # 2^-32
-        half = inv * 0.5
-
-        base = tid * 4
-        stride = threads_per_round
-
-        # k=0
-        i0 = base + (r * UNROLL) * stride
-        m0 = i0 < n_elements
-        lane0 = tl.full(tid.shape, (lane_shift + 0) % 4, tl.uint32)
-        f0 = _pick_lane(u0, u1, u2, u3, lane0).to(tl.float32) * inv + half
-        tl.store(out_ptr + i0, 1.0 - f0, mask=m0)
-
-        # k=1
-        i1 = base + 1 + (r * UNROLL) * stride
-        m1 = i1 < n_elements
-        lane1 = tl.full(tid.shape, (lane_shift + 1) % 4, tl.uint32)
-        f1 = _pick_lane(u0, u1, u2, u3, lane1).to(tl.float32) * inv + half
-        tl.store(out_ptr + i1, 1.0 - f1, mask=m1)
-
-        # k=2
-        i2 = base + 2 + (r * UNROLL) * stride
-        m2 = i2 < n_elements
-        lane2 = tl.full(tid.shape, (lane_shift + 2) % 4, tl.uint32)
-        f2 = _pick_lane(u0, u1, u2, u3, lane2).to(tl.float32) * inv + half
-        tl.store(out_ptr + i2, 1.0 - f2, mask=m2)
-
-        # k=3
-        i3 = base + 3 + (r * UNROLL) * stride
-        m3 = i3 < n_elements
-        lane3 = tl.full(tid.shape, (lane_shift + 3) % 4, tl.uint32)
-        f3 = _pick_lane(u0, u1, u2, u3, lane3).to(tl.float32) * inv + half
-        tl.store(out_ptr + i3, 1.0 - f3, mask=m3)
-
+from torch.library import triton_op, wrap_triton, register_fake, register_kernel, custom_op
 
 # ---- host helpers ----
 def _compute_grid_x(nelem: int, block: int, device_index: int) -> int:
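For reference, a small host-side check (plain PyTorch, not part of the diff) of the scaling trick the removed kernel used: each uint32 Philox output u is mapped to f = u * 2**-32 + 2**-33, and the kernel stores 1.0 - f, which keeps every sample strictly inside (0, 1).

import torch

u = torch.tensor([0, 1, 2**31, 2**32 - 1], dtype=torch.float64)
inv = 1.0 / 4294967296.0   # 2**-32, as in the kernel
half = inv * 0.5           # 2**-33: centers each 32-bit bucket
f = u * inv + half
print(1.0 - f)             # u=0 maps to 1 - 2**-33, u=2**32-1 maps to 2**-33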
@@ -108,48 +36,49 @@ def _compute_grid_x(nelem: int, block: int, device_index: int) -> int:
     need_blocks = (nelem + block - 1) // block
     return min(max_blocks, need_blocks)
 
-def _reserve_seed_and_offset_gridstride(x_device_index: int, nelem: int, block: int):
+def _shape_to_offset(size, device: torch.device) -> int:
+    nelem = 1
+    for s in size:
+        nelem *= int(s)
+
     UNROLL = 4
-    gen = torch.cuda.default_generators[x_device_index]
-    seed = int(gen.initial_seed())
-    grid_x = _compute_grid_x(nelem, block, x_device_index)
-    rounds_per_thread = (nelem + (block * grid_x * UNROLL) - 1) // (block * grid_x * UNROLL)
-    counter_offset_per_thread = rounds_per_thread * UNROLL
-    used_32 = counter_offset_per_thread  # * block * grid_x
+    prop = torch.cuda.get_device_properties(device)
+    threads_per_round = prop.multi_processor_count * prop.max_threads_per_multi_processor
+    rounds_per_thread = (nelem + threads_per_round * UNROLL - 1) // (threads_per_round * UNROLL)
+    used_32 = rounds_per_thread * UNROLL
+    return used_32
 
+# @torch._dynamo.disable
+def _reserve_offset(device: torch.device, used_32: int) -> int:
+    dev_index = device.index if isinstance(device, torch.device) else int(device)
+    gen = torch.cuda.default_generators[dev_index]
     old_off = int(gen.get_offset())
-    gen.set_offset(old_off + used_32)
-    return seed, (old_off // 4), (old_off % 4), grid_x
+    gen.set_offset(old_off + used_32)
+    return old_off // 4
 
-def get_and_acc_base_offset():
-    pass
-
-
-@triton.jit
-def _write_offset(out_ptr, base_offset):
-    # BLOCK = 64
-    # idx = tl.arange(0, BLOCK)
-    # offset_vec = tl.full(idx.shape, base_offset, tl.uint32)
-    # mask = (idx == 0)
-    # tl.store(out_ptr + idx, offset_vec, mask=mask)
-    tl.store(out_ptr + 0, base_offset, mask=True)
-
-
-@triton_op("triton_op::rand_eager_offset", mutates_args={})
-def rand_eager_offset(
-    shape: Sequence[int],  # *,
-    device: torch.device,
-) -> torch.Tensor:
-    out = torch.empty(1, dtype=torch.uint32, device=device)
-
-    n = 1
-    for d in shape:
-        n *= d
-
-    seed_val, base_offset, lane_shift, grid_x = _reserve_seed_and_offset_gridstride(device.index, n, BLOCK)
-
-    grid = lambda meta: (1,)
-    wrap_triton(_write_offset)[grid](out, base_offset)
+@custom_op("my_ops::rand_eager_offset", mutates_args={})
+def rand_eager_offset(offset: int, device: torch.device) -> torch.Tensor:
+    base = _reserve_offset(device, int(offset))
+    out = torch.empty(1, dtype=torch.int64, device=device)
+    out.fill_(int(base))
     return out
 
+@custom_op("my_ops::rand_eager_offsets", mutates_args={})
+def rand_eager_offsets(offsets: list[int], device: torch.device) -> torch.Tensor:
+    bases = [int(_reserve_offset(device, int(off))) for off in offsets]
+    cpu = torch.tensor(bases, dtype=torch.int64).pin_memory()
+    out = torch.empty_like(cpu, device=device)
+    out.copy_(cpu, non_blocking=True)
+    return out
+
+@rand_eager_offset.register_fake
+def _(offset: int, device: torch.device):
+    return torch.empty((1,), dtype=torch.int64, device=device)
+
+@rand_eager_offsets.register_fake
+def _(offsets: list[int], device: torch.device):
+    return torch.empty((len(offsets),), dtype=torch.int64, device=device)
 
 def replace_random_passes(gm: torch.fx.GraphModule):
     """Modify the given FX graph to use backend-native random ops"""
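A hedged usage sketch (assumes the module defining these ops has been imported and a CUDA device is present): reserving several Philox ranges in one call. Because rand_eager_offsets stages the bases in pinned memory and uses copy_(..., non_blocking=True), this enqueues a single async H2D transfer instead of one blocking copy per reservation.

import torch

if torch.cuda.is_available():
    dev = torch.device("cuda")
    bases = torch.ops.my_ops.rand_eager_offsets([4, 4, 8], dev)
    # Each base is old_offset // 4 and the generator advances by used_32,
    # so on a fresh default generator this would print tensor([0, 1, 2]).
    print(bases)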
@@ -159,9 +88,57 @@ def replace_random_passes(gm: torch.fx.GraphModule):
     count = patterns.apply(gm)
     with GraphTransformObserver(gm, "fuse_seed_creation_pass"):
         count += fuse_seed_creation_pass(gm.graph)
 
+    if config.align_random_eager:
+        with GraphTransformObserver(gm, "fuse_offset_creation_pass"):
+            count += fuse_offset_creation_pass(gm.graph)
+
     return count
 
+
+def fuse_offset_creation_pass(graph: torch.fx.Graph):
+    """
+    Horizontally fuse all the offset reservations on each device:
+
+        a = my_ops.rand_eager_offset(offset1, dev)
+        b = my_ops.rand_eager_offset(offset2, dev)
+
+    Becomes:
+        offsets = my_ops.rand_eager_offsets([offset1, offset2, ...], dev)
+        a = inductor_lookup_seed(offsets, 0)
+        b = inductor_lookup_seed(offsets, 1)
+
+    We do this because offset creation is entirely launch-overhead bound.
+    """
+    device_offsets = collections.defaultdict(list)
+    for node in graph.nodes:
+        if CallFunctionVarArgs(torch.ops.my_ops.rand_eager_offset).match(node):
+            device_offsets[node.args[1]].append(node)
+
+    if not device_offsets:
+        return 0
+
+    for device, offsets in device_offsets.items():
+        with graph.inserting_before(offsets[0]):
+            offs = [n.args[0] for n in offsets]
+            combined = graph.call_function(
+                torch.ops.my_ops.rand_eager_offsets.default, (offs, device)
+            )
+            with V.fake_mode:
+                combined.meta["val"] = torch.empty(
+                    [len(offsets)], device=device, dtype=torch.int64
+                )
+                combined.meta["tensor_meta"] = _extract_tensor_metadata(
+                    combined.meta["val"]
+                )
+
+        for idx, offset in enumerate(offsets):
+            with graph.inserting_before(offset):
+                new_offset = graph.call_function(
+                    inductor_prims.lookup_seed, (combined, idx)
+                )
+            offset.replace_all_uses_with(new_offset)
+            new_offset.meta.update(offset.meta)
+            graph.erase_node(offset)
+
+    return len(device_offsets)
+
+
 def fuse_seed_creation_pass(graph: torch.fx.Graph):
     """
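The pass above has the same horizontal-fusion shape as inductor's existing fuse_seed_creation_pass: collect N single-element ops per device, emit one combined op, and reroute each consumer through an indexed lookup. As a self-contained toy (fuse_fills and the hand-built graph are illustrative, not from the diff), the same rewrite on a plain FX graph looks like this:

import operator

import torch
import torch.fx as fx

def fuse_fills(graph: fx.Graph) -> int:
    """Replace N torch.full calls with one torch.tensor plus N getitems."""
    nodes = [
        n for n in graph.nodes
        if n.op == "call_function" and n.target is torch.full
    ]
    if len(nodes) < 2:
        return 0
    vals = [n.args[1] for n in nodes]
    # One combined op replaces N tiny launches...
    with graph.inserting_before(nodes[0]):
        combined = graph.call_function(torch.tensor, (vals,))
    # ...and every consumer indexes into the combined result.
    for idx, n in enumerate(nodes):
        with graph.inserting_before(n):
            item = graph.call_function(operator.getitem, (combined, idx))
        n.replace_all_uses_with(item)
        graph.erase_node(n)
    return 1

# Hand-built graph computing x + full(3.0) + full(5.0)
g = fx.Graph()
x = g.placeholder("x")
a = g.call_function(torch.full, ((1,), 3.0))
b = g.call_function(torch.full, ((1,), 5.0))
out = g.call_function(operator.add, (g.call_function(operator.add, (x, a)), b))
g.output(out)

fuse_fills(g)
print(fx.GraphModule(torch.nn.Module(), g)(torch.zeros(1)))  # tensor([8.])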
@@ -235,25 +212,6 @@ def replace_random(
     if generator is not None:
         return
 
-    mode = {
-        aten.rand: "rand",
-        aten.randn: "randn",
-    }[
-        match.output_node().target.overloadpacket  # type: ignore[union-attr]
-    ]  # type: ignore[union-attr]
-    device = get_device(device)
-
-    # For uniform rand (e.g., used by dropout), call our custom Triton op directly
-    if mode == "rand":
-        def replacement(size):
-            # dtype: keep the caller's dtype if provided, else default to fp32
-            use_dtype = dtype if dtype is not None else torch.float32
-            return torch.ops.my_triton_op.philox_rand(size, device, use_dtype)
-
-        match.replace_by_example(replacement, [size])
-        return
-
-    # Fallback (e.g., randn) keeps existing inductor behavior
     def replacement(size):
         result = inductor_prims.random(
             size, inductor_prims.seed(device), mode, **default_kwargs(device)
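A quick check of the dispatch key the removed mode dict keyed on: every ATen overload points back to its overload packet, so a matched aten.rand.default node resolves to aten.rand.

import torch

aten = torch.ops.aten
assert aten.rand.default.overloadpacket is aten.rand
assert aten.randn.default.overloadpacket is aten.randn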
@@ -272,8 +230,9 @@ def replace_random(
 
     if mode == "rand" and config.align_random_eager:
         def replacement(size):
+            offset = _shape_to_offset(size, device)
             result = inductor_prims.random(
-                size, inductor_prims.lookup_seed(torch.ops.my_triton_op.rand_eager_offset(size, device), 0), mode, **default_kwargs(device)
+                size, torch.ops.my_ops.rand_eager_offset(offset, device), mode, **default_kwargs(device)
             )
             if dtype is not None:
                 result = result.to(dtype)
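Worked example of the offset accounting feeding this replacement (device figures are assumed, e.g. a 108-SM part with 2048 resident threads per SM): _shape_to_offset returns the per-thread Philox counter advance, rounded up to whole UNROLL=4 blocks, which _reserve_offset then uses to bump the generator.

n = 1 << 20                                  # 1,048,576 elements
UNROLL = 4
threads_per_round = 108 * 2048               # 221,184 resident threads per round
rounds_per_thread = -(-n // (threads_per_round * UNROLL))  # ceil division -> 2
used_32 = rounds_per_thread * UNROLL         # 8: each thread consumes 8 Philox outputs
print(used_32)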