Mirror of https://github.com/pytorch/pytorch.git
Resubmit: Set the correct engine name for position weighted pooling when fp16 is used for training
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13768

Reviewed By: xianjiec

Differential Revision: D12996103

fbshipit-source-id: 5ca4cda4210f68ece2b5d6eced8cf52ee91fb36f
Committed by: Facebook Github Bot
Parent: ae1b37650c
Commit: 0199d59d3a
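For context: the fix routes SparseLengthsWeightedSum through the FP16 engine whenever fp16 embeddings are used for training, so the backward pass applies the weight updates with an FP16-aware kernel. Below is a minimal sketch of how an engine name gets attached to a Caffe2 operator; it is not taken from this commit, assumes caffe2.python is importable, and uses made-up blob names.

# Sketch only: attach an engine hint to a Caffe2 operator.
from caffe2.python import core

net = core.Net("sparse_lookup_example")

# DATA, WEIGHT, INDICES, LENGTHS are the usual SparseLengthsWeightedSum inputs;
# the blob names here are illustrative assumptions, not names from the commit.
pooled = net.SparseLengthsWeightedSum(
    ["embedding_table", "per_id_weights", "indices", "lengths"],
    "pooled_output",
    engine="FP16",  # the engine name this commit selects for fp16 training
)

# The engine string is recorded on the serialized OperatorDef; at run time the
# framework uses it to pick a matching (here, FP16-aware) kernel if registered.
print(net.Proto())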
@@ -173,6 +173,8 @@ class SparseLookup(ModelLayer):
             "Train version {} is not currently supported".format(trainer_version)
         )
 
+        self.trainer_version = trainer_version
+
         return default_weight_init
 
     def _gather_wrapper(self, net, version, in_indices, out):
@@ -215,11 +217,22 @@ class SparseLookup(ModelLayer):
         if version in ['fp32', 'fp16']:
             # SparseLengths* Ops will accept either fp16 or fp32 embedding
             # matrix and output fp32 pooled embedding
-            net.__getattr__(layer_name)(
-                op_input,
-                self.output_schema.field_blobs(),
-                grad_on_weights=grad_on_weights,
-            )
+            # A special case here is that we need FP16 engine for
+            # SparseLengthsWeightedSum when FP16 embeedings are used for
+            # correct backward updates
+            if reducer == "WeightedSum" and version == "fp16":
+                net.SparseLengthsWeightedSum(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                    engine='FP16',
+                )
+            else:
+                net.__getattr__(layer_name)(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                )
         elif version == 'uint8rowwise':
             op_input.insert(len(op_input), self.scale_bias)
             net.__getattr__(layer_name + '8BitsRowwise')(
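The comment added in this hunk says the FP16 engine is needed so that backward updates stay correct when the embedding matrix itself is stored in fp16. As a standalone illustration of the underlying numerical issue (a toy example of my own, not code from this commit), naive fp16 accumulation of many small updates can silently stall, which is why a kernel that handles fp16 parameters explicitly matters:

import numpy as np

# Toy demonstration: accumulate 10,000 small updates of 1e-4 onto a value of 1.0.
updates = np.full(10000, 1e-4, dtype=np.float16)

acc_fp16 = np.float16(1.0)
for u in updates:
    acc_fp16 = np.float16(acc_fp16 + u)  # result is rounded back to fp16 each step

acc_fp32 = np.float32(1.0) + updates.astype(np.float32).sum()

print(acc_fp16)   # stays at 1.0: each 1e-4 is below half the fp16 spacing near 1.0
print(acc_fp32)   # ~2.0, the expected total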
@@ -345,6 +358,17 @@ class SparseLookup(ModelLayer):
             raise "Only Sum, Mean, None are supported for IdScoreList input." +\
                 "Trying to create with {}".format(self.reducer)
 
+    def _add_ops(self, net, version='fp32'):
+        if _is_id_list(self.input_record):
+            self._add_ops_id_list(net, version=version)
+        elif _is_id_score_list(self.input_record):
+            self._add_ops_id_score_list(net, version=version)
+        else:
+            raise "Unsupported input type {0}".format(self.input_record)
+
+    def add_train_ops(self, net):
+        self._add_ops(net, self.trainer_version)
+
     def add_ops(self, net):
         cur_scope = get_current_scope()
         version = get_sparse_lookup_predictor_version(
@@ -357,9 +381,4 @@ class SparseLookup(ModelLayer):
                 'fused_uint8rowwise'}:
             version = 'fp32'
 
-        if _is_id_list(self.input_record):
-            self._add_ops_id_list(net, version=version)
-        elif _is_id_score_list(self.input_record):
-            self._add_ops_id_score_list(net, version=version)
-        else:
-            raise "Unsupported input type {0}".format(self.input_record)
+        self._add_ops(net, version)
|