[functorch] Add support on CUDA keys for control flow ops. (#94465)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94465
Approved by: https://github.com/tugsbayasgalan
Author: zhxchen17
Date: 2023-02-12 06:45:53 +00:00
Committed by: PyTorch MergeBot
Parent: 989fb7c921
Commit: e3c4cea668
3 changed files with 32 additions and 2 deletions


@@ -101,16 +101,18 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


+@cond.py_impl(DispatchKey.CUDA)
 @cond.py_impl(DispatchKey.CPU)
 def cond_dense(pred, true_fn, false_fn, operands):
     mode = _get_current_dispatch_mode()
-    assert (mode is None), "Mode should never be enabled for CPU key"
+    assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
     if pred:
         return true_fn(*operands)
     else:
         return false_fn(*operands)


+@cond.py_impl(DispatchKey.AutogradCUDA)
 @cond.py_impl(DispatchKey.AutogradCPU)
 def cond_autograd(pred, true_fn, false_fn, *operands):
     # TODO: support autograd

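For context, a minimal sketch of the eager path this hunk enables: with no dispatch mode active, a cond call on CUDA tensors now dispatches to cond_dense through the new DispatchKey.CUDA registration. The tensor shapes and branch functions below are illustrative, not taken from the commit:

    import torch
    from functorch.experimental.control_flow import cond

    x = torch.randn(4, device="cuda")
    pred = torch.tensor(True, device="cuda")

    # cond_dense checks the predicate eagerly and runs exactly one branch.
    out = cond(pred, lambda t: t.sin(), lambda t: t.cos(), [x])  # == x.sin()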

@@ -57,13 +57,15 @@ def trace_map(proxy_mode, func_overload, f, xs, *args):
     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


+@map.py_impl(DispatchKey.CUDA)
 @map.py_impl(DispatchKey.CPU)
 def map_cpu(f, xs, *args):
     mode = _get_current_dispatch_mode()
-    assert (mode is None), "Mode should never be enabled for CPU key"
+    assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
     return torch.stack([f(x, *args) for x in xs])


+@map.py_impl(DispatchKey.AutogradCUDA)
 @map.py_impl(DispatchKey.AutogradCPU)
 def map_autograd(f, xs, *args):
     # TODO: support autograd

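The map path mirrors cond: the same map_cpu body now also serves DispatchKey.CUDA, unrolling the leading dimension of xs, applying f to each slice, and stacking the results on the original device. An illustrative call, not taken from the commit:

    import torch
    from functorch.experimental import control_flow

    xs = torch.ones(3, 2, 2, device="cuda")
    y = torch.ones(2, device="cuda")

    # Equivalent to torch.stack([f(x, y) for x in xs]); the result stays on CUDA.
    res = control_flow.map(lambda x, y: x + y, xs, y)  # shape (3, 2, 2), all 2.0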

@@ -1,4 +1,6 @@
 # Owner(s): ["module: functorch"]
+import unittest
+
 import torch
 from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond
@@ -20,6 +22,30 @@ class TestControlFlow(TestCase):
         result = cond(False, true_fn, false_fn, [x])
         self.assertEqual(result, torch.cos(x))

+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    def test_cond_gpu(self):
+        def true_fn(x):
+            return x.sin()
+
+        def false_fn(x):
+            return x.cos()
+
+        x = torch.randn(4, device="cuda")
+        pred = torch.tensor(False, device="cuda")
+        result = cond(pred, true_fn, false_fn, [x])
+        self.assertEqual(result, torch.cos(x))
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    def test_map_gpu(self):
+        def f(x, y):
+            return x + y
+
+        xs = torch.ones(3, 2, 2, device="cuda")
+        y = torch.ones(2, device="cuda")
+        res = control_flow.map(f, xs, y)
+        self.assertEqual(res, control_flow.map(f, torch.ones(3, 2, 2), torch.ones(2)))
+

 class TestControlFlowTraced(TestCase):
     def test_cond_traced_not_nested(self):
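On a CUDA machine, the two new tests can be selected with unittest's -k filter, e.g. python test/functorch/test_control_flow.py -k gpu (the file path is inferred from the owner annotation and imports above, not stated in this diff).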