[7/N] Fix Wextra-semi warning (#140225)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140225
Approved by: https://github.com/ezyang
cyy
2024-11-10 14:28:10 +00:00
committed by PyTorch MergeBot
parent d90c25e3e2
commit ffb979032d
27 changed files with 470 additions and 471 deletions
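Nearly every hunk below drops a single redundant ';'. For context, -Wextra-semi (available in recent clang and gcc, with slightly different coverage) flags such stray semicolons, most classically after an in-class member-function definition. A minimal, self-contained sketch of the pattern — the names are illustrative, not taken from the PyTorch sources:

// A member-function body is already a complete definition; the ';' after it is
// an extra, empty declaration that -Wextra-semi reports.
struct Example {
  void set_flag() {
    flag_ = true;
  };               // warning: extra ';' after member function definition
  bool flag_ = false;
};                 // this ';' IS required: it terminates the class definition

int main() {
  Example e;
  e.set_flag();
  return e.flag_ ? 0 : 1;
}

The same family of diagnostics also catches stray semicolons after constructors, free-function definitions, and macro invocations whose expansion is already a complete definition, which is why the fix shows up in several different shapes in the hunks that follow.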


@ -101,9 +101,16 @@ SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: c++17
StatementMacros:
- C10_DEFINE_bool
- C10_DEFINE_int
- C10_DEFINE_int32
- C10_DEFINE_int64
- C10_DEFINE_string
- PyObject_HEAD
- PyObject_VAR_HEAD
- PyException_HEAD
- DEFINE_BINARY
TabWidth: 8
UseTab: Never
---
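The .clang-format change is a companion to the semicolon removals: macros listed under StatementMacros are treated by clang-format as complete statements, so call sites written without a trailing ';' (as the rest of this commit does) still format correctly, and the // clang-format off / on guards around the DEFINE_BINARY block later in this commit can be dropped. A toy sketch of the idea — DEMO_DEFINE_BOOL is hypothetical, not the real c10 macro:

// Hypothetical flag macro whose expansion is already a complete declaration.
// With "DEMO_DEFINE_BOOL" registered under StatementMacros, clang-format does
// not expect a ';' after the call and does not indent the next line as a
// continuation of it.
#define DEMO_DEFINE_BOOL(name, default_value, help) \
  bool FLAGS_##name = default_value;

DEMO_DEFINE_BOOL(demo_verbose, false, "print extra output")  // no trailing ';'

int main() {
  return FLAGS_demo_verbose ? 1 : 0;
}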


@ -1594,8 +1594,8 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(c
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
return cvt_from_fp32<type>(__m512(a), __m512(b)); \
}
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16)
CONVERT_VECTORIZED_INIT(Half, half)
#else //defined(CPU_CAPABILITY_AVX512)
@ -1624,8 +1624,8 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
} \
return Vectorized<type>::loadu(arr2); \
}
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(Half, half);
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
CONVERT_NON_VECTORIZED_INIT(Half, half)
#endif // defined(CPU_CAPABILITY_AVX512)
@ -1663,8 +1663,8 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
data += Vectorized<float>::size(); \
load_fp32_from_##name(data, out2); \
}
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16);
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16)
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16)
#endif
}}}
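CONVERT_VECTORIZED_INIT, CONVERT_NON_VECTORIZED_INIT and LOAD_FP32_NON_VECTORIZED_INIT each expand to complete function definitions ending in '}', so the ';' previously written after the invocation was an empty declaration at namespace scope. A toy macro with the same shape (illustrative only, not the real vec512 code):

// The expansion already ends with the function's closing brace, so the
// invocation must not be followed by another ';'.
#define DEFINE_CONVERT(type)              \
  inline float convert_##type(type v) {   \
    return static_cast<float>(v);         \
  }

DEFINE_CONVERT(int)    // OK
DEFINE_CONVERT(short)  // writing "DEFINE_CONVERT(short);" would leave a stray ';'

int main() {
  return static_cast<int>(convert_int(0) + convert_short(0));
}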


@ -61,7 +61,6 @@ bool SymInt::has_hint() const {
} \
}
// clang-format off
DEFINE_BINARY(operator+, std::plus<>(), add, SymInt)
DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt)
DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt)
@ -75,7 +74,6 @@ DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool)
DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool)
DEFINE_BINARY(min, std::min, sym_min, SymInt)
DEFINE_BINARY(max, std::max, sym_max, SymInt)
// clang-format on
SymInt::operator SymFloat() const {
if (auto ma = maybe_as_int()) {


@ -18,7 +18,7 @@
namespace at {
struct Quantizer;
};
}
namespace torch { namespace autograd {
@ -54,6 +54,6 @@ namespace VariableType {
const at::Tensor & unpack(const Tensor & t, const char * name, int pos);
at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
std::vector<at::Tensor> unpack(const at::ITensorListRef& tl, const char *name, int pos);
};
}
}} // namespace torch::autograd
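These two hunks drop the ';' after a closing namespace brace: unlike a class definition, a namespace body is not a declaration that needs a terminator, so '};' leaves a redundant empty declaration behind. A minimal sketch mirroring the two spots touched here (the declarations are illustrative):

// The ';' after a struct forward declaration is required; the one after a
// namespace's closing brace is not.
namespace at {
struct Quantizer;  // required ';'
}                  // '};' here would be flagged

namespace VariableType {
void unpack();     // illustrative declaration only
}                  // same: end the namespace with '}' alone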


@ -21,7 +21,7 @@ std::unordered_set<rpc::worker_id_t> DistAutogradContext::getKnownWorkerIds()
const {
std::lock_guard<std::mutex> guard(lock_);
return knownWorkerIds_;
};
}
void DistAutogradContext::addKnownWorkerId(const rpc::worker_id_t workerId) {
std::lock_guard<std::mutex> guard(lock_);


@ -103,7 +103,7 @@ class TORCH_API Reducer {
// been applied.
void set_optimizer_in_backward() {
optim_in_backward_ = true;
};
}
// Runs allreduce or installed communication hook given GradBucket instance.
c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(


@ -16,7 +16,7 @@
C10_DEFINE_bool(
torch_jit_do_not_store_optimized_graph,
false,
"Do not store the optimized graph.");
"Do not store the optimized graph.")
namespace torch::jit {
namespace {
@ -133,8 +133,8 @@ GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast) {
Inline(*graph);
// Peephole Optimize cleans up many "is None" checks and creates constant prop
// opportunities
// Peephole Optimize cleans up many "is None" checks and creates constant
// prop opportunities
PeepholeOptimize(graph, true);
// AliasDb construction can be slow, so run it just on immutable types


@ -6,7 +6,7 @@
namespace torch::jit {
struct ChunkOutput {
ChunkOutput(Value* v, size_t o) : val(v), offset(o){};
ChunkOutput(Value* v, size_t o) : val(v), offset(o) {}
Value* val;
size_t offset;
};
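The ChunkOutput change is the constructor flavour of the same warning: ": val(v), offset(o){};" defines the constructor inline, and the trailing ';' after its body is redundant (the reformatted version also adds the conventional space before "{}"). The same pattern recurs in the lazy SizeNode/SizeAdd/SizeMul/SizeDiv constructors near the end of this diff. A small self-contained sketch:

#include <cstddef>

// Illustrative struct in the shape of ChunkOutput: an inline constructor with
// a member-initializer list ends at its '}', so no ';' should follow it.
struct ChunkOutputLike {
  ChunkOutputLike(int v, std::size_t o) : val(v), offset(o) {}
  int val;
  std::size_t offset;
};

int main() {
  ChunkOutputLike c{7, 0};
  return c.val == 7 ? 0 : 1;
}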


@ -29,12 +29,12 @@
C10_DEFINE_bool(
torch_jit_disable_cat,
false,
"disable aten::cat in TE fusion groups");
"disable aten::cat in TE fusion groups")
C10_DEFINE_bool(
torch_jit_enable_dynamic_shape_fusion,
false,
"enable TE fusion using dynamic shapes");
"enable TE fusion using dynamic shapes")
namespace torch::jit {
@ -82,9 +82,8 @@ static const OperatorSet& supported_non_eltwise_set() {
"aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
"aten::matmul(Tensor self, Tensor other) -> Tensor",
};
// clang-format on
return supported_non_eltwise_set;
};
}
bool isSupported(Node* node) {
// For Block codegen we allow limited ops.
@ -102,7 +101,6 @@ bool isSupported(Node* node) {
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
};
// clang-format on
if (get_tensorexpr_elementwise_set().contains(node) ||
node->isMemberOf(supported_non_eltwise_set()) ||
@ -903,7 +901,6 @@ class TensorExprFuser {
static const OperatorSet pow{
"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor",
};
// clang-format on
// Check types of input values.
for (const Value* v : node->inputs()) {


@ -167,7 +167,7 @@ static std::optional<std::vector<Value*>> build_script_grad(
auto grad_inputs = insertGraph(*graph, *bw_graph, grad);
grad_inputs = unpackOutputs(grad_inputs);
return grad_inputs;
};
}
namespace {
class GradientHelper {


@ -56,10 +56,9 @@
C10_DEFINE_bool(
torch_jit_execution_plan_reuse_code_graph,
false,
"Directly reuse the preprocessed graph in the CodeImpl to reduce the memory consumption. This is aggressive memory saving, and please be cautious!");
"Directly reuse the preprocessed graph in the CodeImpl to reduce the memory consumption. This is aggressive memory saving, and please be cautious!")
namespace torch::jit {
EnableProfilingGuard::EnableProfilingGuard() {
auto& executor_mode = getExecutorMode();
old_executor_mode = executor_mode;
@ -432,8 +431,8 @@ struct DifferentiableGraphOp {
{
auto inputs = last(stack, num_inputs);
// hook up the outputs of df to the gradient functions of the inputs that
// require gradients
// hook up the outputs of df to the gradient functions of the inputs
// that require gradients
for (auto idx : grad.df_output_vjps) {
grad_fn->addOutputForIValue(inputs[idx]);
}
@ -455,8 +454,8 @@ struct DifferentiableGraphOp {
// TODO - XXX - if any output is the same tensor multiple times, views
// have to be setup here. We need to refactor autograd until it is safe
// for tensors to be constructed without all the viewing infrastructure.
// this is currently intentionally not done here so we can get an idea of
// our perf before introducing overhead for correctness
// this is currently intentionally not done here so we can get an idea
// of our perf before introducing overhead for correctness
for (auto idx : grad.df_input_vjps) {
grad_fn->addInputIValue(outputs[idx]);
}
@ -501,7 +500,8 @@ struct DifferentiableGraphOp {
detach(stack[i]);
}
}
// Capture (save) inputs that would be required to subsequently run backwards
// Capture (save) inputs that would be required to subsequently run
// backwards
void captureInputs(
DifferentiableGraphBackward& grad_fn,
at::ArrayRef<IValue> inputs) const {
@ -736,8 +736,10 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
runOptimization(opt_graph);
// Phase 4. If this graph will be differentiated, we need to slice out the
// symbolically differentiable subgraphs for further optimizations.
// Phase 5. Apply non-differentiable optimizations to the graphs we've found
// symbolically differentiable subgraphs for further
// optimizations.
// Phase 5. Apply non-differentiable optimizations to the graphs we've
// found
// (or the whole graph if we know we won't need its derivative).
if (needsGradient(opt_graph)) {
auto diff_nodes = CreateAutodiffSubgraphs(
@ -781,8 +783,8 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
ArgumentSpecCreator arg_spec_creator_;
// Populated only when optimize is false (and in that case plan_cache will be
// unused). The compiled version of graph.
// Populated only when optimize is false (and in that case plan_cache will
// be unused). The compiled version of graph.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
ExecutionPlan fallback;


@ -49,12 +49,12 @@ using torch::distributed::autograd::DistAutogradContainer;
C10_DEFINE_bool(
torch_jit_enable_rethrow_caught_exception,
false,
"enable rethrowing caught exception");
"enable rethrowing caught exception")
C10_DEFINE_bool(
torch_jit_enable_expanded_stacks,
false,
"When true we will attemps to pre-expand node stacks and cache expanded stacks.");
"When true we will attemps to pre-expand node stacks and cache expanded stacks.")
namespace torch::jit {


@ -41,32 +41,32 @@
C10_DEFINE_bool(
torch_jit_enable_new_executor,
true,
"If this flag is set to false TorchScript will be using the legacy/original executor");
"If this flag is set to false TorchScript will be using the legacy/original executor")
C10_DEFINE_bool(
torch_jit_disable_warning_prints,
false,
"Disables warning.warn prints in TorchScript graph");
"Disables warning.warn prints in TorchScript graph")
C10_DEFINE_bool(
torch_jit_static_then_dynamic,
false,
"fuse on two static compilations then 10 dynamic");
"fuse on two static compilations then 10 dynamic")
C10_DEFINE_bool(
torch_jit_always_dynamic,
false,
"fuse on 12 dynamic compilations");
"fuse on 12 dynamic compilations")
C10_DEFINE_bool(
torch_jit_release_profiling_graph_after_optimization,
false,
"After getOptimizedPlanFor release the optimization record for reduction of memory in inference. This is aggressive memory saving, and please be cautious!");
"After getOptimizedPlanFor release the optimization record for reduction of memory in inference. This is aggressive memory saving, and please be cautious!")
C10_DEFINE_int32(
torch_jit_release_profiling_graph_delay_in_seconds,
60,
"How long to wait before releasing the profiling graph after optimizaiton is done. Only used if torch_jit_release_profiling_graph_after_optimization is set to true.");
"How long to wait before releasing the profiling graph after optimizaiton is done. Only used if torch_jit_release_profiling_graph_after_optimization is set to true.")
constexpr size_t kDefaultNumProfiledRuns = 1;
constexpr size_t kDefaultBailoutDepth = 20;
@ -74,11 +74,11 @@ constexpr size_t kDefaultBailoutDepth = 20;
C10_DEFINE_int64(
torch_jit_num_profiled_runs,
kDefaultNumProfiledRuns,
"Number of profiling runs");
"Number of profiling runs")
C10_DEFINE_int64(
torch_jit_bailout_depth,
kDefaultBailoutDepth,
"Number of re-specializations");
"Number of re-specializations")
namespace torch::jit {

File diff suppressed because it is too large.


@ -50,7 +50,7 @@
C10_DEFINE_bool(
static_runtime_disable_debug_memory_overlap_check,
false,
"If true, disable the memory overlap check in debug mode in ProcessedNode::run()");
"If true, disable the memory overlap check in debug mode in ProcessedNode::run()")
namespace torch::jit {


@ -72,7 +72,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
// put output back
p_node->Output(0) = std::move(stack[0]);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::TupleUnpack,
@ -91,7 +91,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(i) = elems[i];
}
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::DictConstruct,
@ -116,7 +116,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
p_node->Output(0) = result;
};
});
})
// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
@ -139,7 +139,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(i - 1) = createBorrowedIValue(value->value());
}
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::__getitem__, aten_getitem, [](Node* n) -> SROperator {
if (!sr_schema_check(
@ -177,7 +177,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::__getitem__, aten_getitem, [](Node* n) ->
// TODO(T98581096): make __getitem__ work for other container types
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::ListConstruct,
@ -197,7 +197,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
// put output back
p_node->Output(0) = std::move(stack[0]);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::ListUnpack,
@ -219,7 +219,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(i) = list[i];
}
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::append,
@ -233,7 +233,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
auto list = p_node->Input(0).toList();
list.push_back(p_node->Input(1));
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::list,
@ -260,7 +260,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::numel,
@ -273,7 +273,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& arg = p_node->Input(0).toTensor();
p_node->Output(0) = arg.numel();
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::cpu,
@ -286,7 +286,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& arg = p_node->Input(0).toTensor();
p_node->Output(0) = arg.cpu();
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::__range_length,
@ -312,7 +312,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(0) = 0;
}
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -332,7 +332,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) ->
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::item,
@ -345,7 +345,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& self = p_node->Input(0).toTensor();
p_node->Output(0) = at::native::item(self);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::GetAttr,
@ -362,7 +362,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto slot = type.getAttributeSlot(field);
p_node->Output(0) = module.getSlot(slot);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::SetAttr,
@ -379,7 +379,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto slot = type.getAttributeSlot(field);
module.setSlot(slot, p_node->Input(1));
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::transpose,
@ -396,7 +396,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto in2_i = p_node->Input(2).toInt();
p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -410,7 +410,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SRO
const auto in2_i = p_node->Input(2).toInt();
p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::permute,
@ -426,7 +426,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto in1_iv = p_node->Input(1).toDimVector();
p_node->Output(0) = at::native::permute(in0_t, in1_iv);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::reshape,
@ -442,7 +442,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto in1_iv = p_node->Input(1).toDimVector();
p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -458,7 +458,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROpera
const auto in4_i = p_node->Input(4).toInt();
p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -497,7 +497,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROpe
").");
p_node->Output(0) = at::native::slice(self, dim, start, start + length, 1);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -544,7 +544,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::detach,
@ -559,7 +559,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& in0_t = p_node->Input(0).toTensor();
p_node->Output(0) = at::native::alias(in0_t);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::expand_as,
@ -575,7 +575,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& other = p_node->Input(1).toTensor();
p_node->Output(0) = self.expand(other.sizes());
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::isinstance,
@ -600,7 +600,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(0) = false;
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::TypeCheck,
@ -633,7 +633,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(num_inputs) = true;
};
});
})
// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
@ -653,7 +653,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
}
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::view,
@ -669,7 +669,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto size = p_node->Input(1).toIntList();
p_node->Output(0) = at::native::view(input, size.vec());
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::size,
@ -696,7 +696,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::squeeze,
@ -713,7 +713,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto dim = p_node->Input(1).toInt();
p_node->Output(0) = at::native::squeeze(self, dim);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::split, aten_split, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -739,7 +739,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::split, aten_split, [](Node* n) -> SROpera
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::split_with_sizes,
@ -759,7 +759,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(0) =
at::native::split_with_sizes(self, split_sizes.vec(), dim);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
static_runtime::select_tensor,
@ -788,7 +788,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
IValue(c10::MaybeOwnedTraits<at::TensorBase>::createBorrow(
assignFrom.toTensor()));
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::mul,
@ -814,7 +814,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
pnode->Output(0) = ret;
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::sub,
@ -829,7 +829,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto b = pnode->Input(1).toInt();
pnode->Output(0) = a - b;
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::add,
@ -855,7 +855,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -889,7 +889,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node*
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::Int,
@ -903,7 +903,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& input = pnode->Input(0).toTensor();
pnode->Output(0) = at::native::item(input).toInt();
};
});
})
// See [Create owned refs for special values]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
@ -915,7 +915,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
return
[](ProcessedNode* p_node) { p_node->Output(0) = p_node->Input(0); };
});
})
namespace {
bool outputsEmpty(const Block* block) {
@ -1020,7 +1020,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
return [](ProcessedNode*) {};
}
return [](ProcessedNode*) {};
});
})
namespace {
@ -1147,7 +1147,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
smodule, args, future, *launcher);
(*launcher)(std::move(runtime_launcher));
};
});
})
/*
aten::wait waits on the future (present in corresponding fork)
to be executed. Once the execution is complete, the future is marked
@ -1181,7 +1181,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(i) = elems[i];
}
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::Loop,
@ -1225,7 +1225,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
p_node->Output(i) = std::move(args[i + 1]);
}
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::CreateObject,
@ -1240,7 +1240,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
c10::StrongTypePtr(class_type->compilation_unit(), class_type),
class_type->numAttributes());
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::TupleIndex,
@ -1262,7 +1262,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
pnode->Output(0) = elems[norm_idx];
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::RaiseException,
@ -1275,7 +1275,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& message = pnode->Input(0).toStringRef();
throw std::runtime_error(message);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::Uninitialized,
@ -1287,7 +1287,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
return [](ProcessedNode* pnode) {
pnode->Output(0) = IValue::uninitialized();
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::format,
@ -1304,7 +1304,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
TORCH_DCHECK_EQ(stack.size(), 1);
pnode->Output(0) = std::move(stack[0]);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::device,
@ -1317,7 +1317,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& input = pnode->Input(0).toTensor();
pnode->Output(0) = input.device();
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::dtype,
@ -1330,7 +1330,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& input = pnode->Input(0).toTensor();
pnode->Output(0) = static_cast<int64_t>(input.scalar_type());
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::dim,
@ -1343,7 +1343,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& input = pnode->Input(0).toTensor();
pnode->Output(0) = input.dim();
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::__not__,
@ -1356,7 +1356,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
auto input = pnode->Input(0).toBool();
pnode->Output(0) = !input;
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::Bool,
@ -1382,7 +1382,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::is_cuda,
@ -1395,7 +1395,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& input = pnode->Input(0).toTensor();
pnode->Output(0) = input.is_cuda();
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::tolist,
@ -1413,7 +1413,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
TORCH_DCHECK_EQ(stack.size(), 1);
pnode->Output(0) = std::move(stack[0]);
};
});
})
// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
@ -1428,7 +1428,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1))
: createBorrowedIValue(pnode->Input(2));
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::len,
@ -1474,7 +1474,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::IntImplicit,
@ -1500,7 +1500,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
}
pnode->Output(0) = at::native::item(tensor).toInt();
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::select,
@ -1517,7 +1517,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto index = pnode->Input(2).toInt();
pnode->Output(0) = at::native::select(self, dim, index);
};
});
})
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::reshape_as,
@ -1533,6 +1533,6 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto& other = pnode->Input(1).toTensor();
pnode->Output(0) = at::native::reshape(self, other.sizes());
};
});
})
} // namespace torch::jit
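REGISTER_NATIVE_OPERATOR_FUNCTOR here (and REGISTER_OPERATOR_FUNCTOR plus C10_DEFINE_REGISTRY in the next file) take a lambda and, like most registration macros, expand to a complete namespace-scope definition, so every call now ends with "})" instead of "});". A hypothetical registration macro in the same spirit — a sketch of the pattern, not the real static-runtime implementation:

#include <functional>
#include <utility>

using Op = std::function<int(int)>;

struct Registrar {
  explicit Registrar(Op op) : op_(std::move(op)) {}
  Op op_;
};

// The expansion is a full variable definition (note the ';' inside the macro),
// so call sites end with ')' and adding ';' after them is redundant.
#define REGISTER_DEMO_OP(name, ...) \
  static Registrar demo_registrar_##name{__VA_ARGS__};

REGISTER_DEMO_OP(identity, [](int x) { return x; })  // ends with ')', no ';'

int main() {
  return demo_registrar_identity.op_(0);
}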


@ -37,8 +37,6 @@
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <iterator>
#include <mutex>
#include <unordered_map>
#include <ATen/CompositeExplicitAutogradFunctions.h>
@ -46,10 +44,9 @@ C10_DEFINE_bool(
static_runtime_enable_fast_math,
true,
"If on, static runtime may use use optimizations that cause accuracy loss "
"vs the jit interpreter");
"vs the jit interpreter")
namespace at::native {
static void repeat_out(
at::Tensor& result,
const Tensor& self,
@ -140,9 +137,9 @@ static at::Tensor& flatten_copy_out(
// We don't want to infer_size on the entire shape, because that can give us
// an extra degree of freedom we don't want; for example, consider shape [0,
// 1, 3, 0], with start_dim=1, end_dim=2. It's clear we want result shape [0,
// 3, 0] but passing [0, -1, 0] to infer_size means the -1 can take on any
// value and satisfy the constraints.
// 1, 3, 0], with start_dim=1, end_dim=2. It's clear we want result shape
// [0, 3, 0] but passing [0, -1, 0] to infer_size means the -1 can take on
// any value and satisfy the constraints.
auto iter = self.sizes().data();
auto slice_numel = std::accumulate(
iter + start_dim,
@ -326,8 +323,8 @@ static Tensor& c2_argmin_out(
return true;
}
// if a is not nan and b is nan, then a is not less than b
// with LessOrNan semantics otherwise, act normally. If `b` is
// NaN then a < b will always return false, so this is
// with LessOrNan semantics otherwise, act normally. If `b`
// is NaN then a < b will always return false, so this is
// equivalent to the first snippet.
return a < b;
});
@ -378,7 +375,7 @@ static at::Tensor& dequantize_copy_out(Tensor& out, const Tensor& self) {
namespace torch::jit {
C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor)
bool opIsRegistered(const c10::Symbol& op_name) {
const std::string name(op_name.toQualString());
@ -505,7 +502,7 @@ REGISTER_OPERATOR_FUNCTOR(
}
listConstructSlowPath(type, size, p_node);
};
});
})
static void tupleConstructSlowPath(const size_t size, ProcessedNode* p_node) {
// prepare inputs
@ -557,7 +554,7 @@ REGISTER_OPERATOR_FUNCTOR(
}
tupleConstructSlowPath(size, p_node);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::abs, aten_abs, [](Node* n) -> SROperator {
if (!n->matches(torch::schema("aten::abs(Tensor self) -> Tensor"))) {
@ -574,7 +571,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::abs, aten_abs, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::native::abs_out(in0_t, out_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -594,7 +591,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::cpu::mul_out(out_t, in0_t, in1_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -616,7 +613,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::cpu::addmm_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
};
});
})
#ifdef FBCODE_CAFFE2
// Disable externally to avoid MSVC errors in open-source CI
@ -676,9 +673,9 @@ REGISTER_OPERATOR_FUNCTOR(
&clamp_min,
&clamp_max,
&nan,
&output_size});
&output_size})
};
});
})
#endif
@ -723,7 +720,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator {
if (!n->matches(
@ -741,7 +738,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::cpu::bmm_out(out_t, in0_t, in1_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::nan_to_num, aten_nan_to_num, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -762,7 +759,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::nan_to_num, aten_nan_to_num, [](Node* n) -> SROp
fastResizeToZero(out_t);
at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t);
};
});
})
namespace {
@ -895,7 +892,7 @@ static SROperator aten_stack(Node* n) {
};
}
REGISTER_OPERATOR_FUNCTOR(aten::stack, aten_stack, aten_stack);
REGISTER_OPERATOR_FUNCTOR(aten::stack, aten_stack, aten_stack)
REGISTER_OPERATOR_FUNCTOR(
prim::VarStack,
@ -913,7 +910,7 @@ REGISTER_OPERATOR_FUNCTOR(
}
varStackOut(*p_node, dim);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::leaky_relu, aten_leaky_relu, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -931,7 +928,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::leaky_relu, aten_leaky_relu, [](Node* n) -> SROp
auto& out_t = p_node->Output(0).toTensor();
at::cpu::leaky_relu_out(out_t, in0_t, in1_s);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
if (!n->matches(torch::schema("aten::relu(Tensor self) -> Tensor"))) {
@ -954,7 +951,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
int64_t nn = in0_t.numel();
te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator {
if (!n->matches(torch::schema("aten::tanh(Tensor self) -> Tensor"))) {
@ -977,7 +974,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator {
int64_t nn = in0_t.numel();
te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
};
});
})
REGISTER_OPERATOR_FUNCTOR(
prim::TensorExprDynamicGroup,
@ -1012,7 +1009,7 @@ REGISTER_OPERATOR_FUNCTOR(
}
}
};
});
})
REGISTER_OPERATOR_FUNCTOR(
aten::sigmoid,
@ -1038,7 +1035,7 @@ REGISTER_OPERATOR_FUNCTOR(
int64_t nn = in0_t.numel();
te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -1073,7 +1070,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
float c = clamp_value;
te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn, &c});
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -1114,7 +1111,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
out_t.unsafeGetTensorImpl(), src.sizes(), src.strides());
at::native::copy_(out_t, src, false);
};
});
})
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_byte_rowwise_offsets,
@ -1152,7 +1149,7 @@ REGISTER_OPERATOR_FUNCTOR(
compressed_indices_mapping,
include_last_offset);
};
});
})
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_4bit_rowwise_offsets,
@ -1190,7 +1187,7 @@ REGISTER_OPERATOR_FUNCTOR(
compressed_indices_mapping,
include_last_offset);
};
});
})
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_byte_prepack,
@ -1211,7 +1208,7 @@ REGISTER_OPERATOR_FUNCTOR(
fastResizeToZero(out_t);
at::native::qembeddingbag_byte_prepack_out(out_t, weight);
};
});
})
// The out variant takes precedence over native
REGISTER_OPERATOR_FUNCTOR(aten::narrow_copy, aten_narrow_copy, [](Node* n) -> SROperator {
@ -1241,7 +1238,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::narrow_copy, aten_narrow_copy, [](Node* n) -> SR
fastResizeToZero(output);
at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::index, aten_index, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
"aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"))) {
@ -1260,7 +1257,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::index, aten_index, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::cpu::index_out(out_t, in0_t, in1_l);
};
});
})
REGISTER_OPERATOR_FUNCTOR(
aten::index_select,
@ -1283,7 +1280,7 @@ REGISTER_OPERATOR_FUNCTOR(
fastResizeToZero(out);
at::native::index_select_out_cpu_(self, dim, index, out);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -1345,7 +1342,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator {
}
LogAndDumpSchema(n);
return nullptr;
});
})
namespace {
@ -1623,7 +1620,7 @@ REGISTER_OPERATOR_FUNCTOR(
return to_maybe_copy_out_functor<false, false>;
}
}
});
})
// out variant takes precedence over native
// NB: This impl doesn't work for cpu->cuda copy/cast or vice versa.
@ -1646,7 +1643,7 @@ REGISTER_OPERATOR_FUNCTOR(
const bool has_memory_format = n->inputs().size() == 5;
return get_to_copy_functor(
has_constant_non_tensor_dtype_and_flags, has_memory_format);
});
})
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_OPERATOR_FUNCTOR(
@ -1671,7 +1668,7 @@ REGISTER_OPERATOR_FUNCTOR(
fastResizeToZero(out_t);
at::native::dequantize_copy_out(out_t, self);
};
});
})
// Out variants for view ops are registered to a separate registry because
// their outputs (views) can't participate in memory reuse.
@ -1695,7 +1692,7 @@ REGISTER_OPERATOR_FUNCTOR(
auto& out = p_node->Output(0).toTensor();
at::native::reshape_copy_out(out, self, proposed_shape, true);
};
});
})
REGISTER_OPERATOR_FUNCTOR(
static_runtime::flatten_copy,
@ -1718,7 +1715,7 @@ REGISTER_OPERATOR_FUNCTOR(
auto& out = p_node->Output(0).toTensor();
at::native::flatten_copy_out(out, self, start_dim, end_dim);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
if (n->inputs().size() != 2 && n->inputs().size() != 4) {
@ -1758,7 +1755,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(aten::mean, aten_mean, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -1795,7 +1792,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::mean, aten_mean, [](Node* n) -> SROperator {
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -1814,7 +1811,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
at::Tensor& output = p_node->Output(0).toTensor();
at::native::repeat_out(output, self, repeats);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -1869,7 +1866,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(aten::sign, aten_sign, [](Node* n) -> SROperator {
if (!n->matches(torch::schema("aten::sign.Tensor(Tensor input) -> Tensor"))) {
@ -1886,7 +1883,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::sign, aten_sign, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::cpu::sign_out(out_t, in0_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -1944,7 +1941,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
}
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::log, aten_log, [](Node* n) -> SROperator {
if (!n->matches(torch::schema("aten::log.Tensor(Tensor input) -> Tensor"))) {
@ -1961,7 +1958,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::log, aten_log, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::cpu::log_out(out_t, in0_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::sub, aten_sub, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -1997,7 +1994,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::sub, aten_sub, [](Node* n) -> SROperator {
}
LogAndDumpSchema(n);
return nullptr;
});
})
// TODO: support clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
REGISTER_OPERATOR_FUNCTOR(
@ -2020,7 +2017,7 @@ REGISTER_OPERATOR_FUNCTOR(
fastResizeToZero(out_t);
at::cpu::clamp_min_out(out_t, in0_t, in1_s);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2044,7 +2041,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
}
at::cpu::argmin_out(out_t, in0_t, dim, keepdim);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2066,7 +2063,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator
dtype == at::ScalarType::Float;
at::cpu::_softmax_out(out_t, in_t, dim, half_to_float);
};
});
})
namespace {
@ -2122,7 +2119,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROp
at::Tensor& output = p_node->Output(0).toTensor();
at::native::layer_norm_cpu_out(output, *X, *gamma, *beta, eps, M, N);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -2187,7 +2184,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(aten::matmul, aten_matmul, [](Node* n) -> SROperator {
if (!n->matches(
@ -2207,7 +2204,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::matmul, aten_matmul, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::native::matmul_out(in0_t, in1_t, out_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2249,7 +2246,7 @@ REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SR
input, output_scale, output_zero_point, out_t);
}
};
});
})
REGISTER_OPERATOR_FUNCTOR(
fb::quantized_linear,
@ -2296,7 +2293,7 @@ REGISTER_OPERATOR_FUNCTOR(
input, output_scale, output_zero_point, out_t);
}
};
});
})
namespace {
@ -2376,7 +2373,7 @@ REGISTER_OPERATOR_FUNCTOR(
return nullptr;
}
return quantized_linear_dynamic_fp16_impl<false>(n);
});
})
REGISTER_OPERATOR_FUNCTOR(
quantized::linear_relu_dynamic_fp16,
@ -2389,7 +2386,7 @@ REGISTER_OPERATOR_FUNCTOR(
return nullptr;
}
return quantized_linear_dynamic_fp16_impl<true>(n);
});
})
// device & pin_memory matter only when CUDA is enabled.
static bool hasTensorWithOptions(
@ -2438,7 +2435,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::full, aten_full, [](Node* n) -> SROperator {
p_node->Output(0) =
at::native::full_out(size, fill_value, p_node->Output(0).toTensor());
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::full_like, aten_full_like, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2464,7 +2461,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::full_like, aten_full_like, [](Node* n) -> SROper
at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
at::native::fill_out(out_t, in1_s);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::ones, aten_ones, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2487,7 +2484,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::ones, aten_ones, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::native::ones_out(size, out_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::ones_like, aten_ones_like, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2512,7 +2509,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::ones_like, aten_ones_like, [](Node* n) -> SROper
fastResizeToZero(out_t);
at::native::ones_out(self.sizes(), out_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2533,7 +2530,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::compositeexplicitautograd::zeros_out(out_t, size);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::linear, aten_linear, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2555,7 +2552,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::linear, aten_linear, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
at::native::linear_out(out_t, in0_t, in1_t, in2_t);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -2605,7 +2602,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SR
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
if (!n->matches(
@ -2625,7 +2622,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
fastResizeToZero(output);
at::cpu::cat_outf(inputs, dim, output);
};
});
})
REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@ -2645,7 +2642,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
fastResizeToZero(output);
at::cpu::cumsum_out(output, input, dim, dtype);
};
});
})
REGISTER_OPERATOR_FUNCTOR(
aten::nonzero,
@ -2665,7 +2662,7 @@ REGISTER_OPERATOR_FUNCTOR(
fastResizeToZero(output);
at::native::nonzero_out_cpu(input, output);
};
});
})
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_OPERATOR_FUNCTOR(
@ -2690,7 +2687,7 @@ REGISTER_OPERATOR_FUNCTOR(
fastResizeToZero(out_t);
at::cpu::cat_outf(inputs, dim, out_t);
};
});
})
namespace {
// This template and its specialization help us avoid compiler warnings
@ -2752,7 +2749,7 @@ REGISTER_OPERATOR_FUNCTOR(
int64_t nn = input.numel();
te->call({out.data_ptr(), input.data_ptr(), &nn});
};
});
})
REGISTER_OPERATOR_FUNCTOR(
aten::remainder,
@ -2790,7 +2787,7 @@ REGISTER_OPERATOR_FUNCTOR(
// Unrecognized overload
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(aten::where, aten_where, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
@ -2811,7 +2808,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::where, aten_where, [](Node* n) -> SROperator {
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(
prim::NumToTensor,
@ -2833,7 +2830,7 @@ REGISTER_OPERATOR_FUNCTOR(
}
LogAndDumpSchema(n);
return nullptr;
});
})
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_byte_unpack,
@ -2855,6 +2852,6 @@ REGISTER_OPERATOR_FUNCTOR(
auto& out = pnode->Output(0).toTensor();
at::native::qembeddingbag_byte_unpack_out(out, weight);
};
});
})
} // namespace torch::jit


@ -12,10 +12,9 @@
C10_DEFINE_bool(
enable_clip_ranges_gather_fusions,
true,
"If on, static runtime or optimize_sparse_nn_model will fuse clip ranges gather ops.");
"If on, static runtime or optimize_sparse_nn_model will fuse clip ranges gather ops.")
namespace torch::jit {
bool graphHasOp(std::shared_ptr<Graph>& graph, const char* op_name) {
DepthFirstGraphNodeIterator graph_it(graph);
for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
@ -715,8 +714,8 @@ static void ReplaceWithCopyImpl(
// b and c are aliases of a, sigmoid_ changes b, c, as well as a. e should
// equal to d in this case. If we replace reshape with the copy version, b
// and c are no longer aliases of a, the value of e would change as a
// result. To keep static runtime consistent with the jit interpreter, here
// we choose not to replace reshape with the copy version
// result. To keep static runtime consistent with the jit interpreter,
// here we choose not to replace reshape with the copy version
if (db.hasInputWriters(n)) {
continue;
}
@ -1086,8 +1085,8 @@ void ForceNonEmptyOutputsHelper(Value* none_value, Block* block) {
}
if (needs_output) {
// Loop sub-blocks should always return at least one output (the new loop
// condition)
// Loop sub-blocks should always return at least one output (the new
// loop condition)
DCHECK(node->kind() == prim::If);
auto* output = node->addOutput();
output->setType(c10::NoneType::get());
@ -1340,8 +1339,8 @@ bool isNoOpSlice(Node* node) {
return false;
}
auto end = toIValue(node->input(2));
// Could also look at list length, but most models that have this pattern are
// just doing list[0:], so it's not needed for now.
// Could also look at list length, but most models that have this pattern
// are just doing list[0:], so it's not needed for now.
return end.has_value() && end->isNone();
}
} // namespace


@ -538,7 +538,7 @@ struct FileCheckImpl {
std::vector<std::vector<Check>> groups;
};
FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
FileCheck::FileCheck() : fcImpl(new FileCheckImpl()) {}
std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
out << "FileCheck checks:\n";
@ -546,7 +546,7 @@ std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
out << "\t" << c << "\n";
}
return out;
};
}
FileCheck::~FileCheck() {
if (!fcImpl->has_run) {
@ -554,17 +554,17 @@ FileCheck::~FileCheck() {
std::cout << *fcImpl;
}
fcImpl.reset();
};
}
void FileCheck::run(const std::string& test_file) {
fcImpl->run(test_file);
};
}
void FileCheck::run(const Graph& graph) {
std::stringstream graph_str;
graph_str << graph;
fcImpl->run(graph_str.str());
};
}
void FileCheck::run(
const std::string& input_checks_string,


@ -6,74 +6,74 @@ C10_DEFINE_bool(torch_lazy_ir_debug, false, "Enable lazy tensor IR debugging");
C10_DEFINE_bool(
torch_lazy_param_aliasing,
true,
"Enable parameter aliasing support");
"Enable parameter aliasing support")
C10_DEFINE_bool(
torch_lazy_handle_special_scalars,
false,
"Handle special scalars 0 and 1 differently");
"Handle special scalars 0 and 1 differently")
C10_DEFINE_bool(
torch_lazy_all_numbers_special_scalars,
false,
"Handle all numbers as special scalars");
"Handle all numbers as special scalars")
C10_DEFINE_bool(
torch_lazy_reuse_ir,
false,
"Reuse IR nodes from previous tracing when possible");
"Reuse IR nodes from previous tracing when possible")
C10_DEFINE_bool(
torch_lazy_use_thread_pool,
false,
"Use thread pool to schedule backend execution");
"Use thread pool to schedule backend execution")
C10_DEFINE_bool(
torch_lazy_enable_device_data_cache,
true,
"Enable or disable device data cache (turns cache on or off), does not change cache state");
"Enable or disable device data cache (turns cache on or off), does not change cache state")
C10_DEFINE_int(
torch_lazy_compilation_cache_size,
1024,
"Size of the compilation cache");
"Size of the compilation cache")
C10_DEFINE_int(
torch_lazy_device_data_cache_size,
128,
"Size of the DeviceData cache");
"Size of the DeviceData cache")
C10_DEFINE_int(
torch_lazy_io_thread_pool_size,
// TODO: measure which default value will give better
// performance, std::thread::hardware_concurrency()?
// TODO: measure which default value
// will give better performance,
// std::thread::hardware_concurrency()?
1,
"Size of the execution thread pool");
"Size of the execution thread pool")
C10_DEFINE_int(torch_lazy_metrics_samples, 1024, "Max metrics sample size");
C10_DEFINE_int(torch_lazy_metrics_samples, 1024, "Max metrics sample size")
C10_DEFINE_int(
torch_lazy_trim_graph_check_frequency,
5000,
"How often to check for whether a graph needs to be split");
"How often to check for whether a graph needs to be split")
C10_DEFINE_int(
torch_lazy_trim_graph_size,
100000,
"The threshold (in terms of the number of nodes) for splitting a graph");
"The threshold (in terms of the number of nodes) for splitting a graph")
C10_DEFINE_string(
torch_lazy_metrics_percentiles,
"0.01:0.05:0.1:0.2:0.5:0.8:0.9:0.95:0.99",
"Metrics percentiles to be collected, using : as the delimiter");
"Metrics percentiles to be collected, using : as the delimiter")
C10_DEFINE_int(
torch_lazy_shape_cache_size,
4096,
"Set the size for the shape cache used for shape inference");
"Set the size for the shape cache used for shape inference")
namespace torch::lazy {
std::string& getLTCForceFallback() {
static std::string config;
static bool _ignore = [&]() {


@ -35,13 +35,13 @@ class TORCH_API DimensionNode {
public:
virtual bool isSymbolic() const {
return false;
};
}
virtual int64_t getDynamicValue() const {
TORCH_CHECK(false, "NYI");
};
}
virtual int64_t getStaticValue() const {
TORCH_CHECK(false, "NYI");
};
}
virtual ~DimensionNode() = default;
};


@ -10,10 +10,9 @@
C10_DEFINE_bool(
ltc_enable_dynamic_shapes,
false,
"Whether dynamic shape is enabled");
"Whether dynamic shape is enabled")
namespace torch::lazy {
static const torch::lazy::Output kNullOutput = torch::lazy::Output();
size_t Output::Hasher::operator()(const Output& output) const {


@ -30,7 +30,7 @@ SizeNode::SizeNode(Value input, size_t dim)
std::vector<Shape>{},
1,
MHash(dim)),
dim_(dim){};
dim_(dim) {}
int64_t SizeNode::getStaticValue() const {
return dynamic_cast<const TsNode*>(operand(0).node)
@ -55,7 +55,7 @@ SizeAdd::SizeAdd(Value a, Value b)
OpKind{c10::Symbol::fromQualString("aten::add")},
{std::move(a), std::move(b)},
std::vector<Shape>{},
1){};
1) {}
int64_t SizeAdd::getStaticValue() const {
return DimCast(operand(0))->getStaticValue() +
@ -75,7 +75,7 @@ SizeMul::SizeMul(Value a, Value b)
OpKind{c10::Symbol::fromQualString("aten::mul")},
{std::move(a), std::move(b)},
std::vector<Shape>{},
1){};
1) {}
int64_t SizeMul::getStaticValue() const {
return DimCast(operand(0))->getStaticValue() *
@ -95,7 +95,7 @@ SizeDiv::SizeDiv(Value a, Value b)
OpKind{c10::Symbol::fromQualString("aten::div")},
{std::move(a), std::move(b)},
std::vector<Shape>{},
1){};
1) {}
int64_t SizeDiv::getStaticValue() const {
TORCH_CHECK(


@ -268,7 +268,7 @@ at::Tensor LazyNativeFunctions::_to_copy(
std::move(node), lazy_self->GetDevice()));
return result;
}
};
}
at::Tensor LazyNativeFunctions::empty_symint(
at::SymIntArrayRef sym_size,


@ -129,7 +129,7 @@ def gen_custom_ops_registration(
static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
}}"""
anonymous_definition = "\n".join(
list(
concatMap(


@ -1615,7 +1615,7 @@ def get_native_function_definitions(
registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{newline.join(registrations[kernel_namespace][namespace])}
}};"""
}}"""
definitions.extend(
fm.substitute_with_template(
"RegisterDispatchDefinitions.ini",


@ -460,7 +460,7 @@ def gen_dispatcher_registrations(
"""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
$dispatch_registrations_body
};"""
}"""
)
static_init_dispatch_registrations = static_template.substitute(
dispatch_key=dispatch_key,
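The final three hunks are in the Python code generators rather than in C++ sources: the f-strings and templates that emit TORCH_LIBRARY_IMPL registration blocks used to close them with "};", so every generated file tripped the same warning. After the change the emitted code looks roughly like the sketch below (illustrative kernel and dispatch key; the real registrations are produced by the generators above and reference real ATen kernels):

#include <torch/library.h>

// Illustrative kernel only.
static at::Tensor add_tensor_kernel(
    const at::Tensor& self,
    const at::Tensor& other,
    const at::Scalar& alpha) {
  return self.add(other, alpha);
}

// TORCH_LIBRARY_IMPL already expands so that the braces form a function body;
// the generators now emit '}' here instead of '};'.
TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl("add.Tensor", TORCH_FN(add_tensor_kernel));
}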