[Profiler][1/N] add profiler support for custom device. (#101554)

1. `torch.autograd.profiler` interface parameters changed. (use `self.use_device` instead of `self.use_cuda` facilitates access by other devices and integrate it in subsequent pr)
2. Modify `ProfilerEventStub`(aka `std::shared_ptr<CUevent_st>`) to `ProfilerVoidEventStub`(aka `std::shared_ptr<void>`) so that `ProfilerStubs` can be inherited by any `{device}Methods`.
In addition, `cuda_event_start_` is renamed to `device_event_start_` , cuda and other devices can use this event pointer if needed.
4. custom device support using legacy profiling(add `ProfilerState::KINETO_PRIVATEUSE1_FALLBACK` option)
5. add `privateuse1Stubs` register
(parse results and test cases are added in subsequent pr)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101554
Approved by: https://github.com/aaronenyeshi
This commit is contained in:
dujinhang
2023-06-02 09:19:14 +00:00
committed by PyTorch MergeBot
parent 1204463bd0
commit 2e8ce910bb
16 changed files with 99 additions and 34 deletions

View File

@ -33,6 +33,7 @@ class ProfilerEvent:
def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
def cpu_memory_usage(self) -> int: ...
def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
def cuda_memory_usage(self) -> int: ...
def device(self) -> int: ...
def handle(self) -> int: ...

View File

@ -25,6 +25,7 @@ class ProfilerState(Enum):
ITT = ...
KINETO = ...
KINETO_GPU_FALLBACK = ...
KINETO_PRIVATEUSE1_FALLBACK = ...
class ActiveProfilerType(Enum):
NONE = ...

View File

@ -6,6 +6,7 @@ import torch
import torch.cuda
from torch._C._profiler import _ExperimentalConfig
from torch._C import _get_privateuse1_backend_name
from torch.autograd import (
_disable_profiler,
@ -170,6 +171,7 @@ class profile:
enabled=True,
*,
use_cuda=False,
use_device=None,
record_shapes=False,
with_flops=False,
profile_memory=False,
@ -182,6 +184,7 @@ class profile:
if not self.enabled:
return
self.use_cuda = use_cuda
self.use_device = use_device
self.function_events: Optional[EventList] = None
self.entered = False
self.record_shapes = record_shapes
@ -217,6 +220,22 @@ class profile:
else:
self.kineto_activities.add(ProfilerActivity.CUDA)
if self.use_device:
if self.use_device == 'cuda':
# TODO:using 'use_device' instead of 'use_cuda' facilitates access by other devices
# and integrate it in subsequent pr.
pass
elif self.use_device == _get_privateuse1_backend_name():
if not use_kineto:
assert self.use_cpu, "Legacy custombackend profiling requires use_cpu=True"
self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
else:
raise AssertionError(
"Now, custombackend events does not support Kineto (use_kineto=False)"
)
else:
raise AssertionError(f"{self.use_device} doesn't support profile.")
assert len(self.kineto_activities) > 0, \
"No activities specified for the profiler"

View File

@ -245,6 +245,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
// Whether this is async event or not
.def("is_async", [](const KinetoEvent& e) { return e.isAsync(); })
.def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs)
.def("privateuse1_elapsed_us", &KinetoEvent::privateuse1ElapsedUs)
.def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); });
m.def("_soft_assert_raises", &setSoftAssertRaises);

View File

@ -442,10 +442,15 @@ void onFunctionExit(
auto fallback = kineto_ctx_ptr->fallback_;
TORCH_INTERNAL_ASSERT(fallback != nullptr);
torch::profiler::impl::cudaStubs()->record(
nullptr, &fallback->cuda_event_end_, nullptr);
nullptr, &fallback->device_event_end_, nullptr);
} catch (const std::exception& e) {
LOG(WARNING) << "Failed to record CUDA event. " << e.what();
}
} else if (config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
auto fallback = kineto_ctx_ptr->fallback_;
TORCH_INTERNAL_ASSERT(fallback != nullptr);
torch::profiler::impl::privateuse1Stubs()->record(
nullptr, &fallback->device_event_end_, nullptr);
}
if (fn.scope() == at::RecordScope::USER_SCOPE) {
@ -519,7 +524,8 @@ void prepareProfiler(
}
TORCH_CHECK(
config.state == ProfilerState::KINETO ||
config.state == ProfilerState::KINETO_GPU_FALLBACK,
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK,
"Supported only in Kineto profiler");
torch::profiler::impl::kineto::prepareTrace(
/*cpuOnly=*/!(at::hasCUDA() || at::hasXPU() || at::hasMTIA()),
@ -594,7 +600,9 @@ void enableProfiler(
TORCH_CHECK(
config.state == ProfilerState::KINETO ||
config.state == ProfilerState::KINETO_GPU_FALLBACK || config.global());
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
config.global());
TORCH_CHECK(!activities.empty(), "No activities specified.");
TORCH_INTERNAL_ASSERT(
has_cpu || !config.global(),
@ -620,6 +628,7 @@ std::unique_ptr<ProfilerResult> disableProfiler() {
state_ptr &&
(config.state == ProfilerState::KINETO ||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
config.state == ProfilerState::KINETO_ONDEMAND ||
config.state == ProfilerState::NVTX ||
config.state == ProfilerState::ITT),
@ -634,14 +643,15 @@ std::unique_ptr<ProfilerResult> disableProfiler() {
return std::make_unique<ProfilerResult>();
}
// Shared among NVTX, KINETO, KINETO_GPU_FALLBACK
// Shared among NVTX, KINETO, KINETO_GPU_FALLBACK, KINETO_PRIVATEUSE1_FALLBACK
std::unique_ptr<ProfilerResult> result;
if (state_ptr->config().state == ProfilerState::NVTX) {
result = std::make_unique<ProfilerResult>();
}
if (config.state == ProfilerState::KINETO ||
config.state == ProfilerState::KINETO_GPU_FALLBACK) {
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
auto kineto_state_ptr =
std::static_pointer_cast<KinetoThreadLocalState>(state_ptr);
auto trace = kineto_state_ptr->finalizeTrace();
@ -774,6 +784,17 @@ int64_t KinetoEvent::cudaElapsedUs() const {
return -1;
}
int64_t KinetoEvent::privateuse1ElapsedUs() const {
auto privateuse1_event_start = fallbackStart();
auto privateuse1_event_end = fallbackEnd();
if (!privateuse1_event_start || !privateuse1_event_end) {
return -1;
}
return (int64_t)torch::profiler::impl::privateuse1Stubs()->elapsed(
&privateuse1_event_start, &privateuse1_event_end);
return -1;
}
void KinetoEvent::getPerfEventCounters(std::vector<uint64_t>& in) const {
return result_->visit(c10::overloaded(
[&in](const ExtraFields<EventType::TorchOp>& e) -> void {
@ -829,8 +850,8 @@ TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0)
TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_))
TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty())
TYPED_ATTR(TorchOp, isAsync, e.is_async_)
TYPED_ATTR(TorchOp, fallbackStart, e.gpu_fallback_.cuda_event_start_)
TYPED_ATTR(TorchOp, fallbackEnd, e.gpu_fallback_.cuda_event_end_)
TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_)
TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_)
TYPED_ATTR(
TorchOp,
flops,

View File

@ -57,11 +57,12 @@ struct TORCH_API KinetoEvent {
std::string backend() const;
bool isPythonFunction() const;
int64_t cudaElapsedUs() const;
int64_t privateuse1ElapsedUs() const;
void getPerfEventCounters(torch::profiler::perf_counters_t&) const;
private:
torch::profiler::impl::ProfilerEventStub fallbackStart() const;
torch::profiler::impl::ProfilerEventStub fallbackEnd() const;
torch::profiler::impl::ProfilerVoidEventStub fallbackStart() const;
torch::profiler::impl::ProfilerVoidEventStub fallbackEnd() const;
std::shared_ptr<const torch::profiler::impl::Result> result_;
std::vector<std::string> python_stack_;

View File

@ -267,7 +267,7 @@ struct TORCH_API LegacyEvent {
int64_t cpu_memory_usage_ = 0;
int64_t cuda_memory_usage_ = 0;
int device_ = -1;
torch::profiler::impl::ProfilerEventStub cuda_event = nullptr;
torch::profiler::impl::ProfilerVoidEventStub cuda_event = nullptr;
int node_id_ = 0;
bool is_remote_ = false;
int64_t cuda_us_ = -1;

View File

@ -408,7 +408,8 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();
if (profilingConfig.state == ProfilerState::KINETO ||
profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK) {
profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK ||
profilingConfig.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
profilingConfig = ProfilerConfig(
ProfilerState::CPU,
profilingConfig.report_input_shapes,

View File

@ -335,12 +335,16 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) {
try {
out->fallback_ = torch_ops_.gpu_fallback_.emplace_back();
out->fallback_ = torch_ops_.device_fallback_.emplace_back();
torch::profiler::impl::cudaStubs()->record(
nullptr, &out->fallback_->cuda_event_start_, nullptr);
nullptr, &out->fallback_->device_event_start_, nullptr);
} catch (const std::exception& e) {
LOG(WARNING) << "Failed to record CUDA event. " << e.what();
}
} else if (config_.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
out->fallback_ = torch_ops_.device_fallback_.emplace_back();
torch::profiler::impl::privateuse1Stubs()->record(
nullptr, &out->fallback_->device_event_start_, nullptr);
}
event->start_time_ = torch::profiler::impl::getApproximateTime();
@ -420,7 +424,8 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize(
auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_);
auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_);
auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_);
auto gpu_fallback = StealOrDefault<decltype(gpu_fallback_)>(gpu_fallback_);
auto gpu_fallback =
StealOrDefault<decltype(device_fallback_)>(device_fallback_);
for (auto event = op_events_.begin(); event != op_events_.end(); ++event) {
ExtraFields<EventType::TorchOp> e{

View File

@ -116,8 +116,8 @@ using jit_modules_t = std::vector<std::string>;
using extra_args_t = std::unordered_map<std::string, c10::IValue>;
struct FallbackPair {
ProfilerEventStub cuda_event_start_ = nullptr;
ProfilerEventStub cuda_event_end_ = nullptr;
ProfilerVoidEventStub device_event_start_ = nullptr;
ProfilerVoidEventStub device_event_end_ = nullptr;
};
template <>
@ -131,7 +131,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
jit_stack_t&& jit_stack,
jit_modules_t&& jit_modules,
extra_args_t&& extra_args,
FallbackPair&& gpu_fallback,
FallbackPair&& device_fallback,
bool allow_tf32_cublas,
std::unique_ptr<perf_counters_t>&& perf_event_counters)
: TorchOpBasicFields(std::move(f)),
@ -142,7 +142,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
jit_stack_{std::move(jit_stack)},
jit_modules_{std::move(jit_modules)},
extra_args_{std::move(extra_args)},
gpu_fallback_{std::move(gpu_fallback)},
device_fallback_{std::move(device_fallback)},
allow_tf32_cublas_{allow_tf32_cublas},
perf_event_counters_{std::move(perf_event_counters)} {}
uint64_t correlation_id_;
@ -152,7 +152,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
jit_stack_t jit_stack_;
jit_modules_t jit_modules_;
extra_args_t extra_args_;
FallbackPair gpu_fallback_;
FallbackPair device_fallback_;
bool allow_tf32_cublas_;
std::unique_ptr<perf_counters_t> perf_event_counters_;
};
@ -579,8 +579,9 @@ class TORCH_API ThreadLocalSubqueue {
// with_flops
AppendOnlyList<extra_args_t, BlockSize> extra_args_;
// ProfilerState::KINETO_GPU_FALLBACK
AppendOnlyList<FallbackPair, BlockSize> gpu_fallback_;
// ProfilerState::KINETO_GPU_FALLBACK or
// ProfilerState::KINETO_PRIVATEUSE1_FALLBACK
AppendOnlyList<FallbackPair, BlockSize> device_fallback_;
} torch_ops_;
// reportBackendEventToActiveKinetoProfiler

View File

@ -28,6 +28,7 @@ enum class C10_API_ENUM ProfilerState {
ITT, // only emit ITT markers
KINETO, // use libkineto
KINETO_GPU_FALLBACK, // use CUDA events when CUPTI is not available
KINETO_PRIVATEUSE1_FALLBACK, // use PrivateUse1 events
KINETO_ONDEMAND, // run the profiler in on-demand mode
NUM_PROFILER_STATES, // must be the last one
};

View File

@ -38,7 +38,10 @@ void initPythonBindings(PyObject* module) {
.value("NVTX", ProfilerState::NVTX)
.value("ITT", ProfilerState::ITT)
.value("KINETO", ProfilerState::KINETO)
.value("KINETO_GPU_FALLBACK", ProfilerState::KINETO_GPU_FALLBACK);
.value("KINETO_GPU_FALLBACK", ProfilerState::KINETO_GPU_FALLBACK)
.value(
"KINETO_PRIVATEUSE1_FALLBACK",
ProfilerState::KINETO_PRIVATEUSE1_FALLBACK);
py::enum_<ActiveProfilerType>(m, "ActiveProfilerType")
.value("NONE", ActiveProfilerType::NONE)

View File

@ -12,10 +12,10 @@ namespace {
struct DefaultStubs : public ProfilerStubs {
DefaultStubs(const char* name) : name_{name} {}
void record(int*, ProfilerEventStub*, int64_t*) const override {
void record(int*, ProfilerVoidEventStub*, int64_t*) const override {
fail();
}
float elapsed(const ProfilerEventStub*, const ProfilerEventStub*)
float elapsed(const ProfilerVoidEventStub*, const ProfilerVoidEventStub*)
const override {
fail();
return 0.f;
@ -74,6 +74,7 @@ struct DefaultStubs : public ProfilerStubs {
REGISTER_DEFAULT(cuda, CUDA)
REGISTER_DEFAULT(itt, ITT)
REGISTER_DEFAULT(privateuse1, PrivateUse1)
#undef REGISTER_DEFAULT
} // namespace impl

View File

@ -16,13 +16,16 @@ namespace impl {
// -- Annotation --------------------------------------------------------------
// ----------------------------------------------------------------------------
using ProfilerEventStub = std::shared_ptr<CUevent_st>;
using ProfilerVoidEventStub = std::shared_ptr<void>;
struct TORCH_API ProfilerStubs {
virtual void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
const = 0;
virtual void record(
int* device,
ProfilerVoidEventStub* event,
int64_t* cpu_ns) const = 0;
virtual float elapsed(
const ProfilerEventStub* event,
const ProfilerEventStub* event2) const = 0;
const ProfilerVoidEventStub* event,
const ProfilerVoidEventStub* event2) const = 0;
virtual void mark(const char* name) const = 0;
virtual void rangePush(const char* name) const = 0;
virtual void rangePop() const = 0;
@ -38,6 +41,8 @@ TORCH_API void registerCUDAMethods(ProfilerStubs* stubs);
TORCH_API const ProfilerStubs* cudaStubs();
TORCH_API void registerITTMethods(ProfilerStubs* stubs);
TORCH_API const ProfilerStubs* ittStubs();
TORCH_API void registerPrivateUse1Methods(ProfilerStubs* stubs);
TORCH_API const ProfilerStubs* privateuse1Stubs();
using vulkan_id_t = strong::type<
int64_t,

View File

@ -36,7 +36,7 @@ static inline void cudaCheck(cudaError_t result, const char* file, int line) {
#define TORCH_CUDA_CHECK(result) cudaCheck(result, __FILE__, __LINE__);
struct CUDAMethods : public ProfilerStubs {
void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
void record(int* device, ProfilerVoidEventStub* event, int64_t* cpu_ns)
const override {
if (device) {
TORCH_CUDA_CHECK(c10::cuda::GetDevice(device));
@ -54,8 +54,11 @@ struct CUDAMethods : public ProfilerStubs {
TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream));
}
float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2)
const override {
float elapsed(
const ProfilerVoidEventStub* event_,
const ProfilerVoidEventStub* event2_) const override {
auto event = (const ProfilerEventStub*)(event_);
auto event2 = (const ProfilerEventStub*)(event2_);
TORCH_CUDA_CHECK(cudaEventSynchronize(event->get()));
TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get()));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)

View File

@ -10,11 +10,12 @@ namespace impl {
namespace {
struct ITTMethods : public ProfilerStubs {
void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
void record(int* device, ProfilerVoidEventStub* event, int64_t* cpu_ns)
const override {}
float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2)
const override {
float elapsed(
const ProfilerVoidEventStub* event,
const ProfilerVoidEventStub* event2) const override {
return 0;
}