mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add determinmistic kernel for reflection2d (#136241)
Adds feature for #98925 Tests pass for both existing reflectionpad2d and the new one I inserted. **Summary of the work:** Simple conditional check for deterministic mode that will dispatch to a different kernel. This kernel does not use any atomic operations, and will lead to deterministic results as instead of going from the output to input(1:1) relationship, I am doing the opposite. I am going from input -> all outputs, which is 1 to many. These operations are done in the same order every execution as I simply traverse the data set with a grid stride loop and use simple linearized indexing into the input tensor. So each thread will compute the 4 conditionals, which are then used to see if the input has an output in the 8 regions. These 8 regions are top left, top, top right, left, right, bottom left, bottom, bottom right`. I did not focus on performance for this PR as that would expand the scope heavily. If there are any performance questions though i can answer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136241 Approved by: https://github.com/eqy, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
2b8c28099a
commit
bb4964013f
@ -1632,18 +1632,6 @@ else:
|
||||
'reflection_pad1d_backward_out_cuda',
|
||||
torch.device(device).type == 'cuda')
|
||||
|
||||
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
|
||||
def test_nondeterministic_alert_ReflectionPad2d(self, device):
|
||||
module = torch.nn.ReflectionPad2d((1, 2, 3, 4))
|
||||
input = torch.randn(2, 3, 8, 8, device=device, requires_grad=True)
|
||||
res = module(input)
|
||||
grad = torch.ones_like(res)
|
||||
|
||||
self.check_nondeterministic_alert(
|
||||
lambda: res.backward(grad, retain_graph=True),
|
||||
'reflection_pad2d_backward_cuda',
|
||||
torch.device(device).type == 'cuda')
|
||||
|
||||
@skipIfMPS
|
||||
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
|
||||
def test_nondeterministic_alert_ReflectionPad3d(self, device):
|
||||
|
||||
Reference in New Issue
Block a user