mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow functionalization to work with optional mutable (#114803)
Summary: - Added functionalization to allow Optionals Test Plan: CI tests. Reviewed By: zou3519 Differential Revision: D51209981 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114803 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
7b3e45be59
commit
ad09d81694
@ -428,7 +428,6 @@ class TestPythonRegistration(TestCase):
|
||||
register_functional_op(lib, "abs", torch.ops.aten.abs.out)
|
||||
|
||||
schemas = [
|
||||
'foo(Tensor x, Tensor(a!)? y) -> ()',
|
||||
'foo(Tensor x, Tensor(a!)[] y) -> ()',
|
||||
'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
|
||||
'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
|
||||
@ -466,7 +465,7 @@ class TestPythonRegistration(TestCase):
|
||||
|
||||
# check rest of functional_result is the mutated args
|
||||
mutated_args = [maybe_mutated_arg for maybe_mutated_arg, arg in zip(cloned_args, args)
|
||||
if not torch.allclose(maybe_mutated_arg, arg)]
|
||||
if not(maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))]
|
||||
self.assertEqual(flat_functional_result[len(flat_mutable_result):], mutated_args)
|
||||
|
||||
# check that functionalization kernel was indeed registered
|
||||
@ -504,6 +503,33 @@ class TestPythonRegistration(TestCase):
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
|
||||
def test_register_functional_op_with_optional(self):
|
||||
lib = Library(self.test_ns, 'FRAGMENT')
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
|
||||
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
z.fill_(2.71)
|
||||
if w is not None:
|
||||
w.fill_(1.618)
|
||||
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
'foo_functional',
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
|
||||
|
||||
|
||||
def test_register_functional_op_one_return(self):
|
||||
lib = Library(self.test_ns, 'FRAGMENT')
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
|
||||
|
@ -1,12 +1,21 @@
|
||||
import torch
|
||||
from torch.library import Library
|
||||
from torch._ops import OpOverload
|
||||
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseTy, BaseType
|
||||
from torch._C import _ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
|
||||
from .autograd import autograd_not_implemented
|
||||
import torch.utils._pytree as pytree
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
|
||||
from torch._ops import OpOverload
|
||||
from torch.library import Library
|
||||
from torchgen.model import (
|
||||
BaseTy,
|
||||
BaseType,
|
||||
FunctionSchema,
|
||||
OperatorName,
|
||||
OptionalType,
|
||||
SchemaKind,
|
||||
)
|
||||
|
||||
from .autograd import autograd_not_implemented
|
||||
|
||||
|
||||
def register_functional_op(
|
||||
lib: Library,
|
||||
@ -66,7 +75,7 @@ def construct_functional_impl(mutable_op):
|
||||
extra_rets = []
|
||||
for is_write, arg in zip(mutable_args(mutable_op), args):
|
||||
if is_write:
|
||||
cloned = arg.clone()
|
||||
cloned = arg.clone() if arg is not None else None
|
||||
new_args.append(cloned)
|
||||
extra_rets.append(cloned)
|
||||
else:
|
||||
@ -117,6 +126,8 @@ def construct_functionalization_kernel(mutable_op, functional_op):
|
||||
if is_write]
|
||||
assert len(new_values_to_propagate) == len(inputs_to_replace)
|
||||
for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
|
||||
if (arg is None and new_value is None) or (arg is not None and new_value is not None):
|
||||
continue
|
||||
torch._C._propagate_xla_data(arg, new_value)
|
||||
torch._C._replace_(arg, new_value)
|
||||
torch._C._commit_update(arg)
|
||||
@ -156,9 +167,12 @@ def validate(mutable_op: OpOverload):
|
||||
"not return the mutated value or aliases)")
|
||||
for arg in schema.arguments.flat_all:
|
||||
# construct_functionalization_kernel assumes this for simplicity
|
||||
if arg.type.is_tensor_like() and arg.type != BaseType(BaseTy.Tensor):
|
||||
if arg.type.is_tensor_like() and (
|
||||
arg.type != BaseType(BaseTy.Tensor)
|
||||
and arg.type != OptionalType(BaseType(BaseTy.Tensor))
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"NYI: register_functional_op(op) where op accepts Optional or List of tensors."
|
||||
"NYI: register_functional_op(op) where op has a List[Tensor] input."
|
||||
"Please file an issue.")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user