Update from facebook 1ee4edd286a3 (#8040)
* Adding instance weight to batch distill loss, as title.
* Added bfloat 16-31 and their respective unit tests.
* [CUDA9] Upgrade - fbcode. The CUDA 9 upgrade diff D5654023 has been out for a while thanks to Pieter, but as time goes on it is becoming quite hard to rebase because of the symlinks and auto-generated build/config files in tp2. Break D5654023 into two diffs: one touching tp2 config files, and another touching the fbcode TARGETS file (adding the nvcc flag). These two should be a bit easier to rebase (for the detailed procedure see "Test Plan"). This diff can only be committed if: 1. the CUDA 9 rpm is rolled out fleet-wide (TBD); 2. NVIDIA driver 390.40 is rolled out fleet-wide (done); 3. upgrade to CUDA 9.1, cuDNN 7.1, NCCL 2.1 (done); 4. all dependents are built (done); 5. all C2 operators and PyTorch are tested (see test plan).
* Share intermediate int32 buffer across Conv ops, adding a known type.
* [C2 fix] Infer function for ensure_cpu_output_op: adds the missing device-inference function for ensure_cpu_output_op.
* [int8] Add blob serializer/deserializer for Int8TensorCPU, to export to logfiledb.
* [nomnigraph] Add try/catch block to optimization passes in predictor. This will catch failures that happen in the optimization pass.
* Caffe2: avoid the static initialization order fiasco for CAFFE_ENFORCE. CAFFE_ENFORCE uses a stack-trace fetcher, which is currently a global static variable. If CAFFE_ENFORCE is used at static initialization time, this is a SIOF. Recently CAFFE_ENFORCE was added into init-function registration, so we started to see this. A Meyers singleton provides safety here: if the stack-trace fetcher has not been registered yet, a dummy one is used (a minimal sketch follows this list).
* NUMA support in the SparseNN CPU benchmark.
* [mobile-roofline] Add the logging needed for the roofline model. This should be all that's needed.
* Let the operators use the same input if the operators are not chained; otherwise, we have to change the input data dims.
* Fix null-pointer-use UBSAN errors in reshape_op.h.
* Revert the previous fix on the input blob name, as title.
* Adding a flag to let MineHardNegative automatically extract a single value from a dict. The model exporter requires the output of the model to be a struct; this makes it convenient to use those models directly in MineHardNegative by allowing automatic extraction of the single element of the dict, which is a common use case.
* Reverting a change that broke internal tests back to an OSS-compatible state.
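The CAFFE_ENFORCE item above hinges on the Meyers singleton: a function-local static is initialized on first use, so it is safe to call during static initialization, unlike a namespace-scope global. A minimal standalone sketch of the pattern (hypothetical names, not the caffe2 source):

#include <functional>
#include <string>

// Meyers singleton: the function-local static is constructed the first time
// GetFetcher() runs, so callers executing during static initialization always
// see a valid object. A namespace-scope global may not be constructed yet.
std::function<std::string()>* GetFetcher() {
  static std::function<std::string()> fetcher = []() { return std::string(); };
  return &fetcher;
}

void SetFetcher(std::function<std::string()> f) {
  *GetFetcher() = std::move(f);  // replace the dummy once a real fetcher registers
}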
Committed by: Soumith Chintala
Parent: 9060b7f4e2
Commit: 82b981e4db
@@ -35,7 +35,7 @@ def main(args):
     input_name = args.input_name
     output_name = args.output_name

-    iters = int(args.iters)
+    iters = int(args.instances)
     for i in range(iters):
         input_blob_name = input_name + (str(i) if i > 0 and args.chain else '')
         output_blob_name = output_name + str(i + 1)
@@ -85,8 +85,8 @@ if __name__ == "__main__":
                         default="data")
     parser.add_argument("--output_name", help="Name of the output blob.",
                         default="output")
-    parser.add_argument("--iters",
-                        help="Number of iterations to run the operator.",
+    parser.add_argument("--instances",
+                        help="Number of instances to run the operator.",
                         default="1")
     parser.add_argument("-d", "--debug", help="Print debug information.",
                         action='store_true')
caffe2/core/int8_serialization.cc (new file, 104 lines)
@@ -0,0 +1,104 @@
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/common.h"
#include "caffe2/core/context.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/core/typeid.h"
#include "caffe2/core/types.h"

namespace caffe2 {
namespace int8 {

class Int8TensorCPUSerializer : public BlobSerializerBase {
 public:
  void Serialize(
      const Blob& blob,
      const string& name,
      SerializationAcceptor acceptor) override {
    const auto& tensor = blob.template Get<Int8TensorCPU>();
    BlobProto blob_proto;
    blob_proto.set_name(name);
    blob_proto.set_type("Int8TensorCPU");
    QTensorProto& proto = *blob_proto.mutable_qtensor();
    proto.set_name(name);
    for (int i = 0; i < tensor.t.ndim(); ++i) {
      proto.add_dims(tensor.t.dim32(i));
    }
    proto.set_precision(8);
    proto.set_scale(tensor.scale);
    proto.set_bias(tensor.zero_point);
    proto.set_is_signed(false);

    const TensorProto::DataType data_type = TypeMetaToDataType(tensor.t.meta());
    proto.set_data_type(data_type);
    switch (data_type) {
      case TensorProto_DataType_INT32:
        detail::CopyToProtoAsIs(
            tensor.t.size(),
            tensor.t.template data<int32_t>(),
            proto.mutable_data(),
            &this->context_);
        break;
      case TensorProto_DataType_UINT8:
        detail::CopyToProtoWithCast(
            tensor.t.size(),
            tensor.t.template data<uint8_t>(),
            proto.mutable_data(),
            &this->context_);
        break;
      default:
        CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
    }

    acceptor(name, blob_proto.SerializeAsString());
  }

 private:
  CPUContext context_;
};

class Int8TensorCPUDeserializer : public TensorDeserializer<CPUContext> {
 public:
  void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
    const QTensorProto& proto = blob_proto.qtensor();
    Int8TensorCPU* tensor = blob->template GetMutable<Int8TensorCPU>();
    tensor->scale = proto.scale();
    tensor->zero_point = proto.bias();
    vector<int> dims;
    for (const int d : proto.dims()) {
      dims.push_back(d);
    }
    tensor->t.Resize(dims);
    switch (proto.data_type()) {
      case TensorProto_DataType_INT32:
        detail::CopyFromProtoAsIs(
            tensor->t.size(),
            proto.data(),
            tensor->t.template mutable_data<int32_t>(),
            &this->context_);
        break;
      case TensorProto_DataType_UINT8:
        detail::CopyFromProtoWithCast(
            tensor->t.size(),
            proto.data(),
            tensor->t.template mutable_data<uint8_t>(),
            &this->context_);
        break;
      default:
        CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
    }
  }

 private:
  CPUContext context_;
};

} // namespace int8

namespace {
REGISTER_BLOB_SERIALIZER(
    (TypeMeta::Id<int8::Int8TensorCPU>()),
    int8::Int8TensorCPUSerializer);
REGISTER_BLOB_DESERIALIZER(Int8TensorCPU, int8::Int8TensorCPUDeserializer);
} // namespace

} // namespace caffe2
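With the serializer registered, any code that serializes Blobs generically now handles Int8TensorCPU. A round-trip sketch under assumed caffe2 Blob APIs (Blob::GetMutable and Blob::Serialize; treat the exact calls as illustrative):

#include "caffe2/core/blob.h"
#include "caffe2/core/tensor_int8.h"

void Int8RoundTripSketch() {
  caffe2::Blob blob;
  auto* t = blob.GetMutable<caffe2::int8::Int8TensorCPU>();
  t->scale = 0.05f;     // quantization scale, stored in QTensorProto.scale
  t->zero_point = 128;  // stored in QTensorProto.bias
  t->t.Resize(2, 3);
  t->t.mutable_data<uint8_t>();  // allocate uint8 storage
  // blob.Serialize("my_int8_blob") now routes through Int8TensorCPUSerializer;
  // deserializing the resulting BlobProto restores scale/zero_point and dims.
}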
@@ -37,10 +37,15 @@ size_t ReplaceAll(string& s, const char* from, const char* to) {
   return numReplaced;
 }

-static std::function<string(void)> FetchStackTrace = []() { return ""; };
+namespace {
+std::function<string(void)>* GetFetchStackTrace() {
+  static std::function<string(void)> func = []() { return ""; };
+  return &func;
+};
+} // namespace

 void SetStackTraceFetcher(std::function<string(void)> fetcher) {
-  FetchStackTrace = fetcher;
+  *GetFetchStackTrace() = fetcher;
 }

 static std::function<void(const OperatorDef&)> OperatorLogger =
@@ -70,7 +75,7 @@ EnforceNotMet::EnforceNotMet(
           ". ",
           msg,
           " ")},
-      stack_trace_(FetchStackTrace()) {
+      stack_trace_((*GetFetchStackTrace())()) {
   if (FLAGS_caffe2_use_fatal_for_enforce) {
     LOG(FATAL) << msg_stack_[0];
   }
@@ -54,7 +54,7 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
   type_ = operator_def.type();
 }

-vector<TensorShape> OperatorBase::InputTensorShapes() {
+vector<TensorShape> OperatorBase::InputTensorShapes() const {
   vector<TensorShape> tps;
   for (const auto& blob : inputs_) {
     tps.push_back(GetTensorShapeOfBlob(blob));
@@ -115,7 +115,7 @@ class OperatorBase : public Observable<OperatorBase> {
   }
   inline const vector<const Blob*>& Inputs() const { return inputs_; }
   inline const vector<Blob*>& Outputs() { return outputs_; }
-  vector<TensorShape> InputTensorShapes();
+  vector<TensorShape> InputTensorShapes() const;

   virtual void WaitEvent(const Event& ev, int /*stream_id */ = -1) {
     ev.Finish();
@@ -489,6 +489,7 @@ bool RunPlanOnWorkspace(

   NetDefMap net_defs;
   for (const NetDef& net_def : plan.network()) {
+    LOG(INFO) << "Processing net '" << net_def.name() << "'";
     CAFFE_ENFORCE(
         net_defs.count(net_def.name()) == 0,
         "Your plan contains networks of the same name \"",
@@ -97,7 +97,11 @@ Predictor::Predictor(

   if (optimization) {
 #ifdef CAFFE2_OPTIMIZER
-    run_net_ = opt::optimize(run_net_, &ws_, optimization);
+    try {
+      run_net_ = opt::optimize(run_net_, &ws_, optimization);
+    } catch (const std::exception& e) {
+      LOG(WARNING) << "Optimization pass failed: " << e.what();
+    }
 #else
     LOG(WARNING) << "Caffe2 is compiled without optimization passes.";
 #endif
@@ -22,6 +22,7 @@ CAFFE_KNOWN_TYPE(double);
 CAFFE_KNOWN_TYPE(char);
 CAFFE_KNOWN_TYPE(std::unique_ptr<std::mutex>);
 CAFFE_KNOWN_TYPE(std::unique_ptr<std::atomic<bool>>);
+CAFFE_KNOWN_TYPE(std::vector<int32_t>);
 CAFFE_KNOWN_TYPE(std::vector<int64_t>);
 CAFFE_KNOWN_TYPE(std::vector<unsigned long>);
 CAFFE_KNOWN_TYPE(bool*);
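CAFFE_KNOWN_TYPE assigns the type a TypeMeta id, which is what lets a Blob hold and introspect it; registering std::vector<int32_t> is what allows Conv ops to park a shared int32 scratch buffer in the workspace. A sketch (assuming Blob::GetMutable is the access path and requires the registration):

#include <vector>
#include "caffe2/core/blob.h"

void SharedInt32BufferSketch() {
  caffe2::Blob blob;
  // Valid only because std::vector<int32_t> is now a CAFFE_KNOWN_TYPE.
  auto* buf = blob.GetMutable<std::vector<int32_t>>();
  buf->resize(1024);  // scratch space reusable across Conv invocations
}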
@@ -28,7 +28,9 @@ class PrependDimOp : public Operator<Context> {
     CAFFE_ENFORCE(
         input.dim(0) % dim_size_ == 0,
         "First dimension must be multiple of prepend_dim. Current first dimension: ",
-        input.dim(0));
+        input.dim(0),
+        ", prepend dim: ",
+        dim_size_);

     vector<int64_t> actual_new_shape(input.ndim() + 1);
     actual_new_shape[0] = dim_size_;
@@ -48,13 +48,15 @@ class ReshapeOp : public Operator<Context> {
     auto& shape = Input(1);
     CAFFE_ENFORCE(shape.ndim() == 1, "Shape should be 1-D");

-    const T* shape_data = shape.template data<T>();
-
-    // Bit awkward, but needed so works on both CPU and CUDA contexts
-    std::vector<T> tmpv(shape.size());
-    context_.template CopyBytes<Context, CPUContext>(
-        shape.size() * sizeof(T), shape_data, &tmpv[0]);
-    actual_new_shape.assign(tmpv.begin(), tmpv.begin() + shape.size());
+    if (shape.size()) {
+      const T* shape_data = shape.template data<T>();
+
+      // Bit awkward, but needed so works on both CPU and CUDA contexts
+      std::vector<T> tmpv(shape.size());
+      context_.template CopyBytes<Context, CPUContext>(
+          shape.size() * sizeof(T), shape_data, &tmpv[0]);
+      actual_new_shape.assign(tmpv.begin(), tmpv.begin() + shape.size());
+    }
   }

   // Copy over the dimensions for those that are specified zero.
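The guard matters because data<T>() on an empty tensor may return a null pointer, and passing a null source to a byte copy is undefined behavior even when the count is zero; that is the UBSAN error named in the commit message. The general pattern, in an isolated sketch:

#include <cstring>
#include <vector>

template <typename T>
void CopyIfNonEmpty(const std::vector<T>& src, std::vector<T>& dst) {
  dst.resize(src.size());
  if (!src.empty()) {
    // memcpy with a null src is UB even for n == 0, so the copy is guarded
    // the same way the ReshapeOp fix guards its CopyBytes call.
    std::memcpy(dst.data(), src.data(), src.size() * sizeof(T));
  }
}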
@@ -70,6 +70,7 @@ message QTensorProto {
   required bool is_signed = 5;
   repeated int32 data = 6 [packed = true];
   optional string name = 7;
+  optional TensorProto.DataType data_type = 8 [default = INT32];
 }

 // TensorProtos stores multiple TensorProto objects in one single proto. This
@@ -82,7 +82,7 @@ def IsOperatorWithEngine(op_type, engine):
     return C.op_registry_key(op_type, engine) in _REGISTERED_OPERATORS


-def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None, node_name=None):
+def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None, node_name=None, numa_node_id=None):
     option = caffe2_pb2.DeviceOption()
     option.device_type = device_type
     option.cuda_gpu_id = cuda_gpu_id
@@ -90,6 +90,9 @@ def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None, node_name=None):
         option.node_name = node_name
     if random_seed is not None:
         option.random_seed = random_seed
+    if numa_node_id is not None:
+        assert device_type == caffe2_pb2.CPU
+        option.numa_node_id = numa_node_id
     return option

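The same numa_node_id field is reachable from C++ through the generated proto API; a sketch (assumes the field was added to DeviceOption in caffe2.proto by this change):

#include "caffe2/proto/caffe2.pb.h"

// Pin a CPU DeviceOption to a NUMA node, mirroring the Python helper above.
caffe2::DeviceOption MakeNumaCpuOption(int numa_node) {
  caffe2::DeviceOption opt;
  opt.set_device_type(caffe2::CPU);  // NUMA placement only applies to CPU
  opt.set_numa_node_id(numa_node);
  return opt;
}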
@@ -2256,6 +2259,8 @@ def InjectCrossDeviceCopies(net, blob_to_device=None, blob_remap=None,
     Assumptions:
       1. every external inputs of this net is already in blob_to_device!
       2. if not, this function will use net device option
+      3. InferOpBlobDevices might fail to get the correct inference for ops like
+         EnsureCPUOutput that could take in input from multiple places.
     '''
     new_net = net.Clone(net._net.name + '_cross_device', keep_schema=True)
     del new_net._net.op[:]
@@ -655,14 +655,7 @@ class TestInferDevice(test_util.TestCase):
     def test_infer_device_cross_device(self):
         self._test_op("CopyGPUToCPU", self.cuda_option, self.cpu_option)
         self._test_op("CopyCPUToGPU", self.cpu_option, self.cuda_option)
-        self._test_op(
-            "EnsureCPUOutput",
-            self.cpu_option,
-            self.cpu_option,
-            op_option=self.cpu_option
-        )
-        self._test_op(
-            "CopyFromCPUInput",
-            self.cpu_option,
+        self._test_op("EnsureCPUOutput", self.cuda_option, self.cpu_option)
+        self._test_op("CopyFromCPUInput", self.cpu_option, self.cuda_option)
@@ -87,6 +87,26 @@ class BatchDistillLRLoss(ModelLayer):
             net.NextScopedBlob('scaled_teacher_cross_entropy'),
             scale=self._teacherWeight,
         )
+        if 'weight' in self.input_record.fields:
+            weight_blob = self.input_record.weight()
+            if self.input_record.weight.field_type().base != np.float32:
+                weight_blob = net.Cast(
+                    weight_blob,
+                    weight_blob + '_float32',
+                    to=core.DataType.FLOAT
+                )
+            weight_blob = net.StopGradient(
+                [weight_blob],
+                [net.NextScopedBlob('weight_stop_gradient')],
+            )
+            scaled_true_xent = net.Mul(
+                [scaled_true_xent, weight_blob],
+                net.NextScopedBlob('weighted_xent_label'),
+            )
+            scaled_teacher_xent = net.Mul(
+                [scaled_teacher_xent, weight_blob],
+                net.NextScopedBlob('weighted_xent_teacher'),
+            )

         true_loss = net.AveragedLoss(
             scaled_true_xent,
@@ -96,7 +116,6 @@ class BatchDistillLRLoss(ModelLayer):
             scaled_teacher_xent,
             net.NextScopedBlob('teacher_loss')
         )
-
         net.Add(
             [true_loss, teacher_loss],
             self.output_schema.field_blobs()
@@ -613,6 +613,17 @@ class TestLayers(LayersTestCase):
             ]
         )

+    def testDistillBatchLRLoss(self):
+        input_record = self.new_record(schema.Struct(
+            ('label', schema.Scalar((np.float64, (1,)))),
+            ('logit', schema.Scalar((np.float32, (2,)))),
+            ('teacher_label', schema.Scalar((np.float32, (1,)))),
+            ('weight', schema.Scalar((np.float64, (1,))))
+        ))
+        loss = self.model.BatchDistillLRLoss(input_record)
+        self.assertEqual(schema.Scalar((np.float32, tuple())), loss)
+
+
     def testBatchLRLoss(self):
         input_record = self.new_record(schema.Struct(
             ('label', schema.Scalar((np.float64, (1,)))),
@@ -105,12 +105,10 @@ class TestAdagrad(hu.HypothesisTestCase):
                              allow_nan=False, allow_infinity=False),
            epsilon=st.floats(min_value=0.01, max_value=0.99,
                              allow_nan=False, allow_infinity=False),
-           data_strategy=st.data(),
            **hu.gcs)
-    def test_sparse_adagrad(self, inputs, lr, epsilon,
-                            data_strategy, gc, dc):
+    def test_sparse_adagrad(self, inputs, lr, epsilon, gc, dc):
        return adagrad_sparse_test_helper(self, inputs, lr, epsilon,
-                                         data_strategy, None, ref_adagrad, gc, dc)
+                                         None, ref_adagrad, gc, dc)

    @given(inputs=hu.tensors(n=2),
           lr=st.floats(min_value=0.01, max_value=0.99,
@@ -49,22 +49,15 @@ def ref_adagrad(param_in, mom_in, grad, lr, epsilon, using_fp16=False,


 def adagrad_sparse_test_helper(parent_test, inputs, lr, epsilon,
-                               data_strategy, engine, ref_adagrad, gc, dc):
+                               engine, ref_adagrad, gc, dc):
     param, momentum, grad = inputs
     momentum = np.abs(momentum)
     lr = np.array([lr], dtype=np.float32)

-    # Create an indexing array containing values that are lists of indices,
-    # which index into grad
-    indices = data_strategy.draw(
-        hu.tensor(dtype=np.int64,
-                  elements=st.sampled_from(np.arange(grad.shape[0]))),
-    )
-    hypothesis.note('indices.shape: %s' % str(indices.shape))
-
-    # For now, the indices must be unique
-    hypothesis.assume(np.array_equal(np.unique(indices.flatten()),
-                                     np.sort(indices.flatten())))
+    indices = np.random.choice(np.arange(grad.shape[0]),
+                               size=np.random.randint(grad.shape[0]), replace=False)

     # Sparsify grad
     grad = grad[indices]
@@ -386,6 +386,7 @@ class TestConvolution(hu.HypothesisTestCase):
             order=order,
             engine=engine,
             device_option=gc,
+            exhaustive_search=True,
         )
         if order == "NCHW":
             X_f = X.transpose((0, 3, 1, 2))
@@ -7,6 +7,19 @@

 namespace caffe2 {

+struct PerformanceInformation {
+  // Analytic
+  int64_t flops = 0;
+  int64_t bytes_written = 0;
+  int64_t bytes_read = 0;
+  std::vector<TensorShape> tensor_shapes = {};
+  std::vector<Argument> args = {};
+  std::string engine = ""; // the engine used
+  std::string type = ""; // the type of the operator
+  // Measured
+  double latency = 0;
+};
+
 class CAFFE2_OBSERVER_API NetObserverReporter {
  public:
   virtual ~NetObserverReporter() = default;
@@ -16,9 +29,8 @@ class CAFFE2_OBSERVER_API NetObserverReporter {
   The delays are saved in a map. The key is an identifier associated
   with the reported delay. The value is the delay value in float
   */
-  virtual void reportDelay(
+  virtual void report(
       NetBase* net,
-      std::map<std::string, double>& delays,
-      const char* unit) = 0;
+      std::map<std::string, PerformanceInformation>&) = 0;
 };
 }
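Downstream reporters now implement the single report() entry point instead of reportDelay(). A minimal custom reporter sketch (include path and log format are assumptions, not part of this commit):

#include <map>
#include <string>
#include "caffe2/core/logging.h"
#include "observers/net_observer_reporter.h" // assumed header location

class LatencyLogReporter : public caffe2::NetObserverReporter {
 public:
  void report(
      caffe2::NetBase* net,
      std::map<std::string, caffe2::PerformanceInformation>& info) override {
    // One line per entry: net/op name, measured latency, and the analytic
    // metadata the new struct carries alongside it.
    for (const auto& kv : info) {
      LOG(INFO) << net->Name() << "/" << kv.first << ": "
                << kv.second.latency << " ms, type=" << kv.second.type
                << ", engine=" << kv.second.engine;
    }
  }
};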
@@ -7,16 +7,13 @@ namespace caffe2 {

 const std::string NetObserverReporterPrint::IDENTIFIER = "Caffe2Observer ";

-void NetObserverReporterPrint::reportDelay(
+void NetObserverReporterPrint::report(
     NetBase* net,
-    std::map<std::string, double>& delays,
-    const char* unit) {
-  CAFFE_ENFORCE(unit != nullptr, "Unit is null");
+    std::map<std::string, PerformanceInformation>& info) {
   LOG(INFO) << IDENTIFIER << "Net Name - " << net->Name();
   LOG(INFO) << IDENTIFIER << "Delay Start";
-  for (auto& p : delays) {
-    LOG(INFO) << IDENTIFIER << p.first << " - " << p.second << "\t(" << *unit
-              << ")";
+  for (auto& p : info) {
+    LOG(INFO) << IDENTIFIER << p.first << " - " << p.second.latency << "\t(ms)";
   }
   LOG(INFO) << IDENTIFIER << "Delay End";
 }
@@ -9,9 +9,7 @@ namespace caffe2 {
 class CAFFE2_OBSERVER_API NetObserverReporterPrint : public NetObserverReporter {
  public:
   static const std::string IDENTIFIER;
-  void reportDelay(
-      NetBase* net,
-      std::map<std::string, double>& delays,
-      const char* unit);
+  void report(NetBase* net, std::map<std::string, PerformanceInformation>&);
 };
-}
+} // namespace caffe2
@@ -73,17 +73,31 @@ void PerfNetObserver::Stop() {
     return;
   }
   auto currentRunTime = timer_.MilliSeconds();
-  std::map<std::string, double> delays;
-  delays.insert({"NET_DELAY", currentRunTime});
+  std::map<std::string, PerformanceInformation> info;
+  PerformanceInformation net_perf;
+  net_perf.latency = currentRunTime;
   if (logType_ == PerfNetObserver::OPERATOR_DELAY) {
     const auto& operators = subject_->GetOperators();
     for (int idx = 0; idx < operators.size(); ++idx) {
       const auto* op = operators[idx];
       auto name = getObserverName(op, idx);
-      double delay = static_cast<const PerfOperatorObserver*>(observerMap_[op])
-                         ->getMilliseconds();
-      delays.insert({name, delay});
+      PerformanceInformation p;
+
+      p.latency = static_cast<const PerfOperatorObserver*>(observerMap_[op])
+                      ->getMilliseconds();
+
+      p.engine = op->engine();
+      p.type = op->type();
+      p.tensor_shapes = op->InputTensorShapes();
+      if (op->has_debug_def()) {
+        for (auto arg : op->debug_def().arg()) {
+          p.args.emplace_back(arg);
+        }
+      }
+
+      info.insert({name, p});
     }

     /* clear all operator delay after use so that we don't spend time
        collecting the operator delay info in later runs */
     for (auto* op : operators) {
@@ -91,7 +105,8 @@ void PerfNetObserver::Stop() {
     }
     observerMap_.clear();
   }
-  ObserverConfig::getReporter()->reportDelay(subject_, delays, "ms");
+  info.insert({"NET_DELAY", net_perf});
+  ObserverConfig::getReporter()->report(subject_, info);
 }

 caffe2::string PerfNetObserver::getObserverName(const OperatorBase* op, int idx)
@@ -138,4 +153,25 @@ double PerfOperatorObserver::getMilliseconds() const {
   return milliseconds_;
 }

+OpSchema::Cost PerfOperatorObserver::getAnalyticalCost() const {
+  auto* op = subject_;
+  auto* schema = OpSchemaRegistry::Schema(op->type());
+  OpSchema::Cost cost;
+  if (schema && schema->HasCostInferenceFunction()) {
+    vector<TensorShape> shapes = op->InputTensorShapes();
+
+    auto all_good_shapes = std::accumulate(
+        shapes.begin(),
+        shapes.end(),
+        true,
+        [](bool acc, const TensorShape& shape) {
+          return acc && !shape.unknown_shape();
+        });
+    if (all_good_shapes) {
+      cost = schema->InferCost(op->debug_def(), shapes);
+    }
+  }
+  return cost;
+}
+
 } // namespace caffe2
@@ -45,6 +45,7 @@ class PerfOperatorObserver : public ObserverBase<OperatorBase> {
   virtual ~PerfOperatorObserver();

   double getMilliseconds() const;
+  OpSchema::Cost getAnalyticalCost() const;

  private:
   void Start() override;
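The analytic fields (flops, bytes_read, bytes_written) plus the measured latency are exactly the inputs a roofline model needs; a sketch of the derived quantity, using only the PerformanceInformation struct added above (the helper is illustrative, not part of the commit):

// Arithmetic intensity = flops per byte moved: the x-axis of a roofline plot.
double ArithmeticIntensity(const caffe2::PerformanceInformation& p) {
  const double bytes = static_cast<double>(p.bytes_read + p.bytes_written);
  return bytes > 0.0 ? static_cast<double>(p.flops) / bytes : 0.0;
}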