mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Automatic generation of unittest for Glow integration
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19936 Reviewed By: ipiszy Differential Revision: D15138090 fbshipit-source-id: 29a812548bb5da00176b00c1e9a26a7c31cea9c0
This commit is contained in:
committed by
Facebook Github Bot
parent
3a0727e58b
commit
9a81d1e692
@ -22,7 +22,7 @@ void DataNetFiller::fill_input_internal(TensorList_t* input_data) const {
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_with_type(
|
||||
void fill_with_type(
|
||||
const TensorFiller& filler,
|
||||
const std::string& type,
|
||||
TensorCPU* output) {
|
||||
|
@ -79,6 +79,11 @@ class DataNetFiller : public Filler {
|
||||
const NetDef data_net_;
|
||||
};
|
||||
|
||||
void fill_with_type(
|
||||
const TensorFiller& filler,
|
||||
const std::string& type,
|
||||
TensorCPU* output);
|
||||
|
||||
/*
|
||||
* @run_net: the predict net with parameter and input names
|
||||
* @input_dims: the input dimentions of all operator inputs of run_net
|
||||
@ -95,10 +100,7 @@ class DataRandomFiller : public Filler {
|
||||
|
||||
void fill_parameter(Workspace* ws) const override;
|
||||
|
||||
protected:
|
||||
DataRandomFiller() {}
|
||||
|
||||
TensorFiller get_tensor_filler(
|
||||
static TensorFiller get_tensor_filler(
|
||||
const OperatorDef& op_def,
|
||||
int input_index,
|
||||
const std::vector<std::vector<int64_t>>& input_dims) {
|
||||
@ -117,6 +119,9 @@ class DataRandomFiller : public Filler {
|
||||
return filler;
|
||||
}
|
||||
|
||||
protected:
|
||||
DataRandomFiller() {}
|
||||
|
||||
using filler_type_pair_t = std::pair<TensorFiller, std::string>;
|
||||
std::unordered_map<std::string, filler_type_pair_t> parameters_;
|
||||
std::unordered_map<std::string, filler_type_pair_t> inputs_;
|
||||
|
Reference in New Issue
Block a user