Add deterministic kernel for reflection_pad2d (#136241)

Adds feature for #98925

Tests pass for both the existing reflection_pad2d kernel and the new deterministic one I inserted.

**Summary of the work:**

A simple conditional check for deterministic mode dispatches to a different kernel. This kernel uses no atomic operations and produces deterministic results by inverting the usual mapping: instead of going from each output element to its single input source (a 1:1 relationship), I go from each input element to all of its outputs, which is 1-to-many. These operations happen in the same order on every execution, since I simply traverse the data set with a grid-stride loop and use simple linearized indexing into the input tensor.
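As a rough illustration of that traversal (my own Python simulation, not the PR's CUDA code; the thread count and tensor shape below are made up), each simulated thread starts at its global id and advances by the total thread count, decoding each linear index into (n, c, y, x). The set of elements a given thread visits, and their order, are therefore fixed across runs:

```python
# Simulation of a CUDA grid-stride loop over linearized input indices.
# `num_threads` stands in for gridDim.x * blockDim.x; shapes are illustrative.
def traversal(numel, shape, num_threads):
    N, C, H, W = shape
    visits = []
    for tid in range(num_threads):       # one pass per simulated thread
        idx = tid                        # first element: the global thread id
        while idx < numel:               # grid-stride loop
            # decode the linear index into (n, c, y, x) coordinates
            x = idx % W
            y = (idx // W) % H
            c = (idx // (W * H)) % C
            n = idx // (W * H * C)
            visits.append((tid, n, c, y, x))
            idx += num_threads           # stride = total thread count
    return visits
```

Because every input element is owned by exactly one thread (element `idx` by thread `idx % num_threads`), its gradient contributions are always accumulated by the same thread in the same order, which is what removes the need for atomics.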

So each thread computes the 4 conditionals, which are then used to see whether its input element has an output in any of the 8 reflected regions. These 8 regions are top left, top, top right, left, right, bottom left, bottom, and bottom right.
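To make the region logic concrete, here is a minimal NumPy sketch for a single 2D plane (my own reconstruction of the input-centric approach, not the PR's kernel): for each input element, at most one extra source row (top or bottom reflection) and one extra source column (left or right reflection) in the padded output map onto it, and crossing those possibilities gives the center region plus the eight surrounding ones. The contributions are summed in a fixed order with no atomics:

```python
import numpy as np

def reflect(idx, size):
    # fold an out-of-range coordinate back into [0, size) by reflection
    if idx < 0:
        return -idx
    if idx >= size:
        return 2 * (size - 1) - idx
    return idx

def reflection_pad2d_backward_det(grad_out, pad, in_shape):
    # Deterministic backward: input -> all outputs (1-to-many), no atomics.
    pl, pr, pt, pb = pad                  # left, right, top, bottom padding
    H, W = in_shape
    grad_in = np.zeros((H, W), dtype=grad_out.dtype)
    for y in range(H):
        rows = [pt + y]                   # center copy of this input row
        if 1 <= y <= pt:                  # top reflection hits this row
            rows.append(pt - y)
        if H - 1 - pb <= y <= H - 2:      # bottom reflection hits this row
            rows.append(pt + 2 * (H - 1) - y)
        for x in range(W):
            cols = [pl + x]               # center copy of this input column
            if 1 <= x <= pl:              # left reflection
                cols.append(pl - x)
            if W - 1 - pr <= x <= W - 2:  # right reflection
                cols.append(pl + 2 * (W - 1) - x)
            # fixed-order sum over the (up to 3 x 3) contributing regions
            grad_in[y, x] = sum(grad_out[r, c] for r in rows for c in cols)
    return grad_in
```

The two row checks and two column checks are the four conditionals per element; their cross product selects which of the center-plus-eight regions actually contribute.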

I did not focus on performance for this PR, as that would have expanded the scope heavily. If there are any performance questions, though, I can answer them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136241
Approved by: https://github.com/eqy, https://github.com/albanD
Danial Javady
2025-01-29 20:34:01 +00:00
committed by PyTorch MergeBot
parent 2b8c28099a
commit bb4964013f
3 changed files with 311 additions and 47 deletions

@@ -1632,18 +1632,6 @@ else:
             'reflection_pad1d_backward_out_cuda',
             torch.device(device).type == 'cuda')

-    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
-    def test_nondeterministic_alert_ReflectionPad2d(self, device):
-        module = torch.nn.ReflectionPad2d((1, 2, 3, 4))
-        input = torch.randn(2, 3, 8, 8, device=device, requires_grad=True)
-        res = module(input)
-        grad = torch.ones_like(res)
-        self.check_nondeterministic_alert(
-            lambda: res.backward(grad, retain_graph=True),
-            'reflection_pad2d_backward_cuda',
-            torch.device(device).type == 'cuda')
-
     @skipIfMPS
     @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
     def test_nondeterministic_alert_ReflectionPad3d(self, device):