|
539363a873
|
[inductor] Lowering of rngprims philox_rand (#99289)
An example graph with dynamic shapes on:
`arg0_1` is the seed, `arg1_1` is the base offset.
~~~
===== Forward graph 0 =====
<eval_with_key>.5 class <lambda>(torch.nn.Module):
def forward(self, arg0_1: i64[], arg1_1: i64[], arg2_1: Sym(s0), arg3_1: f32[s0]):
# File: /scratch/anijain/work/pytorch/test/inductor/test_torchinductor.py:4605, code: a = torch.rand_like(x) * x
add: i64[] = torch.ops.aten.add.Tensor(arg1_1, 0)
philox_rand = torch.ops.rngprims.philox_rand.default([arg2_1], arg0_1, add, None, device(type='cuda', index=0), torch.float32); add = None
getitem: f32[s0] = philox_rand[0]
getitem_1: i64[] = philox_rand[1]; philox_rand = None
add_1: i64[] = torch.ops.aten.add.Tensor(getitem_1, 0); getitem_1 = None
mul: f32[s0] = torch.ops.aten.mul.Tensor(getitem, arg3_1); getitem = arg3_1 = None
# File: /scratch/anijain/work/pytorch/test/inductor/test_torchinductor.py:4606, code: a = torch.rand_like(x) * a
add_2: i64[] = torch.ops.aten.add.Tensor(arg1_1, add_1)
philox_rand_1 = torch.ops.rngprims.philox_rand.default([arg2_1], arg0_1, add_2, None, device(type='cuda', index=0), torch.float32); arg2_1 = arg0_1 = add_2 = None
getitem_2: f32[s0] = philox_rand_1[0]
getitem_3: i64[] = philox_rand_1[1]; philox_rand_1 = None
add_3: i64[] = torch.ops.aten.add.Tensor(add_1, getitem_3); add_1 = getitem_3 = None
mul_1: f32[s0] = torch.ops.aten.mul.Tensor(getitem_2, mul); getitem_2 = mul = None
# No stacktrace found for following nodes
add_4: i64[] = torch.ops.aten.add.Tensor(arg1_1, add_3); arg1_1 = add_3 = None
add_5: i64[] = torch.ops.aten.add.Tensor(add_4, 3); add_4 = None
div: i64[] = torch.ops.aten.div.Tensor_mode(add_5, 4, rounding_mode = 'floor'); add_5 = None
mul_2: i64[] = torch.ops.aten.mul.Tensor(div, 4); div = None
return (mul_1, mul_2)
~~~
Note that the output `mul_2` is essentially the base offset plus the total `numel` consumed by the random ops, rounded up to a multiple of 4 (the `add_5`/`div`/`mul_2` nodes).
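As a minimal sketch in plain Python (derived only from the trailing `add_5`/`div`/`mul_2` nodes above), the epilogue computes the next base offset like this:
~~~
# Minimal sketch of the offset epilogue in the graph above: add the consumed
# offset to the incoming base offset and round up to a multiple of 4
# (matching the add_5 / div / mul_2 nodes).
def next_base_offset(base_offset: int, consumed: int) -> int:
    return (base_offset + consumed + 3) // 4 * 4
~~~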
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99289
Approved by: https://github.com/jansel
|
2023-04-26 01:22:41 +00:00 |
|
|
6bc4651193
|
[philox_rand] Dynamic shape support (#99290)
Extends the RNG functionalization work to dynamic shapes. An example of the generated graph looks like this:
~~~
[2023-04-24 21:41:37,446] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
===== Forward graph 1 =====
<eval_with_key>.7 class <lambda>(torch.nn.Module):
def forward(self, arg0_1: i64[], arg1_1: i64[], arg2_1: Sym(s0), arg3_1: Sym(s1), arg4_1: f32[s0, s1]):
# File: /scratch/anijain/work/pytorch/test/test_functionalization_of_rng_ops.py:46, code: a = torch.rand_like(x) * x
add: i64[] = torch.ops.aten.add.Tensor(arg1_1, 0)
philox_rand = torch.ops.rngprims.philox_rand.default([arg2_1, arg3_1], arg0_1, add, None, device(type='cuda', index=0), torch.float32); add = None
getitem: f32[s0, s1] = philox_rand[0]
getitem_1: i64[] = philox_rand[1]; philox_rand = None
add_1: i64[] = torch.ops.aten.add.Tensor(getitem_1, 0); getitem_1 = None
mul: f32[s0, s1] = torch.ops.aten.mul.Tensor(getitem, arg4_1); getitem = arg4_1 = None
# File: /scratch/anijain/work/pytorch/test/test_functionalization_of_rng_ops.py:47, code: a = torch.rand_like(x) * a
add_2: i64[] = torch.ops.aten.add.Tensor(arg1_1, add_1)
philox_rand_1 = torch.ops.rngprims.philox_rand.default([arg2_1, arg3_1], arg0_1, add_2, None, device(type='cuda', index=0), torch.float32); arg2_1 = arg3_1 = arg0_1 = add_2 = None
getitem_2: f32[s0, s1] = philox_rand_1[0]
getitem_3: i64[] = philox_rand_1[1]; philox_rand_1 = None
add_3: i64[] = torch.ops.aten.add.Tensor(add_1, getitem_3); add_1 = getitem_3 = None
mul_1: f32[s0, s1] = torch.ops.aten.mul.Tensor(getitem_2, mul); getitem_2 = mul = None
# No stacktrace found for following nodes
add_4: i64[] = torch.ops.aten.add.Tensor(arg1_1, add_3); arg1_1 = add_3 = None
return (mul_1, add_4)
~~~
Each rand op is accompanied by its offset calculation op.
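For context, a hypothetical reproduction sketch of the two-line test above. Assumptions: `torch._functorch.config.functionalize_rng_ops` is the flag that enables this rewrite, and running under `TORCH_LOGS=aot_graphs` prints the traced forward graph.
~~~
import torch
import torch._functorch.config as functorch_config

functorch_config.functionalize_rng_ops = True  # assumption: opts into RNG functionalization

def fn(x):
    a = torch.rand_like(x) * x  # the two lines referenced in the stack traces above
    a = torch.rand_like(x) * a
    return a

compiled = torch.compile(fn, dynamic=True)         # dynamic shapes -> s0, s1 in the graph
out = compiled(torch.randn(8, 16, device="cuda"))  # CUDA only; CPU is not supported here
~~~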
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99290
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
|
2023-04-25 22:40:28 +00:00 |
|
|
fdbc8625a1
|
Functionalization of torch.rand/rand_like ops (#97377)
This PR introduces the functionalization of RNG ops. Key points are
* Introduces a new `philox_rand` prim operator that accepts a seed and an offset.
* Adds decompositions for random operators so that they use these `philox_rand` prims.
* Adds a PhiloxStateTracker to track the offset for each occurrence of a rand op (a conceptual sketch follows this list).
* Changes the calling convention of AOT Autograd and adds <fwd_seed, fwd_base_offset> and <bwd_seed, bwd_base_offset> as extra graph inputs.
* Monkeypatches set_rng_state and get_rng_state during AOT Autograd tracing to record the RNG state behavior.
* Raises an assertion for CPU because CPU does not use Philox RNG.
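As a rough conceptual sketch (not the actual implementation; the field and method names below are hypothetical stand-ins), the tracker's job is bookkeeping of this shape:
~~~
# Conceptual sketch only: PhiloxStateTracker is real, but these members are
# hypothetical illustrations of its role during tracing.
class PhiloxStateTrackerSketch:
    def __init__(self, seed, base_offset):
        self.seed = seed                  # seed captured as a graph input
        self.base_offset = base_offset    # base offset captured as a graph input
        self.relative_offset = 0          # offset consumed by rand ops traced so far

    def offset_for_next_op(self, increment):
        # hand out the current relative offset for a rand op, then advance it by
        # however much that op consumes (4 per [16, 16] tensor in the example below)
        offset = self.relative_offset
        self.relative_offset += increment
        return offset
~~~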
Not dealt with in this PR
* dropout op - offset calculation is different
* other distributions like normal, poisson, etc.
* Inductor support
* Cudagraph support
* Dynamic shape support
An example
~~~
class Custom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        a = torch.rand_like(x) * x
        a = torch.rand_like(x) * a
        return a

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * torch.rand_like(grad_out) * torch.cos(x)
====== Forward graph 0 ======
def forward(self, fwd_seed_1: i64[], fwd_base_offset_1: i64[], primals_1: f32[16, 16]):
# No stacktrace found for following nodes
add: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 0)
philox_rand: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add, [16, 1], device(type='cuda', index=0), torch.float32); add = None
mul: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand, primals_1); philox_rand = None
add_1: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 4); fwd_base_offset_1 = None
philox_rand_1: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add_1, [16, 1], device(type='cuda', index=0), torch.float32); fwd_seed_1 = add_1 = None
mul_1: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand_1, mul); philox_rand_1 = mul = None
return [mul_1, primals_1]
====== Backward graph 0 ======
def forward(self, bwd_seed_1: i64[], bwd_base_offset_1: i64[], primals_1: f32[16, 16], tangents_1: f32[16, 16]):
# No stacktrace found for following nodes
add_2: i64[] = torch.ops.aten.add.Tensor(bwd_base_offset_1, 0); bwd_base_offset_1 = None
philox_rand_2: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], bwd_seed_1, add_2, [16, 1], device(type='cuda', index=0), torch.float32); bwd_seed_1 = add_2 = None
mul_2: f32[16, 16] = torch.ops.aten.mul.Tensor(tangents_1, philox_rand_2); tangents_1 = philox_rand_2 = None
cos: f32[16, 16] = torch.ops.aten.cos.default(primals_1); primals_1 = None
mul_3: f32[16, 16] = torch.ops.aten.mul.Tensor(mul_2, cos); mul_2 = cos = None
return [mul_3]
~~~
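For orientation, a conceptual sketch of the decomposition direction (not the PR's actual code): a `rand_like` call is rewritten into the `philox_rand` prim using the seed and base-offset graph inputs, with the argument order (shape, seed, offset, stride, device, dtype) mirroring the traced graphs above. The fixed increment of 4 is taken from the example, where the second op starts at `fwd_base_offset_1 + 4`.
~~~
import torch

def rand_like_via_philox(x, seed, base_offset, relative_offset):
    # absolute offset = base offset (graph input) + offset consumed so far,
    # matching the aten.add.Tensor nodes in the graphs above
    offset = torch.ops.aten.add.Tensor(base_offset, relative_offset)
    # prims.philox_rand here, as in this PR's graphs; the later commits above
    # use the rngprims.philox_rand namespace instead
    out = torch.ops.prims.philox_rand.default(
        list(x.shape), seed, offset, list(x.stride()), x.device, x.dtype
    )
    # return the next relative offset; 4 matches the example graph, while the
    # real increment depends on the op's numel and launch configuration
    return out, relative_offset + 4
~~~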
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97377
Approved by: https://github.com/ezyang
|
2023-04-16 09:55:56 +00:00 |
|