Match parameter names and = default (#9737)

Summary:
More clang tidy cleanups in `torch/csrc`. This time:

1. `hicpp-use-equals-default` recommends `= default` instead of `{}` for constructors/destructors. This is better practice because it expresses the intent better (https://stackoverflow.com/questions/6502828/what-does-default-mean-after-a-class-function-declaration)
2. `readability-inconsistent-declaration-parameter-name` enforces that parameter names in the declaration match parameter names in the definition. This is just generally useful and can prevent confusion and bugs.

Also updated my script a little bit.

apaszke ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9737

Differential Revision: D9069069

Pulled By: goldsborough

fbshipit-source-id: f7b3f3a4eb4c9fadc30425a153566d3b613a41ae
This commit is contained in:
Peter Goldsborough
2018-07-30 13:57:19 -07:00
committed by Facebook Github Bot
parent 40a8239984
commit 04939a4745
42 changed files with 137 additions and 84 deletions

View File

@ -2,6 +2,7 @@
# NOTE: there must be no spaces before the '-', so put the comma first.
Checks: '
*
,clang-analyzer-*
,modernize-*
,-cert-err58-cpp
,-cert-err60-cpp
@ -9,6 +10,7 @@ Checks: '
,-cppcoreguidelines-owning-memory
,-cppcoreguidelines-pro-bounds-array-to-pointer-decay
,-cppcoreguidelines-pro-bounds-constant-array-index
,-cppcoreguidelines-pro-type-member-init
,-cppcoreguidelines-pro-type-static-cast-downcast
,-cppcoreguidelines-pro-type-vararg
,-cppcoreguidelines-special-member-functions
@ -23,9 +25,11 @@ Checks: '
,-hicpp-braces-around-statements
,-hicpp-explicit-conversions
,-hicpp-no-array-decay
,-hicpp-signed-bitwise
,-hicpp-special-member-functions
,-hicpp-vararg
,-llvm-header-guard
,-llvm-include-order
,-llvm-namespace-comment
,-misc-unused-parameters
,-modernize-make-unique
@ -34,7 +38,6 @@ Checks: '
,-readability-braces-around-statements
,-readability-else-after-return
,-readability-named-parameter
,clang-analyzer-*
'
WarningsAsErrors: ''
HeaderFilterRegex: 'torch/csrc/'

View File

@ -7,6 +7,7 @@ import re
import subprocess
import sys
DEFAULT_FILE_PATTERN = r".*\.[ch](pp)?"
# @@ -start,count +start,count @@
@ -26,6 +27,11 @@ def run_shell_command(arguments, process_name=None):
return output.decode()
def normalize_directory_path(path):
"""Normalizes a directory path."""
return path.rstrip('/')
def transform_globs_into_regexes(globs):
"""Turns glob patterns into regular expressions."""
return [glob.replace("*", ".*").replace("?", ".") for glob in globs]
@ -49,16 +55,37 @@ def git_diff(args, verbose):
return run_shell_command(command, process_name="git diff")
def filter_files(files, file_patterns):
def filter_files(files, file_patterns, verbose):
"""Returns all files that match any of the patterns."""
filtered = []
for file in files:
has_match = False
for pattern in file_patterns:
if pattern.match(file):
if pattern.search(file):
filtered.append(file)
has_match = True
if not has_match and verbose:
message = "{} does not match any ".format(file)
message += "file pattern in {{{}}}".format(', '.join(map(str, file_patterns)))
print(message)
return filtered
def remove_recursive_files(files, paths, verbose):
"""
Removes all files that are not immediately under one of the given paths.
"""
for file in files:
if os.path.dirname(file) in paths:
yield file
else:
if verbose:
message = "{} ({}) does not match any ".format(file, os.path.dirname(file))
message += "non-recursive path in {{{}}}".format(", ".join(paths))
print(message)
def get_changed_files(revision, paths, verbose):
"""Runs git diff to get the paths of all changed files."""
# --diff-filter AMU gets us files that are (A)dded, (M)odified or (U)nmerged (in the working copy).
@ -152,7 +179,17 @@ def parse_options():
)
parser.add_argument("-r", "--revision", help="Git revision to get changes from")
parser.add_argument(
"-p", "--paths", nargs="+", default=["."], help="Lint only the given paths"
"-p",
"--paths",
nargs="+",
default=["."],
help="Lint only the given paths (recursively)",
)
parser.add_argument(
"-n",
"--no-recursive",
action="store_true",
help="If paths are supplied with -p/--paths, do not recurse into paths",
)
parser.add_argument(
"-s",
@ -173,12 +210,15 @@ def parse_options():
def main():
options = parse_options()
paths = map(normalize_directory_path, options.paths)
if options.revision:
files = get_changed_files(options.revision, options.paths, options.verbose)
files = get_changed_files(options.revision, paths, options.verbose)
else:
files = get_all_files(options.paths)
files = get_all_files(paths)
if options.no_recursive:
files = remove_recursive_files(files, paths, options.verbose)
file_patterns = get_file_patterns(options.glob, options.regex)
files = filter_files(files, file_patterns)
files = filter_files(files, file_patterns, options.verbose)
# clang-tidy error's when it does not get input files.
if not files:

View File

@ -24,6 +24,7 @@ cmake -DUSE_CUDA:BOOL=$USE_CUDA \
-DCMAKE_BUILD_TYPE:STRING=$BUILD_TYPE \
-DCMAKE_INSTALL_PREFIX:STRING=$INSTALL_PREFIX \
-DCMAKE_INSTALL_MESSAGE=NEVER \
-DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=ON \
-G "$GENERATE" \
$PYTORCHPATH/
$MAKE -j "$JOBS" install

View File

@ -24,6 +24,7 @@ cmake -DUSE_CUDA:BOOL=$USE_CUDA \
-DCMAKE_INSTALL_MESSAGE=NEVER \
-Dnanopb_BUILD_GENERATOR:BOOL=OFF \
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON \
-DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=ON \
-DVERBOSE:BOOL=${VERBOSE:-0} \
-G "$GENERATE" \
$PYTORCHPATH/torch

View File

@ -48,7 +48,7 @@ class CursorBase {
/// A `(key, value)` pair exposed by cursor iterators.
struct Item {
Item(const std::string& key_, T& module_);
Item(const std::string& key_, T& value_);
T& operator*();
const T& operator*() const;

View File

@ -18,7 +18,7 @@ private:
struct AnomalyMetadata {
virtual ~AnomalyMetadata(){};
virtual ~AnomalyMetadata() = default;
virtual void store_stack() = 0;
virtual void print_stack() = 0;
};

View File

@ -159,7 +159,7 @@ struct GraphTask {
std::unordered_map<Function*, ExecInfo> exec_info;
std::vector<Variable> captured_vars;
void init_to_execute(Function& graph_root, const edge_list& captures);
void init_to_execute(Function& graph_root, const edge_list& outputs);
// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
@ -499,14 +499,14 @@ struct ClearCallbacks {
std::mutex& callbacks_lock;
};
auto Engine::execute(const edge_list& input_roots,
auto Engine::execute(const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
const edge_list& outputs) -> variable_list {
std::call_once(start_threads_flag, &Engine::start_threads, this);
validate_outputs(input_roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
return msg;
});
@ -517,7 +517,7 @@ auto Engine::execute(const edge_list& input_roots,
std::unique_lock<std::mutex> lock(graph_task.mutex);
// Now compute the dependencies for all executable functions and queue the root
auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
auto graph_root = std::make_shared<GraphRoot>(roots, inputs);
compute_dependencies(graph_root.get(), graph_task);
if (!outputs.empty()) {
graph_task.init_to_execute(*graph_root, outputs);

View File

@ -57,7 +57,7 @@ protected:
ReadyQueue& ready_queue(int device);
void start_threads();
virtual void thread_init(int device);
virtual void thread_main(GraphTask *task);
virtual void thread_main(GraphTask *graph_task);
virtual void thread_on_exception(FunctionTask& task, std::exception& e);
std::once_flag start_threads_flag;

View File

@ -328,7 +328,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
/// See Function::is_traceable() for definition.
struct TraceableFunction : public Function {
using Function::Function;
bool is_traceable() final override {
bool is_traceable() final {
return true;
}
};

View File

@ -10,12 +10,12 @@ struct Variable;
using variable_list = std::vector<Variable>;
struct FunctionPreHook {
virtual ~FunctionPreHook() {}
virtual ~FunctionPreHook() = default;
virtual variable_list operator()(const variable_list& grads) = 0;
};
struct FunctionPostHook {
virtual ~FunctionPostHook() {}
virtual ~FunctionPostHook() = default;
virtual variable_list operator()(const variable_list& grad_input, const variable_list& grad_output) = 0;
};

View File

@ -6,9 +6,9 @@
namespace torch { namespace autograd {
struct AccumulateGrad : public Function {
explicit AccumulateGrad(Variable variable);
explicit AccumulateGrad(Variable variable_);
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& grads) override;
Variable variable;
};

View File

@ -11,7 +11,7 @@
namespace torch { namespace autograd {
auto Error::apply(variable_list&& grad_outputs) -> variable_list {
auto Error::apply(variable_list&& inputs) -> variable_list {
throw std::runtime_error(msg);
}

View File

@ -13,7 +13,7 @@
namespace torch { namespace autograd {
struct CopyBackwards : public Function {
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& grads) override;
at::Type *src_type;
int32_t src_device = -1;
@ -23,9 +23,12 @@ struct CopyBackwards : public Function {
// grad[idx] is defined by the relative sizes, strides, and offset of base and
// view.
struct CopySlices : public Function {
CopySlices(const Variable& base, at::TensorGeometry view, std::shared_ptr<Function> fn);
CopySlices(
const Variable& base_var,
at::TensorGeometry view_,
std::shared_ptr<Function> fn_);
variable_list apply(variable_list&& grads) override;
variable_list apply(variable_list&& inputs) override;
void release_variables() override;
at::TensorGeometry base;

View File

@ -22,14 +22,14 @@ struct InputBuffer {
InputBuffer& operator=(InputBuffer&& other) = default;
// Accumulates the variable at a specified index.
void add(size_t idx, Variable var);
void add(size_t pos, Variable var);
int device() const;
Variable operator[](size_t pos) { return buffer[pos]; }
// Returns the inputs as a list of variables. Destroys given InputBuffer.
static std::vector<Variable> variables(InputBuffer&& buffer);
static std::vector<Variable> variables(InputBuffer&& g);
private:
std::vector<Variable> buffer;

View File

@ -185,7 +185,7 @@ struct TORCH_API RecordFunction {
using thread_event_lists = std::vector<std::vector<Event>>;
// NOTE: changing profiler modes is **NOT THREAD SAFE**. You should ensure that
// there no autograd functions are being executed when these function are used.
TORCH_API void enableProfiler(ProfilerState state);
TORCH_API void enableProfiler(ProfilerState new_state);
TORCH_API thread_event_lists disableProfiler();
} // namespace profiler

View File

@ -45,10 +45,10 @@ class TORCH_API SavedVariable {
std::weak_ptr<Function> grad_accumulator_;
VariableVersion version_counter_;
uint32_t saved_version_;
uint32_t output_nr_;
uint32_t saved_version_ = 0;
uint32_t output_nr_ = 0;
bool was_default_constructed_ = true;
bool requires_grad_;
bool has_grad_fn_;
bool requires_grad_ = false;
bool has_grad_fn_ = false;
};
}} // namespace torch::autograd

View File

@ -263,7 +263,7 @@ struct Variable::Impl : public at::TensorImpl {
TORCH_API explicit Impl(
at::Tensor data,
bool requires_grad = false,
Edge edge = Edge());
Edge gradient_edge = Edge());
~Impl() override;

View File

@ -28,7 +28,7 @@ struct AttributeValue {
Symbol name;
virtual AttributeKind kind() const = 0;
virtual Ptr clone() const = 0;
virtual ~AttributeValue() {}
virtual ~AttributeValue() = default;
};
template<typename T, AttributeKind Kind>
@ -101,7 +101,7 @@ private:
// we return Derived* pointers because Nodes are normally held as pointers.
template<typename Derived>
struct Attributes {
Attributes() {}
Attributes() = default;
void copyAttributes(const Attributes & rhs) {
values_.clear();
for(auto & i : rhs.values_) {

View File

@ -9,6 +9,7 @@
#include <torch/csrc/jit/assertions.h>
#include <algorithm>
#include <memory>
namespace torch { namespace jit {
@ -564,14 +565,13 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
reverse_block->owningNode()->destroy();
}
Gradient differentiate(std::shared_ptr<Graph>& _graph, const std::vector<bool>& requires_grad) {
Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& requires_grad) {
Gradient grad_desc;
// Take ownership of the graph
JIT_ASSERTM(
_graph.use_count() == 1,
JIT_ASSERTM(graph.use_count() == 1,
"differentiate will mutate and destroy the graph, so it requires "
"graph.use_count() == 1, but found ", _graph.use_count());
std::swap(_graph, grad_desc.f);
"graph.use_count() == 1, but found %d", graph.use_count());
std::swap(graph, grad_desc.f);
// XXX: Take care when handling outputs - they can be duplicated!
WithInsertPoint guard(grad_desc.f->block());

View File

@ -4,7 +4,9 @@
#include "torch/csrc/jit/ir.h"
#include <ATen/ATen.h>
#include <vector>
#include <memory>
namespace torch { namespace jit {

View File

@ -86,7 +86,7 @@ struct CompiledFusionFunction {
TH_DISALLOW_COPY_AND_ASSIGN(CompiledFusionFunction);
CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph);
virtual ~CompiledFusionFunction() {}
virtual ~CompiledFusionFunction() = default;
// expects outputs to be pre-allocated
void launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);

View File

@ -516,28 +516,28 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g) {
RemoveExpands(g);
}
void specializeToSpec(const std::shared_ptr<Graph>& graph_, const ArgumentSpec& spec) {
void specializeToSpec(const std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
// clean up GradOf and AutogradAdd nodes
// this must be first because later passes do not know what GradOfs are
std::vector<bool> defined;
for(size_t i = 0; i < spec.size(); ++i) {
defined.push_back(spec.at(i).defined());
}
specializeUndef(*graph_, defined);
specializeUndef(*graph, defined);
// required passes shared with autograd fallback
runRequiredPasses(graph_);
runRequiredPasses(graph);
// Decompose addmm nodes to add + mm, so expands can be inserted and
// gradients accumulated on the backward pass
//
// In the future, if we need more passes like this, we should convert this
// into a generic canonicalization pass.
DecomposeAddmm(graph_);
DecomposeAddmm(graph);
// clean up dead constants from specialization
EliminateDeadCode(graph_);
EliminateDeadCode(graph);
// calculate all input shapes
PropagateInputShapes(*graph_, spec);
PropagateInputShapes(*graph, spec);
}
void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables) {

View File

@ -34,7 +34,7 @@ struct GraphExecutorState {
struct GraphExecutorImpl;
struct TORCH_API GraphExecutor {
GraphExecutor() {}
GraphExecutor() = default;
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
// note: if not specified, symbolically_differentiable is computed from the graph.
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable);

View File

@ -1,3 +1,5 @@
#pragma once
#include "torch/csrc/jit/assertions.h"
namespace torch { namespace jit {

View File

@ -66,7 +66,7 @@ struct Model_ {
// Readers
struct ReaderBase {
ReaderBase() {}
ReaderBase() = default;
ReaderBase(pb_callback_t& cb) {
initialize_callback(cb);
}

View File

@ -339,7 +339,7 @@ public:
ContainerTensor()
: TensorImpl(&(at::globalContext().getType(at::Backend::Undefined,at::ScalarType::Undefined)), nullptr) {}
virtual ~ContainerTensor() {}
virtual ~ContainerTensor() = default;
virtual at::IntList sizes() const override {
throw std::runtime_error("sizes() on ContainerTensor");
}
@ -685,8 +685,8 @@ struct CodeImpl {
// InterpreterState state that is held across stages and used to compute a Code
struct InterpreterStateImpl {
InterpreterStateImpl(const Code & function_)
: function(function_.pImpl),
InterpreterStateImpl(const Code & code)
: function(code.pImpl),
int_data(function->int_data.data()),
bool_data(function->bool_data),
registers(function->register_size) {
@ -775,15 +775,15 @@ std::ostream & operator<<(std::ostream & out, const Code & code) {
Code::Code(std::shared_ptr<Graph>& graph)
: pImpl(new CodeImpl(graph)) {}
Code::~Code() {}
Code::~Code() = default;
const std::vector<GraphExecutor*>& Code::executors() {
return pImpl->executors();
}
InterpreterState::InterpreterState(const Code & function)
: pImpl(new InterpreterStateImpl(function)) {}
InterpreterState::~InterpreterState() {}
InterpreterState::InterpreterState(const Code & code)
: pImpl(new InterpreterStateImpl(code)) {}
InterpreterState::~InterpreterState() = default;
void InterpreterState::runOneStage(Stack & stack) {
return pImpl->runOneStage(stack);

View File

@ -355,7 +355,7 @@ void Graph::lint() const {
// - every use will occur later in the topsort
struct LintScope {
LintScope() {}
LintScope() = default;
LintScope(std::unique_ptr<LintScope> parent)
: parent(std::move(parent)) {}
bool contains(const Value * v) {
@ -487,13 +487,13 @@ void LintGraph(std::shared_ptr<Graph>& graph) {
graph->lint();
}
void Block::cloneFrom(Block * src, std::function<Value*(Value*)> outer_map) {
void Block::cloneFrom(Block * src, std::function<Value*(Value*)> value_map) {
std::unordered_map<Value*, Value*> local_map;
auto env = [&](Value * v) {
auto it = local_map.find(v);
if(it != local_map.end())
return it->second;
return outer_map(v);
return value_map(v);
};
auto graph = owningGraph();

View File

@ -54,7 +54,7 @@ struct Value;
TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g);
TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t);
TORCH_API std::ostream& operator<<(std::ostream & out, const Node & t);
TORCH_API std::ostream& operator<<(std::ostream & out, const Node & n);
// A list of nodes, with inputs and outputs
struct Block;
@ -683,7 +683,7 @@ public:
return *schema_;
}
virtual ~Node() {}
virtual ~Node() = default;
private:
std::pair<Value*, const Argument&> findInput(Symbol name);
void findSchema() const;
@ -889,8 +889,7 @@ public:
, block_(new Block(this, nullptr))
, insert_before_(return_node()) {}
Graph()
: Graph( std::make_shared<Scope>()) {}
Graph() : Graph(std::make_shared<Scope>()) {}
at::ArrayRef<Value*> inputs() {
return block_->inputs();

View File

@ -10,7 +10,7 @@
namespace torch { namespace jit {
FunctionSchema parseSchema(const std::string& decl);
FunctionSchema parseSchema(const std::string& schema);
using OperationCreator = std::function<Operation(Node*)>;
@ -33,7 +33,7 @@ struct TORCH_API Operator {
FunctionSchema schema;
bool matches(const Node* n) const;
bool matches(const Node* node) const;
// Operators have different versions depending on if some inputs are encoded
// as attributes or inputs. This function returns the right Operation function,
// given a node encoded for one variant.

View File

@ -10,6 +10,6 @@ namespace torch { namespace jit {
// outputs = <original_computation>
// else:
// outputs = undefineds
TORCH_API void LowerGradOf(Graph& graph);
TORCH_API void LowerGradOf(Graph& g);
}}

View File

@ -104,7 +104,7 @@ struct ParsedArgs {
ParsedArgs flatten(py::handle obj);
PyObject* unflatten(at::ArrayRef<autograd::Variable> outputs,
PyObject* unflatten(at::ArrayRef<autograd::Variable> vars,
const IODescriptor& structure);
}}} // namespace torch::jit::python

View File

@ -103,10 +103,10 @@ void pythonRecordSourceLocation(Node* n) {
n->setSourceLocation(sl);
}
void initPythonTracerBindings(PyObject* module_) {
void initPythonTracerBindings(PyObject* module) {
setRecordSourceLocation(pythonRecordSourceLocation);
auto m = py::handle(module_).cast<py::module>();
auto m = py::handle(module).cast<py::module>();
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
// NB: no constructor; you have to get it from C++ code
.def("__repr__", [](const TracingState& s) {

View File

@ -68,7 +68,7 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
SourceRange loc,
Method & m,
// note: names for args will be 'argument 0', 'argument 1', etc..
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> inputs_,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
// n_binders is always set to the number of variables an expression is
@ -89,7 +89,7 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
throw ErrorReport(loc) << "cannot call a " << kind();
}
virtual ~SugaredValue() {}
virtual ~SugaredValue() = default;
};
// most things in the environment are just simple value types

View File

@ -89,7 +89,7 @@ struct Tree : std::enable_shared_from_this<Tree> {
throw std::runtime_error(ss.str());
}
}
virtual ~Tree() {}
virtual ~Tree() = default;
private:
int kind_;

View File

@ -77,8 +77,8 @@ inline void pack(Stack & stack, T&& v) {
}
template<>
inline void pack(Stack & stack, std::vector<at::Tensor>&& ts) {
for(auto& t : ts) {
inline void pack(Stack & stack, std::vector<at::Tensor>&& v) {
for(auto& t : v) {
stack.push_back(IValue(std::move(t)));
}
}

View File

@ -80,7 +80,7 @@ public:
JIT_ASSERT(T::Kind == kind());
return std::static_pointer_cast<const T>(shared_from_this());
}
virtual ~Type() {}
virtual ~Type() = default;
};
inline bool operator!=(const Type & lhs, const Type & rhs) {

View File

@ -6,7 +6,7 @@ namespace torch { namespace jit {
// a wrapper to mark places where we expect all the at::Tensors to be
// variables
struct variable_tensor_list : public std::vector<at::Tensor> {
variable_tensor_list() {}
variable_tensor_list() = default;
template<class InputIt>
variable_tensor_list(InputIt first, InputIt last)
: std::vector<at::Tensor>(first, last) {}

View File

@ -32,7 +32,7 @@ namespace torch {
// DEALINGS IN THE SOFTWARE.
inline size_t hash_combine(size_t seed, size_t value) {
return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u));
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -16,7 +16,7 @@ std::string py_typename(PyObject *object) {
struct Type {
virtual bool is_matching(PyObject *object) = 0;
virtual ~Type() {};
virtual ~Type() = default;
};
struct SimpleType: public Type {

View File

@ -7,7 +7,9 @@
namespace torch {
std::string format_invalid_args(
PyObject *args, PyObject *kwargs, const std::string& name,
PyObject* given_args,
PyObject* given_kwargs,
const std::string& function_name,
const std::vector<std::string>& options);
} // namespace torch

View File

@ -90,8 +90,8 @@ struct PythonArgParser {
private:
[[noreturn]]
void print_error(PyObject* args, PyObject* kwargs, PyObject* dst[]);
PythonArgs raw_parse(PyObject* args, PyObject* kwargs, PyObject* dst[]);
void print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]);
PythonArgs raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]);
std::vector<FunctionSignature> signatures_;
std::string function_name;

View File

@ -6,8 +6,8 @@
namespace torch { namespace utils {
at::Tensor & apply_(at::Tensor & self, PyObject* fn);
at::Tensor & map_(at::Tensor & self, const at::Tensor & other, PyObject* fn);
at::Tensor & map2_(at::Tensor & self, const at::Tensor & other1,
const at::Tensor & other2, PyObject* fn);
at::Tensor & map_(at::Tensor & self, const at::Tensor & other_, PyObject* fn);
at::Tensor & map2_(at::Tensor & self, const at::Tensor & x_,
const at::Tensor & y_, PyObject* fn);
}} // namespace torch::utils