[MTIA Aten Backend][2/n] Migrate clamp ops (clamp.out/clamp_min.out/clamp_max.out) from out-of-tree to in-tree (#154015)

Summary:
# Context

See the first PR in this series: https://github.com/pytorch/pytorch/pull/153670

# This PR
1. Migrate three clamp ops from out-of-tree to in-tree (they had to move together because clamp.out calls all three stubs, which the other two ops also use; a short usage sketch follows this list):
- clamp.out
- clamp_min.out
- clamp_max.out
2. Enable structured kernel codegen for MTIA, which the clamp ops require.
3. Introduce a `--mtia` flag in torchgen so that OSS builds do not generate MTIA code (otherwise the OSS build fails to link: `lib/libtorch_cpu.so: undefined reference to at::detail::empty_mtia`).
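For illustration only (not part of this PR's diff), here is a minimal C++ usage sketch. It assumes an MTIA-enabled build with an MTIA device available; everything else is stock ATen API.

```cpp
// Minimal usage sketch, assuming an MTIA-enabled build and an available MTIA device.
#include <ATen/ATen.h>

void clamp_out_variants_on_mtia() {
  auto opts = at::TensorOptions().device(c10::DeviceType::MTIA);
  auto x = at::empty({4, 4}, opts);    // input on the MTIA device
  auto out = at::empty({4, 4}, opts);  // preallocated output buffer

  at::clamp_out(out, x, /*min=*/0.0, /*max=*/1.0);  // clamp.out
  at::clamp_min_out(out, x, /*min=*/0.0);           // clamp_min.out
  at::clamp_max_out(out, x, /*max=*/1.0);           // clamp_max.out
}
```

With this PR, all three calls dispatch to the shared structured clamp kernels now registered for MTIA in native_functions.yaml (see the hunks below).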

Differential Revision: D74674418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154015
Approved by: https://github.com/albanD, https://github.com/nautsimon
Author: Andy (An) Wang
Date: 2025-05-23 17:59:47 +00:00
Committed by: PyTorch MergeBot
Parent: bcb2125f0a
Commit: 0d62fd5c3c
8 changed files with 150 additions and 5 deletions

View File

@@ -1535,7 +1535,6 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
// Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
if (privateuse1_without_storage ||
common_device_.type() == DeviceType::MTIA ||
common_device_.type() == DeviceType::XLA ||
common_device_.type() == DeviceType::IPU ||
common_device_.type() == DeviceType::Lazy ||

View File

@@ -0,0 +1,86 @@
#include <ATen/Context.h>
#include <ATen/EmptyTensor.h>
#include <ATen/native/mtia/EmptyTensor.h>
#include <c10/core/Allocator.h>
#include <c10/core/DeviceGuard.h>
namespace at::detail {
at::Allocator* GetMTIAAllocator() {
return GetAllocator(DeviceType::MTIA);
}
TensorBase empty_mtia(
IntArrayRef size,
ScalarType dtype,
std::optional<Device> device_opt,
std::optional<c10::MemoryFormat> memory_format_opt) {
at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
const auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT(device.is_mtia());
const DeviceGuard device_guard(device);
auto* allocator = GetMTIAAllocator();
constexpr c10::DispatchKeySet mtia_dks(c10::DispatchKey::MTIA);
return at::detail::empty_generic(
size, allocator, mtia_dks, dtype, memory_format_opt);
}
TensorBase empty_mtia(
IntArrayRef size,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt) {
const auto dtype = dtype_or_default(dtype_opt);
return at::detail::empty_mtia(size, dtype, device_opt, memory_format_opt);
}
TensorBase empty_mtia(IntArrayRef size, const TensorOptions& options) {
return at::detail::empty_mtia(
size,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
options.pinned_memory_opt(),
options.memory_format_opt());
}
TensorBase empty_strided_mtia(
IntArrayRef size,
IntArrayRef stride,
ScalarType dtype,
std::optional<Device> device_opt) {
at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
const auto device = device_or_default(device_opt);
const DeviceGuard device_guard(device);
auto* allocator = GetMTIAAllocator();
constexpr c10::DispatchKeySet mtia_dks(c10::DispatchKey::MTIA);
return at::detail::empty_strided_generic(
size, stride, allocator, mtia_dks, dtype);
}
TensorBase empty_strided_mtia(
IntArrayRef size,
IntArrayRef stride,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt) {
const auto dtype = dtype_or_default(dtype_opt);
return at::detail::empty_strided_mtia(size, stride, dtype, device_opt);
}
TensorBase empty_strided_mtia(
IntArrayRef size,
IntArrayRef stride,
const TensorOptions& options) {
return at::detail::empty_strided_mtia(
size,
stride,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
options.pinned_memory_opt());
}
} // namespace at::detail

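Editorial note on how these new helpers are consumed (a simplified sketch, not code from this PR): once MTIA is added to torchgen's gen_empty_impl_names and to the create_out dispatch-key list (see the torchgen hunks further down), the out-tensor helper generated for MTIA structured kernels should look roughly like the following, modeled on the existing CUDA/XPU pattern. This is also the symbol chain behind the link error mentioned in the summary: codegen that references at::detail::empty_mtia in a build that does not compile this file, hence the `--mtia` gate.

```cpp
// Simplified sketch of the torchgen-generated helper (assumption: mirrors the
// CUDA/XPU pattern; not the verbatim generated code).
#include <ATen/core/Tensor.h>
#include <ATen/native/mtia/EmptyTensor.h>

static at::Tensor create_out(
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    const at::TensorOptions& options) {
  if (strides.empty()) {
    // Contiguous output: uses the new at::detail::empty_mtia overload above.
    return at::detail::empty_mtia(sizes, options);
  }
  // Strided output: uses at::detail::empty_strided_mtia.
  return at::detail::empty_strided_mtia(sizes, strides, options);
}
```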
View File

@@ -0,0 +1,42 @@
#pragma once
#include <ATen/core/TensorBase.h>
namespace at::detail {
TensorBase empty_mtia(
IntArrayRef size,
ScalarType dtype,
std::optional<Device> device_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
TensorBase empty_mtia(
IntArrayRef size,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
TensorBase empty_mtia(IntArrayRef size, const TensorOptions& options);
TensorBase empty_strided_mtia(
IntArrayRef size,
IntArrayRef stride,
ScalarType dtype,
std::optional<Device> device_opt);
TensorBase empty_strided_mtia(
IntArrayRef size,
IntArrayRef stride,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt);
TensorBase empty_strided_mtia(
IntArrayRef size,
IntArrayRef stride,
const TensorOptions& options);
} // namespace at::detail

View File

@@ -1548,7 +1548,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: clamp_out
CPU, CUDA, MTIA: clamp_out
MPS: clamp_out_mps
tags: pointwise
@@ -1588,7 +1588,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: clamp_max_out
CPU, CUDA, MTIA: clamp_max_out
MPS: clamp_max_out_mps
tags: pointwise
@@ -1628,7 +1628,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: clamp_min_out
CPU, CUDA, MTIA: clamp_min_out
MPS: clamp_min_out_mps
tags: pointwise

View File

@@ -72,7 +72,7 @@ def define_targets(rules):
"--install_dir=$(RULEDIR)",
"--source-path aten/src/ATen",
"--aoti_install_dir=$(RULEDIR)/torch/csrc/inductor/aoti_torch/generated"
] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else []))
] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else []) + ["--mtia"])
gen_aten_outs_cuda = (
GENERATED_H_CUDA + GENERATED_CPP_CUDA + GENERATED_AOTI_CUDA_CPP +

View File

@@ -66,6 +66,8 @@ def gen_registration_headers(
elif backend_index.dispatch_key == DispatchKey.XPU:
# XPU specific, this header resides in third_party/torch-xpu-ops
headers.append("#include <ATen/xpu/EmptyTensor.h>")
elif backend_index.dispatch_key == DispatchKey.MTIA:
headers.append("#include <ATen/native/mtia/EmptyTensor.h>")
elif per_operator_headers:
headers += [
"#include <ATen/ops/empty.h>",
@@ -92,6 +94,7 @@ def gen_empty_impl_names(
DispatchKey.CUDA,
DispatchKey.MPS,
DispatchKey.XPU,
DispatchKey.MTIA,
):
dispatch = str(backend_index.dispatch_key).lower()
empty_impl = f"at::detail::empty_{dispatch}"
@@ -645,6 +648,7 @@ if (C10_UNLIKELY(maybe_proxy.has_value())) {
DispatchKey.CUDA,
DispatchKey.MPS,
DispatchKey.XPU,
DispatchKey.MTIA,
DispatchKey.CompositeExplicitAutogradNonFunctional,
)
return f"""{maybe_set_guard_line}
@@ -724,6 +728,8 @@ resize_out(out, sizes, strides, options);
guard_field = "c10::OptionalDeviceGuard guard_;"
elif self.backend_index.dispatch_key == DispatchKey.XPU:
guard_field = "c10::OptionalDeviceGuard guard_;"
elif self.backend_index.dispatch_key == DispatchKey.MTIA:
guard_field = "c10::OptionalDeviceGuard guard_;"
else:
guard_field = ""

View File

@@ -2820,6 +2820,11 @@ def main() -> None:
action="store_true",
help="Generate XPU registration code when set",
)
parser.add_argument(
"--mtia",
action="store_true",
help="Generate MTIA registration code when set",
)
# TODO: --op-registration-whitelist will be removed when all call-sites
# for gen.py are moved over to using the operator YAML file for mobile
@@ -2918,6 +2923,12 @@ def main() -> None:
if DispatchKey.XPU in dispatch_keys:
del dispatch_keys[dispatch_keys.index(DispatchKey.XPU)]
if not options.mtia:
ignore_keys.add(DispatchKey.MTIA)
if DispatchKey.MTIA in dispatch_keys:
del dispatch_keys[dispatch_keys.index(DispatchKey.MTIA)]
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
native_functions, backend_indices = (

View File

@@ -271,6 +271,7 @@ STRUCTURED_DISPATCH_KEYS = {
DispatchKey.CUDA,
DispatchKey.CPU,
DispatchKey.XPU,
DispatchKey.MTIA,
}
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}