Mirror of https://github.com/pytorch/pytorch.git
Enable max.unary_out (#86855)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86855
Approved by: https://github.com/jerryzh168, https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 25811663af
Commit: cff333bdb5
@@ -34,9 +34,16 @@ Tensor max(const Tensor &self) {
 }
 
 Tensor& max_unary_out(const Tensor &self, Tensor& out) {
-  Tensor tmp_output = at::max(self);
-  at::native::resize_output(out, tmp_output.sizes());
-  out.copy_(tmp_output);
+  // First check if the devices match (CPU vs GPU)
+  TORCH_CHECK(self.device() == out.device());
+
+  TORCH_CHECK(canCast(
+      typeMetaToScalarType(self.dtype()),
+      typeMetaToScalarType(out.dtype())));
+
+  at::native::resize_output(out, {});
+
+  max_all_stub(self.device().type(), out, self.contiguous());
   return out;
 }
 
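The rewritten kernel resizes out to a 0-dim tensor and lets max_all_stub write into it directly, instead of materializing a temporary with at::max and copying. A minimal sketch of the user-visible effect from the Python side (illustrative values; this assumes the overload is reachable through torch.max with an out= argument, as the OpInfo change below suggests):

    import torch

    # Sketch: with max.unary_out enabled, the full reduction can write its
    # 0-dim result into a preallocated tensor instead of allocating a new one.
    x = torch.randn(4, 5)
    out = torch.empty((), dtype=x.dtype)   # 0-dim output buffer
    torch.max(x, out=out)                  # routed to max_unary_out on CPU/CUDA
    assert out == x.max()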
@@ -8789,13 +8789,6 @@
     MPS: max_mps
     QuantizedCPU: max_quantized_cpu
 
-# Not to be confused with binary op `max.out`. Commented because of failed CI
-# FIXME: enable this
-#- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
-#  device_check: NoCheck # TensorIterator
-#  dispatch:
-#    CompositeExplicitAutograd: max_unary_out
-
 - func: fmax(Tensor self, Tensor other) -> Tensor
   structured_delegate: fmax.out
   device_check: NoCheck # TensorIterator
@@ -8831,6 +8824,12 @@
 - func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
 
+- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck # TensorIterator
+  dispatch:
+    CPU, CUDA: max_unary_out
+    QuantizedCPU: max_quantized_unary_out
+
 - func: minimum(Tensor self, Tensor other) -> Tensor
   structured_delegate: minimum.out
   device_check: NoCheck # TensorIterator
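A hedged sketch of how the new dispatch entries are exercised: per the table above, the dispatcher should route CPU/CUDA inputs to max_unary_out and quantized CPU inputs to max_quantized_unary_out. The overload can also be reached directly through torch.ops (illustrative only):

    import torch

    x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    out = torch.empty((), dtype=torch.float32)
    torch.ops.aten.max.unary_out(x, out=out)   # CPU input -> max_unary_out
    print(out)                                  # tensor(5.)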
@@ -14,6 +14,19 @@ Tensor max_quantized_cpu(const Tensor& self) {
   return std::get<0>(self.reshape({-1}).max(/*dim=*/0));
 }
 
+Tensor& max_quantized_unary_out(const Tensor& self, Tensor& out) {
+  // TODO this implementation is inefficient for now.
+  TORCH_CHECK(self.device() == out.device());
+
+  TORCH_CHECK(canCast(
+      typeMetaToScalarType(self.dtype()),
+      typeMetaToScalarType(out.dtype())));
+  Tensor temp = max_quantized_cpu(self);
+  at::native::resize_output(out, temp.sizes());
+  out.copy_(temp);
+  return out;
+}
+
 Tensor min_quantized_cpu(const Tensor& self) {
   return std::get<0>(self.reshape({-1}).min(/*dim=*/0));
 }
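The quantized kernel is a thin wrapper over max_quantized_cpu, which flattens the tensor and reduces along dim 0. A rough Python analogue of that reduction (illustrative only; the real kernel runs in C++ and additionally checks device and dtype castability, as shown above):

    import torch

    # Flatten, then reduce along dim 0 -- the same shape trick used by
    # max_quantized_cpu.
    qx = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1,
                                   zero_point=0, dtype=torch.qint8)
    values, _ = qx.reshape(-1).max(dim=0)
    print(values)   # 0-dim quantized tensor holding the max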
@@ -115,11 +115,17 @@ def meta_index_select_out(self, dim, index, out):
     return out.copy_(torch.index_select(self, dim, index))
 
 
-@register_meta([aten.max.default, aten.min.default])
+@register_meta([aten.max.default, aten.max.unary_out])
+@out_wrapper()
 def meta_max(self):
     return self.new_empty(())
 
 
+@register_meta([aten.min.default])
+def meta_min(self):
+    return self.new_empty(())
+
+
 @register_meta(aten.angle.default)
 def meta_angle(self):
     if self.is_complex():
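Registering the meta kernel for both aten.max.default and aten.max.unary_out lets shape inference run without touching real data; meta_max only needs to produce an empty 0-dim tensor. A small sketch of the user-visible behavior (assuming the standard "meta" device path):

    import torch

    x = torch.empty(128, 64, device="meta")
    y = torch.max(x)
    print(y.shape, y.device)   # torch.Size([]) meta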
@@ -324,16 +324,23 @@ static bool varargsCanBeUsedAsList(
       !typevar_list;
 }
 
-// Note (@zasdfgbnm):
-// This is a workaround for https://github.com/pytorch/pytorch/issues/47964
-// Currently JIT does not distinguish ScalarType vs int, so there is really
-// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode
-// the aten::view.dtype here to block this overload. This blocklist should be
-// removed when JIT fully suports ScalarType as its own type.
 bool isBlockListedSchema(const FunctionSchema& schema) {
+  // Note (@zasdfgbnm):
+  // This is a workaround for https://github.com/pytorch/pytorch/issues/47964
+  // Currently JIT does not distinguish ScalarType vs int, so there is really
+  // no way to distinguish x.view(1) vs x.view(torch.int8). So we have to
+  // hardcode the aten::view.dtype here to block this overload. This blocklist
+  // should be removed when JIT fully suports ScalarType as its own type.
   if (schema.name() == "aten::view" && schema.overload_name() == "dtype") {
     return true;
   }
+  // Note (@tugsbayasgalan)
+  // TorchScript doesn't suport kwargs so this op collides with aten.max.others
+  // since both of them have 2 Tensor inputs. Since we don't expect users to
+  // use this op in TS, we just skip it
+  if (schema.name() == "aten::max" && schema.overload_name() == "unary_out") {
+    return true;
+  }
   return false;
 }
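The blocklist keeps TorchScript's schema matching away from the new overload, so two-tensor torch.max calls still resolve to the elementwise overload rather than colliding with the keyword-only out= signature. A quick sanity sketch (assuming standard torch.jit.script behavior):

    import torch

    @torch.jit.script
    def elementwise_max(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Still matches the elementwise overload, not the blocklisted unary_out.
        return torch.max(a, b)

    print(elementwise_max(torch.tensor([1.0, 4.0]), torch.tensor([3.0, 2.0])))
    # tensor([3., 4.])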
@@ -20,6 +20,8 @@ struct MatchedSchema {
   std::string schema_name;
 };
 
+TORCH_API bool isBlockListedSchema(const FunctionSchema& schema);
+
 TORCH_API MatchedSchema matchSchema(
     const ::c10::FunctionSchema& schema,
     const SourceRange& loc,
@@ -1009,6 +1009,9 @@ Value* Node::namedInput(Symbol name) const {
 }
 
 bool Node::matches(const FunctionSchema& schema) const {
+  if (isBlockListedSchema(schema)) {
+    return false;
+  }
   // wrong name
   if (kind().toQualString() != schema.name()) {
     return false;
@@ -9901,7 +9901,7 @@ op_db: List[OpInfo] = [
     OpInfo('max',
            variant_test_name='reduction_no_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-           supports_out=False,
+           supports_out=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
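Flipping supports_out to True lets the standard OpInfo out= tests run against this variant. Roughly, they exercise calls like the following sketch, including the resize-on-shape-mismatch behavior provided by resize_output (a warning is expected when a non-empty out is resized):

    import torch

    x = torch.randn(3, 4)
    wrong_shape_out = torch.empty(7)
    torch.max(x, out=wrong_shape_out)   # out is resized to the 0-dim result shape
    assert wrong_shape_out.shape == torch.Size([])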