mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Profiler][1/N] add profiler support for custom device. (#101554)
1. `torch.autograd.profiler` interface parameters changed. (use `self.use_device` instead of `self.use_cuda` facilitates access by other devices and integrate it in subsequent pr) 2. Modify `ProfilerEventStub`(aka `std::shared_ptr<CUevent_st>`) to `ProfilerVoidEventStub`(aka `std::shared_ptr<void>`) so that `ProfilerStubs` can be inherited by any `{device}Methods`. In addition, `cuda_event_start_` is renamed to `device_event_start_` , cuda and other devices can use this event pointer if needed. 4. custom device support using legacy profiling(add `ProfilerState::KINETO_PRIVATEUSE1_FALLBACK` option) 5. add `privateuse1Stubs` register (parse results and test cases are added in subsequent pr) Pull Request resolved: https://github.com/pytorch/pytorch/pull/101554 Approved by: https://github.com/aaronenyeshi
This commit is contained in:
committed by
PyTorch MergeBot
parent
1204463bd0
commit
2e8ce910bb
@ -33,6 +33,7 @@ class ProfilerEvent:
|
||||
def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
||||
def cpu_memory_usage(self) -> int: ...
|
||||
def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
||||
def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
||||
def cuda_memory_usage(self) -> int: ...
|
||||
def device(self) -> int: ...
|
||||
def handle(self) -> int: ...
|
||||
|
@ -25,6 +25,7 @@ class ProfilerState(Enum):
|
||||
ITT = ...
|
||||
KINETO = ...
|
||||
KINETO_GPU_FALLBACK = ...
|
||||
KINETO_PRIVATEUSE1_FALLBACK = ...
|
||||
|
||||
class ActiveProfilerType(Enum):
|
||||
NONE = ...
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import torch.cuda
|
||||
from torch._C._profiler import _ExperimentalConfig
|
||||
from torch._C import _get_privateuse1_backend_name
|
||||
|
||||
from torch.autograd import (
|
||||
_disable_profiler,
|
||||
@ -170,6 +171,7 @@ class profile:
|
||||
enabled=True,
|
||||
*,
|
||||
use_cuda=False,
|
||||
use_device=None,
|
||||
record_shapes=False,
|
||||
with_flops=False,
|
||||
profile_memory=False,
|
||||
@ -182,6 +184,7 @@ class profile:
|
||||
if not self.enabled:
|
||||
return
|
||||
self.use_cuda = use_cuda
|
||||
self.use_device = use_device
|
||||
self.function_events: Optional[EventList] = None
|
||||
self.entered = False
|
||||
self.record_shapes = record_shapes
|
||||
@ -217,6 +220,22 @@ class profile:
|
||||
else:
|
||||
self.kineto_activities.add(ProfilerActivity.CUDA)
|
||||
|
||||
if self.use_device:
|
||||
if self.use_device == 'cuda':
|
||||
# TODO:using 'use_device' instead of 'use_cuda' facilitates access by other devices
|
||||
# and integrate it in subsequent pr.
|
||||
pass
|
||||
elif self.use_device == _get_privateuse1_backend_name():
|
||||
if not use_kineto:
|
||||
assert self.use_cpu, "Legacy custombackend profiling requires use_cpu=True"
|
||||
self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Now, custombackend events does not support Kineto (use_kineto=False)"
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"{self.use_device} doesn't support profile.")
|
||||
|
||||
assert len(self.kineto_activities) > 0, \
|
||||
"No activities specified for the profiler"
|
||||
|
||||
|
@ -245,6 +245,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
// Whether this is async event or not
|
||||
.def("is_async", [](const KinetoEvent& e) { return e.isAsync(); })
|
||||
.def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs)
|
||||
.def("privateuse1_elapsed_us", &KinetoEvent::privateuse1ElapsedUs)
|
||||
.def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); });
|
||||
|
||||
m.def("_soft_assert_raises", &setSoftAssertRaises);
|
||||
|
@ -442,10 +442,15 @@ void onFunctionExit(
|
||||
auto fallback = kineto_ctx_ptr->fallback_;
|
||||
TORCH_INTERNAL_ASSERT(fallback != nullptr);
|
||||
torch::profiler::impl::cudaStubs()->record(
|
||||
nullptr, &fallback->cuda_event_end_, nullptr);
|
||||
nullptr, &fallback->device_event_end_, nullptr);
|
||||
} catch (const std::exception& e) {
|
||||
LOG(WARNING) << "Failed to record CUDA event. " << e.what();
|
||||
}
|
||||
} else if (config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
|
||||
auto fallback = kineto_ctx_ptr->fallback_;
|
||||
TORCH_INTERNAL_ASSERT(fallback != nullptr);
|
||||
torch::profiler::impl::privateuse1Stubs()->record(
|
||||
nullptr, &fallback->device_event_end_, nullptr);
|
||||
}
|
||||
|
||||
if (fn.scope() == at::RecordScope::USER_SCOPE) {
|
||||
@ -519,7 +524,8 @@ void prepareProfiler(
|
||||
}
|
||||
TORCH_CHECK(
|
||||
config.state == ProfilerState::KINETO ||
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK,
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
|
||||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK,
|
||||
"Supported only in Kineto profiler");
|
||||
torch::profiler::impl::kineto::prepareTrace(
|
||||
/*cpuOnly=*/!(at::hasCUDA() || at::hasXPU() || at::hasMTIA()),
|
||||
@ -594,7 +600,9 @@ void enableProfiler(
|
||||
|
||||
TORCH_CHECK(
|
||||
config.state == ProfilerState::KINETO ||
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK || config.global());
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
|
||||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
|
||||
config.global());
|
||||
TORCH_CHECK(!activities.empty(), "No activities specified.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
has_cpu || !config.global(),
|
||||
@ -620,6 +628,7 @@ std::unique_ptr<ProfilerResult> disableProfiler() {
|
||||
state_ptr &&
|
||||
(config.state == ProfilerState::KINETO ||
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
|
||||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
|
||||
config.state == ProfilerState::KINETO_ONDEMAND ||
|
||||
config.state == ProfilerState::NVTX ||
|
||||
config.state == ProfilerState::ITT),
|
||||
@ -634,14 +643,15 @@ std::unique_ptr<ProfilerResult> disableProfiler() {
|
||||
return std::make_unique<ProfilerResult>();
|
||||
}
|
||||
|
||||
// Shared among NVTX, KINETO, KINETO_GPU_FALLBACK
|
||||
// Shared among NVTX, KINETO, KINETO_GPU_FALLBACK, KINETO_PRIVATEUSE1_FALLBACK
|
||||
std::unique_ptr<ProfilerResult> result;
|
||||
if (state_ptr->config().state == ProfilerState::NVTX) {
|
||||
result = std::make_unique<ProfilerResult>();
|
||||
}
|
||||
|
||||
if (config.state == ProfilerState::KINETO ||
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK) {
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
|
||||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
|
||||
auto kineto_state_ptr =
|
||||
std::static_pointer_cast<KinetoThreadLocalState>(state_ptr);
|
||||
auto trace = kineto_state_ptr->finalizeTrace();
|
||||
@ -774,6 +784,17 @@ int64_t KinetoEvent::cudaElapsedUs() const {
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64_t KinetoEvent::privateuse1ElapsedUs() const {
|
||||
auto privateuse1_event_start = fallbackStart();
|
||||
auto privateuse1_event_end = fallbackEnd();
|
||||
if (!privateuse1_event_start || !privateuse1_event_end) {
|
||||
return -1;
|
||||
}
|
||||
return (int64_t)torch::profiler::impl::privateuse1Stubs()->elapsed(
|
||||
&privateuse1_event_start, &privateuse1_event_end);
|
||||
return -1;
|
||||
}
|
||||
|
||||
void KinetoEvent::getPerfEventCounters(std::vector<uint64_t>& in) const {
|
||||
return result_->visit(c10::overloaded(
|
||||
[&in](const ExtraFields<EventType::TorchOp>& e) -> void {
|
||||
@ -829,8 +850,8 @@ TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0)
|
||||
TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_))
|
||||
TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty())
|
||||
TYPED_ATTR(TorchOp, isAsync, e.is_async_)
|
||||
TYPED_ATTR(TorchOp, fallbackStart, e.gpu_fallback_.cuda_event_start_)
|
||||
TYPED_ATTR(TorchOp, fallbackEnd, e.gpu_fallback_.cuda_event_end_)
|
||||
TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_)
|
||||
TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_)
|
||||
TYPED_ATTR(
|
||||
TorchOp,
|
||||
flops,
|
||||
|
@ -57,11 +57,12 @@ struct TORCH_API KinetoEvent {
|
||||
std::string backend() const;
|
||||
bool isPythonFunction() const;
|
||||
int64_t cudaElapsedUs() const;
|
||||
int64_t privateuse1ElapsedUs() const;
|
||||
void getPerfEventCounters(torch::profiler::perf_counters_t&) const;
|
||||
|
||||
private:
|
||||
torch::profiler::impl::ProfilerEventStub fallbackStart() const;
|
||||
torch::profiler::impl::ProfilerEventStub fallbackEnd() const;
|
||||
torch::profiler::impl::ProfilerVoidEventStub fallbackStart() const;
|
||||
torch::profiler::impl::ProfilerVoidEventStub fallbackEnd() const;
|
||||
|
||||
std::shared_ptr<const torch::profiler::impl::Result> result_;
|
||||
std::vector<std::string> python_stack_;
|
||||
|
@ -267,7 +267,7 @@ struct TORCH_API LegacyEvent {
|
||||
int64_t cpu_memory_usage_ = 0;
|
||||
int64_t cuda_memory_usage_ = 0;
|
||||
int device_ = -1;
|
||||
torch::profiler::impl::ProfilerEventStub cuda_event = nullptr;
|
||||
torch::profiler::impl::ProfilerVoidEventStub cuda_event = nullptr;
|
||||
int node_id_ = 0;
|
||||
bool is_remote_ = false;
|
||||
int64_t cuda_us_ = -1;
|
||||
|
@ -408,7 +408,8 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
|
||||
auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();
|
||||
|
||||
if (profilingConfig.state == ProfilerState::KINETO ||
|
||||
profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK) {
|
||||
profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK ||
|
||||
profilingConfig.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
|
||||
profilingConfig = ProfilerConfig(
|
||||
ProfilerState::CPU,
|
||||
profilingConfig.report_input_shapes,
|
||||
|
@ -335,12 +335,16 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
|
||||
|
||||
if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) {
|
||||
try {
|
||||
out->fallback_ = torch_ops_.gpu_fallback_.emplace_back();
|
||||
out->fallback_ = torch_ops_.device_fallback_.emplace_back();
|
||||
torch::profiler::impl::cudaStubs()->record(
|
||||
nullptr, &out->fallback_->cuda_event_start_, nullptr);
|
||||
nullptr, &out->fallback_->device_event_start_, nullptr);
|
||||
} catch (const std::exception& e) {
|
||||
LOG(WARNING) << "Failed to record CUDA event. " << e.what();
|
||||
}
|
||||
} else if (config_.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
|
||||
out->fallback_ = torch_ops_.device_fallback_.emplace_back();
|
||||
torch::profiler::impl::privateuse1Stubs()->record(
|
||||
nullptr, &out->fallback_->device_event_start_, nullptr);
|
||||
}
|
||||
|
||||
event->start_time_ = torch::profiler::impl::getApproximateTime();
|
||||
@ -420,7 +424,8 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize(
|
||||
auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_);
|
||||
auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_);
|
||||
auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_);
|
||||
auto gpu_fallback = StealOrDefault<decltype(gpu_fallback_)>(gpu_fallback_);
|
||||
auto gpu_fallback =
|
||||
StealOrDefault<decltype(device_fallback_)>(device_fallback_);
|
||||
|
||||
for (auto event = op_events_.begin(); event != op_events_.end(); ++event) {
|
||||
ExtraFields<EventType::TorchOp> e{
|
||||
|
@ -116,8 +116,8 @@ using jit_modules_t = std::vector<std::string>;
|
||||
using extra_args_t = std::unordered_map<std::string, c10::IValue>;
|
||||
|
||||
struct FallbackPair {
|
||||
ProfilerEventStub cuda_event_start_ = nullptr;
|
||||
ProfilerEventStub cuda_event_end_ = nullptr;
|
||||
ProfilerVoidEventStub device_event_start_ = nullptr;
|
||||
ProfilerVoidEventStub device_event_end_ = nullptr;
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -131,7 +131,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
|
||||
jit_stack_t&& jit_stack,
|
||||
jit_modules_t&& jit_modules,
|
||||
extra_args_t&& extra_args,
|
||||
FallbackPair&& gpu_fallback,
|
||||
FallbackPair&& device_fallback,
|
||||
bool allow_tf32_cublas,
|
||||
std::unique_ptr<perf_counters_t>&& perf_event_counters)
|
||||
: TorchOpBasicFields(std::move(f)),
|
||||
@ -142,7 +142,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
|
||||
jit_stack_{std::move(jit_stack)},
|
||||
jit_modules_{std::move(jit_modules)},
|
||||
extra_args_{std::move(extra_args)},
|
||||
gpu_fallback_{std::move(gpu_fallback)},
|
||||
device_fallback_{std::move(device_fallback)},
|
||||
allow_tf32_cublas_{allow_tf32_cublas},
|
||||
perf_event_counters_{std::move(perf_event_counters)} {}
|
||||
uint64_t correlation_id_;
|
||||
@ -152,7 +152,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
|
||||
jit_stack_t jit_stack_;
|
||||
jit_modules_t jit_modules_;
|
||||
extra_args_t extra_args_;
|
||||
FallbackPair gpu_fallback_;
|
||||
FallbackPair device_fallback_;
|
||||
bool allow_tf32_cublas_;
|
||||
std::unique_ptr<perf_counters_t> perf_event_counters_;
|
||||
};
|
||||
@ -579,8 +579,9 @@ class TORCH_API ThreadLocalSubqueue {
|
||||
// with_flops
|
||||
AppendOnlyList<extra_args_t, BlockSize> extra_args_;
|
||||
|
||||
// ProfilerState::KINETO_GPU_FALLBACK
|
||||
AppendOnlyList<FallbackPair, BlockSize> gpu_fallback_;
|
||||
// ProfilerState::KINETO_GPU_FALLBACK or
|
||||
// ProfilerState::KINETO_PRIVATEUSE1_FALLBACK
|
||||
AppendOnlyList<FallbackPair, BlockSize> device_fallback_;
|
||||
} torch_ops_;
|
||||
|
||||
// reportBackendEventToActiveKinetoProfiler
|
||||
|
@ -28,6 +28,7 @@ enum class C10_API_ENUM ProfilerState {
|
||||
ITT, // only emit ITT markers
|
||||
KINETO, // use libkineto
|
||||
KINETO_GPU_FALLBACK, // use CUDA events when CUPTI is not available
|
||||
KINETO_PRIVATEUSE1_FALLBACK, // use PrivateUse1 events
|
||||
KINETO_ONDEMAND, // run the profiler in on-demand mode
|
||||
NUM_PROFILER_STATES, // must be the last one
|
||||
};
|
||||
|
@ -38,7 +38,10 @@ void initPythonBindings(PyObject* module) {
|
||||
.value("NVTX", ProfilerState::NVTX)
|
||||
.value("ITT", ProfilerState::ITT)
|
||||
.value("KINETO", ProfilerState::KINETO)
|
||||
.value("KINETO_GPU_FALLBACK", ProfilerState::KINETO_GPU_FALLBACK);
|
||||
.value("KINETO_GPU_FALLBACK", ProfilerState::KINETO_GPU_FALLBACK)
|
||||
.value(
|
||||
"KINETO_PRIVATEUSE1_FALLBACK",
|
||||
ProfilerState::KINETO_PRIVATEUSE1_FALLBACK);
|
||||
|
||||
py::enum_<ActiveProfilerType>(m, "ActiveProfilerType")
|
||||
.value("NONE", ActiveProfilerType::NONE)
|
||||
|
@ -12,10 +12,10 @@ namespace {
|
||||
struct DefaultStubs : public ProfilerStubs {
|
||||
DefaultStubs(const char* name) : name_{name} {}
|
||||
|
||||
void record(int*, ProfilerEventStub*, int64_t*) const override {
|
||||
void record(int*, ProfilerVoidEventStub*, int64_t*) const override {
|
||||
fail();
|
||||
}
|
||||
float elapsed(const ProfilerEventStub*, const ProfilerEventStub*)
|
||||
float elapsed(const ProfilerVoidEventStub*, const ProfilerVoidEventStub*)
|
||||
const override {
|
||||
fail();
|
||||
return 0.f;
|
||||
@ -74,6 +74,7 @@ struct DefaultStubs : public ProfilerStubs {
|
||||
|
||||
REGISTER_DEFAULT(cuda, CUDA)
|
||||
REGISTER_DEFAULT(itt, ITT)
|
||||
REGISTER_DEFAULT(privateuse1, PrivateUse1)
|
||||
#undef REGISTER_DEFAULT
|
||||
|
||||
} // namespace impl
|
||||
|
@ -16,13 +16,16 @@ namespace impl {
|
||||
// -- Annotation --------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
using ProfilerEventStub = std::shared_ptr<CUevent_st>;
|
||||
using ProfilerVoidEventStub = std::shared_ptr<void>;
|
||||
|
||||
struct TORCH_API ProfilerStubs {
|
||||
virtual void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
|
||||
const = 0;
|
||||
virtual void record(
|
||||
int* device,
|
||||
ProfilerVoidEventStub* event,
|
||||
int64_t* cpu_ns) const = 0;
|
||||
virtual float elapsed(
|
||||
const ProfilerEventStub* event,
|
||||
const ProfilerEventStub* event2) const = 0;
|
||||
const ProfilerVoidEventStub* event,
|
||||
const ProfilerVoidEventStub* event2) const = 0;
|
||||
virtual void mark(const char* name) const = 0;
|
||||
virtual void rangePush(const char* name) const = 0;
|
||||
virtual void rangePop() const = 0;
|
||||
@ -38,6 +41,8 @@ TORCH_API void registerCUDAMethods(ProfilerStubs* stubs);
|
||||
TORCH_API const ProfilerStubs* cudaStubs();
|
||||
TORCH_API void registerITTMethods(ProfilerStubs* stubs);
|
||||
TORCH_API const ProfilerStubs* ittStubs();
|
||||
TORCH_API void registerPrivateUse1Methods(ProfilerStubs* stubs);
|
||||
TORCH_API const ProfilerStubs* privateuse1Stubs();
|
||||
|
||||
using vulkan_id_t = strong::type<
|
||||
int64_t,
|
||||
|
@ -36,7 +36,7 @@ static inline void cudaCheck(cudaError_t result, const char* file, int line) {
|
||||
#define TORCH_CUDA_CHECK(result) cudaCheck(result, __FILE__, __LINE__);
|
||||
|
||||
struct CUDAMethods : public ProfilerStubs {
|
||||
void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
|
||||
void record(int* device, ProfilerVoidEventStub* event, int64_t* cpu_ns)
|
||||
const override {
|
||||
if (device) {
|
||||
TORCH_CUDA_CHECK(c10::cuda::GetDevice(device));
|
||||
@ -54,8 +54,11 @@ struct CUDAMethods : public ProfilerStubs {
|
||||
TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream));
|
||||
}
|
||||
|
||||
float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2)
|
||||
const override {
|
||||
float elapsed(
|
||||
const ProfilerVoidEventStub* event_,
|
||||
const ProfilerVoidEventStub* event2_) const override {
|
||||
auto event = (const ProfilerEventStub*)(event_);
|
||||
auto event2 = (const ProfilerEventStub*)(event2_);
|
||||
TORCH_CUDA_CHECK(cudaEventSynchronize(event->get()));
|
||||
TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get()));
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
|
@ -10,11 +10,12 @@ namespace impl {
|
||||
namespace {
|
||||
|
||||
struct ITTMethods : public ProfilerStubs {
|
||||
void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
|
||||
void record(int* device, ProfilerVoidEventStub* event, int64_t* cpu_ns)
|
||||
const override {}
|
||||
|
||||
float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2)
|
||||
const override {
|
||||
float elapsed(
|
||||
const ProfilerVoidEventStub* event,
|
||||
const ProfilerVoidEventStub* event2) const override {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user