Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[philox_rand] Add decomps (#100206)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100206 Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 9cda7b9e47
Commit: a8ad0dc333
@@ -9,7 +9,6 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from functorch.compile import aot_function, nop, min_cut_rematerialization_partition
from unittest import skip
from unittest.mock import patch
import functools
import torch.utils.checkpoint
@@ -287,7 +286,6 @@ class TestFunctionalizationRngOps(TestCase):
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @skip("Dropout needs more work because of offset calculation")
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    @dtypes(torch.float32)
    def test_checkpoint(self, dtype, device):
@@ -307,11 +305,22 @@ class TestFunctionalizationRngOps(TestCase):
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        torch.cuda.manual_seed(123)
        # We can't check accuracy here because rand_like generates different random numbers than dropout
        res = aot_fn(x, y)
        # res.sum().backward()
        # TODO - This is not the same. Debug this further.
        self.assertEqual(ref, res)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomp is happening
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
        # We can't check accuracy here because rand_like generates different random numbers than dropout
        aot_fn(x)


only_for = ("cuda",)
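The tests above pass `count_philox_rand` as the AOT Autograd forward/backward compiler to check how many philox RNG calls survive in the traced graph. Its real definition lives in the test module; a minimal sketch of such a checker, assuming the functionalized RNG op is exposed as `torch.ops.rngprims.philox_rand`, could look like this:

import torch

def count_philox_rand(gm: torch.fx.GraphModule, example_inputs, freq):
    # Count philox_rand nodes in the traced FX graph, assert the expected
    # frequency, then hand the (callable) GraphModule back to AOT Autograd.
    # Assumption: the functionalized RNG op is torch.ops.rngprims.philox_rand.
    calls = [
        node
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target is torch.ops.rngprims.philox_rand
    ]
    assert len(calls) == freq, f"expected {freq} philox_rand call(s), found {len(calls)}"
    return gm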
@@ -1,8 +1,10 @@
import functools
from collections import defaultdict
from typing import Callable, Dict

import torch
import torch._decomp as decomp
from torch._decomp import get_decompositions
from torch._ops import OpOverload

aten = torch.ops.aten
@@ -201,3 +203,51 @@ class PhiloxStateTracker:
        return cls.multiple_of_4(
            cls.bwd_state.base_offset + cls.bwd_state.relative_offset
        )


# Adding more decompositions which eventually use rand_like inside the decomps.
# Adding these in rng_decompositions ensures the functionalization of the
# rand_like ops used in these decomps. The list is copied from the inductor
# codebase, which uses it for a similar purpose.
#
# Caution - These decomps do not have the same accuracy as eager. However, we
# can't just disable them with a config flag like fallback_random, because the
# functionalization of rng ops requires these ops to be decomposed.
extra_random_decomps = get_decompositions(
    [
        aten.cauchy,
        aten.cauchy_,
        aten.exponential,
        aten.exponential_,
        aten.geometric,
        aten.geometric_,
        aten.native_dropout,
        aten.normal,
        aten.normal_,
        aten.normal_functional,
        aten.log_normal,
        aten.log_normal_,
        aten.uniform_,
    ]
)
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition, registry=extra_random_decomps
)


@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    if self.device == torch.device("cpu"):
        return NotImplemented
    return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)


@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    if self.device == torch.device("cpu"):
        return NotImplemented
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < p


rng_decompositions.update(extra_random_decomps)  # type: ignore[arg-type]
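For illustration (not part of the diff): `register_extra_random_decomp` is simply `decomp.register_decomposition` bound to the `extra_random_decomps` table, so the same pattern can populate any local registry. A hedged sketch, where `my_random_decomps` and `_bernoulli_p` are hypothetical names:

import torch
import torch._decomp as decomp

aten = torch.ops.aten
my_random_decomps = {}  # hypothetical local table, analogous to extra_random_decomps

@decomp.register_decomposition([aten.bernoulli.p], registry=my_random_decomps)
def _bernoulli_p(self, p=0.5, *, generator=None):
    # Same rewrite as above: express bernoulli as a rand_like comparison so
    # the RNG functionalization pass can turn it into a stateless philox call.
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < p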
@@ -9,6 +9,7 @@ import torch.ao.quantization.fx._decomposed
from torch import Tensor
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp.decompositions import pw_cast_for_opmath
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch.utils._mode_utils import no_dispatch

from . import config, utils
@@ -497,47 +498,6 @@ def dequantize_per_tensor_tensor_decomp_impl(
    return (input.to(torch.float32) - zero_point) * scale


"""
Some decomps result in differences from eager related to randomness.
We put these decomps in a separate table `extra_random_decomps` to allow
turning them on and off via `config.fallback_random`.
"""
extra_random_decomps = get_decompositions(
    [
        aten.cauchy,
        aten.cauchy_,
        aten.exponential,
        aten.exponential_,
        aten.geometric,
        aten.geometric_,
        aten.normal,
        aten.normal_,
        aten.normal_functional,
        aten.log_normal,
        aten.log_normal_,
        aten.uniform_,
    ]
)
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition, registry=extra_random_decomps
)


@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    if self.device == torch.device("cpu"):
        return NotImplemented
    return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)


@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    if self.device == torch.device("cpu"):
        return NotImplemented
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < p


@functools.lru_cache(None)
def fast_random_decomps():
    return {**decompositions, **extra_random_decomps}
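For context (not part of the diff), the removed docstring explains why these decomps sit in a separate table: they trade bit-exactness with eager for functionalization, so they should only apply when `config.fallback_random` is off. A hedged sketch of combining the tables along those lines; `merged_random_decomps` and `pick_decomp_table` are illustrative names, not inductor's actual selection code:

import functools
from typing import Callable, Dict

from torch._decomp import core_aten_decompositions
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._ops import OpOverload

@functools.lru_cache(None)
def merged_random_decomps() -> Dict[OpOverload, Callable]:
    # Stand-in for fast_random_decomps: merge a base table with the RNG
    # decomps; later entries win on key clashes, and the merge is cached.
    return {**core_aten_decompositions(), **extra_random_decomps}

def pick_decomp_table(fallback_random: bool) -> Dict[OpOverload, Callable]:
    # Illustrative selector: fallback_random keeps eager RNG semantics by
    # skipping the less-accurate random decomps.
    if fallback_random:
        return core_aten_decompositions()
    return merged_random_decomps()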