Add decompositions for zero_, fill_, new_full, new_zeros, new_ones (#82332)
This makes the symbolic tracing tests for logsigmoid and xlogy start working again. While I'm at it, add pin_memory and layout kwargs to empty; they don't actually do anything and raise an error if they are non-standard.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82332
Approved by: https://github.com/eellison
committed by PyTorch MergeBot
parent 4a000ff03e
commit 98b9dfa129
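For context on why a decomposition helps here: symbolic tracing fails on ops that have no symbolic meta function, but succeeds if the op is first rewritten into ops the tracer already handles. A minimal sketch of the idea, with a hypothetical DECOMPOSITIONS table standing in for the real @register_decomposition machinery shown in the diff below:

import torch

# Hypothetical registry, named here for illustration only.
DECOMPOSITIONS = {}

def register(op):
    def wrap(fn):
        DECOMPOSITIONS[op] = fn
        return fn
    return wrap

@register(torch.ops.aten.new_full.default)
def new_full(a, size, fill_value):
    # new_full is expressible as new_empty followed by an in-place fill,
    # both of which the tracer already understands.
    r = a.new_empty(size)
    r.fill_(fill_value)
    return r

a = torch.ones(2)
out = DECOMPOSITIONS[torch.ops.aten.new_full.default](a, (3,), 7)
assert torch.equal(out, a.new_full((3,), 7))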
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
@@ -646,9 +646,6 @@ symbolic_tensor_failures = {
     xfail('nanmean', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('narrow', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('native_layer_norm', ''),  # Unexpected type <class 'torch.SymbolicIntNode'> when computing elementwise type promot...
-    xfail('new_full', ''),
-    xfail('new_ones', ''),
-    xfail('new_zeros', ''),
     xfail('nn.functional.adaptive_avg_pool1d', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.adaptive_avg_pool2d', ''),  # argument 'size' must be tuple of ints, but found element o...
     xfail('nn.functional.adaptive_avg_pool3d', ''),  # aten._adaptive_avg_pool3d.default - couldn't find symbolic meta func...
@@ -697,7 +694,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.layer_norm', ''),  # Unexpected type <class 'torch.SymbolicIntNode'> when computing elementwise type...
     xfail('nn.functional.linear', ''),  # aten.mv.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.local_response_norm', ''),  # Tensors of type TensorImpl do not have numel
-    xfail('nn.functional.logsigmoid', ''),
     xfail('nn.functional.margin_ranking_loss', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('nn.functional.max_pool2d', ''),  # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
@@ -831,7 +827,6 @@ symbolic_tensor_failures = {
     xfail('view', ''),  # Tensors of type TensorImpl do not have numel
     xfail('vsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('where', ''),  # expected predicate to be bool, got torch.float32
-    xfail('xlogy', ''),
     xfail('zero_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('zeros_like', ''),  # aten.zeros_like.default - couldn't find symbolic meta function/decomposition
 }
diff --git a/torch/_prims/context.py b/torch/_prims/context.py
@@ -34,6 +34,12 @@ def torch_to_refs_map():
         torch.Tensor.__and__: torch._refs.bitwise_and,
         torch.Tensor.__or__: torch._refs.bitwise_or,
         torch.Tensor.__eq__: torch._refs.eq,
+        torch.Tensor.new_empty: torch._refs.new_empty,
+        torch.Tensor.new_full: torch._refs.new_full,
+        torch.Tensor.new_zeros: torch._refs.new_zeros,
+        torch.Tensor.new_ones: torch._refs.new_ones,
+        torch.Tensor.fill_: torch._refs.fill_,
+        torch.Tensor.zero_: torch._refs.zero_,
         # TODO: Should these methods be mapped some other way?
         torch.Tensor.copy_: torch._prims.copy_to,
         torch.Tensor.resize: torch._prims.resize,
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
@@ -488,6 +488,18 @@ def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
     return prims.fill(a, value)
 
 
+def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
+    r = prims.fill(a, value)
+    prims.copy_to(a, r)
+    return a
+
+
+def zero_(a: TensorLikeType) -> TensorLikeType:
+    r = prims.fill(a, 0)
+    prims.copy_to(a, r)
+    return a
+
+
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
 def floor(a):
     return prims.floor(a)
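The new in-place refs follow the usual pattern: compute an out-of-place result, then copy it back into the input. A behavioral sketch in eager PyTorch, with torch.full_like and Tensor.copy_ standing in for prims.fill and prims.copy_to (an assumption for illustration; the prims are not public API):

import torch

def fill__sketch(a: torch.Tensor, value: float) -> torch.Tensor:
    # Out-of-place fill: a fresh tensor with the same metadata as `a`.
    r = torch.full_like(a, value)
    # Write the result back into the input, mirroring prims.copy_to(a, r).
    a.copy_(r)
    return a

t = torch.empty(2, 3)
fill__sketch(t, 7.0)
assert torch.equal(t, torch.full((2, 3), 7.0))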
@@ -2949,7 +2961,9 @@ def empty(
     *shape,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
+    layout: Optional[torch.layout] = None,
     requires_grad: bool = False,
+    pin_memory: bool = False,
     memory_format: torch.memory_format = torch.contiguous_format,
 ) -> TensorLikeType:
     check(
@@ -2971,7 +2985,13 @@ def empty(
         strides = utils.make_channels_last_2d_strides_for(shape)
 
     return torch.empty_strided(
-        shape, strides, dtype=dtype, device=device, requires_grad=requires_grad
+        shape,
+        strides,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        requires_grad=requires_grad,
     )
 
 
@@ -2998,13 +3018,66 @@ def new_empty(
     )
 
 
+# TODO: missing kwargs (e.g. layout)
+@register_decomposition(torch.ops.aten.new_zeros)
+def new_zeros(
+    a: TensorLikeType,
+    size: ShapeType,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    r = a.new_empty(
+        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+    )
+    r.zero_()
+    return r
+
+
+@register_decomposition(torch.ops.aten.new_ones)
+def new_ones(
+    a: TensorLikeType,
+    size: ShapeType,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    r = a.new_empty(
+        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+    )
+    r.fill_(1)
+    return r
+
+
+@register_decomposition(torch.ops.aten.new_full)
+def new_full(
+    a: TensorLikeType,
+    size: ShapeType,
+    fill_value: NumberType,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    r = a.new_empty(
+        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+    )
+    r.fill_(fill_value)  # type: ignore[arg-type]
+    return r
+
+
 def empty_like(
     a: TensorLikeType,
     *,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
     layout: Optional[torch.layout] = None,
     requires_grad: bool = False,
     pin_memory: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> TensorLikeType:
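Each of the three new refs is the same two-step composite: a.new_empty(...) followed by an in-place initializer (zero_, fill_(1), or fill_(fill_value)). A quick eager-mode check that the composite matches the fused op, including the dtype inheritance that makes new_* different from plain factory functions:

import torch

a = torch.ones(2, 2, dtype=torch.float64)

# What the new_full decomposition computes: allocate via new_empty
# (inheriting dtype/device from `a`), then fill in place.
r = a.new_empty((3,))
r.fill_(5.0)

assert torch.equal(r, a.new_full((3,), 5.0))
assert r.dtype == torch.float64  # dtype inherited from `a`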
@@ -3017,15 +3090,23 @@ def empty_like(
         return torch.empty(
             a.shape,
             dtype=dtype,
+            layout=layout,
             device=device,
             requires_grad=requires_grad,
+            pin_memory=pin_memory,
             memory_format=memory_format,
         )
 
     # memory_format == torch.preserve_format
     strides = utils.compute_elementwise_output_strides(a)
     return torch.empty_strided(
-        a.shape, strides, dtype=dtype, device=device, requires_grad=requires_grad
+        a.shape,
+        strides,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        requires_grad=requires_grad,
     )
 
 
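The torch.preserve_format branch is the subtle one: rather than copying the input's exact strides, it recomputes dense strides with the same dimension ordering (via utils.compute_elementwise_output_strides). Eager PyTorch shows the intended behavior:

import torch

a = torch.empty(2, 3, 4).permute(0, 2, 1)  # non-contiguous stride order

# preserve_format keeps the input's stride ordering in the fresh
# allocation; contiguous_format lays out row-major over the new shape.
b = torch.empty_like(a, memory_format=torch.preserve_format)
c = torch.empty_like(a, memory_format=torch.contiguous_format)

print(a.stride())  # (12, 1, 4)
print(b.stride())  # (12, 1, 4) -- same ordering as `a`
print(c.stride())  # (12, 3, 1) -- row-major for shape (2, 4, 3)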
@@ -3226,15 +3307,26 @@ def empty_strided(
     *,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
+    layout: Optional[torch.layout] = None,
     requires_grad: bool = False,
+    pin_memory: bool = False,
 ) -> TensorLikeType:
 
+    if pin_memory:
+        raise NotImplementedError("PrimTorch doesn't support pinned memory")
+    if layout is not None and layout is not torch.strided:
+        raise NotImplementedError(f"PrimTorch doesn't support layout={layout}")
+
     shape = utils.extract_shape_from_varargs(shape)
     dtype = torch.get_default_dtype() if dtype is None else dtype
     device = torch.device("cpu") if device is None else device
 
     return prims.empty_strided(
-        shape, strides, dtype=dtype, device=device, requires_grad=requires_grad
+        shape,
+        strides,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
     )
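These guards back up the commit message: layout and pin_memory are accepted so call sites type-check, but any non-default value raises rather than being silently dropped. Usage sketch (assuming the ref is reachable as torch._refs.empty_strided, consistent with the diff; it is an internal module):

import torch
import torch._refs as refs

# Default layout/pin_memory pass through unchanged.
t = refs.empty_strided((2, 3), (3, 1), dtype=torch.float32)

# Non-standard values fail loudly instead of being silently ignored.
try:
    refs.empty_strided((2, 3), (3, 1), pin_memory=True)
except NotImplementedError as e:
    print(e)  # PrimTorch doesn't support pinned memory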
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
@@ -21976,6 +21976,21 @@ python_ref_db = [
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
         ),
     ),
+    PythonRefInfo(
+        "_refs.new_full",
+        torch_opinfo_name="new_full",
+        supports_nvfuser=False,
+    ),
+    PythonRefInfo(
+        "_refs.new_ones",
+        torch_opinfo_name="new_ones",
+        supports_nvfuser=False,
+    ),
+    PythonRefInfo(
+        "_refs.new_zeros",
+        torch_opinfo_name="new_zeros",
+        supports_nvfuser=False,
+    ),
     #
     # Conditional Reference OpInfos
     #