mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	fix ctc_loss argument check error message (#26325)
Summary: Was confused by the wrong message while debugging. Turns out cpu version is wrong on comparison direction and gpu version is printing wrong number in addition to that. This fix should make the error message correct. jjsjann123 for tracking Pull Request resolved: https://github.com/pytorch/pytorch/pull/26325 Differential Revision: D17408969 Pulled By: soumith fbshipit-source-id: 0d9330e00aaabcb3e8e893b37a6a53fb378171c5
This commit is contained in:
		
				
					committed by
					
						
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							a76403f609
						
					
				
				
					commit
					3ce2ceca05
				
			@ -84,7 +84,7 @@ std::tuple<Tensor, Tensor> ctc_loss_cpu_template(const Tensor& log_probs, const
 | 
			
		||||
  int64_t max_input_length = log_probs.size(0);
 | 
			
		||||
  for (int64_t b = 0; b < batch_size; b++) {
 | 
			
		||||
    TORCH_CHECK(input_lengths[b] <= max_input_length,
 | 
			
		||||
             "Expected tensor to have size at least ", max_input_length, " at dimension 1, but got size ", input_lengths[b], " for ", log_probs_arg,
 | 
			
		||||
             "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b],
 | 
			
		||||
             " (while checking arguments for ", c, ")");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -248,7 +248,7 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
 | 
			
		||||
  int64_t max_input_length = log_probs.size(0);
 | 
			
		||||
  for (int64_t b = 0; b < batch_size; b++) {
 | 
			
		||||
    TORCH_CHECK(input_lengths[b] <= max_input_length,
 | 
			
		||||
             "Expected tensor to have size at least ", max_input_length, " at dimension 1, but got size ", targets.size(0), " for ", targets_arg,
 | 
			
		||||
             "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b],
 | 
			
		||||
             " (while checking arguments for ", c, ")");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user