[PyTorch] Make RecordFunction store inputs as ArrayRef (#72484)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72484

Stepping stone toward stack-allocating the array of inputs.

Funnily enough, this seems to improve performance too.
ghstack-source-id: 155492056

Test Plan:
1) CI
2) framework overhead benchmark with --stressTestRecordFunction --captureRecordFunctionInputs goes from 0.76 usec/iter to 0.72.

Reviewed By: chaekit, robieta

Differential Revision: D34061169

fbshipit-source-id: 073fedf1d3d162f927c4e9867cfda7dbfabba215
(cherry picked from commit dae77cf1cd8813d902d73999ad97133a3ef8e291)
This commit is contained in:
Scott Wolchok
2022-05-05 14:31:07 -07:00
committed by PyTorch MergeBot
parent 710246ea99
commit 52af4fc5ba
16 changed files with 101 additions and 66 deletions

View File

@ -7,7 +7,7 @@ void record_kernel_function_dtype(std::string name) {
RECORD_FUNCTION_WITH_SCOPE(
at::RecordScope::KERNEL_FUNCTION_DTYPE,
name,
{});
c10::ArrayRef<const c10::IValue>{});
}
}} // namespace at::detail

View File

@ -45,7 +45,7 @@ namespace torch {
namespace detail {
void record_custom_class(std::string name) {
RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::CUSTOM_CLASS, name, {});
RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::CUSTOM_CLASS, name, c10::ArrayRef<const c10::IValue>{});
}
} // namespace detail

View File

@ -357,11 +357,11 @@ int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchK
}
void Dispatcher::runRecordFunction(at::RecordFunction& guard, const OperatorHandle& op, DispatchKey dispatchKey, const torch::jit::Stack &stack) {
guard.before(op, stack, sequenceNumberForRunningRecordFunction(dispatchKey));
guard.before(op, c10::ArrayRef<const IValue>(stack.data(), stack.size()), sequenceNumberForRunningRecordFunction(dispatchKey));
}
void Dispatcher::runRecordFunction(at::RecordFunction& guard, const OperatorHandle& op, DispatchKey dispatchKey, torch::jit::Stack &&stack) {
guard.before(op, std::move(stack), sequenceNumberForRunningRecordFunction(dispatchKey));
guard.before(op, c10::ArrayRef<const IValue>(stack.data(), stack.size()), sequenceNumberForRunningRecordFunction(dispatchKey));
}
void Dispatcher::runRecordFunction(at::RecordFunction& guard, const OperatorHandle& op, DispatchKey dispatchKey) {

View File

@ -637,6 +637,7 @@ void RecordFunction::before(const char* name, int64_t sequence_nr) {
state_->operator_name_.reset();
runStartCallbacks();
invalidateInputs();
}
void RecordFunction::before(std::string name, int64_t sequence_nr) {
@ -649,6 +650,7 @@ void RecordFunction::before(std::string name, int64_t sequence_nr) {
state_->operator_name_.reset();
runStartCallbacks();
invalidateInputs();
}
void RecordFunction::before(
@ -664,6 +666,7 @@ void RecordFunction::before(
state_->name_ = op.schema().name();
runStartCallbacks();
invalidateInputs();
}
/* static */ void RecordFunction::setDefaultNodeId(int64_t newDefaultNodeId) {

View File

@ -284,15 +284,26 @@ struct TORCH_API RecordFunction {
template <typename F>
void before(
F fn,
const std::vector<c10::IValue>* args,
c10::ArrayRef<const c10::IValue> args,
int64_t current_sequence_nr = -1) {
if (!isActive()) {
return;
}
state_->inputs_ = *args;
state_->inputs_ = args;
#ifndef NDEBUG
state_->inputs_valid_ = true;
#endif
before(fn, current_sequence_nr);
}
template <typename F>
void before(
F fn,
const std::vector<IValue>* args,
int64_t current_sequence_nr = -1) {
before(std::move(fn), c10::ArrayRef<const c10::IValue>(args->data(), args->size()), current_sequence_nr);
}
// Destructor calls end callbacks
virtual ~RecordFunction();
@ -309,8 +320,11 @@ struct TORCH_API RecordFunction {
return state_->sequence_nr_;
}
const std::vector<c10::IValue>& inputs() const {
c10::ArrayRef<const IValue> inputs() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called inputs() on inactive RecordFunction");
#ifndef NDEBUG
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_->inputs_valid_, "Called inputs() outside RecordFunction start callback");
#endif
return state_->inputs_;
}
@ -383,30 +397,6 @@ struct TORCH_API RecordFunction {
// Gets node ID for distributed profiling
static int64_t getDefaultNodeId();
template<typename F>
void before(
F fn,
c10::ArrayRef<c10::IValue> args,
int64_t current_sequence_nr = -1) {
if (!isActive()) {
return;
}
state_->inputs_ = args.vec();
before(fn, current_sequence_nr);
}
template<typename F>
void before(
F fn,
std::vector<c10::IValue>&& args,
int64_t current_sequence_nr = -1) {
if (!isActive()) {
return;
}
state_->inputs_ = std::move(args);
before(fn, current_sequence_nr);
}
// Calls end callbacks. After end(), accessors will no longer provide useful results.
void end();
@ -460,6 +450,13 @@ struct TORCH_API RecordFunction {
state_->debug_handle_ = debug_handle;
}
void invalidateInputs() {
#ifndef NDEBUG
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called invalidateInputs() on inactive RecordFunction");
state_->inputs_valid_ = false;
#endif
}
private:
void runStartCallbacks();
@ -474,12 +471,16 @@ struct TORCH_API RecordFunction {
// flag is used to check whether the start callbacks were called
bool called_start_callbacks_ = false;
#ifndef NDEBUG
bool inputs_valid_ = false;
#endif
// Stores various ObserverContext objects with event metadata for callbacks.
ObserverContextList ctx_;
std::string name_;
int64_t sequence_nr_ = -1;
std::vector<c10::IValue> inputs_;
c10::ArrayRef<const IValue> inputs_;
std::vector<c10::IValue> outputs_;
c10::optional<c10::OperatorName> operator_name_;
@ -514,16 +515,43 @@ struct TORCH_API RecordFunction {
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
// Using a macro to minimize copies of the inputs.
namespace detail {
template <typename Inputs, typename F, typename... Args>
void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) {
if (guard.needsInputs()) {
guard.before(fn, c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()), std::forward<Args>(args)...);
} else {
guard.before(fn, std::forward<Args>(args)...);
}
}
template <typename Inputs, typename F, typename... Args>
void record_function_with_scope_and_debug_handle(RecordFunction& guard, F fn, int64_t debug_handle, const Inputs& inputs, Args&&... args) {
guard.setDebugHandle(debug_handle);
if (guard.needsInputs()) {
guard.before(fn, c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()), std::forward<Args>(args)...);
} else {
guard.before(fn, std::forward<Args>(args)...);
}
}
template <typename F, typename... Args>
void record_function_with_scope(RecordFunction& guard, F fn, c10::ArrayRef<const c10::IValue> inputs, Args&&... args) {
return record_function_with_scope<c10::ArrayRef<const c10::IValue>, F, Args...>(guard, std::move(fn), inputs, std::forward<Args>(args)...);
}
template <typename F, typename... Args>
void record_function_with_scope_and_debug_handle(RecordFunction& guard, F fn, int64_t debug_handle, c10::ArrayRef<const c10::IValue> inputs, Args&&... args) {
return record_function_with_scope_and_debug_handle<c10::ArrayRef<const c10::IValue>, F, Args...>(guard, std::move(fn), debug_handle, inputs, std::forward<Args>(args)...);
}
} // namespace detail
// optional argument - function's seq_no
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
at::RecordFunction guard(scope); \
if (guard.isActive()) { \
if (guard.needsInputs()) { \
guard.before(fn, inputs, ##__VA_ARGS__); \
} else { \
guard.before(fn, ##__VA_ARGS__); \
} \
if (guard.isActive()) { \
::at::detail::record_function_with_scope(guard, fn, inputs, ##__VA_ARGS__); \
}
#define RECORD_FUNCTION(fn, inputs, ...) \
@ -538,7 +566,7 @@ TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
// Custom user scopes in C++; similar to Python's 'with record_function("..."):'
#define RECORD_USER_SCOPE(fn) \
RECORD_FUNCTION_WITH_SCOPE( \
at::RecordScope::USER_SCOPE, fn, {})
at::RecordScope::USER_SCOPE, fn, c10::ArrayRef<const c10::IValue>{})
// RECORD_USER_SCOPE with inputs
#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \
@ -549,15 +577,10 @@ TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
// post process events
#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \
scope, fn, debug_handle, inputs, ...) \
at::RecordFunction guard(scope); \
if (guard.isActive()) { \
guard.setDebugHandle(debug_handle); \
if (guard.needsInputs()) { \
guard.before(fn, inputs, ##__VA_ARGS__); \
} else { \
guard.before(fn, ##__VA_ARGS__); \
} \
}
at::RecordFunction guard(scope); \
if (guard.isActive()) { \
::at::detail::record_function_with_scope_and_debug_handle(guard, fn, debug_handle, inputs, ##__VA_ARGS__); \
}
// Helper macros to record LITE INTERPRETER scope events with debug handles
#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \

View File

@ -433,7 +433,7 @@ auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
c10::str(
"autograd::engine::evaluate_function: ",
task.fn_.get()->name()),
std::vector<c10::IValue>());
c10::ArrayRef<const c10::IValue>());
evaluate_function(
local_graph_task,
task.fn_.get(),

View File

@ -159,9 +159,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// the forward pass function
guard.setForwardThreadId(thread_id_);
if (guard.needsInputs()) {
std::vector<c10::IValue> inputs_vec(inputs.begin(), inputs.end());
guard.before(
name(),
std::vector<c10::IValue>(inputs.begin(), inputs.end()),
c10::ArrayRef<const c10::IValue>(inputs_vec.data(), inputs_vec.size()),
sequence_nr());
} else {
guard.before(name(), sequence_nr());

View File

@ -303,7 +303,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
for (const auto& arg : args) {
iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg));
}
rec->before(name, iv_inputs);
rec->before(name, c10::ArrayRef<const c10::IValue>(iv_inputs.data(), iv_inputs.size()));
} else {
rec->before(name);
}

View File

@ -24,7 +24,7 @@ void record_function_enter(
at::RecordFunction &rec) {
if (rec.isActive()) {
if (rec.needsInputs() && args.has_value()) {
rec.before(name, std::vector<c10::IValue>{c10::IValue{args.value()}});
rec.before(name, c10::ArrayRef<const c10::IValue>{c10::IValue{args.value()}});
} else {
rec.before(name);
}

View File

@ -78,7 +78,7 @@ ProcessGroup::Work::Work(
inputs.emplace_back(tensor);
}
}
recordingFunction->before(profilingTitle, inputs);
recordingFunction->before(profilingTitle, c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
std::function<void()> end_handler = [recordingFunction]() {
recordingFunction->end();
};

View File

@ -490,7 +490,7 @@ inline void ProcessGroupGloo::AsyncWork::recordAsyncWorkProfilingInfo(
inputs.emplace_back(tensor);
}
}
recordingFunction->before(profilingTitle, inputs);
recordingFunction->before(profilingTitle, c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
};
recordFunctionBeforeCallback_ = at::wrapPropagateTLSState(before_handler);
std::function<void()> end_handler = [recordingFunction]() {

View File

@ -503,7 +503,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"),
[](Stack& stack) {
RECORD_FUNCTION("get_device", std::vector<c10::IValue>());
RECORD_FUNCTION("get_device", c10::ArrayRef<const c10::IValue>{});
auto result =
at::get_device((std::move(peek(stack, 0, 1))).toTensor());
drop(stack, 1);
@ -513,7 +513,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::storage_offset(Tensor self) -> int"),
[](Stack& stack) {
RECORD_FUNCTION("storage_offset", std::vector<c10::IValue>());
RECORD_FUNCTION("storage_offset", c10::ArrayRef<const c10::IValue>{});
auto result =
((std::move(peek(stack, 0, 1))).toTensor()).storage_offset();
drop(stack, 1);
@ -523,7 +523,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::is_contiguous(Tensor self) -> bool"),
[](Stack& stack) {
RECORD_FUNCTION("is_contiguous", std::vector<c10::IValue>());
RECORD_FUNCTION("is_contiguous", c10::ArrayRef<const c10::IValue>{});
auto result =
((std::move(peek(stack, 0, 1))).toTensor()).is_contiguous();
drop(stack, 1);

View File

@ -1203,8 +1203,10 @@ c10::IValue BlockRunner::run_impl_record_functions(
if (!step_callbacks.empty()) {
at::RecordFunction guard(std::move(step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
guard.needsInputs() ? guard.before("forward", &args)
: guard.before("forward");
guard.needsInputs()
? guard.before(
"forward", c10::ArrayRef<const IValue>(args.data(), args.size()))
: guard.before("forward");
return run_impl(std::forward<IValueList>(args), kwargs);
}
@ -1845,8 +1847,14 @@ void ProcessedNode::run() {
if (!step_callbacks.empty()) {
at::RecordFunction guard(std::move(step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
guard.needsInputs() ? guard.before(get_op_name(), inputs_ivalue_vec())
: guard.before(get_op_name());
if (guard.needsInputs()) {
const auto inputs = inputs_ivalue_vec();
guard.before(
get_op_name(),
c10::ArrayRef<const IValue>(inputs.data(), inputs.size()));
} else {
guard.before(get_op_name());
}
if (has_out_variant()) {
guard._setStaticRuntimeOutVariant();
}

View File

@ -11,7 +11,7 @@ namespace torch {
namespace profiler {
namespace impl {
void InputOutputEncoder::push(const std::vector<c10::IValue>& values) {
void InputOutputEncoder::push(c10::ArrayRef<const c10::IValue> values) {
for (const auto& value : values) {
if (value.isTensor()) {
push(value.toTensor());

View File

@ -105,7 +105,7 @@ constexpr int IO_ENCODER_DEFAULT_BLOCK_SIZE = 1024;
// Those vectors can be created during post-processing.
class InputOutputEncoder final {
public:
void push(const std::vector<c10::IValue>& values);
void push(c10::ArrayRef<const c10::IValue> values);
// Used during post-processing to create vectors for shapes and dtype.
auto getNextShapesAndDtypes();

View File

@ -317,7 +317,7 @@ static constexpr auto kMat2Size = "mat2_size";
static bool validateInput(
const std::string& op_name,
size_t min_size,
const std::vector<c10::IValue>& inputs,
c10::ArrayRef<const c10::IValue> inputs,
const c10::ArrayRef<int>& should_be_tensor) {
std::stringstream ss;
if (inputs.size() < min_size) {
@ -342,7 +342,7 @@ std::unordered_map<std::string, c10::IValue> saveExtraArgs(
const at::RecordFunction& fn) {
// for specific types of fn, return the saved extra args for computing flops
std::unordered_map<std::string, c10::IValue> map;
std::vector<c10::IValue> inputs = fn.inputs();
auto inputs = fn.inputs();
std::string fname(fn.name());
if (inputs.empty()) {