Mirror of https://github.com/pytorch/pytorch.git

LSTM: support dropping hidden / cell states when sequence ends

Summary: This is useful when data has standalone sequences which are not
connected to each other by any meaningful context.

Reviewed By: yqwangustc

Differential Revision: D4835164

fbshipit-source-id: f95626acc26acc3eba3bca7efb08ed1dbdb36c83

Committed by: Facebook Github Bot
Parent: c4ce118393
Commit: ad6204eb0b
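Before the diff, a hedged usage sketch (not part of the commit): calling the
rnn_cell.LSTM helper with the new drop_states flag, matching the signature in
the diff below. The module paths, blob names, and dimensions here are
assumptions for illustration, not taken from the commit.

# Sketch only: setup names/shapes are assumed, not from this commit.
from caffe2.python import model_helper, rnn_cell

model = model_helper.ModelHelper(name="lstm_example")

# Assumed external inputs: a (T, N, 64) input blob, (N,) int32 sequence
# lengths, and initial hidden / cell state blobs created elsewhere.
input_blob, seq_lengths, h_init, c_init = model.net.AddExternalInputs(
    "input_blob", "seq_lengths", "hidden_init", "cell_init",
)

result = rnn_cell.LSTM(
    model,
    input_blob,
    seq_lengths,
    (h_init, c_init),
    dim_in=64,
    dim_out=128,
    scope="lstm",
    drop_states=True,  # reset hidden/cell state once a sample's sequence ends
)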
@@ -129,12 +129,14 @@ class LSTMCell(RNNCell):
         memory_optimization,
         name,
         forward_only=False,
+        drop_states=False,
     ):
         super(LSTMCell, self).__init__(name, forward_only)
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.forget_bias = float(forget_bias)
         self.memory_optimization = memory_optimization
+        self.drop_states = drop_states
 
     def _apply(
         self,
@@ -163,6 +165,7 @@ class LSTMCell(RNNCell):
             ],
             list(self.get_state_names()),
             forget_bias=self.forget_bias,
+            drop_states=self.drop_states,
         )
         model.net.AddExternalOutputs(hidden_t, cell_t)
         if self.memory_optimization:
@@ -202,7 +205,8 @@ class LSTMCell(RNNCell):
 
 def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
          scope, outputs_with_grads=(0,), return_params=False,
-         memory_optimization=False, forget_bias=0.0, forward_only=False):
+         memory_optimization=False, forget_bias=0.0, forward_only=False,
+         drop_states=False):
     '''
     Adds a standard LSTM recurrent network operator to a model.
 
@@ -241,6 +245,7 @@ def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
         memory_optimization=memory_optimization,
         name=scope,
         forward_only=forward_only,
+        drop_states=drop_states,
     )
     result = cell.apply_over_sequence(
         model=model,
@@ -795,6 +800,7 @@ class MILSTMCell(LSTMCell):
             [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
             [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
             forget_bias=self.forget_bias,
+            drop_states=self.drop_states,
         )
         hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
         model.net.AddExternalOutputs(
@@ -808,7 +814,7 @@ class MILSTMCell(LSTMCell):
 
 def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
            scope, outputs_with_grads=(0,), memory_optimization=False,
-           forget_bias=0.0, forward_only=False):
+           forget_bias=0.0, forward_only=False, drop_states=False):
     '''
     Adds MI flavor of standard LSTM recurrent network operator to a model.
     See https://arxiv.org/pdf/1606.06630.pdf
@@ -831,9 +837,9 @@ def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
     outputs_with_grads : position indices of output blobs which will receive
     external error gradient during backpropagation
 
-    memory_optimization: if enabled, the LSTM step is recomputed on backward step
-                   so that we don't need to store forward activations for each
-                   timestep. Saves memory with cost of computation.
+    memory_optimization: if enabled, the LSTM step is recomputed on backward
+    step. So that we don't need to store forward activations for each timestep.
+    Saves memory with cost of computation.
 
     forward_only run only forward pass
     '''
@@ -844,6 +850,7 @@ def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
         memory_optimization=memory_optimization,
         name=scope,
         forward_only=forward_only,
+        drop_states=drop_states,
     )
     result = cell.apply_over_sequence(
         model=model,
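The diff above only threads drop_states through the Python helpers; the
behavioral change lives in the underlying LSTMUnit step. A minimal numpy
sketch of the assumed end-of-sequence handling (function name and masking are
illustrative, not the operator's actual implementation): past a sample's
seq_length, drop_states=False carries the previous hidden / cell state
forward unchanged, while drop_states=True drops it to zero, so standalone
sequences sharing a batch slot cannot leak context into each other.

import numpy as np

def end_of_sequence_states(h_prev, c_prev, h_new, c_new,
                           timestep, seq_lengths, drop_states):
    # (N, 1) mask: True while a sample's sequence is still running.
    valid = (timestep < seq_lengths)[:, None]
    # Past the end of a sequence, either carry the old state forward
    # (drop_states=False) or drop it to zero (drop_states=True).
    h_stale = np.zeros_like(h_prev) if drop_states else h_prev
    c_stale = np.zeros_like(c_prev) if drop_states else c_prev
    h_t = np.where(valid, h_new, h_stale)
    c_t = np.where(valid, c_new, c_stale)
    return h_t, c_t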