fix bug in warnings T5 pipelines (#3545)
This commit is contained in:
parent 9de9ceb6c5
commit 06dd597552
@@ -1235,17 +1235,19 @@ class SummarizationPipeline(Pipeline):
         elif self.framework == "tf":
             input_length = tf.shape(inputs["input_ids"])[-1].numpy()

-        if input_length < self.model.config.min_length // 2:
+        min_length = generate_kwargs.get("min_length", self.model.config.min_length)
+        if input_length < min_length // 2:
             logger.warning(
                 "Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
-                    self.model.config.min_length, input_length
+                    min_length, input_length
                 )
             )

-        if input_length < self.model.config.max_length:
+        max_length = generate_kwargs.get("max_length", self.model.config.max_length)
+        if input_length < max_length:
             logger.warning(
                 "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
-                    self.model.config.max_length, input_length
+                    max_length, input_length
                 )
             )
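The change above makes the summarization warnings respect a min_length or max_length passed at call time, falling back to the model config only when the caller did not override them. A minimal standalone sketch of that fallback logic follows; the function name and parameters are illustrative only and are not part of the transformers API.

import logging

logger = logging.getLogger(__name__)


def check_summary_lengths(input_length, generate_kwargs, config_min_length, config_max_length):
    # Prefer the lengths the caller passed in generate_kwargs; fall back to the
    # model config values, mirroring the generate_kwargs.get(...) pattern
    # introduced in this commit.
    min_length = generate_kwargs.get("min_length", config_min_length)
    max_length = generate_kwargs.get("max_length", config_max_length)

    if input_length < min_length // 2:
        logger.warning(
            "Your min_length is set to %d, but your input_length is only %d.",
            min_length,
            input_length,
        )
    if input_length < max_length:
        logger.warning(
            "Your max_length is set to %d, but your input_length is only %d.",
            max_length,
            input_length,
        )

With this sketch, check_summary_lengths(20, {"min_length": 10}, 56, 142) compares against the caller's min_length of 10 rather than the config value of 56, which is exactly the behavior the commit restores.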
@@ -1349,10 +1351,11 @@ class TranslationPipeline(Pipeline):
         elif self.framework == "tf":
             input_length = tf.shape(inputs["input_ids"])[-1].numpy()

-        if input_length > 0.9 * self.model.config.max_length:
+        max_length = generate_kwargs.get("max_length", self.model.config.max_length)
+        if input_length > 0.9 * max_length:
             logger.warning(
                 "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
-                    input_length, self.model.config.max_length
+                    input_length, max_length
                 )
             )
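For the translation pipeline, the practical effect is that the 0.9 * max_length warning is now checked against the max_length the user actually requested at call time. A hedged usage sketch, where the checkpoint name and input text are examples only:

from transformers import pipeline

# Any T5-style checkpoint with a translation head works here; t5-base is just an example.
translator = pipeline("translation_en_to_de", model="t5-base")

# After this fix, the warning threshold uses the max_length passed below (400),
# not the smaller default stored in the model config.
print(translator("A reasonably long English sentence to translate into German ...", max_length=400))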