Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[functorch] Add support on CUDA keys for control flow ops. (#94465)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94465
Approved by: https://github.com/tugsbayasgalan
Committed by: PyTorch MergeBot
Parent: 989fb7c921
Commit: e3c4cea668
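For orientation (not part of the commit itself): with the DispatchKey.CUDA and DispatchKey.AutogradCUDA registrations added below, `cond` and `control_flow.map` can be called eagerly on CUDA tensors. A minimal usage sketch mirroring the tests added in this PR, assuming a CUDA-capable build of PyTorch at this commit:

import torch
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

# cond with CUDA operands now dispatches through the DispatchKey.CUDA registration.
x = torch.randn(4, device="cuda")
result = cond(False, true_fn, false_fn, [x])  # equivalent to torch.cos(x)

# control_flow.map applies f over the leading dimension of xs, on CUDA.
xs = torch.ones(3, 2, 2, device="cuda")
y = torch.ones(2, device="cuda")
res = control_flow.map(lambda xi, yi: xi + yi, xs, y)  # shape (3, 2, 2), all entries 2.0

Note that both device keys share the same dense implementations (cond_dense / map_cpu), which is why only the decorators and the assertion messages change in the diff below.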
@@ -101,16 +101,18 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
 
 
+@cond.py_impl(DispatchKey.CUDA)
 @cond.py_impl(DispatchKey.CPU)
 def cond_dense(pred, true_fn, false_fn, operands):
     mode = _get_current_dispatch_mode()
-    assert (mode is None), "Mode should never be enabled for CPU key"
+    assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
     if pred:
         return true_fn(*operands)
     else:
         return false_fn(*operands)
 
 
+@cond.py_impl(DispatchKey.AutogradCUDA)
 @cond.py_impl(DispatchKey.AutogradCPU)
 def cond_autograd(pred, true_fn, false_fn, *operands):
     # TODO: support autograd
@@ -57,13 +57,15 @@ def trace_map(proxy_mode, func_overload, f, xs, *args):
     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
 
 
+@map.py_impl(DispatchKey.CUDA)
 @map.py_impl(DispatchKey.CPU)
 def map_cpu(f, xs, *args):
     mode = _get_current_dispatch_mode()
-    assert (mode is None), "Mode should never be enabled for CPU key"
+    assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
     return torch.stack([f(x, *args) for x in xs])
 
 
+@map.py_impl(DispatchKey.AutogradCUDA)
 @map.py_impl(DispatchKey.AutogradCPU)
 def map_autograd(f, xs, *args):
     # TODO: support autograd
@@ -1,4 +1,6 @@
 # Owner(s): ["module: functorch"]
+import unittest
+
 import torch
 from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond
@@ -20,6 +22,30 @@ class TestControlFlow(TestCase):
         result = cond(False, true_fn, false_fn, [x])
         self.assertEqual(result, torch.cos(x))
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    def test_cond_gpu(self):
+        def true_fn(x):
+            return x.sin()
+
+        def false_fn(x):
+            return x.cos()
+
+        x = torch.randn(4, device="cuda")
+        pred = torch.tensor(False, device="cuda")
+        result = cond(False, true_fn, false_fn, [x])
+        self.assertEqual(result, torch.cos(x))
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    def test_map_gpu(self):
+        def f(x, y):
+            return x + y
+
+        xs = torch.ones(3, 2, 2, device="cuda")
+        y = torch.ones(2, device="cuda")
+        res = control_flow.map(f, xs, y)
+
+        self.assertEqual(res, control_flow.map(f, torch.ones(3, 2, 2), torch.ones(2)))
+
+
 class TestControlFlowTraced(TestCase):
     def test_cond_traced_not_nested(self):