[philox_rand] Add decomps (#100206)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100206
Approved by: https://github.com/ngimel
Author: Animesh Jain
Date: 2023-04-27 21:30:59 +00:00
Committed by: PyTorch MergeBot
parent 9cda7b9e47
commit a8ad0dc333
3 changed files with 66 additions and 47 deletions
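At a high level, the commit moves the rand_like-based decompositions out of the inductor-only table and into torch._decomp.decompositions_for_rng, so that functionalization of RNG ops can decompose these ops as well. For orientation, here is a rough, illustrative sketch of what such a decomposition looks like; it is not the exact decomposition PyTorch registers, but it shows why these ops matter for this pass: rand_like is the call that RNG functionalization later rewrites into an explicit philox_rand with a seed/offset.

import torch

# Illustrative sketch only; the real native_dropout decomposition lives in
# torch._decomp and differs in details.
def native_dropout_sketch(x, p, train):
    if not train or p == 0.0:
        return x, torch.ones_like(x, dtype=torch.bool)
    if p == 1.0:
        return torch.zeros_like(x), torch.zeros_like(x, dtype=torch.bool)
    # rand_like is what the RNG functionalization pass replaces with a
    # philox_rand call carrying an explicit seed/offset.
    mask = torch.rand_like(x) > p
    return x * mask * (1.0 / (1.0 - p)), mask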

View File

@@ -9,7 +9,6 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from functorch.compile import aot_function, nop, min_cut_rematerialization_partition
from unittest import skip
from unittest.mock import patch
import functools
import torch.utils.checkpoint
@@ -287,7 +286,6 @@ class TestFunctionalizationRngOps(TestCase):
self.assertEqual(x.grad, x_clone.grad)
# TODO - Dropout needs more work because of offset calculation
@skip("Dropout needs more work because of offset calculation")
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
@dtypes(torch.float32)
def test_checkpoint(self, dtype, device):
@@ -307,11 +305,22 @@ class TestFunctionalizationRngOps(TestCase):
fwd_compiler = functools.partial(count_philox_rand, freq=1)
bwd_compiler = functools.partial(count_philox_rand, freq=1)
aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
torch.cuda.manual_seed(123)
# We can't check accuracy here because rand_like generates different random numbers than dropout
res = aot_fn(x, y)
# res.sum().backward()
# TODO - This is not the same. Debug this further.
self.assertEqual(ref, res)
res.sum().backward()
@dtypes(torch.float32)
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
def test_dropout_decomp(self, dtype, device):
def fn(x):
return torch.nn.functional.dropout(x, 0.6) * x
x = torch.rand(10, device=device, dtype=dtype)
# Ensure the decomp is happening
aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
# We can't check accuracy here because rand_like generates different random numbers than dropout
aot_fn(x)
only_for = ("cuda",)
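Note: count_philox_rand is a helper defined earlier in this test file and is not shown in the diff. A plausible sketch of such an aot_function compiler follows; it is an assumption, not the actual helper, and the rngprims.philox_rand overload name is likewise assumed. It counts philox_rand calls in the traced FX graph, asserts the expected frequency, and returns the graph's forward unchanged.

import torch

def count_philox_rand_sketch(gm, example_inputs, freq):
    # Collect call_function nodes targeting the functionalized RNG op.
    philox_nodes = [
        node
        for node in gm.graph.nodes
        if node.op == "call_function"
        and node.target == torch.ops.rngprims.philox_rand.default
    ]
    assert len(philox_nodes) == freq
    return gm.forward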

View File

@@ -1,8 +1,10 @@
import functools
from collections import defaultdict
from typing import Callable, Dict
import torch
import torch._decomp as decomp
from torch._decomp import get_decompositions
from torch._ops import OpOverload
aten = torch.ops.aten
@@ -201,3 +203,51 @@ class PhiloxStateTracker:
return cls.multiple_of_4(
cls.bwd_state.base_offset + cls.bwd_state.relative_offset
)
# Add more decompositions that eventually call rand_like inside their decomps.
# Adding these to rng_decompositions ensures that the rand_like ops used in
# these decomps get functionalized. The list is copied from the inductor
# codebase, which uses it for a similar purpose.
#
# Caution - These decomps do not have the same accuracy as eager. However,
# we can't just disable them with a config flag like fallback_random, because
# functionalization of rng ops requires these ops to be decomposed.
extra_random_decomps = get_decompositions(
[
aten.cauchy,
aten.cauchy_,
aten.exponential,
aten.exponential_,
aten.geometric,
aten.geometric_,
aten.native_dropout,
aten.normal,
aten.normal_,
aten.normal_functional,
aten.log_normal,
aten.log_normal_,
aten.uniform_,
]
)
register_extra_random_decomp = functools.partial(
decomp.register_decomposition, registry=extra_random_decomps
)
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
if self.device == torch.device("cpu"):
return NotImplemented
return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
if self.device == torch.device("cpu"):
return NotImplemented
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < p
rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type]
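A quick usage sketch for the table above (an assumption about how it is consumed, not code from this commit): get_decompositions keys the table by OpOverload, so the registered bernoulli decomposition can be looked up and exercised directly on a CUDA tensor, while downstream callers merge the whole table into their own decomposition dict, as rng_decompositions.update does here.

# Usage sketch; assumes the names defined in the module above and a CUDA device.
if torch.cuda.is_available():
    x = torch.empty(1024, device="cuda")
    bernoulli_decomp = extra_random_decomps[torch.ops.aten.bernoulli.p]
    mask = bernoulli_decomp(x, 0.25)  # bool tensor, roughly 25% True on average
    assert mask.dtype == torch.bool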

View File

@@ -9,6 +9,7 @@ import torch.ao.quantization.fx._decomposed
from torch import Tensor
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp.decompositions import pw_cast_for_opmath
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch.utils._mode_utils import no_dispatch
from . import config, utils
@@ -497,47 +498,6 @@ def dequantize_per_tensor_tensor_decomp_impl(
return (input.to(torch.float32) - zero_point) * scale
"""
Some decomps result in differences from eager related to randomness.
We put these decomps in a separate table `extra_random_decomps` to allow
turning them on and off via `config.fallback_random`.
"""
extra_random_decomps = get_decompositions(
[
aten.cauchy,
aten.cauchy_,
aten.exponential,
aten.exponential_,
aten.geometric,
aten.geometric_,
aten.normal,
aten.normal_,
aten.normal_functional,
aten.log_normal,
aten.log_normal_,
aten.uniform_,
]
)
register_extra_random_decomp = functools.partial(
decomp.register_decomposition, registry=extra_random_decomps
)
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
if self.device == torch.device("cpu"):
return NotImplemented
return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
if self.device == torch.device("cpu"):
return NotImplemented
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < p
@functools.lru_cache(None)
def fast_random_decomps():
return {**decompositions, **extra_random_decomps}
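For context on how this merged table is consumed, here is a sketch under the assumption that callers gate on config.fallback_random, mirroring inductor's table selection; it relies on the config, decompositions, and fast_random_decomps names from the module above and is not a quote from this file.

def select_decomp_table_sketch():
    # With fallback_random set, keep eager RNG kernels for bitwise
    # reproducibility; otherwise use the rand_like-based decomps so the
    # RNG ops can be functionalized and fused.
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()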