mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
refactor ps benchmark (#60784)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60784 This pr refactors the ps benchmark for modular trainers. Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D29697291 Pulled By: gcramer23 fbshipit-source-id: 64579a1f5326d3cd9f32936dcf53bc243d54b71d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7d2ea9a8f7
commit
304c02ee44
@ -9,13 +9,24 @@ class DummyModel(nn.Module):
|
||||
embedding_dim: int,
|
||||
dense_input_size: int,
|
||||
dense_output_size: int,
|
||||
dense_layers_count: int,
|
||||
sparse: bool
|
||||
):
|
||||
r"""
|
||||
A dummy model with an EmbeddingBag Layer and Dense Layer.
|
||||
Args:
|
||||
num_embeddings (int): size of the dictionary of embeddings
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
dense_input_size (int): size of each input sample
|
||||
dense_output_size (int): size of each output sample
|
||||
dense_layers_count (int): number of dense layers in dense Sequential module
|
||||
sparse (bool): if True, gradient w.r.t. weight matrix will be a sparse tensor
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding = nn.EmbeddingBag(
|
||||
num_embeddings, embedding_dim, sparse=sparse
|
||||
)
|
||||
self.dense = nn.Sequential(*[nn.Linear(dense_input_size, dense_output_size) for _ in range(10)])
|
||||
self.dense = nn.Sequential(*[nn.Linear(dense_input_size, dense_output_size) for _ in range(dense_layers_count)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.embedding(x)
|
||||
|
||||
Reference in New Issue
Block a user