optimize: eliminate duplicate split_enc_dec_inputs calls (#25573)

Signed-off-by: nicole-lihui <nicole.li@daocloud.io>
This commit is contained in:
Nicole LiHui 🥜
2025-09-25 13:03:25 +08:00
committed by GitHub
parent 845adb3ec6
commit c85be1f6dd

View File

@ -388,9 +388,9 @@ class Processor:
eos_token_id = self.input_preprocessor.get_eos_token_id()
self._validate_model_inputs(processed_inputs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs)
# Mypy does not always properly infer the types of some elements of
# discriminated unions of TypedDicts, because of how it handles
# inheritance of TypedDict. If we explicitly extract the items we want
@ -458,9 +458,8 @@ class Processor:
trace_headers=trace_headers,
)
def _validate_model_inputs(self, inputs: ProcessorInputs):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
def _validate_model_inputs(self, encoder_inputs: Optional[SingletonInputs],
decoder_inputs: SingletonInputs):
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs, prompt_type="encoder")