Allow functionalization to work with optional mutable (#114803)

Summary: Added support for optional mutable tensor arguments (Tensor(a!)?) in register_functional_op and the functionalization kernel it generates.

Test Plan: CI tests.

Reviewed By: zou3519

Differential Revision: D51209981

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114803
Approved by: https://github.com/zou3519
Author:    Flavio Sales Truzzi
Date:      2023-11-30 23:48:03 +00:00
Committer: PyTorch MergeBot
Parent:    7b3e45be59
Commit:    ad09d81694

2 changed files with 52 additions and 12 deletions
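For context, the usage this change enables looks roughly like the following. This is a condensed sketch adapted from the new test added below; the "mylib" namespace, the example values, and the import path for register_functional_op are illustrative assumptions, not code from the commit.

import torch
from torch.library import Library
from torch._custom_op.functional import register_functional_op  # assumed import path

# Illustrative namespace; the test below uses a per-test namespace instead.
lib = Library("mylib", "FRAGMENT")
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!)? w) -> ()')

def foo_impl(x, y, w):
    # The CPU kernel mutates y and, when present, the optional tensor w.
    y.fill_(3.14)
    if w is not None:
        w.fill_(1.618)

lib.impl('foo', foo_impl, 'CPU')

# Before this change, the optional mutable argument Tensor(c!)? made
# register_functional_op raise NotImplementedError; it is now accepted.
register_functional_op(lib, 'foo_functional', torch.ops.mylib.foo.default)

# The generated functional variant does not mutate its inputs; it returns the
# would-be-mutated values instead, and an absent optional comes back as None.
x, y = torch.randn([]), torch.randn([])
new_y, new_w = torch.ops.mylib.foo_functional(x, y, None)
assert new_w is None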


@@ -428,7 +428,6 @@ class TestPythonRegistration(TestCase):
             register_functional_op(lib, "abs", torch.ops.aten.abs.out)
         schemas = [
-            'foo(Tensor x, Tensor(a!)? y) -> ()',
             'foo(Tensor x, Tensor(a!)[] y) -> ()',
             'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
             'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
@@ -466,7 +465,7 @@ class TestPythonRegistration(TestCase):
         # check rest of functional_result is the mutated args
         mutated_args = [maybe_mutated_arg for maybe_mutated_arg, arg in zip(cloned_args, args)
-                        if not torch.allclose(maybe_mutated_arg, arg)]
+                        if not(maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))]
         self.assertEqual(flat_functional_result[len(flat_mutable_result):], mutated_args)
 
         # check that functionalization kernel was indeed registered
@@ -504,6 +503,33 @@ class TestPythonRegistration(TestCase):
             getattr(torch.ops, self.test_ns).foo.default,
             getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
 
+    def test_register_functional_op_with_optional(self):
+        lib = Library(self.test_ns, 'FRAGMENT')
+        lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
+
+        def foo_impl(x, y, z, w):
+            y.fill_(3.14)
+            z.fill_(2.71)
+            if w is not None:
+                w.fill_(1.618)
+
+        lib.impl('foo', foo_impl, 'CPU')
+        register_functional_op(
+            lib,
+            'foo_functional',
+            getattr(torch.ops, self.test_ns).foo.default)
+        x = torch.randn([])
+        y = torch.randn([])
+        z = torch.randn([])
+        w = torch.randn([])
+        self._check_is_functional_variant(
+            getattr(torch.ops, self.test_ns).foo.default,
+            getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
+        self._check_is_functional_variant(
+            getattr(torch.ops, self.test_ns).foo.default,
+            getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
+
     def test_register_functional_op_one_return(self):
         lib = Library(self.test_ns, 'FRAGMENT')
         lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')


@@ -1,12 +1,21 @@
-import torch
-from torch.library import Library
-from torch._ops import OpOverload
-from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseTy, BaseType
-from torch._C import _ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
-from .autograd import autograd_not_implemented
-import torch.utils._pytree as pytree
 import weakref
+
+import torch
+import torch.utils._pytree as pytree
+from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
+from torch._ops import OpOverload
+from torch.library import Library
+from torchgen.model import (
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    OperatorName,
+    OptionalType,
+    SchemaKind,
+)
+
+from .autograd import autograd_not_implemented
 
 
 def register_functional_op(
     lib: Library,
@@ -66,7 +75,7 @@ def construct_functional_impl(mutable_op):
         extra_rets = []
         for is_write, arg in zip(mutable_args(mutable_op), args):
             if is_write:
-                cloned = arg.clone()
+                cloned = arg.clone() if arg is not None else None
                 new_args.append(cloned)
                 extra_rets.append(cloned)
             else:
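To make the one-line change above easier to read in context, here is a simplified sketch of the strategy construct_functional_impl follows, paraphrased from this hunk; the helper name and the exact return handling are illustrative rather than copied from the file.

def functional_impl_sketch(mutable_op, is_write_flags, *args):
    # Clone every argument the mutable op writes to (an absent optional stays
    # None), run the op on the clones, and return the clones as extra outputs
    # so that the caller's tensors are never mutated.
    new_args, extra_rets = [], []
    for is_write, arg in zip(is_write_flags, args):
        if is_write:
            cloned = arg.clone() if arg is not None else None
            new_args.append(cloned)
            extra_rets.append(cloned)
        else:
            new_args.append(arg)
    result = mutable_op(*new_args)
    if result is None:
        return tuple(extra_rets)
    rets = result if isinstance(result, tuple) else (result,)
    return (*rets, *extra_rets)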
@@ -117,6 +126,8 @@ def construct_functionalization_kernel(mutable_op, functional_op):
                              if is_write]
         assert len(new_values_to_propagate) == len(inputs_to_replace)
         for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
+            if (arg is None and new_value is None) or (arg is not None and new_value is not None):
+                continue
             torch._C._propagate_xla_data(arg, new_value)
             torch._C._replace_(arg, new_value)
             torch._C._commit_update(arg)
@@ -156,9 +167,12 @@ def validate(mutable_op: OpOverload):
                            "not return the mutated value or aliases)")
     for arg in schema.arguments.flat_all:
         # construct_functionalization_kernel assumes this for simplicity
-        if arg.type.is_tensor_like() and arg.type != BaseType(BaseTy.Tensor):
+        if arg.type.is_tensor_like() and (
+            arg.type != BaseType(BaseTy.Tensor)
+            and arg.type != OptionalType(BaseType(BaseTy.Tensor))
+        ):
             raise NotImplementedError(
-                "NYI: register_functional_op(op) where op accepts Optional or List of tensors."
+                "NYI: register_functional_op(op) where op has a List[Tensor] input."
                 "Please file an issue.")