From cd7de1f4fa112a3c2d649cae37dfd88fab7f985d Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Mon, 16 Dec 2024 13:28:14 -0800 Subject: [PATCH] remove allow-untyped-defs for torch/masked/maskedtensor/creation.py (#143321) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143321 Approved by: https://github.com/laithsakka --- torch/masked/maskedtensor/creation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/masked/maskedtensor/creation.py b/torch/masked/maskedtensor/creation.py index a013ef1beb66..35c8e3d2aa94 100644 --- a/torch/masked/maskedtensor/creation.py +++ b/torch/masked/maskedtensor/creation.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from .core import MaskedTensor @@ -15,9 +14,11 @@ __all__ = [ # torch.as_tensor - differentiable constructor that preserves the autograd history -def masked_tensor(data, mask, requires_grad=False): +def masked_tensor( + data: object, mask: object, requires_grad: bool = False +) -> MaskedTensor: return MaskedTensor(data, mask, requires_grad) -def as_masked_tensor(data, mask): +def as_masked_tensor(data: object, mask: object) -> MaskedTensor: return MaskedTensor._from_values(data, mask)