remove allow-untyped-defs for torch/masked/maskedtensor/creation.py (#143321)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143321
Approved by: https://github.com/laithsakka
This commit is contained in:
bobrenjc93
2024-12-16 13:28:14 -08:00
committed by PyTorch MergeBot
parent 4d90c487d8
commit cd7de1f4fa

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from .core import MaskedTensor
@ -15,9 +14,11 @@ __all__ = [
# torch.as_tensor - differentiable constructor that preserves the autograd history
def masked_tensor(data, mask, requires_grad=False):
def masked_tensor(
data: object, mask: object, requires_grad: bool = False
) -> MaskedTensor:
return MaskedTensor(data, mask, requires_grad)
def as_masked_tensor(data, mask):
def as_masked_tensor(data: object, mask: object) -> MaskedTensor:
return MaskedTensor._from_values(data, mask)