mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492 Approved by: https://github.com/albanD
119 lines
3.5 KiB
Python
119 lines
3.5 KiB
Python
# mypy: allow-untyped-defs
|
|
try:
|
|
import halide as hl # type: ignore[import-untyped, import-not-found]
|
|
except ImportError:
|
|
hl = None
|
|
|
|
# Round count used by the public samplers whenever the caller omits n_rounds.
PHILOX_N_ROUNDS_DEFAULT = 10

# Philox constants as Halide uint32 expressions: the KEY_* values are added to
# the key words after every round, the ROUND_* values are the per-round
# multipliers.  Without halide installed they are left as None placeholders.
if hl is None:
    PHILOX_KEY_A_U32 = PHILOX_KEY_B_U32 = None
    PHILOX_ROUND_A_U32 = PHILOX_ROUND_B_U32 = None
else:
    PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9)
    PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85)
    PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53)
    PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57)
|
|
|
|
|
|
def _pair_uniform_to_normal(u1, u2):
    """Turn two uniform samples into two normal samples (Box-Muller transform).

    ``u1`` is clamped from below at 1e-7 before the logarithm so that a zero
    sample never reaches ``log``.  Returns the pair
    ``(r * cos(theta), r * sin(theta))`` with ``r = sqrt(-2 * log(u1))`` and
    ``theta = 2 * pi * u2``.
    """
    two_pi = hl.f32(6.283185307179586)
    clamped_u1 = hl.max(hl.f32(1.0e-7), u1)
    radius = hl.sqrt(hl.f32(-2.0) * hl.log(clamped_u1))
    angle = two_pi * u2
    return radius * hl.cos(angle), radius * hl.sin(angle)
|
|
|
|
|
|
def _uint_to_uniform_float(x):
    """
    Numerically stable conversion of a random 32-bit integer expression into a
    random float uniformly sampled in [0, 1).

    ``x`` must have Halide type UInt(32) or Int(32).  The result is the
    non-negative 32-bit value multiplied by a float64 scale.
    """
    # TODO:
    # conditions can be simplified
    # The scale sits just below 2**-31 (i.e. 2**-(N_BITS - 1) for N_BITS = 32),
    # so the largest non-negative int32 still maps to a value under 1.0.
    # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132.
    dtype = x.type()
    assert dtype == hl.UInt(32) or dtype == hl.Int(32)

    # Reinterpret as signed, then fold negatives onto the non-negative range:
    # -v - 1 maps [-2**31, -1] onto [2**31 - 1, 0].
    # NOTE(review): for v == -2**31 the intermediate -v does not fit in int32;
    # confirm this is acceptable under Halide's overflow rules.
    signed = hl.cast(hl.Int(32), x)
    folded = hl.select(signed < 0, -signed - 1, signed)
    return folded * hl.f64(4.6566127342e-10)
|
|
|
|
|
|
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds):
    """Apply ``n_rounds`` Philox rounds to a four-word counter and two-word key.

    ``c0``..``c3`` are the counter words and ``k0``/``k1`` the key words
    (expected to be 32-bit unsigned Halide expressions).  ``n_rounds`` is a
    plain Python int: the loop is unrolled while the expression is built.
    Returns the four mixed counter words ``(c0, c1, c2, c3)``.
    """

    def umulhi(a, b):
        # High 32 bits of the 64-bit product of two 32-bit operands.
        wide_product = hl.cast(hl.UInt(64), a) * hl.cast(hl.UInt(64), b)
        return hl.cast(hl.UInt(32), (wide_product >> 32) & hl.u64(0xFFFFFFFF))

    for _ in range(n_rounds):
        # The right-hand side is evaluated entirely from the previous round's
        # words before any of them is rebound.
        c0, c1, c2, c3 = (
            umulhi(PHILOX_ROUND_B_U32, c2) ^ c1 ^ k0,
            PHILOX_ROUND_B_U32 * c2,
            umulhi(PHILOX_ROUND_A_U32, c0) ^ c3 ^ k1,
            PHILOX_ROUND_A_U32 * c0,
        )
        # raise key
        k0, k1 = k0 + PHILOX_KEY_A_U32, k1 + PHILOX_KEY_B_U32

    return c0, c1, c2, c3
|
|
|
|
|
|
def halide_philox(seed, c0, c1, c2, c3, n_rounds):
    """Run Philox with a 64-bit seed split into two 32-bit key words.

    ``seed`` is cast to UInt(64); its low half becomes the first key word and
    its high half the second.  ``c0`` must be a 32-bit expression.  Returns the
    four output words produced by ``philox_impl``.
    """
    seed_u64 = hl.cast(hl.UInt(64), seed)

    assert c0.type().bits() == 32

    low_word_mask = hl.u64(0xFFFFFFFF)
    key_hi = hl.cast(hl.UInt(32), (seed_u64 >> 32) & low_word_mask)
    key_lo = hl.cast(hl.UInt(32), seed_u64 & low_word_mask)

    return philox_impl(c0, c1, c2, c3, key_lo, key_hi, n_rounds)
|
|
|
|
|
|
def randint4x(seed, offset, n_rounds):
    """Produce four random 32-bit words for the given seed and offset.

    ``offset`` is cast to UInt(32) and used as the first counter word; the
    remaining three counter words are zero.
    """
    zero = hl.u32(0)
    counter = hl.cast(hl.UInt(32), offset)
    return halide_philox(seed, counter, zero, zero, zero, n_rounds)
|
|
|
|
|
|
def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Produce four uniform samples in [0, 1) for the given seed and offset.

    Each of the four words from ``randint4x`` is converted independently with
    ``_uint_to_uniform_float``.
    """
    w0, w1, w2, w3 = randint4x(seed, offset, n_rounds)
    return tuple(_uint_to_uniform_float(word) for word in (w0, w1, w2, w3))
|
|
|
|
|
|
def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Produce one random 32-bit word: the first of the four from ``randint4x``."""
    words = randint4x(seed, offset, n_rounds)
    return words[0]
|
|
|
|
|
|
def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Produce one uniform sample in [0, 1) for the given seed and offset."""
    return _uint_to_uniform_float(randint(seed, offset, n_rounds))
|
|
|
|
|
|
def randn(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Produce one normally distributed sample for the given seed and offset.

    The first two words from ``randint4x`` are converted to uniform samples
    and passed through the Box-Muller transform; only the first of the two
    resulting normal samples is returned.

    ``n_rounds`` is the number of Philox rounds.  It defaults to
    ``PHILOX_N_ROUNDS_DEFAULT``, matching ``rand``, ``randint`` and ``rand4x``,
    so existing two-argument callers are unaffected.
    """
    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
    u1 = _uint_to_uniform_float(i1)
    u2 = _uint_to_uniform_float(i2)
    n1, _ = _pair_uniform_to_normal(u1, u2)
    return n1
|
|
|
|
|
|
def randint64(seed, offset, low, high, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Produce a random 64-bit integer offset from ``low``.

    Two 32-bit words from ``randint4x`` are combined into one unsigned 64-bit
    value (first word in the low half, second in the high half), reduced
    modulo ``high - low``, cast to Int(64) and shifted by ``low``.  The result
    lies in ``[low, high)`` provided ``high > low`` (the caller is presumably
    responsible for that; it is not checked here).

    ``n_rounds`` is the number of Philox rounds.  It defaults to
    ``PHILOX_N_ROUNDS_DEFAULT``, so existing four-argument callers are
    unaffected.
    """
    r0, r1, _r2, _r3 = randint4x(seed, offset, n_rounds)
    r0 = hl.cast(hl.UInt(64), r0)
    r1 = hl.cast(hl.UInt(64), r1)

    result = r0 | (r1 << 32)
    size = high - low
    result = result % hl.cast(hl.UInt(64), size)
    result = hl.cast(hl.Int(64), result) + low
    return result
|