mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
remove allow-untyped-defs for torch/masked/maskedtensor/creation.py (#143321)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143321 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
4d90c487d8
commit
cd7de1f4fa
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
from .core import MaskedTensor
|
||||
@ -15,9 +14,11 @@ __all__ = [
|
||||
# torch.as_tensor - differentiable constructor that preserves the autograd history
|
||||
|
||||
|
||||
def masked_tensor(data, mask, requires_grad=False):
|
||||
def masked_tensor(
|
||||
data: object, mask: object, requires_grad: bool = False
|
||||
) -> MaskedTensor:
|
||||
return MaskedTensor(data, mask, requires_grad)
|
||||
|
||||
|
||||
def as_masked_tensor(data, mask):
|
||||
def as_masked_tensor(data: object, mask: object) -> MaskedTensor:
|
||||
return MaskedTensor._from_values(data, mask)
|
||||
|
Reference in New Issue
Block a user