mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Remove dependencies from Caffe2Go on PyTorch JIT (#20463)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20463 Source file changes mostly involve ifdef'ing-out references to JIT code from files that are part of Caffe2Go. Update Internal build scripts to remove those files from our globs. After this, changes to most of the JIT files should not trigger mobile CI. Reviewed By: dzhulgakov Differential Revision: D15329407 fbshipit-source-id: 48f614c6b028eef0a03ce5161d083a3e078b0412
This commit is contained in:
committed by
Facebook Github Bot
parent
3479777519
commit
9e7f22b223
@ -8,7 +8,6 @@
|
||||
// To explicitly use interned strings as symbols in your code, you must add
|
||||
// them to this list.
|
||||
|
||||
#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
|
||||
#define FORALL_ATEN_BASE_SYMBOLS(_) \
|
||||
_(aten, __and__) \
|
||||
_(aten, __iand__) \
|
||||
@ -1013,4 +1012,3 @@ _(attr, workspace) \
|
||||
_(attr, x) \
|
||||
_(attr, x1) \
|
||||
_(attr, x2)
|
||||
#endif
|
||||
|
@ -5,9 +5,12 @@
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
|
||||
#include <ATen/core/aten_interned_strings.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
|
||||
#include <ATen/core/aten_interned_strings.h>
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
|
||||
|
@ -1,10 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
#include <torch/csrc/jit/script/function_schema_parser.h>
|
||||
#endif
|
||||
#include <vector>
|
||||
|
||||
namespace caffe2 {
|
||||
@ -156,7 +155,6 @@ inline std::unique_ptr<c10::KernelCache> noCache() {
|
||||
* - If your operator has a variable number of input tensors, make the first (!)
|
||||
* input an input of type TensorList. There must be no other tensor inputs.
|
||||
*/
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \
|
||||
namespace caffe2 { \
|
||||
namespace _c10_ops { \
|
||||
|
@ -63,8 +63,11 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
|
||||
type_ = operator_def.type();
|
||||
}
|
||||
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
namespace {
|
||||
int compute_input_size_(const std::vector<c10::IValue>& inputs) {
|
||||
int
|
||||
C10_UNUSED // Suppress unused function warning on mobile.
|
||||
compute_input_size_(const std::vector<c10::IValue>& inputs) {
|
||||
if (inputs.empty()) {
|
||||
return 0;
|
||||
}
|
||||
@ -103,6 +106,7 @@ OperatorBase::OperatorBase(
|
||||
input_tensors_.resize(input_size_);
|
||||
output_tensors_.resize(newstyle_outputs_.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
vector<TensorShape> OperatorBase::InputTensorShapes() const {
|
||||
vector<TensorShape> tps;
|
||||
@ -737,7 +741,11 @@ std::function<void(const OperatorDef&)> GetOperatorLogger() {
|
||||
|
||||
c10::optional<int> OperatorBase::argumentIndexWithName(
|
||||
const std::string& name) const {
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
return getFunctionSchema().argumentIndexWithName(name);
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
|
||||
OperatorBase::~OperatorBase() noexcept = default;
|
||||
|
@ -26,7 +26,9 @@
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
#include <ATen/core/ivalue.h>
|
||||
#endif
|
||||
|
||||
C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions);
|
||||
|
||||
@ -50,10 +52,12 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
* Alternatively, inputs can be one tensor list ivalue followed by non-tensors
|
||||
* to represent operators with a variable number of inputs.
|
||||
*/
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
explicit OperatorBase(
|
||||
const c10::FunctionSchema& schema,
|
||||
std::vector<c10::IValue> inputs,
|
||||
std::vector<at::Tensor> outputs);
|
||||
#endif
|
||||
|
||||
virtual ~OperatorBase() noexcept;
|
||||
|
||||
@ -61,12 +65,20 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
* New operators should be instantiated with FunctionSchema
|
||||
*/
|
||||
bool isLegacyOperator() const {
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
return !fn_schema_;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
const c10::FunctionSchema& getFunctionSchema() const {
|
||||
CAFFE_ENFORCE(!isLegacyOperator());
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
return *fn_schema_.get();
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
|
||||
/** @brief Checks if the operator has an argument of the given name.
|
||||
@ -88,10 +100,14 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
|
||||
*operator_def_, name, default_value);
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
auto index = argumentIndexWithName(name);
|
||||
CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
|
||||
const auto& value = newstyle_inputs_[index.value()];
|
||||
return value.template to<T>();
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -100,10 +116,12 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
|
||||
*operator_def_, name);
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
template <typename T>
|
||||
inline vector<T> GetVectorFromIValueList(const c10::IValue& value) const {
|
||||
return value.template to<vector<T>>();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
inline vector<T> GetRepeatedArgument(
|
||||
@ -114,10 +132,14 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
|
||||
*operator_def_, name, default_value);
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
auto index = argumentIndexWithName(name);
|
||||
CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
|
||||
const auto& value = newstyle_inputs_[index.value()];
|
||||
return GetVectorFromIValueList<T>(value);
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the inputs and outputs as specific types.
|
||||
@ -165,6 +187,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
throw enf;
|
||||
}
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
DCHECK_LT(0, newstyle_inputs_.size());
|
||||
IValue ival;
|
||||
if (newstyle_inputs_[0].isTensorList()) {
|
||||
@ -186,6 +209,9 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
CAFFE_ENFORCE_EQ(tensor.GetDeviceType(), type);
|
||||
input_tensors_[idx] = std::move(tensor);
|
||||
return input_tensors_[idx];
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -207,6 +233,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
// When you get a Tensor here it is not fully initialized
|
||||
return BlobGetMutableTensor(outputs_.at(idx), type);
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
auto& output = newstyle_outputs_[idx];
|
||||
Tensor tensor = caffe2::Tensor(output);
|
||||
if (!tensor.defined() || tensor.GetDeviceType() != type) {
|
||||
@ -216,6 +243,9 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
}
|
||||
output_tensors_[idx] = caffe2::Tensor(output);
|
||||
return &output_tensors_[idx];
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
|
||||
inline Tensor
|
||||
@ -232,10 +262,14 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
|
||||
void SetOutputTensor(int idx, Tensor tensor) {
|
||||
if (!isLegacyOperator()) {
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
newstyle_outputs_[idx] = at::Tensor(tensor);
|
||||
|
||||
// also update the tensor in the hack
|
||||
output_tensors_[idx] = std::move(tensor);
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
} else {
|
||||
// update the tensor in the workspace
|
||||
BlobSetTensor(outputs_.at(idx), std::move(tensor));
|
||||
@ -257,6 +291,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
"device must be provided in options.");
|
||||
return BlobGetMutableTensor(outputs_.at(idx), dims, options);
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
auto& output = newstyle_outputs_[idx];
|
||||
Tensor tensor =
|
||||
GetSizedTensorWithOptions(caffe2::Tensor(output), dims, options);
|
||||
@ -265,6 +300,9 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
|
||||
output_tensors_[idx] = caffe2::Tensor(output);
|
||||
return &output_tensors_[idx];
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get output Tensor of the operator and CopyFrom the given Tensor
|
||||
@ -349,7 +387,11 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
if (isLegacyOperator()) {
|
||||
return outputs_.size();
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
return newstyle_outputs_.size();
|
||||
#else
|
||||
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
|
||||
#endif
|
||||
}
|
||||
inline const vector<const Blob*>& Inputs() const { return inputs_; }
|
||||
inline const vector<Blob*>& Outputs() { return outputs_; }
|
||||
@ -540,9 +582,11 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
return helper_;
|
||||
}
|
||||
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
std::vector<at::Tensor> move_newstyle_outputs() && {
|
||||
return std::move(newstyle_outputs_);
|
||||
}
|
||||
#endif
|
||||
|
||||
public:
|
||||
static const int kNoNetPositionSet = -1;
|
||||
@ -556,9 +600,11 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
vector<const Blob*> inputs_;
|
||||
vector<Blob*> outputs_;
|
||||
// Preferrably use c10::optional, but nvcc doesn't work
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
std::unique_ptr<const c10::FunctionSchema> fn_schema_;
|
||||
vector<c10::IValue> newstyle_inputs_;
|
||||
vector<at::Tensor> newstyle_outputs_;
|
||||
#endif
|
||||
// HACK
|
||||
// We preserve the fact that Output() returns Tensor*
|
||||
// by storing Tensor in a vector owned by the
|
||||
@ -618,6 +664,7 @@ inline NetDef OperatorBase::GetSingleArgument<NetDef>(
|
||||
return NetDef();
|
||||
}
|
||||
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
template <>
|
||||
inline vector<int> OperatorBase::GetVectorFromIValueList<int>(
|
||||
const c10::IValue& value) const {
|
||||
@ -649,6 +696,7 @@ inline vector<string> OperatorBase::GetVectorFromIValueList<string>(
|
||||
vector<string> out;
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
// OP_SINGLE_ARG provides a shorter initialization choice for initialization of
|
||||
// member variables for the class constructors.
|
||||
@ -688,6 +736,7 @@ class Operator : public OperatorBase {
|
||||
// constructors will run on that device.
|
||||
context_.SwitchToDevice();
|
||||
}
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
explicit Operator(
|
||||
const c10::FunctionSchema& fn_schema,
|
||||
std::vector<c10::IValue> inputs,
|
||||
@ -697,6 +746,7 @@ class Operator : public OperatorBase {
|
||||
// constructors will run on that device.
|
||||
context_.SwitchToDevice();
|
||||
}
|
||||
#endif
|
||||
~Operator() noexcept override {}
|
||||
|
||||
/// Retrieve a non-owning reference to the input at position 'idx' for this
|
||||
|
@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
// TODO Also register c10 operators on mobile
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
@ -225,9 +227,8 @@ createC10OperatorWrapper(const char* op_name, const char* overload_name) {
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace caffe2
|
||||
|
||||
// TODO Also register c10 operators on mobile
|
||||
#if !defined(CAFFE2_IS_XPLAT_BUILD)
|
||||
// TODO Currently we only register the CPU variant. This is going to be fixed
|
||||
// once the tensor detemplatization lands.
|
||||
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU( \
|
||||
@ -256,4 +257,3 @@ createC10OperatorWrapper(const char* op_name, const char* overload_name) {
|
||||
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP( \
|
||||
OperatorName, Name)
|
||||
#endif
|
||||
} // namespace caffe2
|
||||
|
Reference in New Issue
Block a user