mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-15 23:04:54 +08:00

Compare commits: ciflow/tru... → zhxchen17/... (15 commits)

| SHA1 |
| --- |
| 33f776b894 |
| 4de24bcc56 |
| f2d0a472ef |
| 9ae0ecec7d |
| ce4f31f662 |
| 2c846bb614 |
| 8c86ccfbc9 |
| 8f96e7bc1d |
| 782fc3c72b |
| 1a67403fc6 |
| 3d801a4c01 |
| 2034ca99ae |
| 480b4ff882 |
| f570e589da |
| f9851af59b |
@@ -1680,6 +1680,22 @@ test_operator_microbenchmark() {
  done
}

+test_attention_microbenchmark() {
+  TEST_REPORTS_DIR=$(pwd)/test/test-reports
+  mkdir -p "$TEST_REPORTS_DIR"
+  TEST_DIR=$(pwd)
+
+  # Install attention-gym dependency
+  echo "Installing attention-gym..."
+  python -m pip install git+https://github.com/meta-pytorch/attention-gym.git@main
+  pip show triton
+
+  cd "${TEST_DIR}"/benchmarks/transformer
+
+  $TASKSET python score_mod.py --config configs/config_basic.yaml \
+    --output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json"
+}
+
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
  (cd test && python -c "import torch; print(torch.__config__.show())")
  (cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@@ -1737,6 +1753,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
  fi
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
  test_operator_microbenchmark
+elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then
+  test_attention_microbenchmark
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
  test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
.github/workflows/attention_op_microbenchmark.yml (new file, 73 lines, vendored)
@@ -0,0 +1,73 @@
name: attention_op_microbenchmark

on:
  push:
    tags:
      - ciflow/op-benchmark/*
  workflow_dispatch:
  schedule:
    # Run at 07:00 UTC every day
    - cron: 0 7 * * *

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:
  attn-microbenchmark-build:
    if: github.repository_owner == 'pytorch'
    uses: ./.github/workflows/_linux-build.yml
    with:
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '8.0 9.0'
      test-matrix: |
        { include: [
          { config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
          { config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
        ]}
    secrets: inherit

  attn-microbenchmark-test:
    name: attn-microbenchmark-test
    uses: ./.github/workflows/_linux-test.yml
    needs: attn-microbenchmark-build
    with:
      timeout-minutes: 500
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
      docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
    secrets: inherit

  # B200 runner
  opmicrobenchmark-build-b200:
    if: github.repository_owner == 'pytorch'
    name: opmicrobenchmark-build-b200
    uses: ./.github/workflows/_linux-build.yml
    with:
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '10.0'
      test-matrix: |
        { include: [
          { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
        ]}
    secrets: inherit

  opmicrobenchmark-test-b200:
    name: opmicrobenchmark-test-b200
    uses: ./.github/workflows/_linux-test.yml
    needs: opmicrobenchmark-build-b200
    with:
      timeout-minutes: 500
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
      docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
      test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
    secrets: inherit
@@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
  at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}

+TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
+    c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
+}
} // namespace at::accelerator

namespace at {
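A minimal usage sketch of the accessor added above. This is illustrative only; it assumes the declaration lives in the accelerator utilities header (`ATen/DeviceAccelerator.h`) and that an accelerator backend is present.

```cpp
// Illustrative sketch, not part of the diff. Assumes ATen/DeviceAccelerator.h declares
// at::accelerator::getMemoryInfo (added in the hunk above) and that device 0 exists.
#include <ATen/DeviceAccelerator.h>
#include <iostream>

int main() {
  // {free_bytes, total_bytes} for device 0 of the current accelerator backend.
  const auto [free_bytes, total_bytes] = at::accelerator::getMemoryInfo(0);
  std::cout << "free: " << free_bytes << " bytes, total: " << total_bytes << " bytes\n";
  return 0;
}
```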
@@ -1,6 +1,7 @@
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
+#include <ATen/cuda/MemPool.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDAFunctions.h>

@@ -13,7 +14,7 @@ static bool _cuda_graphs_debug = false;
MempoolId_t graph_pool_handle() {
  // Sets just the second value, to distinguish it from MempoolId_ts created from
  // cudaStreamGetCaptureInfo id_s in capture_begin.
-  return c10::cuda::MemPool::graph_pool_handle();
+  return at::cuda::MemPool::graph_pool_handle();
}

/**
@@ -90,7 +91,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
  } else {
    // User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
    // Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
-    mempool_id_ = c10::cuda::MemPool::graph_pool_handle(false);
+    mempool_id_ = at::cuda::MemPool::graph_pool_handle(false);
    TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
  }
||||
69
aten/src/ATen/cuda/MemPool.cpp
Normal file
69
aten/src/ATen/cuda/MemPool.cpp
Normal file
@ -0,0 +1,69 @@
|
||||
#include <ATen/core/CachingHostAllocator.h>
|
||||
#include <ATen/cuda/MemPool.h>
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created,
|
||||
bool use_on_oom)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
device_ = c10::cuda::current_device();
|
||||
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
|
||||
if (use_on_oom) {
|
||||
CUDACachingAllocator::setUseOnOOM(device_, id_);
|
||||
}
|
||||
}
|
||||
|
||||
MemPool::~MemPool() {
|
||||
// TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// We used to assert that TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// However, this assertion is not true if a memory pool is shared
|
||||
// with a cuda graph. That CUDAGraph will increase the use count
|
||||
// until it is reset.
|
||||
CUDACachingAllocator::releasePool(device_, id_);
|
||||
c10::cuda::CUDACachingAllocator::emptyCache(id_);
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
int MemPool::use_count() {
|
||||
return CUDACachingAllocator::getPoolUseCount(device_, id_);
|
||||
}
|
||||
|
||||
c10::DeviceIndex MemPool::device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
|
||||
if (is_user_created) {
|
||||
return {0, uid_++};
|
||||
}
|
||||
return {uuid_++, 0};
|
||||
}
|
||||
|
||||
} // namespace at::cuda
|
||||
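A hedged usage sketch of the relocated `at::cuda::MemPool`, showing how a user-created pool might be shared with a CUDA graph capture. It is an illustration under stated assumptions (a CUDA build, capture on a non-default stream), not part of the diff.

```cpp
// Illustrative sketch only. Uses the MemPool API introduced above plus CUDAGraph,
// and assumes a CUDA device is available.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/MemPool.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void capture_into_user_pool() {
  at::cuda::MemPool pool;  // is_user_created=true, so id_ = {0, uid_++}

  // CUDA graph capture must run on a non-default stream.
  c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(stream);

  at::cuda::CUDAGraph graph;
  graph.capture_begin(pool.id());          // route capture-time allocations to `pool`
  auto t = at::empty({1024}, at::kCUDA);   // allocation recorded in the pool
  graph.capture_end();

  // Per the destructor comment above, the graph also holds a reference to the pool
  // until it is reset, so pool.use_count() stays above 1 while the graph is alive.
}
```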
aten/src/ATen/cuda/MemPool.h (new file, 44 lines)
@@ -0,0 +1,44 @@
#pragma once

#include <c10/core/Allocator.h>
#include <c10/cuda/CUDACachingAllocator.h>

namespace at::cuda {

// Keep BC only
using c10::CaptureId_t;
using c10::MempoolId_t;

// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//
// An allocator pointer can be passed to the MemPool to define how the
// allocations should be done in the pool. For example: using a different
// system allocator such as ncclMemAlloc.
struct TORCH_CUDA_CPP_API MemPool {
  MemPool(
      c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
      bool is_user_created = true,
      bool use_on_oom = false);
  MemPool(const MemPool&) = delete;
  MemPool(MemPool&&) = default;
  MemPool& operator=(const MemPool&) = delete;
  MemPool& operator=(MemPool&&) = default;
  ~MemPool();

  MempoolId_t id();
  c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator();
  int use_count();
  c10::DeviceIndex device();
  static MempoolId_t graph_pool_handle(bool is_user_created = true);

 private:
  static std::atomic<CaptureId_t> uid_;
  static std::atomic<CaptureId_t> uuid_;
  c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator_;
  bool is_user_created_;
  MempoolId_t id_;
  c10::DeviceIndex device_;
};

} // namespace at::cuda
benchmarks/dynamo/pr_time_benchmarks/benchmarks/dtensor.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import sys

from benchmark_base import BenchmarkBase

import torch
from torch.distributed._tensor import DTensor, Replicate
from torch.testing._internal.distributed.fake_pg import FakeStore


class BenchmarkDTensorDispatch(BenchmarkBase):
    def __init__(self, operator, world_size) -> None:
        super().__init__(
            category=f"dtensor_dispatch_{operator}",
            device="cuda",
        )
        self.world_size = world_size

    def name(self) -> str:
        prefix = f"{self.category()}"
        return prefix

    def description(self) -> str:
        return f"DTensor dispatch time for {self.category()}"

    def _prepare_once(self) -> None:
        self.mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda", (self.world_size,), mesh_dim_names=("dp",)
        )
        self.a = DTensor.from_local(
            torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
        )
        self.b = DTensor.from_local(
            torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
        )

    def _prepare(self) -> None:
        pass


class BenchmarkDetach(BenchmarkDTensorDispatch):
    def __init__(self, world_size) -> None:
        super().__init__(operator="detach", world_size=world_size)

    def _work(self) -> None:
        self.a.detach()


def main():
    world_size = 256
    fake_store = FakeStore()
    torch.distributed.init_process_group(
        "fake", store=fake_store, rank=0, world_size=world_size
    )
    result_path = sys.argv[1]
    BenchmarkDetach(world_size).enable_instruction_count().collect_all().append_results(
        result_path
    )
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
@@ -125,6 +125,17 @@ AttentionType = Literal[
]
DtypeString = Literal["bfloat16", "float16", "float32"]
SpeedupType = Literal["fwd", "bwd"]
+# Operator Name mapping
+backend_to_operator_name = {
+    "math": "math attention kernel",
+    "efficient": "efficient attention kernel",
+    "cudnn": "cudnn attention kernel",
+    "fav2": "flash attention 2 kernel",
+    "fav3": "flash attention 3 kernel",
+    "fakv": "flash attention kv cache kernel",
+    "og-eager": "eager attention kernel",
+    "flex": "flex attention kernel",
+}


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
@@ -1265,12 +1276,14 @@ def _output_json_for_dashboard(
    model: ModelInfo
    metric: MetricInfo

+    operator_name = backend_to_operator_name.get(backend, backend)
+
    # Benchmark extra info
    benchmark_extra_info = {
        "input_config": input_config,
        "device": device,
        "arch": device_arch,
-        "operator_name": backend,
+        "operator_name": operator_name,
        "attn_type": config.attn_type,
        "shape": str(config.shape),
        "max_autotune": config.max_autotune,
@@ -1288,7 +1301,7 @@ def _output_json_for_dashboard(
            type="attention-benchmark",
            origins=["pytorch"],
            extra_info={
-                "operator_name": backend,
+                "operator_name": operator_name,
                "attn_type": config.attn_type,
            },
        ),
@@ -1315,7 +1328,7 @@ def _output_json_for_dashboard(
            type="attention-benchmark",
            origins=["pytorch"],
            extra_info={
-                "operator_name": backend,
+                "operator_name": operator_name,
            },
        ),
        metric=MetricInfo(
@@ -1341,7 +1354,7 @@ def _output_json_for_dashboard(
            type="attention-benchmark",
            origins=["pytorch"],
            extra_info={
-                "operator_name": backend,
+                "operator_name": operator_name,
            },
        ),
        metric=MetricInfo(
@@ -1371,7 +1384,7 @@ def _output_json_for_dashboard(
            type="attention-benchmark",
            origins=["pytorch"],
            extra_info={
-                "operator_name": backend,
+                "operator_name": operator_name,
            },
        ),
        metric=MetricInfo(
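The mapping above falls back to the raw backend string for anything it does not know about. A tiny illustration (not from the diff) of that fallback behavior:

```python
# Illustrative only: unknown backends fall back to the raw backend string via dict.get.
backend_to_operator_name = {"fav2": "flash attention 2 kernel"}

for backend in ("fav2", "my-custom-backend"):
    operator_name = backend_to_operator_name.get(backend, backend)
    print(operator_name)  # "flash attention 2 kernel", then "my-custom-backend"
```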
@@ -19,6 +19,17 @@

namespace c10 {

+using CaptureId_t = unsigned long long;
+// first is set if the instance is created by CUDAGraph::capture_begin.
+// second is set if the instance is created by at::cuda::graph_pool_handle.
+using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
+
+struct MempoolIdHash {
+  std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
+    return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
+  }
+};
+
// A DataPtr is a unique pointer (with an attached deleter and some
// context for the deleter) to some memory, which also records what
// device is for its data.
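Because `MempoolId_t` is a `std::pair`, it has no `std::hash` specialization; `MempoolIdHash` exists so pool IDs can key hash maps. A hedged sketch (assuming these declarations land in `c10/core/Allocator.h`, as the surrounding `DataPtr` comment suggests):

```cpp
// Illustrative only: keying an unordered_map by MempoolId_t with the hasher added above.
#include <unordered_map>
#include <c10/core/Allocator.h>  // assumed header for MempoolId_t / MempoolIdHash

void example() {
  std::unordered_map<c10::MempoolId_t, int, c10::MempoolIdHash> pool_refcounts;
  c10::MempoolId_t user_pool{0, 1};   // user-created pools set the second counter
  c10::MempoolId_t graph_pool{1, 0};  // graph-created pools set the first counter
  pool_refcounts[user_pool] = 1;
  pool_refcounts[graph_pool] = 1;
}
```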
@@ -96,6 +96,13 @@ struct C10_API DeviceAllocator : public c10::Allocator {

  // Resets peak memory usage statistics for the specified device
  virtual void resetPeakStats(c10::DeviceIndex device) = 0;

+  // Return the free memory size and total memory size in bytes for the
+  // specified device.
+  virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, "getMemoryInfo is not implemented for this allocator yet.");
+  }
};

// This function is used to get the DeviceAllocator for a specific device type
@@ -1012,12 +1012,6 @@ PrivatePoolState::PrivatePoolState(
  }
}

-struct MempoolIdHash {
-  std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
-    return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
-  }
-};
-
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
  if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
    *ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
@@ -4510,66 +4504,3 @@ std::atomic<CUDAAllocator*> allocator;
static BackendStaticInitializer backend_static_initializer;
} // namespace cuda::CUDACachingAllocator
} // namespace c10
-
-namespace c10::cuda {
-
-// uid_ is incremented when a user creates a MemPool,
-// for example: using graph_pool_handle() or c10::cuda::MemPool().
-//
-// uuid_ is incremented when CUDAGraph creates a MemPool
-// as a result of a user not providing a pool.
-//
-// MempoolId_t of {0, 0} is used to denote when no MemPool has been
-// passed to a function, either by user or CUDAGraphs. For example,
-// default value of MempoolId_t for capture_begin function is {0, 0}.
-// That's why uid_ and uuid_ start at 1.
-std::atomic<CaptureId_t> MemPool::uid_{1};
-std::atomic<CaptureId_t> MemPool::uuid_{1};
-
-MemPool::MemPool(
-    CUDACachingAllocator::CUDAAllocator* allocator,
-    bool is_user_created,
-    bool use_on_oom)
-    : allocator_(allocator), is_user_created_(is_user_created) {
-  if (is_user_created_) {
-    id_ = {0, uid_++};
-  } else {
-    id_ = {uuid_++, 0};
-  }
-  device_ = c10::cuda::current_device();
-  CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
-  if (use_on_oom) {
-    CUDACachingAllocator::setUseOnOOM(device_, id_);
-  }
-}
-
-MemPool::~MemPool() {
-  TORCH_INTERNAL_ASSERT(use_count() == 1);
-  CUDACachingAllocator::releasePool(device_, id_);
-  c10::cuda::CUDACachingAllocator::emptyCache(id_);
-}
-
-MempoolId_t MemPool::id() {
-  return id_;
-}
-
-CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
-  return allocator_;
-}
-
-int MemPool::use_count() {
-  return CUDACachingAllocator::getPoolUseCount(device_, id_);
-}
-
-c10::DeviceIndex MemPool::device() {
-  return device_;
-}
-
-MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
-  if (is_user_created) {
-    return {0, uid_++};
-  }
-  return {uuid_++, 0};
-}
-
-} // namespace c10::cuda
@@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
      c10::DeviceIndex device,
      std::shared_ptr<AllocatorState> pps) = 0;
  virtual std::string name() = 0;
+  std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
+    c10::DeviceGuard device_guard({at::kCUDA, device});
+    size_t free = 0;
+    size_t total = 0;
+    C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
+    return {free, total};
+  }
};

// Allocator object, statically initialized
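With this override, the CUDA allocator reports device memory directly, so the generic accessor and the CUDA-specific path should agree. A hedged usage sketch (illustration only, not part of the diff):

```cpp
// Illustrative only: querying the CUDA caching allocator's memory info directly.
#include <c10/cuda/CUDACachingAllocator.h>

void print_cuda_memory(c10::DeviceIndex device) {
  auto [free_bytes, total_bytes] =
      c10::cuda::CUDACachingAllocator::get()->getMemoryInfo(device);
  TORCH_WARN("device ", device, ": free=", free_bytes, " total=", total_bytes);
}
```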
@@ -555,41 +562,7 @@ inline std::string getUserMetadata() {
} // namespace c10::cuda::CUDACachingAllocator
-
-namespace c10::cuda {
-
-// Keep BC only
-using c10::CaptureId_t;
-using c10::MempoolId_t;
-
-// MemPool represents a pool of memory in a caching allocator. Currently,
-// it's just the ID of the pool object maintained in the CUDACachingAllocator.
-//
-// An allocator pointer can be passed to the MemPool to define how the
-// allocations should be done in the pool. For example: using a different
-// system allocator such as ncclMemAlloc.
-struct C10_CUDA_API MemPool {
-  MemPool(
-      CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
-      bool is_user_created = true,
-      bool use_on_oom = false);
-  MemPool(const MemPool&) = delete;
-  MemPool(MemPool&&) = default;
-  MemPool& operator=(const MemPool&) = delete;
-  MemPool& operator=(MemPool&&) = default;
-  ~MemPool();
-
-  MempoolId_t id();
-  CUDACachingAllocator::CUDAAllocator* allocator();
-  int use_count();
-  c10::DeviceIndex device();
-  static MempoolId_t graph_pool_handle(bool is_user_created = true);
-
- private:
-  static std::atomic<CaptureId_t> uid_;
-  static std::atomic<CaptureId_t> uuid_;
-  CUDACachingAllocator::CUDAAllocator* allocator_;
-  bool is_user_created_;
-  MempoolId_t id_;
-  c10::DeviceIndex device_;
-};
-
-} // namespace c10::cuda
@@ -926,15 +926,14 @@ class DeviceCachingAllocator {
          (release_cached_blocks() && alloc_block(params, true));
    }
    if (!block_found) {
-      c10::xpu::DeviceProp device_prop;
-      c10::xpu::get_device_properties(&device_prop, device);
-      auto device_total = device_prop.global_mem_size;
+      const auto& raw_device = c10::xpu::get_raw_device(device);
+      const auto device_total =
+          raw_device.get_info<sycl::info::device::global_mem_size>();
      // Estimate the available device memory when the SYCL runtime does not
      // support the corresponding aspect (ext_intel_free_memory).
-      size_t device_free = device_prop.global_mem_size -
+      size_t device_free = device_total -
          stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
-      auto& raw_device = c10::xpu::get_raw_device(device);
      // TODO: Remove the aspect check once the SYCL runtime bug is fixed on
      // affected devices.
      if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@@ -1052,21 +1051,37 @@ class DeviceCachingAllocator {
    }
  }

+  std::pair<size_t, size_t> getMemoryInfo() {
+    const auto& device = c10::xpu::get_raw_device(device_index);
+    const size_t total = device.get_info<sycl::info::device::global_mem_size>();
+    TORCH_CHECK(
+        device.has(sycl::aspect::ext_intel_free_memory),
+        "The device (",
+        device.get_info<sycl::info::device::name>(),
+        ") doesn't support querying the available free memory. ",
+        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
+        "to help us prioritize its implementation.");
+    const size_t free =
+        device.get_info<sycl::ext::intel::info::device::free_memory>();
+    return {free, total};
+  }
+
  double getMemoryFraction() {
    if (!set_fraction) {
      return 1.0;
    }

-    c10::xpu::DeviceProp device_prop;
-    c10::xpu::get_device_properties(&device_prop, device_index);
+    const auto device_total =
+        xpu::get_raw_device(device_index)
+            .get_info<sycl::info::device::global_mem_size>();
    return static_cast<double>(allowed_memory_maximum) /
-        static_cast<double>(device_prop.global_mem_size);
+        static_cast<double>(device_total);
  }

  void setMemoryFraction(double fraction) {
-    c10::xpu::DeviceProp device_prop;
-    c10::xpu::get_device_properties(&device_prop, device_index);
-    auto device_total = device_prop.global_mem_size;
+    const auto device_total =
+        xpu::get_raw_device(device_index)
+            .get_info<sycl::info::device::global_mem_size>();
    allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
    set_fraction = true;
  }
@@ -1240,6 +1255,11 @@ class XPUAllocator : public DeviceAllocator {
        c10::xpu::get_raw_device(dev_to_access));
  }

+  std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
+    assertValidDevice(device);
+    return device_allocators[device]->getMemoryInfo();
+  }
+
  double getMemoryFraction(DeviceIndex device) {
    assertValidDevice(device);
    return device_allocators[device]->getMemoryFraction();
@@ -1,7 +1,7 @@
# This will define the following variables:
# SYCL_FOUND : True if the system has the SYCL library.
# SYCL_INCLUDE_DIR : Include directories needed to use SYCL.
-# SYCL_LIBRARY_DIR :The path to the SYCL library.
+# SYCL_LIBRARY_DIR : The path to the SYCL library.
# SYCL_LIBRARY : SYCL library fullname.
# SYCL_COMPILER_VERSION : SYCL compiler version.
@@ -40,6 +40,7 @@
   :nosignatures:

   empty_cache
+   get_memory_info
   max_memory_allocated
   max_memory_reserved
   memory_allocated
docs/source/accelerator/hooks.md (new file, 164 lines)
@@ -0,0 +1,164 @@
# Accelerator Hooks

## Background

OpenReg hooks provide a mechanism for integrating custom accelerator devices into PyTorch's runtime system. OpenReg (Open Registration) is PyTorch's extensibility framework that allows accelerator vendors to register custom device backends without modifying PyTorch core code.

## Design

The following tables list all hooks that accelerator vendors need to implement when integrating a new device backend. These hooks are categorized into two priority levels:

- **High Priority Hooks**: Core APIs that the PyTorch runtime directly depends on. Accelerator vendors should implement all high priority hooks to ensure full PyTorch compatibility and enable basic device functionality.

- **Low Priority Hooks**: Device management and utility APIs that PyTorch does not directly depend on. These hooks enhance user experience and multi-device support but are *optional*. Accelerator vendors can choose to implement them based on their specific requirements and use cases.

### High Priority Hooks

| Hook Method | Description | Application Scenario |
| --- | --- | --- |
| `init()` | Initializes the accelerator runtime and device contexts | Set up necessary state when PyTorch first accesses the device |
| `hasPrimaryContext(DeviceIndex)` | Checks if a primary context exists for the device | Determine whether device initialization has occurred |
| `getDefaultGenerator(DeviceIndex)` | Returns the default random number generator for a device | Access the device's primary RNG for reproducible random operations |
| `getNewGenerator(DeviceIndex)` | Creates a new independent random number generator | Create isolated RNG instances for parallel operations |
| `getDeviceFromPtr(void*)` | Determines which device a memory pointer belongs to | Identify the accelerator device associated with a memory allocation |
| `getPinnedMemoryAllocator()` | Returns an allocator for pinned (page-locked) host memory | Allocate host memory that can be efficiently transferred to/from the accelerator |
| `isPinnedPtr(void*)` | Checks if a pointer points to pinned memory | Validate memory types before performing operations |

### Low Priority Hooks

| Hook Method | Description | Application Scenario |
| --- | --- | --- |
| `isBuilt()` | Returns whether the accelerator backend is built/compiled into the extension | Check whether the accelerator library is available at compile time |
| `isAvailable()` | Returns whether the accelerator hardware is available at runtime | Verify whether accelerator devices can be detected and initialized |
| `deviceCount()` | Returns the number of available accelerator devices | Enumerate all available accelerator devices for device selection |
| `setCurrentDevice(DeviceIndex)` | Sets the active device for the current thread | Switch the current thread's context to a specific accelerator device |
| `getCurrentDevice()` | Returns the currently active device index | Query which accelerator device is active in the current thread |
| `exchangeDevice(DeviceIndex)` | Atomically exchanges the current device and returns the previous one | Temporarily switch devices and restore the previous device afterward |
| `maybeExchangeDevice(DeviceIndex)` | Conditionally exchanges the device only if the index is valid | Safely attempt device switching with validation |

## Implementation

Take `getDefaultGenerator` as an implementation example:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
   :end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
   :linenos:
```

In this implementation:

1. **Override the base interface**: The `getDefaultGenerator` method overrides the virtual method from `at::PrivateUse1HooksInterface`.

2. **Delegate to the device-specific implementation**: It calls `getDefaultOpenRegGenerator(device_index)`, which manages a per-device generator instance.

3. **Return the device-specific generator**: The returned `at::Generator` wraps an `OpenRegGeneratorImpl` that implements device-specific random number generation.

This pattern applies to all hooks: override the interface method, validate inputs, delegate to your device-specific API, and return results in PyTorch's expected format. A condensed sketch of this override pattern for a hypothetical backend follows.
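The sketch below is illustrative only; the `MyBackend*` names and the `getDefaultMyBackendGenerator` helper are hypothetical, and the real OpenReg code is what the literalinclude blocks pull in.

```cpp
// Hedged sketch of the hook-override pattern described above; names are hypothetical.
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/util/Exception.h>

namespace my_backend {

// Hypothetical backend helper that owns one generator per device.
const at::Generator& getDefaultMyBackendGenerator(c10::DeviceIndex device_index);

struct MyBackendHooksInterface : public at::PrivateUse1HooksInterface {
  // Override the base interface, validate the input, and delegate to the
  // backend-specific API; return the result in PyTorch's expected form.
  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) const override {
    TORCH_CHECK(device_index >= 0, "invalid device index ", device_index);
    return getDefaultMyBackendGenerator(device_index);
  }
};

} // namespace my_backend
```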
## Integration Example

The following sections demonstrate how PyTorch integrates with accelerator hooks when accessing the default random number generator. The example traces the complete flow from user-facing Python code down to the device-specific implementation.

### Layer 1: User Code

User code initiates the operation by calling `manual_seed` to set the random seed for reproducible results:

```python
import torch
torch.openreg.manual_seed(42)
```

### Layer 2: Extension Python API

The Python API layer handles device management and calls into the C++ extension (defined in [`torch_openreg/openreg/random.py`][random.py]):

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py
   :language: python
   :start-after: LITERALINCLUDE START: OPENREG MANUAL SEED
   :end-before: LITERALINCLUDE END: OPENREG MANUAL SEED
   :linenos:
```

The `manual_seed` function gets the current device index, calls `torch_openreg._C._get_default_generator(idx)` to obtain the device-specific generator, and then sets the seed on it. A minimal sketch of such a wrapper follows.
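This sketch is hedged: the authoritative code is the literalinclude above, and the `current_device()` helper name is an assumption made for illustration.

```python
# Hedged sketch of the wrapper described above, not the actual random.py source.
import torch_openreg          # assumed to expose current_device()
import torch_openreg._C       # C extension exposing _get_default_generator


def manual_seed(seed: int) -> None:
    seed = int(seed)
    idx = torch_openreg.current_device()  # assumed helper for the current device index
    default_generator = torch_openreg._C._get_default_generator(idx)
    default_generator.manual_seed(seed)
```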
### Layer 3: Python/C++ Bridge

The C++ extension exposes `_getDefaultGenerator` to Python, which bridges to PyTorch's core runtime:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
   :end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
   :linenos:
   :emphasize-lines: 10-11
```

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
   :end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
   :linenos:
   :emphasize-lines: 3
```

This function unpacks the device index from Python, creates a `PrivateUse1` device object, and calls `at::globalContext().defaultGenerator()`. PyTorch's context then dispatches to the registered hooks. A reduced sketch of that bridge follows.
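The sketch below is hedged: the authoritative code is the Module.cpp literalinclude above, and the `THPUtils_*` / `THPGenerator_*` helpers are assumed to be available to the extension the way they are to PyTorch's own Python bindings.

```cpp
// Hedged sketch of the Python/C++ bridge described above (illustration only).
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "expected an int device index");
  auto device_index = static_cast<c10::DeviceIndex>(THPUtils_unpackLong(arg));
  // Build a PrivateUse1 device and ask the global context for its default
  // generator; the context forwards the call to the registered hooks.
  auto device = c10::Device(c10::DeviceType::PrivateUse1, device_index);
  const at::Generator& gen = at::globalContext().defaultGenerator(device);
  return THPGenerator_initDefaultGenerator(gen);  // wrap as a Python Generator
  END_HANDLE_TH_ERRORS
}
```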
### Layer 4: PyTorch Core Context

PyTorch's Context class dispatches to the appropriate accelerator hooks ([`aten/src/ATen/Context.h`][Context.h]):

```{eval-rst}
.. literalinclude:: ../../../aten/src/ATen/Context.h
   :language: c++
   :lines: 60-103
   :linenos:
   :emphasize-lines: 8-9, 24-25
```

This layered architecture enables PyTorch to remain device-agnostic while delegating hardware-specific operations to accelerator implementations. The hooks are registered once at module load time:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG HOOK REGISTER
   :end-before: LITERALINCLUDE END: OPENREG HOOK REGISTER
   :linenos:
   :emphasize-lines: 4
```

### Layer 5: Accelerator Hooks

The hooks interface provides the abstraction that PyTorch uses to delegate to device-specific implementations:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
   :end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
   :linenos:
```

The `getDefaultGenerator` hook method overrides the base interface and delegates to `getDefaultOpenRegGenerator`, which manages the actual generator instances.

### Layer 6: Device-Specific Implementation

The device-specific implementation manages per-device generator instances:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
   :end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
   :linenos:
```

This function maintains a static vector of generators (one per device), initializes them on first access, validates the device index, and returns the appropriate generator instance. A reduced sketch of that lazy-initialization pattern follows.
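A reduced, hedged sketch of the pattern just described (includes omitted; the real code is what the literalinclude pulls in, and `make_generator<OpenRegGeneratorImpl>` mirrors the `getNewGenerator` hook shown elsewhere in this PR):

```cpp
// Hedged sketch: lazily build one default generator per device, then index into them.
const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
  static std::vector<at::Generator> default_generators;
  // One-time initialization: create and seed a generator for every device.
  static bool flag [[maybe_unused]] = []() {
    auto device_num = device_count();  // backend device-count helper
    default_generators.resize(device_num);
    for (c10::DeviceIndex i = 0; i < device_num; ++i) {
      default_generators[i] = at::make_generator<OpenRegGeneratorImpl>(i);
      default_generators[i].seed();
    }
    return true;
  }();

  const auto idx = static_cast<size_t>(device_index);
  TORCH_CHECK(idx < default_generators.size(), "invalid device index ", device_index);
  return default_generators[idx];
}
```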
[random.py]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py#L48-L53 "random.py"
[Context.h]: https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/Context.h#L61-L102 "Context.h"
@@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
   :glob:
   :maxdepth: 1

+   hooks
   autoload
   operators
   amp
@@ -5,6 +5,7 @@ static std::vector<at::Generator> default_generators;

namespace c10::openreg {

+// LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
  static bool flag [[maybe_unused]] = []() {
    auto deivce_nums = device_count();
@@ -24,5 +25,6 @@ const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
  }
  return default_generators[idx];
}
+// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL

} // namespace c10::openreg
@@ -1,5 +1,6 @@
#include "OpenRegHooks.h"

+// LITERALINCLUDE START: OPENREG HOOK REGISTER
namespace c10::openreg {

static bool register_hook_flag [[maybe_unused]] = []() {
@@ -9,3 +10,4 @@ static bool register_hook_flag [[maybe_unused]] = []() {
}();

} // namespace c10::openreg
+// LITERALINCLUDE END: OPENREG HOOK REGISTER
@@ -8,17 +8,58 @@

#include <include/openreg.h>

+#include "OpenRegFunctions.h"
#include "OpenRegGenerator.h"

namespace c10::openreg {
-struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
+struct OPENREG_EXPORT OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  OpenRegHooksInterface() {};
  ~OpenRegHooksInterface() override = default;

-  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
+  void init() const override {
+    // Initialize OpenReg runtime if needed
+    // This is called when PyTorch first accesses the device
+  }
+
+  bool hasPrimaryContext(DeviceIndex device_index) const override {
    return true;
  }

+  bool isBuilt() const override {
+    // This extension is compiled as part of the OpenReg test extension.
+    return true;
+  }
+
+  bool isAvailable() const override {
+    // Consider OpenReg available if there's at least one device reported.
+    return device_count() > 0;
+  }
+
+  DeviceIndex deviceCount() const override {
+    return device_count();
+  }
+
+  void setCurrentDevice(DeviceIndex device) const override {
+    set_device(device);
+  }
+
+  DeviceIndex getCurrentDevice() const override {
+    return current_device();
+  }
+
+  DeviceIndex exchangeDevice(DeviceIndex device) const override {
+    return ExchangeDevice(device);
+  }
+
+  DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
+    // Only exchange if the requested device is valid; otherwise, no-op and return current
+    auto count = device_count();
+    if (device < 0 || device >= count) {
+      return getCurrentDevice();
+    }
+    return exchangeDevice(device);
+  }
+
  at::Allocator* getPinnedMemoryAllocator() const override {
    return at::getHostAllocator(at::kPrivateUse1);
  }
@@ -30,12 +71,23 @@ struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
    return attr.type == orMemoryTypeHost;
  }

-  const at::Generator& getDefaultGenerator(
-      c10::DeviceIndex device_index) const override {
+  at::Device getDeviceFromPtr(void* data) const override {
+    orPointerAttributes attr{};
+    auto err = orPointerGetAttributes(&attr, data);
+    if (err == orSuccess && attr.type == orMemoryTypeDevice) {
+      return at::Device(at::DeviceType::PrivateUse1, static_cast<int>(attr.device));
+    } else {
+      TORCH_CHECK(false, "failed to get device from pointer");
+    }
+    return at::Device(at::DeviceType::PrivateUse1, current_device());
+  }
+  // LITERALINCLUDE START: OPENREG HOOK EXAMPLES
+  const at::Generator& getDefaultGenerator(DeviceIndex device_index) const override {
    return getDefaultOpenRegGenerator(device_index);
  }
+  // LITERALINCLUDE END: OPENREG HOOK EXAMPLES

-  at::Generator getNewGenerator(c10::DeviceIndex device_index) const override {
+  at::Generator getNewGenerator(DeviceIndex device_index) const override {
    return at::make_generator<OpenRegGeneratorImpl>(device_index);
  }
};
@@ -17,6 +17,7 @@ static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
  END_HANDLE_TH_ERRORS
}

+// LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
@@ -31,6 +32,7 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {

  END_HANDLE_TH_ERRORS
}
+// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR

PyObject* _setDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
@@ -73,6 +75,7 @@ PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) {
  END_HANDLE_TH_ERRORS
}

+// LITERALINCLUDE START: OPENREG MODULE METHODS
static PyMethodDef methods[] = {
    {"_init", _initExtension, METH_NOARGS, nullptr},
    {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
@@ -81,7 +84,7 @@ static PyMethodDef methods[] = {
    {"_exchangeDevice", _exchangeDevice, METH_O, nullptr},
    {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
-
+// LITERALINCLUDE END: OPENREG MODULE METHODS
/*
 * When ASAN is enabled, PyTorch modifies the dlopen flag during import,
 * causing all global and weak symbols in _C.so and its dependent libraries
||||
@ -45,6 +45,7 @@ def initial_seed() -> int:
|
||||
return default_generator.initial_seed()
|
||||
|
||||
|
||||
# LITERALINCLUDE START: OPENREG MANUAL SEED
|
||||
def manual_seed(seed: int) -> None:
|
||||
seed = int(seed)
|
||||
|
||||
@ -53,6 +54,9 @@ def manual_seed(seed: int) -> None:
|
||||
default_generator.manual_seed(seed)
|
||||
|
||||
|
||||
# LITERALINCLUDE END: OPENREG MANUAL SEED
|
||||
|
||||
|
||||
def manual_seed_all(seed: int) -> None:
|
||||
seed = int(seed)
|
||||
|
||||
|
||||
@@ -450,6 +450,9 @@ class TestDTensorDebugMode(TestCase):
            op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
        ][-1]
        self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
+        self.assertTrue(
+            "self.l2(self.l1(x))" in debug_mode.debug_string(show_stack_trace=True)
+        )

    @unittest.skipIf(not HAS_GPU, "requires GPU")
    @unittest.skipIf(not has_triton_package(), "requires triton")
@@ -706,11 +706,11 @@ class DistTensorOpsTest(DTensorTestBase):
    @with_comms
    def test_dtensor_dtype_conversion(self):
        from torch.distributed.tensor.debug import (
-            _clear_sharding_prop_cache,
-            _get_sharding_prop_cache_info,
+            _clear_fast_path_sharding_prop_cache,
+            _get_fast_path_sharding_prop_cache_stats,
        )

-        _clear_sharding_prop_cache()
+        _clear_fast_path_sharding_prop_cache()
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        # by default we start from bf16 dtype
@@ -730,13 +730,13 @@ class DistTensorOpsTest(DTensorTestBase):
        self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)

        # by this point we only have cache misses
-        hits, misses, _, _ = _get_sharding_prop_cache_info()
+        hits, misses = _get_fast_path_sharding_prop_cache_stats()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 2)

        # convert to fp32 again and see if there's cache hit
        bf16_sharded_dtensor1.float()
-        hits, misses, _, _ = _get_sharding_prop_cache_info()
+        hits, misses = _get_fast_path_sharding_prop_cache_stats()
        # by now we should have cache hit
        self.assertEqual(hits, 1)
        self.assertEqual(misses, 2)
@@ -15295,12 +15295,12 @@ graph():
            def forward(self, block):
                return block.a + block.b

-        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+        from torch._dynamo.functional_export import dynamo_graph_capture_for_export

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
        ):
-            _dynamo_graph_capture_for_export(Foo())(
+            dynamo_graph_capture_for_export(Foo())(
                Block(torch.randn(4, 4), torch.randn(4, 4))
            )
@@ -1,71 +0,0 @@
# Owner(s): ["oncall: export"]


from torch._dynamo import config as dynamo_config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._export import config as export_config


try:
    from . import test_export, testing
except ImportError:
    import test_export  # @manual=fbcode//caffe2/test:test_export-library
    import testing  # @manual=fbcode//caffe2/test:test_export-library

from torch.export import export


test_classes = {}


def mocked_strict_export(*args, **kwargs):
    # If user already specified strict, don't make it strict
    if "strict" in kwargs:
        return export(*args, **kwargs)
    return export(*args, **kwargs, strict=True)


def make_dynamic_cls(cls):
    # Some test check for ending in suffix; need to make
    # the `_strict` for end of string as a result
    suffix = test_export.INLINE_AND_INSTALL_STRICT_SUFFIX

    cls_prefix = "InlineAndInstall"

    cls_a = testing.make_test_cls_with_mocked_export(
        cls,
        "StrictExport",
        suffix,
        mocked_strict_export,
        xfail_prop="_expected_failure_strict",
    )
    test_class = make_test_cls_with_patches(
        cls_a,
        cls_prefix,
        "",
        (export_config, "use_new_tracer_experimental", True),
        (dynamo_config, "install_free_tensors", True),
        (dynamo_config, "inline_inbuilt_nn_modules", True),
        xfail_prop="_expected_failure_inline_and_install",
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]
for test in tests:
    make_dynamic_cls(test)
del test


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -1832,6 +1832,71 @@ class TestPatternMatcher(TestCase):
        self.assertEqual(len(sigmoid_nodes), 1)
        self.assertTrue("original_aten" in sigmoid_nodes[0].meta)

+    @inductor_config.patch(is_predispatch=True)
+    def test_remove_noop_pass_with_remove_passes(self):
+        def fn_with_noop(x):
+            batch_size, dim = x.shape
+            y = x.view(batch_size, dim)
+            return y + 1
+
+        def count_view_ops(graph_module):
+            count = 0
+            for node in graph_module.graph.nodes:
+                if node.op == "call_function" and node.target in [
+                    torch.ops.aten.view.default,
+                    torch.ops.aten.reshape.default,
+                ]:
+                    count += 1
+            return count
+
+        device = "cuda" if HAS_GPU else "cpu"
+        input_tensor = torch.randn(8, 16, device=device)
+
+        with inductor_config.patch(remove_pre_grad_passes=None):
+            compiled_fn_default = torch.compile(fn_with_noop, fullgraph=True)
+            result_default = compiled_fn_default(input_tensor)
+
+        with inductor_config.patch(remove_pre_grad_passes="remove_noop"):
+            compiled_fn_skip_noop = torch.compile(fn_with_noop, fullgraph=True)
+            result_skip_noop = compiled_fn_skip_noop(input_tensor)
+
+        expected = fn_with_noop(input_tensor)
+        torch.testing.assert_close(result_default, expected)
+        torch.testing.assert_close(result_skip_noop, expected)
+
+        from torch._inductor.fx_passes.pre_grad import pre_grad_passes
+        from torch.fx.experimental.proxy_tensor import make_fx
+
+        with inductor_config.patch(
+            is_predispatch=True, pattern_matcher=True, remove_pre_grad_passes=None
+        ):
+            gm_default = make_fx(fn_with_noop)(input_tensor)
+            gm_default_processed = pre_grad_passes(
+                gm_default, [input_tensor], add_passes=None, remove_passes=None
+            )
+            view_count_default = count_view_ops(gm_default_processed)
+
+        with inductor_config.patch(
+            is_predispatch=True,
+            pattern_matcher=True,
+            remove_pre_grad_passes="remove_noop",
+        ):
+            gm_skip_noop = make_fx(fn_with_noop)(input_tensor)
+            gm_skip_noop_processed = pre_grad_passes(
+                gm_skip_noop,
+                [input_tensor],
+                add_passes=None,
+                remove_passes="remove_noop",
+            )
+            view_count_skip_noop = count_view_ops(gm_skip_noop_processed)
+
+        self.assertGreaterEqual(
+            view_count_skip_noop,
+            view_count_default,
+            f"Expected view count with remove_noop disabled ({view_count_skip_noop}) "
+            f"to be >= view count with remove_noop enabled ({view_count_default})",
+        )
+

if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
@@ -7,16 +7,17 @@ from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    dtypesIfXPU,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCUDA,
    onlyNativeDeviceTypes,
    onlyOn,
    skipCUDAIf,
    skipMeta,
    skipXPUIf,
    TEST_WITH_ROCM,
)
from torch.testing._internal.common_nn import NNTestCase
@@ -29,6 +30,13 @@ from torch.testing._internal.common_utils import (
    run_tests,
    set_default_dtype,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_XPU,
)


device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

@@ -36,7 +44,7 @@ class TestEmbeddingNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA/XPU unavailable")
    def test_embedding_max_norm_unsorted_repeating_indices(self):
        def create_embedding(device):
            # Seed RNG so we get the same Embedding each time
@@ -48,8 +56,8 @@ class TestEmbeddingNN(NNTestCase):
        ix = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000)
        out_cpu = create_embedding("cpu")(ix)

-        ix = ix.to("cuda")
-        out = create_embedding("cuda")(ix)
+        ix = ix.to(device_type)
+        out = create_embedding(device_type)(ix)
        self.assertEqual(out.cpu(), out_cpu)

    def test_embedding_sparse_basic(self):
@@ -81,9 +89,9 @@ class TestEmbeddingNN(NNTestCase):
        self.assertEqual(embedding.embedding_dim, 3)
        self.assertEqual(embedding.num_embeddings, 10)

-        if torch.cuda.is_available():
-            embedding.to("cuda")
-            self.assertEqual(embedding.weight.device.type, "cuda")
+        if not torch.accelerator.is_available():
+            embedding.to(device_type)
+            self.assertEqual(embedding.weight.device.type, device_type)
            embedding.to("cpu")
            self.assertEqual(embedding.weight.device.type, "cpu")
@@ -182,11 +190,11 @@ class TestEmbeddingNN(NNTestCase):
        self.assertEqual(res_old, res_F)

    # https://github.com/pytorch/pytorch/issues/130806
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
-    @largeTensorTest("40GB", device="cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA/XPU not available")
+    @largeTensorTest("40GB", device=device_type)
    def test_large_tensors(self):
-        input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
-        w = torch.randn([16032, 16384], device="cuda")
+        input = torch.randint(low=0, high=16032, size=[131072], device=device_type)
+        w = torch.randn([16032, 16384], device=device_type)
        out = torch.nn.functional.embedding(input, w)
        self.assertEqual(out.dim(), 2)
        self.assertEqual(out.numel(), 2147483648)
@@ -308,6 +316,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            torch.nn.functional.embedding(indices, weight)

    @dtypesIfCUDA(torch.float16, torch.float64)
+    @dtypesIfXPU(torch.float16, torch.float64)
    @dtypes(torch.float64)
    def test_embedding_backward(self, device, dtype):
        embedding = nn.Embedding(10, 3, sparse=True)
@@ -348,6 +357,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            else (torch.float, torch.double, torch.half)
        )
    )
+    @dtypesIfXPU(torch.float32, torch.double, torch.half)
    @dtypes(torch.float32)
    def test_embedding_max_norm_backward(self, device, dtype):
        # can't use gradcheck since in place renorm makes analytical gradients different from produced ones
@@ -372,6 +382,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            else (torch.float, torch.double, torch.half)
        )
    )
+    @dtypesIfXPU(torch.float32, torch.double, torch.half)
    @dtypes(torch.float32)
    def test_embedding_max_norm_fwd_AD(self, device, dtype):
        if torch.device(device).type == "xla":
@@ -396,6 +407,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            else (torch.float, torch.double, torch.half)
        )
    )
+    @dtypesIfXPU(torch.float32, torch.double, torch.half)
    @dtypes(torch.float32)
    def test_embedding_padding_idx(self, device, dtype):
        embedding = nn.Embedding(10, 20, padding_idx=0).to(device, dtype)
@@ -488,6 +500,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.half, torch.bfloat16)
+    @dtypesIfXPU(torch.half, torch.bfloat16)
    def test_embedding_bag_1D_padding_idx(self, device, dtype):
        num_features = 3
        max_indices_per_bag = 10
@@ -632,11 +645,12 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol
        )

-    @onlyCUDA
+    @onlyOn(["cuda", "xpu"])
    @dtypes(
        torch.bfloat16,
    )
    @largeTensorTest("80GB", device="cuda")
+    @largeTensorTest("80GB", device="xpu")
    def test_embedding_backward_large_batch_overflow(self, device, dtype):
        """
        Test that embedding_dense_backward handles large batches that exceed INT32_MAX thread IDs.
@@ -708,6 +722,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.half, torch.bfloat16)
+    @dtypesIfXPU(torch.half, torch.bfloat16)
    def test_embedding_bag_2D_padding_idx(self, device, dtype):
        # Use a Python implementation of embedding_bag with padding_idx support
        # to check torch.nn.functional.embedding_bag correctness
@@ -818,7 +833,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            rtol = None
        self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol)

-    @onlyCUDA
+    @onlyOn(["cuda", "xpu"])
    @dtypes(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
@@ -854,6 +869,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
        self.assertEqual(output, torch.zeros_like(output))

    @skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.")
+    @skipXPUIf(True, "no out-of-bounds check on XPU for perf.")
    @dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long)))
    @parametrize_test("padding_idx", [None, 0])
    @parametrize_test("mode", ["sum", "mean", "max"])
@@ -1066,6 +1082,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            (torch.float, torch.double, torch.half),
        )
    )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
    def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
        # Test empty input and per sample weight, and backward pass. There was a CUDA
        # invalid configuration bug (more context in #46572)
@@ -1132,6 +1155,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            (torch.float, torch.double, torch.half),
        )
    )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
    def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
        def test_per_sample_weights(mode, trainable_scale):
            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
@@ -1193,6 +1223,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            (torch.float, torch.double, torch.half),
        )
    )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
    def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
        def test_per_sample_weights_new_offsets(
            mode, trainable_scale, include_last_offset, has_weight=True
@@ -1357,6 +1394,11 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            (torch.int, torch.long), (torch.half, torch.float, torch.double)
        )
    )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long), (torch.half, torch.float32, torch.double)
+        )
+    )
    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
    def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
        def run_tests(mode, sparse, trainable_per_sample_weights):
@@ -1390,8 +1432,8 @@ class TestEmbeddingNNDeviceType(NNTestCase):
        ):
            run_tests(mode, sparse, trainable_per_sample_weights)

-        # Test CUDA Dense on half precision
-        if device == "cuda":
+        # Test CUDA/XPU Dense on half precision
+        if device != "cpu":
            modes = ("sum",)
            sparsity = (False,)
            trainable_scale = (True, False)
@@ -1552,9 +1594,18 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            (torch.float, torch.double, torch.half),
        )
    )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
    def test_embedding_bag_device(self, device, dtypes):
        if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu":
            self.skipTest("bfloat16 not supported with Jetson cpu")
+        if dtypes[2] == torch.float64 and "xpu" in device:
+            self.skipTest("https://github.com/intel/torch-xpu-ops/issues/2295")
        with set_default_dtype(torch.double):
            self._test_EmbeddingBag(
                device,
@@ -1582,10 +1633,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
        )

        test_backward = False
-        if self.device_type == "cuda":
+        if self.device_type != "cpu":
            # see 'todo' in test_embedding_bag.
            test_backward = dtypes[2] is not torch.float16
-        elif self.device_type == "cpu":
+        else:
            # TODO: figure out why precision on sparse embeddings isn't the
            # same as for dense.
            test_backward = (
@@ -1626,6 +1677,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            (torch.float, torch.double, torch.half),
        )
    )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
    def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
        weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)

@@ -1703,7 +1761,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
|
||||
bag(x, per_sample_weights=F.softmax(w, dim=-1))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals())
|
||||
instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals(), allow_xpu=True)
|
||||
instantiate_parametrized_tests(TestEmbeddingNN)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -239,6 +239,12 @@ class TestAccelerator(TestCase):
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)

@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
def test_get_memory_info(self):
free_bytes, total_bytes = torch.accelerator.get_memory_info()
self.assertGreaterEqual(free_bytes, 0)
self.assertGreaterEqual(total_bytes, 0)


if __name__ == "__main__":
run_tests()

@ -17,12 +17,14 @@ from torch.testing._internal.common_device_type import (
dtypesIfCPU,
dtypesIfCUDA,
dtypesIfMPS,
dtypesIfXPU,
expectedFailureMPS,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
onlyNativeDeviceTypes,
onlyOn,
skipXLA,
skipXPUIf,
)
from torch.testing._internal.common_dtype import (
all_mps_types_and,
@ -38,6 +40,7 @@ from torch.testing._internal.common_utils import (
skipIfTorchDynamo,
TEST_CUDA,
TEST_MPS,
TEST_XPU,
TestCase,
xfailIfTorchDynamo,
)
@ -598,8 +601,8 @@ class TestIndexing(TestCase):

# test invalid index fails
reference = torch.empty(10, dtype=dtype, device=device)
# can't test cuda because it is a device assert
if not reference.is_cuda:
# can't test cuda/xpu because it is a device assert
if reference.device.type == "cpu":
for err_idx in (10, -11):
with self.assertRaisesRegex(IndexError, r"out of"):
reference[err_idx]
@ -744,7 +747,7 @@ class TestIndexing(TestCase):
assert_get_eq(reference, indexer)
assert_set_eq(reference, indexer, 212)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
if torch.cuda.is_available():
if torch.accelerator.is_available():
assert_backward_eq(reference, indexer)

reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)
@ -1009,7 +1012,7 @@ class TestIndexing(TestCase):
@skipIfTorchDynamo(
"This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
)
@serialTest(TEST_CUDA or TEST_MPS)
@serialTest(TEST_CUDA or TEST_XPU or TEST_MPS)
def test_index_put_accumulate_large_tensor(self, device):
# This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
N = (1 << 31) + 5
@ -1086,7 +1089,7 @@ class TestIndexing(TestCase):
out_cpu = t.index_put_(indices, values2d, accumulate=True)
self.assertEqual(out_cuda.cpu(), out_cpu)

@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_index_put_large_indices(self, device):
def generate_indices(num_indices: int, index_range: int):
indices = []
@ -1138,7 +1141,7 @@ class TestIndexing(TestCase):
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)

@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_index_put_accumulate_non_contiguous(self, device):
t = torch.zeros((5, 2, 2))
t_dev = t.to(device)
@ -1157,7 +1160,7 @@ class TestIndexing(TestCase):

self.assertEqual(out_cuda.cpu(), out_cpu)

@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_index_put_deterministic_with_optional_tensors(self, device):
def func(x, i, v):
with DeterministicGuard(True):
@ -1188,7 +1191,7 @@ class TestIndexing(TestCase):
indices = torch.tensor([1, 4, 3])
indices_dev = indices.to(device)
val = torch.randn(4)
out_cuda = func1(t_dev, indices_dev, val.cuda())
out_cuda = func1(t_dev, indices_dev, val.to(device))
out_cpu = func1(t, indices, val)
self.assertEqual(out_cuda.cpu(), out_cpu)

@ -1321,6 +1324,14 @@ class TestIndexing(TestCase):
torch.float8_e5m2,
torch.float8_e4m3fn,
)
@dtypesIfXPU(
torch.cfloat,
torch.cdouble,
torch.half,
torch.long,
torch.bool,
torch.bfloat16,
)
@dtypesIfMPS(torch.float, torch.float16, torch.long, torch.bool)
def test_index_put_src_datatype(self, device, dtype):
src = torch.ones(3, 2, 4, device=device, dtype=dtype)
@ -1332,6 +1343,7 @@ class TestIndexing(TestCase):
@dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
@dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool)
@dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool)
@dtypesIfXPU(torch.half, torch.long, torch.bfloat16, torch.bool)
def test_index_src_datatype(self, device, dtype):
src = torch.ones(3, 2, 4, device=device, dtype=dtype)
# test index
@ -1630,7 +1642,7 @@ class TestIndexing(TestCase):

self.assertRaisesRegex(IndexError, "invalid index", runner)

@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_invalid_device(self, device):
idx = torch.tensor([0, 1])
b = torch.zeros(5, device=device)
@ -1642,7 +1654,7 @@ class TestIndexing(TestCase):
lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate),
)

@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_cpu_indices(self, device):
idx = torch.tensor([0, 1])
b = torch.zeros(2, device=device)
@ -1718,7 +1730,7 @@ class TestIndexing(TestCase):
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
torch.take_along_dim(t, indices, dim=7)

@onlyCUDA
@onlyOn(["cuda", "xpu"])
@dtypes(torch.float)
def test_gather_take_along_dim_cross_device(self, device, dtype):
shape = (2, 3, 1, 4)
@ -1748,7 +1760,7 @@ class TestIndexing(TestCase):
):
torch.take_along_dim(t.cpu(), indices, dim=0)

@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_cuda_broadcast_index_use_deterministic_algorithms(self, device):
with DeterministicGuard(True):
idx1 = torch.tensor([0])
@ -1969,6 +1981,7 @@ class TestIndexing(TestCase):
return (x, index, src)

@onlyNativeDeviceTypes
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1973")
@expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/161029
def test_index_copy_deterministic(self, device: torch.device) -> None:
for dim in range(3):
@ -2011,6 +2024,7 @@ class TestIndexing(TestCase):
self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)

@onlyNativeDeviceTypes
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1973")
def test_index_put_non_accumulate_deterministic(self, device) -> None:
with DeterministicGuard(True):
for i in range(3):
@ -2048,6 +2062,7 @@ class TestIndexing(TestCase):
# The test fails for zero-dimensional tensors on XLA
@onlyNativeDeviceTypes
@dtypes(*all_types_complex_float8_and(torch.half, torch.bool, torch.bfloat16))
@dtypesIfXPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
def test_index_select(self, device, dtype):
num_src, num_out = 3, 5
@ -2361,8 +2376,8 @@ class NumpyTests(TestCase):
def test_trivial_fancy_out_of_bounds(self, device):
a = torch.zeros(5, device=device)
ind = torch.ones(20, dtype=torch.int64, device=device)
if a.is_cuda:
raise unittest.SkipTest("CUDA asserts instead of raising an exception")
if a.device.type in ["cuda", "xpu"]:
raise unittest.SkipTest("CUDA/XPU asserts instead of raising an exception")
ind[-1] = 10
self.assertRaises(IndexError, a.__getitem__, ind)
self.assertRaises(IndexError, a.__setitem__, ind, 0)
@ -2397,9 +2412,9 @@ class NumpyTests(TestCase):


instantiate_device_type_tests(
TestIndexing, globals(), except_for="meta", allow_mps=True
TestIndexing, globals(), except_for="meta", allow_mps=True, allow_xpu=True
)
instantiate_device_type_tests(NumpyTests, globals(), except_for="meta")
instantiate_device_type_tests(NumpyTests, globals(), except_for="meta", allow_xpu=True)

if __name__ == "__main__":
run_tests()

@ -6,11 +6,14 @@ import torch
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
dtypesIfXPU,
instantiate_device_type_tests,
onlyCUDA,
onlyOn,
skipMeta,
skipXPUIf,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_WITH_ROCM
from torch.nn.attention import SDPBackend

class TestMHADeviceType(TestCase):
@torch.no_grad()
@ -89,6 +92,7 @@ class TestMHADeviceType(TestCase):
torch.testing.assert_close(v, correct_v)

@dtypesIfCUDA(torch.float)
@dtypesIfXPU(torch.float)
@dtypes(torch.float)
@skipMeta
def test_transform_bias_rescale_qkv(self, device, dtype):
@ -99,9 +103,11 @@ class TestMHADeviceType(TestCase):
)

@dtypesIfCUDA(torch.float)
@dtypesIfXPU(torch.float)
@dtypes(torch.float)
@skipMeta
@onlyCUDA
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/2182")
@onlyOn(["cuda", "xpu"])
def test_transform_bias_rescale_qkv_nested(self, device, dtype):
for use_padding in (False, True):
with self.subTest(use_padding=use_padding):
@ -185,9 +191,9 @@ class TestMHADeviceType(TestCase):
embed_dim=embed_dim, num_heads=num_heads, qkv=native_qkv, proj=native_proj
).to(dtype)

if device == "cuda":
pt = pt.cuda()
npt = npt.cuda()
if device == "cuda" or device == "xpu":
pt = pt.to(device)
npt = npt.to(device)

ypt, weight_pt = pt(
q,
@ -266,6 +272,7 @@ class TestMHADeviceType(TestCase):
self.assertEqual(weight_pt, weight_npt)

@dtypesIfCUDA(torch.float, torch.half)
@dtypesIfXPU(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@parametrize("use_nt", [False, True])
@ -285,10 +292,25 @@ class TestMHADeviceType(TestCase):
with self.subTest(use_padding=use_padding, pad_all=pad_all,
use_nt=use_nt, need_weights=need_weights,
average_attn_weights=average_attn_weights):
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False
) if not fused else torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_mem_efficient=True
sdpa_backends_fused = [
SDPBackend.MATH,
SDPBackend.OVERRIDEABLE,
SDPBackend.CUDNN_ATTENTION,
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
]
sdpa_backends_not_fused = [
SDPBackend.MATH,
SDPBackend.OVERRIDEABLE,
SDPBackend.CUDNN_ATTENTION,
]
if device == "xpu":
sdpa_backends_fused = [SDPBackend.OVERRIDEABLE, SDPBackend.MATH]
sdpa_backends_not_fused = [SDPBackend.MATH]
with torch.nn.attention.sdpa_kernel(
sdpa_backends_not_fused
) if not fused else torch.nn.attention.sdpa_kernel(
sdpa_backends_fused
):
self._test_multihead_attention_impl(
device,
@ -302,6 +324,7 @@ class TestMHADeviceType(TestCase):
)

@dtypesIfCUDA(torch.float, torch.half)
@dtypesIfXPU(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
@ -316,6 +339,7 @@ class TestMHADeviceType(TestCase):
)

@dtypesIfCUDA(torch.float, torch.half)
@dtypesIfXPU(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
@ -330,7 +354,7 @@ class TestMHADeviceType(TestCase):
)


instantiate_device_type_tests(TestMHADeviceType, globals())
instantiate_device_type_tests(TestMHADeviceType, globals(), allow_xpu=True)

if __name__ == "__main__":
run_tests()

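Note on the backend-selection pattern above: the test swaps the deprecated torch.backends.cuda.sdp_kernel context manager for torch.nn.attention.sdpa_kernel, which takes an explicit list of SDPBackend values. A minimal sketch of the same idea outside the test harness (the device choice and tensor shapes here are illustrative, not part of the patch):

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # (batch, heads, seq_len, head_dim)
    q = k = v = torch.randn(2, 4, 8, 16, device=device)

    # Restrict scaled_dot_product_attention to the math backend only,
    # mirroring the "not fused" branch of the test.
    with sdpa_kernel([SDPBackend.MATH]):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
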
@ -1967,6 +1967,8 @@ def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ...
def _DTensor_compute_global_tensor_info(
tensor: Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[list[_int], list[_int]]: ...
def _get_DTensor_sharding_propagator_cache_stats() -> tuple[_int, _int]: ...
def _clear_DTensor_sharding_propagator_cache() -> None: ...

# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...
@ -2501,6 +2503,7 @@ def _accelerator_emptyCache() -> None: ...
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
def _accelerator_resetPeakStats(device_index: _int) -> None: ...
def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
def _accelerator_setAllocatorSettings(env: str) -> None: ...

# Defined in torch/csrc/jit/python/python_tracer.cpp

@ -499,6 +499,9 @@ def pytreeify(
root = mod.__self__

flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
torch._dynamo.eval_frame.check_user_input_output(
flat_real_args[1 if root else 0 :], UserErrorType.INVALID_INPUT
)

class Yield(Exception):
pass

@ -264,13 +264,14 @@ def _run_pre_dispatch_passes(
f"[Pre grad(predispatch IR)] Apply {pass_name} pass",
)

# Remove noops at the end, which may be generated other passes.
pass_execution_and_save(
remove_noop_pass,
gm,
example_inputs,
"[Pre grad(predispatch IR)]Apply remove_noop pass",
)
if "remove_noop" not in remove_passes_list:
# Remove noops at the end, which may be generated other passes.
pass_execution_and_save(
remove_noop_pass,
gm,
example_inputs,
"[Pre grad(predispatch IR)]Apply remove_noop pass",
)
shape_prop(gm)

@ -10,6 +10,7 @@ import torch
from ._utils import _device_t, _get_device_index
from .memory import (
empty_cache,
get_memory_info,
max_memory_allocated,
max_memory_reserved,
memory_allocated,
@ -25,9 +26,10 @@ __all__ = [
"current_device_idx", # deprecated
"current_device_index",
"current_stream",
"empty_cache",
"device_count",
"device_index",
"empty_cache",
"get_memory_info",
"is_available",
"max_memory_allocated",
"max_memory_reserved",

@ -8,6 +8,7 @@ from ._utils import _device_t, _get_device_index

__all__ = [
"empty_cache",
"get_memory_info",
"max_memory_allocated",
"max_memory_reserved",
"memory_allocated",
@ -87,6 +88,9 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.

Returns:
OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
"""
if not torch._C._accelerator_isAllocatorInitialized():
return OrderedDict()
@ -117,6 +121,9 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.

Returns:
int: the current memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.current", 0)

@ -134,6 +141,9 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.

Returns:
int: the peak memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.peak", 0)

@ -147,6 +157,9 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.

Returns:
int: the current memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.current", 0)

@ -164,6 +177,9 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.

Returns:
int: the peak memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.peak", 0)

@ -200,3 +216,21 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_resetPeakStats(device_index)


def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
r"""Return the current device memory information for a given device index.

Args:
device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.

Returns:
tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
The first value is the free memory on the device (available across all processes and applications),
The second value is the device's total hardware memory capacity.
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_getMemoryInfo(device_index)

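The get_memory_info helper added above is a thin wrapper over torch._C._accelerator_getMemoryInfo and pairs naturally with the existing allocator statistics. A minimal usage sketch (assumes an accelerator backend such as CUDA or XPU is present; the formatting is illustrative):

    import torch

    if torch.accelerator.is_available():
        # Device-wide view: free and total bytes reported by the backend.
        free_bytes, total_bytes = torch.accelerator.get_memory_info()
        print(f"free {free_bytes / 2**30:.2f} GiB of {total_bytes / 2**30:.2f} GiB")

        # Process-local view: peak bytes held by live tensors.
        peak_bytes = torch.accelerator.max_memory_allocated()
        print(f"peak allocated {peak_bytes / 2**30:.2f} GiB")
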
@ -138,6 +138,13 @@ void initModule(PyObject* module) {
at::accelerator::resetPeakStats(device_index);
});

m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
const auto device_type = at::accelerator::getAccelerator(true).value();
torch::utils::maybe_initialize_device(device_type);
py::gil_scoped_release no_gil;
return at::accelerator::getMemoryInfo(device_index);
});

m.def("_accelerator_setAllocatorSettings", [](std::string env) {
c10::CachingAllocator::setAllocatorSettings(env);
});

@ -357,6 +357,8 @@ void ConcretePyInterpreterVTable::dispatch(
nullptr,
torch_api_function_overload.ptr(),
nullptr,
&op,
&arguments,
TorchFunctionName::TorchDispatch);
pushPyOutToStack(
op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");

File diff suppressed because it is too large
@ -86,6 +86,15 @@ void pushPyOutToStack(
py::object out,
const char* msg);

py::handle get_dtensor_class();

py::object dispatchDTensorOp(
const c10::OperatorHandle& op,
py::handle py_op,
py::handle args,
py::handle kwargs,
torch::jit::Stack* stack);

inline PyObject* THPVariable_WrapList(
const torch::autograd::variable_list& inputs) {
PyObject* pyinput = PyList_New(static_cast<Py_ssize_t>(inputs.size()));

@ -4,6 +4,7 @@
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/cuda/MemPool.h>
#include <c10/cuda/CUDACachingAllocator.h>

template <typename T>
@ -12,16 +13,16 @@ using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
// NOLINTNEXTLINE(misc-use-internal-linkage)
void THCPMemPool_init(PyObject* module) {
auto torch_C_m = py::handle(module).cast<py::module>();
shared_ptr_class_<::c10::cuda::MemPool>(torch_C_m, "_MemPool")
shared_ptr_class_<::at::cuda::MemPool>(torch_C_m, "_MemPool")
.def(
py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom) {
torch::utils::device_lazy_init(at::kCUDA);
return std::make_shared<::c10::cuda::MemPool>(
return std::make_shared<::at::cuda::MemPool>(
allocator, is_user_created, use_on_oom);
}))
.def_property_readonly("id", &::c10::cuda::MemPool::id)
.def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
.def("use_count", &::c10::cuda::MemPool::use_count);
.def_property_readonly("id", &::at::cuda::MemPool::id)
.def_property_readonly("allocator", &::at::cuda::MemPool::allocator)
.def("use_count", &::at::cuda::MemPool::use_count);
}

@ -1104,7 +1104,7 @@ ErrorType ProcessGroupNCCL::getError() {
return error_;
}

void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool, bool symm) {
void ProcessGroupNCCL::registerMemPool(at::cuda::MemPool* pool, bool symm) {
const auto key = std::to_string(pool->device());
LOG(INFO) << logPrefix()
<< "Performing NCCL user buffer registration for all buffers in "
@ -1138,7 +1138,7 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool, bool symm) {
}
}

void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
void ProcessGroupNCCL::deregisterMemPool(at::cuda::MemPool* pool) {
const auto key = std::to_string(pool->device());
LOG(INFO) << logPrefix()
<< "Performing NCCL user buffer deregistration for all buffers in "
@ -5826,7 +5826,7 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
reinterpret_cast<c10::cuda::CUDACachingAllocator::CUDAAllocator*>(
getMemAllocator().get());
// Pool is created
memPool_ = std::make_unique<c10::cuda::MemPool>(allocator);
memPool_ = std::make_unique<at::cuda::MemPool>(allocator);
// Register so that we call ncclCommRegister on all new allocations
registerMemPool(memPool_.get(), /*symmetric*/ false);
LOG(INFO) << logPrefix() << "Created memory pool";

@ -30,6 +30,7 @@
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/MemPool.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
@ -1023,11 +1024,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {

// Performs NCCL user buffer registration for all buffers in
// the given MemPool
void registerMemPool(c10::cuda::MemPool* pool, bool symm = false);
void registerMemPool(at::cuda::MemPool* pool, bool symm = false);

// Performs NCCL user buffer de-registration for all buffers in
// the given MemPool
void deregisterMemPool(c10::cuda::MemPool* pool);
void deregisterMemPool(at::cuda::MemPool* pool);

// This method adds a temporary extension for the timeout period,
// applying to all collectives between the calling of this API and
@ -1491,7 +1492,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::optional<bool> useNonblocking_{std::nullopt};

// Communication-optimized memory pool associated with this PG
std::unique_ptr<c10::cuda::MemPool> memPool_ = nullptr;
std::unique_ptr<at::cuda::MemPool> memPool_ = nullptr;
};

// Reset the flighrecorder recordings for the current rank.

@ -4,6 +4,7 @@
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
@ -12,6 +13,7 @@
#include <ATen/ATen.h>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/TracerMode.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/irange.h>

#include <sstream>
@ -301,6 +303,16 @@ static py::object maybe_get_registered_torch_dispatch_rule(
return result;
}

static bool is_dtensor(PyObject* obj) {
#ifdef USE_DISTRIBUTED
const py::handle dtensor = get_dtensor_class();
return (PyObject*)Py_TYPE(obj) == dtensor.ptr() ||
py::isinstance(py::handle(obj), dtensor);
#else
return false;
#endif
}

// NB: Invariant: if you run this function, you MUST test if the returned
// py::object is nullptr, as this will occur WITHOUT error condition being set.
// And if an error happens, this function is responsible for throwing a C++
@ -313,8 +325,8 @@ static py::object dispatch_on_subclass(
PyObject* torch_api_function,
bool is_torch_function,
const char* torch_function_name_str,
std::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key =
std::nullopt) {
const c10::OperatorHandle* opt_op,
torch::jit::Stack* opt_stack) {
py::object ret;
for (auto& arg : overloaded_args) {
py::object torch_function =
@ -367,13 +379,39 @@ static py::object dispatch_on_subclass(
}
}

ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
torch_function.ptr(),
torch_api_function,
py_types.ptr(),
args,
kwargs,
NULL));
if (!is_torch_function && is_dtensor(arg)) {
if (opt_op && opt_stack) {
ret = dispatchDTensorOp(
*opt_op, torch_api_function, args, kwargs, opt_stack);
} else {
// Slow path -- reconstruct C++ data structures since they were not
// provided.
auto schema = py::cast<at::FunctionSchema>(
py::handle(torch_api_function).attr("_schema"));
auto opt_op_handle =
c10::Dispatcher::singleton().findOp(schema.operator_name());
TORCH_CHECK(
opt_op_handle.has_value(),
"could not look up op for ",
schema.operator_name());
const auto& op_handle = *opt_op_handle;
auto stack = torch::jit::createStackForSchema(
op_handle.schema(),
py::reinterpret_borrow<py::args>(args),
py::reinterpret_borrow<py::kwargs>(kwargs),
std::nullopt);
ret = dispatchDTensorOp(
op_handle, torch_api_function, args, kwargs, &stack);
}
} else {
ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
torch_function.ptr(),
torch_api_function,
py_types.ptr(),
args,
kwargs,
NULL));
}
if (ret.ptr() == nullptr) {
throw python_error();
}
@ -480,6 +518,28 @@ auto handle_torch_function_no_python_arg_parser(
PyObject* torch_api_function,
const char* module_name,
TorchFunctionName torch_function_name) -> PyObject* {
return handle_torch_function_no_python_arg_parser(
overloaded_args,
args,
kwargs,
func_name,
torch_api_function,
module_name,
nullptr,
nullptr,
torch_function_name);
}

auto handle_torch_function_no_python_arg_parser(
at::ArrayRef<PyObject*> overloaded_args,
PyObject* args,
PyObject* kwargs,
const char* func_name,
PyObject* torch_api_function,
const char* module_name,
const c10::OperatorHandle* opt_op,
torch::jit::Stack* opt_stack,
TorchFunctionName torch_function_name) -> PyObject* {
const char* torch_function_name_str = nullptr;
switch (torch_function_name) {
case TorchFunctionName::TorchFunction:
@ -579,7 +639,9 @@ auto handle_torch_function_no_python_arg_parser(
py_types,
torch_api_function,
is_torch_function,
torch_function_name_str);
torch_function_name_str,
opt_op,
opt_stack);
if (curr_ret.ptr() != nullptr) {
ret = curr_ret;
}

@ -1267,6 +1267,18 @@ auto TORCH_PYTHON_API handle_torch_function_no_python_arg_parser(
TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction)
-> PyObject*;

auto handle_torch_function_no_python_arg_parser(
at::ArrayRef<PyObject*> overloaded_args,
PyObject* args,
PyObject* kwargs,
const char* func_name,
PyObject* torch_api_function,
const char* module_name,
const c10::OperatorHandle* opt_op,
torch::jit::Stack* opt_stack,
TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction)
-> PyObject*;

// Used for getters of Tensor properties
auto handle_torch_function_getter(
THPVariable* self,

@ -386,23 +386,8 @@ static void bindGetDeviceProperties(PyObject* module) {
static void initXpuMethodBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
#if SYCL_COMPILER_VERSION >= 20250000
auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
auto& device = c10::xpu::get_raw_device(device_index);
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
at::xpu::getDeviceProperties(device_index)->name,
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
return std::make_tuple(free, total);
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
py::gil_scoped_release no_gil;
return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
});
m.def(
"_xpu_getStreamFromExternal",

@ -391,7 +391,6 @@ class MemTracker(TorchDispatchMode):
# Weak references to the topmost AC module currently active
self._ac_mod: Optional[weakref.ref] = None
self._orig_resize = torch.UntypedStorage.resize_
self._orig_dtensor_dispatch = DTensor._op_dispatcher.dispatch
self._depth = 0

def _update_snap(

@ -338,14 +338,11 @@ class DTensor(torch.Tensor):
)

@classmethod
@torch._disable_dynamo
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
return DTensor._op_dispatcher.dispatch(
func,
args,
kwargs or {},
# We just need to have an implementation here; the __torch_dispatch__ machinery
# calls into a specific C++ fast path that doesn't call here.
raise NotImplementedError(
"DTensor.__torch_dispatch__ should not actually get called"
)

@staticmethod

@ -12,7 +12,12 @@ import torch.distributed.tensor._random as random
from torch._library.utils import fill_defaults
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType
from torch.distributed.tensor._op_schema import (
OpInfo,
OpSchema,
OutputSharding,
OutputSpecType,
)
from torch.distributed.tensor._random import is_rng_supported_mesh
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor._sharding_prop import ShardingPropagator
@ -125,6 +130,8 @@ class OpDispatcher:

def __init__(self) -> None:
self.sharding_propagator = ShardingPropagator()
# NOTE: must stay in sync with is_random_op in
# torch/csrc/autograd/python_variable.cpp
self._random_ops = {
aten.native_dropout.default,
aten.normal_.default,
@ -157,26 +164,17 @@ class OpDispatcher:
def _allow_implicit_replication(self, value: bool) -> None:
return torch._C._set_dtensor_allow_implicit_replication(value)

def dispatch(
def _propagate_op_sharding_non_cached_dispatch_slow_path(
self,
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: dict[str, object],
op_info: OpInfo,
) -> object:
"""
Main dispatching logic. Follows precedence order:
(1) custom_op_handler
(2) registered sharding strategy, then rule
(3) composite implicit autograd decomposition
"""
if op_call in self._custom_op_handlers:
return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator]

# extract local tensor and sharding infos to a OpInfo
op_info = self.unwrap_to_op_info(op_call, args, kwargs)

try:
self.sharding_propagator.propagate(op_info)
return self.sharding_propagator.propagate_op_sharding_non_cached(
op_info.schema
)
except NotImplementedError:
if torch._C._dispatch_has_kernel_for_dispatch_key(
op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd
@ -193,6 +191,12 @@ class OpDispatcher:
f"{e}\n\nSharding propagation failed for {op_info.schema}"
) from e

def _dispatch_get_local_results_slow_path(
self,
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
op_info: OpInfo,
) -> object:
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"

@ -264,7 +268,7 @@ class OpDispatcher:
# 2. if the return type is Tensor or List[Tensor], return empty
# tensor(s) with correct dtype.
spec = output_sharding.output_spec
ret_list = op_info.schema.op._schema.returns
ret_list = op_call._schema.returns

if spec is None:
# For a scalar return type, the non-participating device has None
@ -299,6 +303,23 @@ class OpDispatcher:
raise NotImplementedError(
f"return type {ret_type} in DTensor op is not supported"
)
return local_results

def _dispatch_fast_path_python_tail(
self,
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: dict[str, object],
compute_mesh: DeviceMesh,
output_sharding: OutputSharding,
local_results: object,
participating: bool,
is_inplace_op: bool,
is_out_variant_op: bool,
) -> object:
"""
Tail of main dispatching logic, called from C++ fast path.
"""

if output_sharding.output_spec is None:
if op_call == aten.equal.default:
@ -308,12 +329,12 @@ class OpDispatcher:
assert local_results is None or isinstance(local_results, bool)
r = torch.tensor(
int(local_results) if local_results is not None else 1,
device=mesh.device_type,
device=compute_mesh.device_type,
)
dist.all_reduce(r, op=dist.ReduceOp.MIN)
local_results = bool(r.item())

if op_info.schema.is_inplace_op():
if is_inplace_op:
# inplace op should return self instead of re-wrapping
if output_sharding.output_spec is not None:
# NOTE: aten.squeeze_.dim is an inplace op but it also may change
@ -332,7 +353,7 @@ class OpDispatcher:
return args[0]
else:
return None
elif op_info.schema.is_out_variant_op():
elif is_out_variant_op:
# out variant could possibly have multiple out args (i.e. lu_unpack.out)
output_specs = (
(output_sharding.output_spec,)
@ -351,8 +372,9 @@ class OpDispatcher:
assert len(out_dts) >= 1, "out variant should have at least one out arg"
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
else:
assert op_call == aten.equal.default, op_call
ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
if participating and op_info.schema.is_view_op():
if participating and op_call._schema._is_view_op():
return return_and_correct_aliasing(op_call, args, kwargs, ret)
else:
return ret
@ -419,6 +441,15 @@ class OpDispatcher:
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: dict[str, object],
) -> OpInfo:
return self._unwrap_to_op_info_impl(op_call, args, kwargs, True)

def _unwrap_to_op_info_impl(
self,
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: dict[str, object],
create_schema: bool,
) -> OpInfo:
# get runtime schema info to determine whether to use pytree to flatten inputs
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
@ -495,7 +526,9 @@ class OpDispatcher:
),
kwargs_schema,
schema_info=runtime_schema_info,
),
)
if create_schema
else None, # type: ignore[arg-type]
args_schema,
tuple(local_args),
local_kwargs,

@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import torch._C
from torch.distributed.tensor.debug._comm_mode import CommDebugMode
from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding

@ -6,11 +7,12 @@ from torch.distributed.tensor.debug._visualize_sharding import visualize_shardin
__all__ = ["CommDebugMode", "visualize_sharding"]


def _get_sharding_prop_cache_info():
def _get_python_sharding_prop_cache_info():
"""
Get the cache info for the sharding propagation cache, used for debugging purpose only.
Get the cache info for the Python sharding propagation cache, used for debugging purpose only.
This would return a named tuple showing hits, misses, maxsize and cursize of the sharding
propagator cache.
propagator cache. Note that directly calling into the sharding propagator does not share cache
state with the DTensor dispatch fast path!
"""
from torch.distributed.tensor._api import DTensor

@ -19,9 +21,17 @@ def _get_sharding_prop_cache_info():
)


def _clear_sharding_prop_cache():
def _get_fast_path_sharding_prop_cache_stats():
"""
Clears the cache for the sharding propagation cache, used for debugging purpose only.
Get a tuple (hits, misses) for the fast path sharding propagation cache, used for debugging
only.
"""
return torch._C._get_DTensor_sharding_propagator_cache_stats()


def _clear_python_sharding_prop_cache():
"""
Clears the cache for the Python sharding propagation cache, used for debugging purpose only.
"""
from torch.distributed.tensor._api import DTensor

@ -30,6 +40,13 @@ def _clear_sharding_prop_cache():
)


def _clear_fast_path_sharding_prop_cache():
"""
Clears the cache for the fast path sharding propagation cache, used for debugging purpose only.
"""
torch._C._clear_DTensor_sharding_propagator_cache()


# Set namespace for exposed private names
CommDebugMode.__module__ = "torch.distributed.tensor.debug"
visualize_sharding.__module__ = "torch.distributed.tensor.debug"

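The renamed helpers above separate the Python-level propagator cache from the new C++ fast-path cache exposed via torch._C. A hedged sketch of how they might be used while profiling DTensor dispatch (these are private debug utilities introduced by this change, so the names may still move):

    from torch.distributed.tensor import debug

    # Python sharding-propagation cache (lru_cache-style named tuple).
    print(debug._get_python_sharding_prop_cache_info())

    # C++ fast-path cache: a plain (hits, misses) tuple.
    hits, misses = debug._get_fast_path_sharding_prop_cache_stats()
    print(f"fast path: {hits} hits, {misses} misses")

    # Reset both caches between experiments.
    debug._clear_python_sharding_prop_cache()
    debug._clear_fast_path_sharding_prop_cache()
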
@ -297,7 +297,9 @@ class _ParsedStackTrace:


# get File:lineno code from stack_trace
def _parse_stack_trace(stack_trace: str):
def _parse_stack_trace(
    stack_trace: str, filter_fn: Optional[Callable[[str, str, str], bool]] = None
):
if stack_trace is None:
return None
pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
@ -314,6 +316,8 @@ def _parse_stack_trace(stack_trace: str):
name = matches.group(3)
# next line should be the code
code = lines[idx + 1].strip()
if filter_fn and not filter_fn(file, name, code):
continue
return _ParsedStackTrace(file, lineno, name, code)
return None

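The new optional filter_fn lets callers skip frames by file, function name, or code line while _parse_stack_trace scans a stack-trace string; _get_user_stack_trace in the next file is the in-tree consumer. A small illustrative sketch (the trace string and filter are made up, and _parse_stack_trace is a private torch.fx helper):

    from torch.fx.graph import _parse_stack_trace

    trace_str = (
        'File "/tmp/my_model.py", line 12, in forward\n'
        "    return self.linear(x)\n"
    )

    # Drop frames that live under /usr/lib, keep everything else.
    frame = _parse_stack_trace(
        trace_str,
        filter_fn=lambda file, name, code: not file.startswith("/usr/lib"),
    )
    if frame is not None:
        print(frame.file, frame.lineno, frame.name, frame.code)
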
@ -34,6 +34,8 @@ Usage::

import contextlib
import functools
import inspect
import os
import traceback
import weakref
from collections.abc import Callable
@ -41,6 +43,7 @@ from typing import Any, Optional, TYPE_CHECKING # noqa: F401

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.graph import _parse_stack_trace
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
@ -193,6 +196,16 @@ def _get_stack_trace() -> str:
return "".join(summary.format())


def _get_user_stack_trace(stack_trace_str: str) -> str | None:
# Extract user code stack trace, filtering out torch internals.
torch_dir = os.path.dirname(inspect.getfile(torch))
filter_fn = lambda file, name, code: not file.startswith(torch_dir + os.path.sep) # noqa: E731
trace = _parse_stack_trace(stack_trace_str, filter_fn=filter_fn)
if trace:
return f"File: {trace.file}:{trace.lineno} in {trace.name}, code: {trace.code}"
return None


def _maybe_get_autograd_trace() -> str | None:
if torch._C._current_autograd_node() is not None:
tb = torch._C._current_autograd_node().metadata.get("traceback_") # type: ignore[attr-defined]
@ -781,14 +794,55 @@ class DebugMode(TorchDispatchMode):
self.operators.append(call)
return call

def debug_string(self) -> str:
def debug_string(self, show_stack_trace: bool = False) -> str:
"""
show_stack_trace: If True, display one-line stack trace summaries above groups
of operations (similar to gm.print_readable() style).
Requires record_stack_trace=True.
"""
with torch._C.DisableTorchFunction():
result = ""
result += "\n".join(
" " + " " * op.call_depth + op.render(self.record_tensor_attributes)
for op in self.operators
)
return result
if not show_stack_trace:
result = "\n".join(
" "
+ " " * op.call_depth
+ op.render(self.record_tensor_attributes)
for op in self.operators
)
return result

# Group operations by stack trace
lines = []
prev_stack_summary = None

for op in self.operators:
# Get the stack trace: prefer fwd_stack_trace, fallback to stack_trace
stack_trace = None
if hasattr(op, "fwd_stack_trace") and op.fwd_stack_trace:
stack_trace = op.fwd_stack_trace
elif hasattr(op, "stack_trace") and op.stack_trace:
stack_trace = op.stack_trace

stack_summary = None
if stack_trace:
stack_summary = _get_user_stack_trace(stack_trace)

if stack_summary and stack_summary != prev_stack_summary:
# add blank line before stack trace comment for readability
if lines: # don't add blank line at the very start
lines.append("")
indent = " " * (op.call_depth + 1)
lines.append(indent + "# " + stack_summary)
prev_stack_summary = stack_summary

# Add the operation line
line = (
" "
+ " " * op.call_depth
+ op.render(self.record_tensor_attributes)
)
lines.append(line)

return "\n".join(lines)

@staticmethod
@contextlib.contextmanager

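With the change above, DebugMode.debug_string(show_stack_trace=True) groups the recorded ops under one-line user-code summaries. A hedged usage sketch (the torch.utils._debug_mode import path and the record_stack_trace constructor flag are inferred from this patch's surrounding context and may differ across versions):

    import torch
    from torch.utils._debug_mode import DebugMode

    x = torch.randn(4, 4)
    with DebugMode(record_stack_trace=True) as dm:
        y = (x @ x).relu()

    print(dm.debug_string())                       # one line per op
    print(dm.debug_string(show_stack_trace=True))  # ops grouped under "# File: ..." comments
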
@ -190,6 +190,7 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
int: the memory available on the device in units of bytes.
int: the total memory on the device in units of bytes
"""
_lazy_init()
device = _get_device_index(device, optional=True)
return torch._C._xpu_getMemoryInfo(device)
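The simplified binding above routes torch.xpu.mem_get_info through the generic device allocator's getMemoryInfo; the Python-facing behaviour is unchanged. A minimal sketch (requires an XPU build with a visible device):

    import torch

    if torch.xpu.is_available():
        free, total = torch.xpu.mem_get_info()  # bytes for the current XPU device
        print(f"XPU memory: {free / 2**30:.2f} GiB free of {total / 2**30:.2f} GiB")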