Compare commits

...

3 Commits

Author SHA1 Message Date
223937abce make style 2022-02-08 11:32:41 -05:00
4f8984e9b0 Any tensor with rank > 3 probably has a variable dimension 2022-02-08 11:22:20 -05:00
945a0c991d EXTREMELY BREAKING changes to predict_step 2022-02-08 10:51:05 -05:00

View File

@ -954,6 +954,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
del return_metrics["loss_loss"]
return return_metrics
def predict_step(self, data):
def raggedify_if_possible(array):
if array.shape.ndims >= 3:
return tf.RaggedTensor.from_tensor(array)
else:
return array
x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
output = self(x, training=False)
if isinstance(output, dict):
output = {key: raggedify_if_possible(val) for key, val in output.items()}
elif isinstance(output, tuple) or isinstance(output, list):
output = tuple([raggedify_if_possible(arr) for arr in output])
elif isinstance(output, tf.Tensor):
output = raggedify_if_possible(output)
return output
def create_model_card(
self,
output_dir,