Compare commits

...

50 Commits

SHA1 Message Date
479627f26b update CK flash attention sources for v2 2025-10-17 18:09:11 +00:00
6b226773e9 Merge branch 'main' into hipify_without_caffe_attempt2 2025-10-17 14:52:14 +00:00
82a603414f Merge branch 'main' into hipify_without_caffe_attempt2 2025-10-14 21:37:39 +00:00
3d3e4be9be lint 2025-10-07 16:14:14 +00:00
a7c5524023 Merge branch 'viable/strict' into hipify_without_caffe_attempt2 2025-10-07 15:51:04 +00:00
e365285a57 Merge branch 'main' into hipify_without_caffe_attempt2 2025-10-01 16:27:56 +00:00
7d2afcf919 Merge branch 'main' into hipify_without_caffe_attempt2 2025-09-29 16:31:32 +00:00
7dcbb5f610 Merge branch 'main' into hipify_without_caffe_attempt2 2025-09-09 19:53:17 +00:00
64210febd2 lint 2025-09-09 19:46:28 +00:00
11e1c80965 add work-around for fbgemm hipify v1 vs v2 2025-09-08 23:54:35 +00:00
1bfb16e0c7 Merge branch 'main' into hipify_without_caffe_attempt2 2025-09-08 19:18:08 +00:00
1970cbfaec Merge branch 'main' into hipify_without_caffe_attempt2 2025-08-26 19:55:05 +00:00
e23ec3f287 Merge branch 'main' into hipify_without_caffe_attempt2 2025-08-19 20:28:23 +00:00
6b41f33303 fix build 2025-08-14 16:36:57 +00:00
b864122f8f Merge branch 'main' into hipify_without_caffe_attempt2 2025-08-12 21:20:48 +00:00
5a319f32f7 Merge branch 'main' into hipify_without_caffe_attempt2 2025-07-18 18:57:11 +00:00
dca5868a4b Merge branch 'main' into hipify_without_caffe_attempt2 2025-07-15 20:46:53 +00:00
71e55186e1 Merge branch 'main' into hipify_without_caffe_attempt2 2025-07-07 23:29:44 +00:00
ee1f754c73 add mappings from https://github.com/pytorch/pytorch/pull/157435 2025-07-07 23:25:09 +00:00
c3f73c3759 fix build 2025-07-01 00:56:19 +00:00
b09bba1ae2 Merge branch 'main' into hipify_without_caffe_attempt2 2025-06-30 21:54:49 +00:00
3c4c1aa965 Merge branch 'main' into hipify_without_caffe_attempt2 2025-06-28 22:48:54 +00:00
9cf12cb64f Merge branch 'main' into hipify_without_caffe_attempt2 2025-06-23 20:06:23 +00:00
e7838ab2ef Merge branch 'main' into hipify_without_caffe_attempt2 2025-06-16 22:26:08 +00:00
be81282147 lint 2025-06-16 22:23:43 +00:00
03a4032ba3 Merge branch 'main' into hipify_without_caffe_attempt2 2025-06-13 23:23:50 +00:00
9b60804b95 Merge branch 'main' into hipify_without_caffe_attempt2 2025-06-02 16:46:34 +00:00
dd3ca0b818 Merge branch 'main' into hipify_without_caffe_attempt2 2025-05-30 23:54:29 +00:00
3f7021ee7e new hipify mappings from #150578 2025-05-30 23:08:42 +00:00
bfc83368f1 missing hipify mappings after last merge from main 2025-05-20 23:43:41 +00:00
fa4eae9c4c Merge branch 'main' into hipify_without_caffe_attempt2 2025-05-20 21:44:50 +00:00
d21a727f03 Add cudaLaunchKernel to cuda_to_hip_mappings 2025-05-19 22:36:57 +00:00
84388230df Merge branch 'main' into hipify_without_caffe_attempt2 2025-05-19 22:15:20 +00:00
b25bf8345f do not map cublas to hipsolver 2025-05-19 22:07:32 +00:00
9cd9145a00 Merge branch 'main' into hipify_without_caffe_attempt2 2025-05-09 21:36:29 +00:00
1bd02632ec fix copy/paste typo from last commit 2025-05-09 21:30:11 +00:00
fced155a01 missing sparse mappings needed by #153262 in case of land race 2025-05-09 17:48:33 +00:00
43be93fa8c fix compile errors 2025-05-08 23:15:31 +00:00
0f0a5ccea3 Merge branch 'main' into hipify_without_caffe_attempt2 2025-05-07 18:19:12 +00:00
6c7cbe21f8 missing import os 2025-04-28 15:38:24 +00:00
7c7bee6737 restore _RCCL_HEADER workaround in cuda to hip mappings 2025-04-26 01:30:09 +00:00
52f5528cf0 constants.py updated with hipify_torch changes 2025-04-26 01:22:42 +00:00
c55ca807a5 remove unused Deprecated.h header 2025-04-26 01:09:07 +00:00
f0fca2f739 add missing mappings from hipify_torch project 2025-04-26 00:44:46 +00:00
e538b5052c use FutureWarning since DeprecationWarning is suppressed by default 2025-04-25 22:24:16 +00:00
c6f2cddbba change deprecated comments to warnings.warn 2025-04-25 21:42:44 +00:00
7d7e3fc5c0 do not deprecate unsafe_set_device 2025-04-25 21:31:28 +00:00
c5a07e0770 revert change to MultinomialKernel.cu 2025-04-25 21:28:08 +00:00
5d6943ddd8 update hipify version to 2.0.0 2025-04-21 22:24:34 +00:00
4cb196c3cd [reland][ROCm] remove caffe2 from hipify 2025-04-21 22:05:27 +00:00
36 changed files with 3515 additions and 10573 deletions

View File

@@ -253,7 +253,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
scan_op,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
@@ -531,7 +531,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
init_value,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
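Note: both hunks above replace the HIP-named launch check with the CUDA-named one inside the ROCm branch. Under hipify v2, shared sources keep the CUDA spelling and the rewrite to HIP happens at hipify time. A minimal sketch of the pattern, assuming a trivial kernel for illustration:

    #include <cuda_runtime.h>
    #include <c10/cuda/CUDAException.h>

    __global__ void fill_ones(float* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = 1.0f;
    }

    void launch_fill_ones(float* out, int n, cudaStream_t stream) {
      fill_ones<<<(n + 255) / 256, 256, 0, stream>>>(out, n);
      // One spelling for both backends: hipify rewrites this macro (along
      // with the stream and kernel types) for ROCm builds, so a HIP-named
      // variant is no longer written by hand.
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }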

View File

@@ -1,239 +0,0 @@
#pragma once
#include <c10/hip/HIPCachingAllocator.h>
// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10::hip {
// Takes a valid HIPAllocator (of any sort) and turns it into
// an allocator pretending to be a CUDA allocator. See
// Note [Masquerading as CUDA]
class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
HIPCachingAllocator::HIPAllocator* allocator_;
public:
explicit HIPAllocatorMasqueradingAsCUDA(HIPCachingAllocator::HIPAllocator* allocator)
: allocator_(allocator) {}
virtual ~HIPAllocatorMasqueradingAsCUDA() = default;
// From c10::Allocator
DataPtr allocate(size_t size) override {
DataPtr r = allocator_->allocate(size);
r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
return r;
}
bool is_simple_data_ptr(const DataPtr& data_ptr) const override {
return allocator_->is_simple_data_ptr(data_ptr);
}
DeleterFnPtr raw_deleter() const override {
return allocator_->raw_deleter();
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
allocator_->copy_data(dest, src, count);
}
// From DeviceAllocator
bool initialized() override {
return allocator_->initialized();
}
void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
allocator_->emptyCache(mempool_id);
}
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
HIPStream hip_stream = HIPStream(stream);
recordStream(ptr, hip_stream);
}
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
return allocator_->getDeviceStats(device);
}
void resetAccumulatedStats(c10::DeviceIndex device) override {
allocator_->resetAccumulatedStats(device);
}
void resetPeakStats(c10::DeviceIndex device) override {
allocator_->resetPeakStats(device);
}
// From CUDAAllocator
void* raw_alloc(size_t nbytes) override {
return allocator_->raw_alloc(nbytes);
}
void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override {
return allocator_->raw_alloc_with_stream(nbytes, stream);
}
void raw_delete(void* ptr) override {
allocator_->raw_delete(ptr);
}
void init(int device_count) override {
allocator_->init(device_count);
}
double getMemoryFraction(c10::DeviceIndex device) override {
return allocator_->getMemoryFraction(device);
}
void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
allocator_->setMemoryFraction(fraction, device);
}
std::vector<HIPCachingAllocator::StreamSegmentSize> getExpandableSegmentSizes(c10::DeviceIndex device) override {
return allocator_->getExpandableSegmentSizes(device);
}
void enable(bool value) override {
allocator_->enable(value);
}
bool isEnabled() const override {
return allocator_->isEnabled();
}
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
allocator_->cacheInfo(device, largestBlock);
}
void* getBaseAllocation(void* ptr, size_t* size) override {
return allocator_->getBaseAllocation(ptr, size);
}
void recordStream(const DataPtr& ptr, HIPStream stream) override {
allocator_->recordStream(ptr, stream);
}
HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override {
return allocator_->snapshot(mempool_id);
}
void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(hipStream_t)> filter) override {
allocator_->beginAllocateToPool(device, mempool_id, filter);
}
void endAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) override {
allocator_->endAllocateToPool(device, mempool_id);
}
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->releasePool(device, mempool_id);
}
int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) override {
return allocator_->getPoolUseCount(device, mempool_id);
}
void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
HIPAllocator* allocator = nullptr) override {
allocator_->createOrIncrefPool(device, mempool_id, allocator);
}
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->setUseOnOOM(device, mempool_id);
}
bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) override {
return allocator_->checkPoolLiveAllocations(device, mempool_id, expected_live_allocations);
}
HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) override {
return allocator_->shareIpcHandle(ptr);
}
std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
return allocator_->getIpcDevPtr(handle);
}
bool isHistoryEnabled() override {
return allocator_->isHistoryEnabled();
}
void recordHistory(
bool enabled,
HIPCachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
HIPCachingAllocator::RecordContext when,
bool clearHistory) override {
allocator_->recordHistory(enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}
void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) override {
allocator_->recordAnnotation(md);
}
void pushCompileContext(std::string& md) override {
allocator_->pushCompileContext(md);
}
void popCompileContext() override {
allocator_->popCompileContext();
}
void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) override {
allocator_->attachOutOfMemoryObserver(observer);
}
void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) override {
allocator_->attachAllocatorTraceTracker(tracker);
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
allocator_->enablePeerAccess(dev, dev_to_access);
}
hipError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
hipStream_t stream,
bool p2p_enabled) override {
return allocator_->memcpyAsync(dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}
std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) override {
return allocator_->getCheckpointState(device, id);
}
HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
auto cpd = allocator_->setCheckpointPoolState(device, pps);
for (auto& ptr : cpd.dataptrs_allocd) {
ptr.unsafe_set_device(Device(c10::DeviceType::CUDA, ptr.device().index()));
}
return cpd;
}
std::string name() override {
return allocator_->name();
}
};
} // namespace c10::hip
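Note: the deleted header above was a pure forwarding wrapper: every override delegated to the underlying HIP allocator, and the only real work was relabeling returned DataPtrs with DeviceType::CUDA (see Note [Masquerading as CUDA] in a later file). With hipify v2 translating the CUDA-named allocator API wholesale, callers use it directly. A minimal post-removal sketch, illustrative only:

    #include <c10/cuda/CUDACachingAllocator.h>

    // Allocate device scratch space through the CUDA-named caching allocator.
    // On ROCm builds, hipify v2 rewrites this whole call chain to the HIP
    // implementation, so the DataPtr already carries the right device type.
    at::DataPtr alloc_scratch(size_t nbytes) {
      return c10::cuda::CUDACachingAllocator::get()->allocate(nbytes);
    }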

View File

@@ -1,18 +0,0 @@
#include <c10/hip/HIPCachingAllocator.h>
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
namespace c10 { namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {
HIPCachingAllocator::HIPAllocator* get() {
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
return &allocator;
}
void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream) {
HIPCachingAllocator::recordStream(ptr, stream.hip_stream());
}
} // namespace HIPCachingAllocatorMasqueradingAsCUDA
}} // namespace c10::hip

View File

@@ -1,194 +0,0 @@
#pragma once
#include <c10/hip/HIPCachingAllocator.h>
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
namespace c10 {
// forward declaration
class DataPtr;
namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {
C10_HIP_API HIPCachingAllocator::HIPAllocator* get();
C10_HIP_API void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream);
inline void* raw_alloc(size_t nbytes) {
return get()->raw_alloc(nbytes);
}
inline void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) {
return get()->raw_alloc_with_stream(nbytes, stream);
}
inline void raw_delete(void* ptr) {
return get()->raw_delete(ptr);
}
inline void init(int device_count) {
return get()->init(device_count);
}
inline double getMemoryFraction(c10::DeviceIndex device) {
return get()->getMemoryFraction(device);
}
inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device);
}
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
return get()->emptyCache(mempool_id);
}
inline void enable(bool value) {
return get()->enable(value);
}
inline bool isEnabled() {
return get()->isEnabled();
}
inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) {
return get()->cacheInfo(device, largestBlock);
}
inline void* getBaseAllocation(void* ptr, size_t* size) {
return get()->getBaseAllocation(ptr, size);
}
inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) {
return get()->getDeviceStats(device);
}
inline void resetAccumulatedStats(c10::DeviceIndex device) {
return get()->resetAccumulatedStats(device);
}
inline void resetPeakStats(c10::DeviceIndex device) {
return get()->resetPeakStats(device);
}
inline HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
return get()->snapshot(mempool_id);
}
inline std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) {
return get()->getCheckpointState(device, id);
}
inline HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) {
return get()->setCheckpointPoolState(device, std::move(pps));
}
inline void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(hipStream_t)> filter) {
get()->beginAllocateToPool(device, mempool_id, std::move(filter));
}
inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->endAllocateToPool(device, mempool_id);
}
inline void recordHistory(
bool enabled,
HIPCachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
HIPCachingAllocator::RecordContext when,
bool clearHistory) {
return get()->recordHistory(
enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}
inline void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) {
return get()->recordAnnotation(md);
}
inline void pushCompileContext(std::string& md) {
return get()->pushCompileContext(md);
}
inline void popCompileContext() {
return get()->popCompileContext();
}
inline bool isHistoryEnabled() {
return get()->isHistoryEnabled();
}
inline bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
return get()->checkPoolLiveAllocations(
device, mempool_id, expected_live_allocations);
}
inline void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) {
return get()->attachOutOfMemoryObserver(std::move(observer));
}
inline void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) {
return get()->attachAllocatorTraceTracker(std::move(tracker));
}
inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->releasePool(device, mempool_id);
}
inline void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
HIPCachingAllocator::HIPAllocator* allocator_ptr = nullptr) {
get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
}
inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->setUseOnOOM(device, mempool_id);
}
inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->getPoolUseCount(device, mempool_id);
}
inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
return get()->getIpcDevPtr(std::move(handle));
}
inline HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) {
return get()->shareIpcHandle(ptr);
}
inline std::string name() {
return get()->name();
}
inline hipError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
hipStream_t stream,
bool p2p_enabled) {
return get()->memcpyAsync(
dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}
inline void enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access) {
return get()->enablePeerAccess(dev, dev_to_access);
}
} // namespace HIPCachingAllocatorMasqueradingAsCUDA
} // namespace hip
} // namespace c10

View File

@@ -1,14 +0,0 @@
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
// THIS IS A MASSIVE HACK. This will BREAK your Caffe2 CUDA code if you
// load ATen_hip, even if you don't ever actually use ATen_hip at runtime.
//
// If you ever link ATen_hip statically into the full library along
// with ATen_cuda (libomnibus), the loading order of this versus the regular
// ATen_cuda will be nondeterministic, and you'll nondeterministically get
// one or the other. (This will be obvious because all of your code
// will fail.)
//
// This hack can be removed once PyTorch is out-of-place HIPified, and
// doesn't pretend CUDA is HIP.
C10_REGISTER_GUARD_IMPL(CUDA, at::cuda::HIPGuardImplMasqueradingAsCUDA)
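Note: the deleted file above overrode the CUDA slot of the device-guard registry with the HIP implementation. Generic code resolves guards through that registry, so once the hipified CUDAGuardImpl registers itself under DeviceType::CUDA, the override is redundant. A minimal sketch of the lookup, assuming nothing beyond the c10 registry API:

    #include <c10/core/impl/DeviceGuardImplInterface.h>

    // Generic code asks the registry for the guard impl of a DeviceType;
    // whatever registered under CUDA (the hipified impl, on ROCm) is returned.
    const c10::impl::DeviceGuardImplInterface* cuda_guard_impl() {
      return c10::impl::getDeviceGuardImpl(c10::DeviceType::CUDA);
    }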

View File

@@ -1,383 +0,0 @@
#pragma once
#include <ATen/hip/HIPConfig.h>
// The includes of HIPGuard.h
#include <c10/hip/impl/HIPGuardImpl.h>
#include <c10/hip/HIPMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/util/Exception.h>
#include <c10/hip/impl/HIPGuardImpl.h>
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {
// Note [Masquerading as CUDA]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// c10_hip is very easy to understand: it is HIPified from c10_cuda,
// and anywhere you said CUDA, the source code now says HIP. HIPified
// PyTorch is much harder to understand: it is HIPified from regular
// PyTorch, yes, but NO source-to-source translation from CUDA to
// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
// For example, when you use HIPified PyTorch, you say x.cuda() to
// move a tensor onto ROCm device. We call this situation "HIP
// masquerading as CUDA".
//
// This leads to a very awkward situation when we want to call c10_hip
// code from PyTorch, since c10_hip is expecting things to be called
// HIP, but PyTorch is calling them CUDA (masquerading as HIP). To
// fix this impedance mismatch, we have MasqueradingAsCUDA variants
// for all c10_hip classes. These translate between the "HIP" and "CUDA
// masquerading as HIP" worlds. For example,
// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
// returns CUDA, getDevice() reports the current HIP device as a CUDA
// device.)
//
// We should be able to delete all of these classes entirely once
// we switch PyTorch to calling a HIP a HIP.
//
// When you add a new MasqueradingAsCUDA class/function, you need to
// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
//
//
//
// By the way, note that the cpp file associated with this also
// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
// this HIP implementation.
struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA;
HIPGuardImplMasqueradingAsCUDA() {}
HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA);
}
c10::DeviceType type() const override {
return c10::DeviceType::CUDA;
}
Device exchangeDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
Device old_device = getDevice();
if (old_device.index() != d.index()) {
C10_HIP_CHECK(hipSetDevice(d.index()));
}
return old_device;
}
Device getDevice() const override {
int device;
C10_HIP_CHECK(hipGetDevice(&device));
return Device(c10::DeviceType::CUDA, device);
}
void setDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
C10_HIP_CHECK(hipSetDevice(d.index()));
}
void uncheckedSetDevice(Device d) const noexcept override {
C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
}
Stream getStream(Device d) const override {
return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
}
Stream getDefaultStream(Device d) const override {
return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
}
Stream getNewStream(Device d, int priority = 0) const override {
return getStreamFromPoolMasqueradingAsCUDA(priority, d.index());
}
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
}
Stream exchangeStream(Stream s) const override {
HIPStreamMasqueradingAsCUDA cs(s);
auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
setCurrentHIPStreamMasqueradingAsCUDA(cs);
return old_stream.unwrap();
}
DeviceIndex deviceCount() const noexcept override {
int deviceCnt;
hipError_t _err;
_err = hipGetDeviceCount(&deviceCnt);
if(_err != hipErrorNoDevice && _err != hipSuccess)
C10_HIP_CHECK(_err);
return deviceCnt;
}
// Event-related functions
// Note: hipEventCreateWithFlags should be called on the same device as
// the recording stream's device.
void createEvent(
hipEvent_t* hip_event,
const EventFlag flag) const {
// Maps PyTorch's Event::Flag to HIP flag
auto hip_flag = hipEventDefault;
switch (flag) {
case EventFlag::PYTORCH_DEFAULT:
hip_flag = hipEventDisableTiming;
break;
case EventFlag::BACKEND_DEFAULT:
hip_flag = hipEventDefault;
break;
default:
TORCH_CHECK(false, "HIP event received unknown flag");
}
C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override {
if (!event) return;
auto hip_event = static_cast<hipEvent_t>(event);
int orig_device;
C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
C10_HIP_CHECK_WARN(hipSetDevice(device_index));
C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
}
void record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
HIPStreamMasqueradingAsCUDA hip_stream{stream};
// Moves to stream's device to record
const auto orig_device = getDevice();
setDevice(stream.device());
// Creates the event (lazily)
if (!hip_event) createEvent(&hip_event, flag);
C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
// Makes the void* point to the (possibly just allocated) HIP event
*event = hip_event;
// Resets device
setDevice(orig_device);
}
void block(
void* event,
const Stream& stream) const override {
if (!event) return;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
HIPStreamMasqueradingAsCUDA hip_stream{stream};
const auto orig_device = getDevice();
setDevice(stream.device());
C10_HIP_CHECK(hipStreamWaitEvent(
hip_stream,
hip_event,
/*flags (must be zero)=*/ 0));
setDevice(orig_device);
}
bool queryEvent(void* event) const override {
if (!event) return true;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
const hipError_t err = hipEventQuery(hip_event);
if (err != hipErrorNotReady) C10_HIP_CHECK(err);
else {
// ignore and clear the error if not ready
(void)hipGetLastError();
}
return (err == hipSuccess);
}
// Stream-related functions
bool queryStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
return hip_stream.query();
}
void synchronizeStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
hip_stream.synchronize();
}
void synchronizeEvent(void* event) const override {
if (!event)
return;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
C10_HIP_CHECK(hipEventSynchronize(hip_event));
}
// Note: synchronizeDevice can be safely called from any device
void synchronizeDevice(const c10::DeviceIndex device_index) const override {
int orig_device{-1};
C10_HIP_CHECK(hipGetDevice(&orig_device));
C10_HIP_CHECK(hipSetDevice(device_index));
C10_HIP_CHECK(hipDeviceSynchronize());
C10_HIP_CHECK(hipSetDevice(orig_device));
}
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
}
double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
const override {
TORCH_CHECK(
event1 && event2,
"Both events must be recorded before calculating elapsed time.");
int orig_device;
C10_HIP_CHECK(hipGetDevice(&orig_device));
C10_HIP_CHECK(hipSetDevice(device_index));
hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1);
hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2);
float time_ms = 0;
// raise hipErrorNotReady if either event is recorded but not yet completed
C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2));
C10_HIP_CHECK(hipSetDevice(orig_device));
return static_cast<double>(time_ms);
}
};
// All of the guards which have HIPGuardImpl burned in need to also have
// variants using HIPGuardImplMasqueradingAsCUDA.
/// This code is all a direct copy from c10/cuda/CUDAGuard.h, but with
/// the correct InlineDeviceGuard burned in. Sorry about the
/// copy-pasting.
struct HIPGuardMasqueradingAsCUDA {
explicit HIPGuardMasqueradingAsCUDA() = delete;
explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}
HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;
void set_device(Device device) { guard_.set_device(device); }
void reset_device(Device device) { guard_.reset_device(device); }
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
Device original_device() const { return guard_.original_device(); }
Device current_device() const { return guard_.current_device(); }
private:
c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct OptionalHIPGuardMasqueradingAsCUDA {
explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<Device> device_opt) : guard_(device_opt) {}
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
void set_device(Device device) { guard_.set_device(device); }
void reset_device(Device device) { guard_.reset_device(device); }
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
std::optional<Device> original_device() const { return guard_.original_device(); }
std::optional<Device> current_device() const { return guard_.current_device(); }
void reset() { guard_.reset(); }
private:
c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct HIPStreamGuardMasqueradingAsCUDA {
explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
HIPStreamMasqueradingAsCUDA original_stream() const {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
}
HIPStreamMasqueradingAsCUDA current_stream() const {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
}
Device current_device() const { return guard_.current_device(); }
Device original_device() const { return guard_.original_device(); }
private:
c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct OptionalHIPStreamGuardMasqueradingAsCUDA {
explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional<Stream> stream_opt) : guard_(stream_opt) {}
OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
std::optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
auto r = guard_.original_stream();
if (r.has_value()) {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
} else {
return std::nullopt;
}
}
std::optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
auto r = guard_.current_stream();
if (r.has_value()) {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
} else {
return std::nullopt;
}
}
void reset() { guard_.reset(); }
private:
c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct HIPMultiStreamGuardMasqueradingAsCUDA {
explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
: guard_(unwrapStreams(streams)) {}
HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
private:
c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
std::vector<Stream> streams;
streams.reserve(hipStreams.size());
for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
streams.push_back(hipStream);
}
return streams;
}
};
}} // namespace c10::hip
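Note: with the masquerading guards deleted, ROCm code writes the CUDA-named guard and stream types and hipify v2 rewrites them. A minimal usage sketch, illustrative only:

    #include <ATen/cuda/CUDAGuard.h>
    #include <ATen/cuda/CUDAContext.h>

    void run_on(c10::DeviceIndex idx) {
      // Restores the previous device when it goes out of scope.
      at::cuda::CUDAGuard guard(idx);
      // Current stream for that device; a HIP stream on ROCm after hipify.
      auto stream = at::cuda::getCurrentCUDAStream(idx);
      (void)stream;  // ... launch work on `stream` ...
    }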

View File

@@ -1,135 +0,0 @@
#pragma once
#include <c10/hip/HIPStream.h>
// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {
// See Note [Masquerading as CUDA] for motivation
class HIPStreamMasqueradingAsCUDA {
public:
enum Unchecked { UNCHECKED };
explicit HIPStreamMasqueradingAsCUDA(Stream stream)
: HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
// We did the coercion unchecked; check that it was right.
TORCH_CHECK(stream.device().is_cuda() /* !!! */);
}
explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
// Unsafely coerce the "CUDA" stream into a HIP stream
: stream_(
HIPStream(
Stream(
Stream::UNSAFE,
Device(c10::DeviceType::HIP, stream.device_index()),
stream.id())
)
) {}
// New constructor, just for this. Does NOT coerce.
explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
return stream_ == other.stream_;
}
bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
return stream_ != other.stream_;
}
operator hipStream_t() const { return stream_.stream(); }
operator Stream() const {
// Unsafely coerce HIP stream into a "CUDA" stream
return Stream(Stream::UNSAFE, device(), id());
}
DeviceIndex device_index() const { return stream_.device_index(); }
// Unsafely coerce HIP device into CUDA device
c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }
Device device() const {
// Unsafely coerce HIP device into CUDA device
return Device(c10::DeviceType::CUDA, stream_.device_index());
}
StreamId id() const { return stream_.id(); }
bool query() const { return stream_.query(); }
void synchronize() const { stream_.synchronize(); }
int priority() const { return stream_.priority(); }
hipStream_t stream() const { return stream_.stream(); }
Stream unwrap() const {
// Unsafely coerce HIP stream into "CUDA" stream
return Stream(Stream::UNSAFE, device(), id());
}
c10::StreamData3 pack3() const noexcept {
// Unsafely coerce HIP stream into "CUDA" stream before packing
return unwrap().pack3();
}
static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
DeviceIndex device_index,
c10::DeviceType device_type) {
// NB: constructor manages CUDA->HIP translation for us
return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
stream_id, device_index, device_type));
}
static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
// New method, gets the underlying HIPStream
HIPStream hip_stream() const { return stream_; }
private:
HIPStream stream_;
};
HIPStreamMasqueradingAsCUDA
inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
}
HIPStreamMasqueradingAsCUDA
inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) {
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device));
}
HIPStreamMasqueradingAsCUDA
inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
}
inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
}
inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
}
inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
setCurrentHIPStream(stream.hip_stream());
}
inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
stream << s.hip_stream() << " (masquerading as CUDA)";
return stream;
}
}} // namespace c10::hip
namespace std {
template <>
struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
} // namespace std
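Note: the class above existed to coerce generic c10::Stream handles between the CUDA and HIP device types. Once nothing masquerades, CUDAStream performs that role directly. A minimal sketch, illustrative only:

    #include <c10/cuda/CUDAStream.h>

    // Round-trip a generic stream handle to the raw backend handle. The
    // CUDAStream constructor checks the stream's device type -- exactly the
    // check HIPStreamMasqueradingAsCUDA had to fake with UNSAFE coercions.
    cudaStream_t raw_handle(c10::Stream s) {
      c10::cuda::CUDAStream cs(s);
      return cs.stream();
    }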

View File

@@ -39,7 +39,7 @@ using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<
miopenHandle_t getMiopenHandle() {
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::hip::GetDevice(&device));
AT_CUDA_CHECK(at::cuda::GetDevice(&device));
// Thread local PoolWindows are lazily-initialized
// to avoid initialization issues that caused hangs on Windows.
@@ -51,7 +51,7 @@ miopenHandle_t getMiopenHandle() {
pool->newPoolWindow());
auto handle = myPoolWindow->reserve(device);
MIOPEN_CHECK(miopenSetStream(handle, c10::hip::getCurrentHIPStream()));
MIOPEN_CHECK(miopenSetStream(handle, at::cuda::getCurrentCUDAStream()));
return handle;
}
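Note: the handle-pool change works because the returned stream object converts implicitly to the raw handle MIOpen expects (hipStream_t after hipification). A minimal sketch, assuming the MIOpen headers and ATen's MIOPEN_CHECK from the surrounding ROCm-only file:

    #include <ATen/cuda/CUDAContext.h>
    #include <ATen/miopen/Exceptions.h>

    void bind_current_stream(miopenHandle_t handle) {
      // CUDAStream converts to cudaStream_t implicitly; hipify turns the
      // whole line into the hipStream_t form this file compiles against.
      MIOPEN_CHECK(miopenSetStream(handle, at::cuda::getCurrentCUDAStream()));
    }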

View File

@@ -5,9 +5,13 @@
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
// ROCM hcc doesn't work well with using std:: in kernel functions
// ROCm hip compiler doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#endif
#define compat_exp c10::cuda::compat::exp
#define compat_ceil c10::cuda::compat::ceil
#define compat_floor c10::cuda::compat::floor
@@ -17,17 +21,6 @@
#define compat_tan c10::cuda::compat::tan
#define compat_abs c10::cuda::compat::abs
#define compat_log1p c10::cuda::compat::log1p
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_exp c10::hip::compat::exp
#define compat_ceil c10::hip::compat::ceil
#define compat_floor c10::hip::compat::floor
#define compat_log c10::hip::compat::log
#define compat_pow c10::hip::compat::pow
#define compat_sqrt c10::hip::compat::sqrt
#define compat_tan c10::hip::compat::tan
#define compat_abs c10::hip::compat::abs
#define compat_log1p c10::hip::compat::log1p
#else
#define compat_exp std::exp
#define compat_ceil std::ceil
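Note: the restructured block keeps one set of compat_* bindings to the c10::cuda::compat namespace; only the include now differs per compiler, and hipify maps the namespace for ROCm. A toy device function showing how kernels in this file consume the macros, for illustration:

    template <typename scalar_t>
    __device__ scalar_t exponential_icdf(scalar_t u) {
      // -log1p(-u) is the inverse CDF of Exp(1). compat_log1p resolves to
      // c10::cuda::compat::log1p in device code and std::log1p otherwise.
      return -compat_log1p(-u);
    }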

View File

@@ -52,13 +52,14 @@ inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
#define MIN(X, Y) min_impl(X,Y)
#endif
// ROCM hcc doesn't work well with using std:: in kernel functions
// ROCm hip compiler doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_pow c10::cuda::compat::pow
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_pow c10::hip::compat::pow
#endif
#define compat_pow c10::cuda::compat::pow
#else
#define compat_pow std::pow
#endif

View File

@@ -157,7 +157,7 @@ void bgemm_kernel_impl(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
invoker.Run(argument, StreamConfig{stream, false});
}

View File

@@ -11,7 +11,6 @@
#include <numeric>
#include <ATen/ATen.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <ATen/native/hip/ck_gemm.h>
#include <ATen/native/hip/ck_types.h>
@@ -233,7 +232,7 @@ void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
}
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
invoker.Run(argument, StreamConfig{stream, false});
}
@@ -391,7 +390,7 @@ void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
}
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
#if 1
invoker.Run(argument, StreamConfig{stream, false});
#else

View File

@@ -278,14 +278,14 @@ BenchmarkCache<size_t> bwd_filter_wssizes;
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = c10::hip::HIPCachingAllocator::raw_alloc(size);
data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
c10::hip::HIPCachingAllocator::raw_delete(data);
c10::cuda::CUDACachingAllocator::raw_delete(data);
}
}
@@ -587,7 +587,7 @@ void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
wsscache.insert(args.params, perfResults.memory);
if (at::native::_cudnn_get_conv_benchmark_empty_cache()) {
c10::hip::HIPCachingAllocator::emptyCache();
c10::cuda::CUDACachingAllocator::emptyCache();
}
}

View File

@@ -76,14 +76,14 @@ namespace {
struct DropoutState {
DropoutState(size_t size) : size(size), data(NULL) {
data = c10::hip::HIPCachingAllocator::raw_alloc(size);
data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
}
DropoutState(const DropoutState&) = delete;
DropoutState(DropoutState&&) = default;
DropoutState& operator=(DropoutState&&) = default;
~DropoutState() {
if (data) {
c10::hip::HIPCachingAllocator::raw_delete(data);
c10::cuda::CUDACachingAllocator::raw_delete(data);
}
}

View File

@@ -59,8 +59,6 @@
#include <thrust/transform.h>
#include <thrust/unique.h>
#include <c10/cuda/CUDAMathCompat.h>
namespace at::native {
namespace {

View File

@@ -37,7 +37,6 @@
#ifdef USE_FLASH_ATTENTION
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
@@ -162,7 +161,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
std::optional<int64_t> window_size_right,
const bool return_softmax,
const std::optional<at::Generator>& gen_) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
auto q_dtype = q.dtype();
@@ -348,8 +347,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt");
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
auto q_dtype = q.dtype();
@@ -560,8 +559,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
const at::Tensor& philox_offset) {
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
bool is_dropout = p_dropout > 0.0;
@@ -793,8 +792,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
bool is_dropout = p_dropout > 0.0;

View File

@@ -261,7 +261,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -365,7 +365,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
}
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

View File

@@ -261,7 +261,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;
@@ -299,7 +299,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
hipLaunchKernelGGL(
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr);
seed_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[0])), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[1])), at::dtype(at::kLong));
}
@@ -317,7 +317,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
if (seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
ck_tile::stream_config stream_config{stream};
auto traits =

View File

@@ -255,7 +255,7 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -366,7 +366,7 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
}
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));

View File

@@ -273,7 +273,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;
@@ -307,7 +307,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
hipLaunchKernelGGL(
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr);
}
// remove const from attn_bias_
@@ -320,7 +320,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
if (max_seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
ck_tile::stream_config stream_config{stream};
auto traits =

View File

@@ -7,7 +7,6 @@
#include <ATen/TensorIndexing.h>
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>
#ifndef AT_PER_OPERATOR_HEADERS

View File

@@ -119,8 +119,9 @@ class C10_API DataPtr {
}
// Unsafely mutates the device on a DataPtr. Under normal use,
// you should never actually need to call this function.
// We need this for the implementation of the hack detailed
// in Note [Masquerading as CUDA]
// We used to need this for the implementation of the hack detailed
// in Note [Masquerading as CUDA], but that hack has been removed.
// Other uses of this function now exist so it cannot be deprecated.
void unsafe_set_device(Device device) {
device_ = device;
}
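Note: a minimal sketch of the kind of remaining use the new comment alludes to: relabeling a DataPtr's device without touching the allocation (the helper name is hypothetical):

    // Hypothetical helper: stamp a different device type onto an existing
    // DataPtr. The bytes do not move; only the Device metadata changes,
    // which is why the method is spelled "unsafe".
    at::DataPtr relabel_as_cuda(at::DataPtr ptr) {
      ptr.unsafe_set_device(c10::Device(c10::DeviceType::CUDA, ptr.device().index()));
      return ptr;
    }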

View File

@@ -4,12 +4,13 @@
#include <c10/util/TypeSafeSignMath.h>
#include <cmath>
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define C10_COMPAT_COPYSIGN c10::hip::compat::copysign
#endif
#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
#else
#include <c10/util/copysign.h>
#define C10_COMPAT_COPYSIGN c10::copysign

View File

@@ -12,7 +12,7 @@ TEST_CODES = [
"CUdeviceptr var = reinterpret_cast<CUdeviceptr>(arg.data_ptr());",
"at::cuda::CUDAStreamGuard guard(at::cuda::getStreamFromExternal());",
# Hipification should be idempotent, hipifying should be a no-op for already hipified files
"at::hip::HIPStreamGuardMasqueradingAsCUDA guard(at::hip::getStreamFromExternalMasqueradingAsCUDA());",
"at::cuda::CUDAStreamGuard guard(at::cuda::getStreamFromExternal());",
]
HIP_CODES = [
@@ -20,8 +20,8 @@ HIP_CODES = [
"hipFunction_t kernel = nullptr;",
"static hipFunction_t kernel = nullptr;",
"hipDeviceptr_t var = reinterpret_cast<hipDeviceptr_t>(arg.data_ptr());",
"at::hip::HIPStreamGuardMasqueradingAsCUDA guard(at::hip::getStreamFromExternalMasqueradingAsCUDA());",
"at::hip::HIPStreamGuardMasqueradingAsCUDA guard(at::hip::getStreamFromExternalMasqueradingAsCUDA());",
"at::cuda::CUDAStreamGuard guard(at::cuda::getStreamFromExternal());",
"at::cuda::CUDAStreamGuard guard(at::cuda::getStreamFromExternal());",
]

View File

@@ -202,6 +202,51 @@ for hip_platform_file in hip_platform_files:
print(f"{hip_platform_file} updated")
# TODO Remove once the following submodules are updated to use hipify v2
hipify_v1_to_v2_files = [
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/gemm/ck_extensions.hip",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_common.h",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_common.h",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_common.h",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_preshuffle/kernels/fp8_rowwise_preshuffle_common.h",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip",
"third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.hpp",
]
def hipify_v1_to_v2(line: str) -> str:
line = line.replace("hip::HIPStreamMasqueradingAsCUDA", "cuda::CUDAStream")
line = line.replace(
"hip::HIPStreamGuardMasqueradingAsCUDA", "cuda::CUDAStreamGuard"
)
line = line.replace(
"hip::getStreamFromPoolMasqueradingAsCUDA", "cuda::getStreamFromPool"
)
line = line.replace("getCurrentHIPStream", "getCurrentCUDAStream")
return line
for hipify_v1_to_v2_file in hipify_v1_to_v2_files:
do_write = False
if os.path.exists(hipify_v1_to_v2_file):
with open(hipify_v1_to_v2_file) as sources:
lines = sources.readlines()
newlines = [hipify_v1_to_v2(line) for line in lines]
if lines == newlines:
print(f"{hipify_v1_to_v2_file} skipped")
else:
with open(hipify_v1_to_v2_file, "w") as sources:
for line in newlines:
sources.write(line)
print(f"{hipify_v1_to_v2_file} updated")
hipify_python.hipify(
project_directory=proj_dir,
output_directory=out_dir,
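Note: the workaround above rewrites hipify-v1 spellings in the listed fbgemm sources before hipify proper runs. An illustrative check of one pass over a sample v1 line (the input line is hypothetical):

    line = "c10::hip::HIPStreamMasqueradingAsCUDA s = c10::hip::getStreamFromPoolMasqueradingAsCUDA();"
    print(hipify_v1_to_v2(line))
    # c10::cuda::CUDAStream s = c10::cuda::getStreamFromPool();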

View File

@@ -198,7 +198,7 @@ void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) {
void CUDAPluggableAllocator::recordStream(
const c10::DataPtr& ptr,
streamType stream) {
c10::cuda::CUDAStream stream) {
if (record_stream_fn_) {
record_stream_fn_(ptr.get(), stream);
}

View File

@@ -11,12 +11,6 @@
namespace torch::cuda::CUDAPluggableAllocator {
#if defined(USE_ROCM)
using streamType = c10::hip::HIPStream;
#else
using streamType = c10::cuda::CUDAStream;
#endif
TORCH_CUDA_CPP_API std::shared_ptr<
c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator();
@@ -98,7 +92,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override;
void* getBaseAllocation(void* ptr, size_t* size) override;
void recordStream(const c10::DataPtr&, streamType stream) override;
void recordStream(const c10::DataPtr&, c10::cuda::CUDAStream stream) override;
c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) override;

View File

@@ -74,8 +74,8 @@ AllocationRef::~AllocationRef() {
#endif
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size));
C10_HIP_CHECK(hipMemRelease(handle));
C10_CUDA_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size));
C10_CUDA_CHECK(hipMemRelease(handle));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@@ -387,12 +387,12 @@ void* CUDASymmetricMemoryAllocator::alloc(
prop.requestedHandleType = hipMemHandleTypePosixFileDescriptor;
size_t granularity;
C10_HIP_CHECK(hipMemGetAllocationGranularity(
C10_CUDA_CHECK(hipMemGetAllocationGranularity(
&granularity, &prop, hipMemAllocationGranularityRecommended));
block_size = at::round_up(block_size, granularity);
HandleType handle;
C10_HIP_CHECK(hipMemCreate(
C10_CUDA_CHECK(hipMemCreate(
reinterpret_cast<hipMemGenericAllocationHandle_t*>(&handle),
block_size,
&prop,
@@ -633,7 +633,7 @@ c10::intrusive_ptr<CUDASymmetricMemory> make_symm_mem(
: CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemExportToShareableHandle(
C10_CUDA_CHECK(hipMemExportToShareableHandle(
&block_handle,
block->alloc_ref->handle,
hipMemHandleTypePosixFileDescriptor,
@@ -697,7 +697,7 @@ c10::intrusive_ptr<CUDASymmetricMemory> make_symm_mem(
CU_MEM_HANDLE_TYPE_FABRIC));
}
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemImportFromShareableHandle(
C10_CUDA_CHECK(hipMemImportFromShareableHandle(
&handles[r],
(void*)(uintptr_t) & (imported_handles[r]),
hipMemHandleTypePosixFileDescriptor));
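Note: these hunks apply the same hipify-v2 rule as the launch-check change earlier: the CUDA-named error macro is the one written in shared source, and it can wrap the hip* calls in the USE_ROCM branches because the whole translation unit is hipified. A minimal sketch:

    #include <cuda_runtime.h>
    #include <c10/cuda/CUDAException.h>

    void sync_stream(cudaStream_t stream) {
      // On ROCm this becomes C10_HIP_CHECK(hipStreamSynchronize(...)) after
      // hipify; one spelling serves both backends in the original source.
      C10_CUDA_CHECK(cudaStreamSynchronize(stream));
    }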

View File

@@ -1151,7 +1151,7 @@ at::Tensor memset32_(
count,
at::cuda::getCurrentCUDAStream()));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(addr),
C10_CUDA_CHECK(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(addr),
val,
count,
at::cuda::getCurrentCUDAStream()));
@@ -1208,7 +1208,7 @@ at::Tensor stream_write_value32_(
val,
0));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipStreamWriteValue32(
C10_CUDA_CHECK(hipStreamWriteValue32(
at::cuda::getCurrentCUDAStream(),
reinterpret_cast<void*>(addr),
val,

View File

@@ -246,14 +246,14 @@ void map_block(
desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemAddressReserve(ptr, size, 0ULL, 0, 0ULL));
C10_HIP_CHECK(hipMemMap(
C10_CUDA_CHECK(hipMemAddressReserve(ptr, size, 0ULL, 0, 0ULL));
C10_CUDA_CHECK(hipMemMap(
*ptr,
size,
0,
reinterpret_cast<hipMemGenericAllocationHandle_t>(handle),
0ULL));
C10_HIP_CHECK(hipMemMap(
C10_CUDA_CHECK(hipMemMap(
*ptr,
size,
0,
@@ -265,7 +265,7 @@
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
desc.location.id = static_cast<int>(device_idx);
desc.flags = hipMemAccessFlagsProtReadWrite;
C10_HIP_CHECK(hipMemSetAccess(*ptr, size, &desc, 1));
C10_CUDA_CHECK(hipMemSetAccess(*ptr, size, &desc, 1));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");

View File

@@ -7,6 +7,9 @@ and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported
mapping.
"""
import warnings
warnings.warn("hipify's constants.py is no longer used as of version 2.0.0", FutureWarning)
CONV_VERSION = 0,
CONV_INIT = 1
CONV_DEVICE = 2
@@ -54,8 +57,9 @@ API_LAST = 42
 API_FFT = 43
 API_RTC = 44
 API_ROCTX = 45
-HIP_UNSUPPORTED = 46
+API_PYT_EXT = 46
+HIP_UNSUPPORTED = 47
 API_PYTORCH = 1337
 API_CAFFE2 = 1338
 API_C10 = 1339
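
Two consequences of this hunk for downstream users: importing constants now emits a FutureWarning, and HIP_UNSUPPORTED is renumbered from 46 to 47 to make room for API_PYT_EXT, so any code that hard-coded the integer 46 needs updating. A small sketch to observe both (importlib.reload re-triggers the module-level warning in case torch already imported the module):

import importlib
import warnings

import torch.utils.hipify.constants as constants

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    importlib.reload(constants)  # re-run the module body, including warnings.warn(...)

assert constants.API_PYT_EXT == 46
assert constants.HIP_UNSUPPORTED == 47
assert any(issubclass(w.category, FutureWarning) for w in caught)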

File diff suppressed because it is too large

View File

@@ -30,18 +30,20 @@ import re
 import shutil
 import sys
 import os
+import warnings
 from . import constants
 from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
 from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
 from typing import Optional
-from collections.abc import Iterator
-from collections.abc import Mapping, Iterable
+from collections.abc import Iterable, Iterator, Mapping
 from enum import Enum
 import functools
 import hashlib

+def _deprecated(name):
+    warnings.warn(f"hipify version 2.0.0 no longer uses function {name}", FutureWarning, stacklevel=2)

 class CurrentState(Enum):
     INITIALIZED = 1
     DONE = 2
@@ -68,7 +70,7 @@ __all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_exte
     'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
     'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
     'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file',
-    'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
+    'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
     'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify']
@@ -180,7 +182,6 @@ def matched_files_iter(
             dirs.append("third_party/nvfuser")
         for filename in filenames:
             filepath = _to_unix_path(os.path.join(abs_dirpath, filename))
-            rel_filepath = _to_unix_path(os.path.join(rel_dirpath, filename))
             # We respect extensions, UNLESS you wrote the entire
             # filename verbatim, in which case we always accept it
             if (
@@ -188,11 +189,6 @@
                 and (not _fnmatch(filepath, ignores))
                 and (match_extensions(filepath, extensions) or filepath in exact_matches)
             ):
-                if not is_pytorch_extension:  # for pytorch extensions, consider all files
-                    if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
-                        continue
-                    if out_of_place_only and not is_out_of_place(rel_filepath):
-                        continue
                 yield filepath
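
With the gate removed, matched_files_iter yields every file that survives the ignore patterns and the extension/verbatim-name check; the is_pytorch_file / is_caffe2_gpu_file and out_of_place filters no longer run, which is also why rel_filepath is no longer computed. A condensed, self-contained sketch of the remaining filter (fnmatch.fnmatch stands in for the module's _fnmatch helper; all names and paths here are illustrative):

import fnmatch

def iter_candidates(filepaths, extensions, ignores, exact_matches):
    for filepath in filepaths:
        # drop anything matching an ignore pattern
        if any(fnmatch.fnmatch(filepath, pat) for pat in ignores):
            continue
        # respect extensions, unless the filename was given verbatim
        if filepath.endswith(tuple(extensions)) or filepath in exact_matches:
            yield filepath

print(list(iter_candidates(
    ["a/foo.cu", "a/notes.txt", "third_party/x.cu", "exact.name"],
    extensions=[".cu", ".cuh"],
    ignores=["third_party/*"],
    exact_matches=["exact.name"],
)))
# -> ['a/foo.cu', 'exact.name']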
@@ -562,8 +558,8 @@ def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
     # it gets a different name from the original filename, so
     # that we don't overwrite the original file
     #
-    # There's a lot of different naming conventions across PyTorch
-    # and Caffe2, but the general recipe is to convert occurrences
+    # There's a lot of different naming conventions across PyTorch,
+    # but the general recipe is to convert occurrences
     # of cuda/gpu to hip, and add hip if there are no occurrences
     # of cuda/gpu anywhere.
     #
@@ -627,8 +623,8 @@ def is_out_of_place(rel_filepath):
     return True

-
 # Keep this synchronized with includes/ignores in build_amd.py
 def is_pytorch_file(rel_filepath):
+    _deprecated("is_pytorch_file")
     if os.path.isabs(rel_filepath):
         raise AssertionError("rel_filepath must be a relative path")
     if rel_filepath.startswith("aten/"):
@@ -645,12 +641,14 @@ def is_pytorch_file(rel_filepath):
 def is_cusparse_file(rel_filepath):
+    _deprecated("is_cusparse_file")
     if is_pytorch_file(rel_filepath):
         return "sparse" in rel_filepath.lower()
     return False

 def is_special_file(rel_filepath):
+    _deprecated("is_special_file")
     if is_pytorch_file(rel_filepath):
         if "sparse" in rel_filepath.lower():
             return True
@@ -660,7 +658,9 @@ def is_special_file(rel_filepath):
             return True
     return False

+
 def is_caffe2_gpu_file(rel_filepath):
+    _deprecated("is_caffe2_gpu_file")
     if os.path.isabs(rel_filepath):
         raise AssertionError("rel_filepath must be a relative path")
     if rel_filepath.startswith("c10/cuda"):
@@ -670,6 +670,7 @@ def is_caffe2_gpu_file(rel_filepath):
     # pyrefly: ignore # unsupported-operation
     return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
+
 class TrieNode:
     """A Trie node whose children are represented as a directory of char: TrieNode.
     A special char '' represents end of word
@@ -678,6 +679,7 @@ class TrieNode:
     def __init__(self):
         self.children = {}

+
 class Trie:
     """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
     The corresponding Regex should match much faster than a simple Regex union."""
@@ -772,39 +774,16 @@ class Trie:
         """Export the Trie to a regex pattern."""
         return self._pattern(self.root, self._digest)

-CAFFE2_TRIE = Trie()
-CAFFE2_MAP = {}
 PYTORCH_TRIE = Trie()
 PYTORCH_MAP: dict[str, object] = {}

-# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
-# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
-# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
-# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
-# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
-# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
-# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
-# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
-PYTORCH_SPECIAL_MAP = {}

 for mapping in CUDA_TO_HIP_MAPPINGS:
     if not isinstance(mapping, Mapping):
         raise TypeError("Expected each mapping in CUDA_TO_HIP_MAPPINGS to be a Mapping")
-    for src, value in mapping.items():
-        dst = value[0]
-        meta_data = value[1:]
-        if constants.API_CAFFE2 not in meta_data:
-            PYTORCH_TRIE.add(src)
-            # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL
-            # do not overwrite PYTORCH_MAP, store dst separately
-            if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""):
-                PYTORCH_SPECIAL_MAP[src] = dst
-            else:
-                PYTORCH_MAP[src] = dst
-        if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data:
-            CAFFE2_TRIE.add(src)
-            CAFFE2_MAP[src] = dst
-RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex())
+    for src, dst in mapping.items():
+        PYTORCH_TRIE.add(src)
+        PYTORCH_MAP[src] = dst
 RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)')

 RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
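
The caffe2 and sparse/linalg special tables are gone; every entry in CUDA_TO_HIP_MAPPINGS now lands in the single PYTORCH_TRIE / PYTORCH_MAP pair. The Trie-to-regex machinery itself is unchanged: shared prefixes collapse into one compiled pattern instead of a naive alternation over thousands of entries. A small usage sketch of the two public pieces visible above, add and export_to_regex:

import re

from torch.utils.hipify.hipify_python import Trie

trie = Trie()
for word in ("cudaMalloc", "cudaMallocHost", "cudaMemcpy"):
    trie.add(word)

pattern = re.compile(fr"(?<=\W)({trie.export_to_regex()})(?=\W)")
print(pattern.findall(" cudaMallocHost(&p, n); cudaMemcpy(d, s, n); "))
# -> ['cudaMallocHost', 'cudaMemcpy']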
@@ -865,22 +844,7 @@ def preprocessor(
         def pt_repl(m):
             return PYTORCH_MAP[m.group(0)]

-        def pt_special_repl(m):
-            # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings
-            return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))
-
-        if is_pytorch_extension:
-            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
-        else:
-            if is_special_file(rel_filepath):
-                output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source)
-            elif is_pytorch_file(rel_filepath):
-                output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
-            else:
-                def c2_repl(m):
-                    return CAFFE2_MAP[m.group(0)]
-                output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
+        output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)

         # Header rewrites
         def mk_repl(templ, include_current_dir=True):
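
For contrast, the retired routing chose a replacement function per file: pytorch extensions and regular pytorch files used pt_repl, "sparse"/"linalg" files consulted PYTORCH_SPECIAL_MAP first via pt_special_repl, and everything else went through the caffe2 table; v2 applies pt_repl to every file unconditionally. A minimal sketch of the retired two-tier lookup (map contents are illustrative, not the real tables):

import re

PYTORCH_MAP = {"cusparseHandle_t": "rocsparse_handle"}        # illustrative only
PYTORCH_SPECIAL_MAP = {"cusparseHandle_t": "hipsparseHandle_t"}

def pt_repl(m):
    return PYTORCH_MAP[m.group(0)]

def pt_special_repl(m):
    # SPECIAL map first; on a miss, fall back to the pytorch mapping
    return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))

pattern = re.compile(r"cusparseHandle_t")
print(pattern.sub(pt_special_repl, "cusparseHandle_t handle;"))
# -> hipsparseHandle_t handle;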

View File

@@ -1 +1 @@
-__version__ = '1.0.0'
+__version__ = '2.0.0'
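
The major-version bump gives out-of-tree extensions something to key off when they must support both hipify generations; a hedged sketch of such a guard (the torch.utils.hipify.version module path is inferred from the one-line file in this diff):

try:
    from torch.utils.hipify.version import __version__ as hipify_version
except ImportError:
    hipify_version = "1.0.0"  # torch builds predating the version module

HIPIFY_V2 = int(hipify_version.split(".")[0]) >= 2
if HIPIFY_V2:
    # no caffe2 handling; everything goes through the single PYTORCH_MAP path
    pass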

View File

@@ -714,10 +714,7 @@ resize_out(out, sizes, strides, options);
             raise RuntimeError(f"Unsupported SchemaKind {k}")

         if self.backend_index.dispatch_key == DispatchKey.CUDA:
-            if self.rocm:
-                guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
-            else:
-                guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
+            guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
         elif (
             self.backend_index.dispatch_key
             == DispatchKey.CompositeExplicitAutogradNonFunctional

View File

@@ -2228,7 +2228,7 @@ def gen_source_files(
 #include <ATen/cuda/CUDAContext.h>"""
     if rocm:
         extra_cuda_headers = """\
-#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
+#include <c10/hip/HIPGuard.h>
 #include <ATen/hip/ATenHIPGeneral.h>
 #include <ATen/hip/HIPDevice.h>
 #include <ATen/hip/HIPContext.h>"""