mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
broadcast op: it is an in-place op with both input and output set
This commit is contained in:
@ -12,7 +12,10 @@ class BroadcastOp final : public Operator<dtype, DeviceContext> {
|
||||
USE_OPERATOR_BASE_FUNCTIONS;
|
||||
BroadcastOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<dtype, DeviceContext>(operator_def, ws),
|
||||
root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
|
||||
root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {
|
||||
CHECK_EQ(operator_def.input(0), operator_def.output(0))
|
||||
<< "Broadcast is an in-place operator.";
|
||||
}
|
||||
~BroadcastOp() {}
|
||||
|
||||
bool RunOnDevice() {
|
||||
@ -25,8 +28,9 @@ class BroadcastOp final : public Operator<dtype, DeviceContext> {
|
||||
|
||||
protected:
|
||||
int root_;
|
||||
// Output: X. Note that X must have been initialized on root.
|
||||
INPUT_OUTPUT_STATS(0, 0, 1, 1);
|
||||
// Input: X. Output: X.
|
||||
// Note that Broadcast works in-place by definition.
|
||||
INPUT_OUTPUT_STATS(1, 1, 1, 1);
|
||||
DISABLE_COPY_AND_ASSIGN(BroadcastOp);
|
||||
};
|
||||
|
||||
|
@ -23,6 +23,7 @@ const char kBcastNet[] =
|
||||
" }"
|
||||
" }"
|
||||
" op {"
|
||||
" input: \"X\""
|
||||
" output: \"X\""
|
||||
" type: \"Broadcast\""
|
||||
" arg {"
|
||||
|
Reference in New Issue
Block a user