Commit Graph

196 Commits

Author · SHA1 · Message · Date
Billy Bradley
dcc49d8a7e
In assisted decoding, pass model_kwargs to model's forward call (fix prepare_inputs_for_generation in all models) (#25242)
* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding would ignore any additional kwargs
that it doesn't explicitly handle. This was inconsistent with other
generation methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.

The prepare_inputs_for_generation method needs to be amended in all
models, as previously it only kept the last input ID when past_key_values
was passed.

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function

* Replace deepcopy with copy to optimize performance

* Update new persimmon model with llama changes for assisted generation

* Update new mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
2023-10-11 13:18:42 +02:00
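
The commit above standardizes the prepare_inputs_for_generation contract. A minimal sketch of the resulting pattern (illustrative only, not the exact code of any one model; the cache-length arithmetic varies by architecture):

```python
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
    # With a cache, feed only the tokens that have not been processed yet. Assisted
    # decoding can append several candidate tokens at once, so keeping just the last
    # token (the previous behavior) is no longer sufficient.
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[2]  # cached sequence length (layout varies per model)
        input_ids = input_ids[:, past_length:]

    model_inputs = {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
    }
    # Forward any remaining model_kwargs untouched, as the other generation methods do.
    model_inputs.update(kwargs)
    return model_inputs
```
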
Dong-Yong Lee
8881f38a4f
Fix beam search when using model parallel (#24969)
* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2023-09-14 11:00:52 -04:00
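
The device bug fixed above appears when the cache is reordered between beam steps while its layers live on different GPUs. A hedged sketch of the fix pattern (the real method lives on each model class and the cache layout differs per architecture):

```python
def _reorder_cache(past_key_values, beam_idx):
    # Under model parallelism each layer's cache may sit on a different GPU, so the
    # beam indices must be moved to that layer's device before index_select is called.
    return tuple(
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )
```
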
Joao Gante
3319eb5490
Generate: legacy mode is only triggered when generation_config is untouched (#25962) 2023-09-12 12:08:17 +01:00
Joao Gante
3c2383b1c6
Generate: general test for decoder-only generation from inputs_embeds (#25687)
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2023-08-23 19:17:01 +01:00
Joao Gante
3f9cb33504
Generate: fix default max length warning (#25539) 2023-08-16 15:30:54 +01:00
hukuda222
cb3c821cb7
aligned sample_beam output selection with beam_search (#25375)
* aligned sample_beam specs with beam_search

* pull origin main

* Revert "pull origin main"

This reverts commit 06d356f113.

* update test_utils.py

* fix format

* remove comment

---------

Co-authored-by: Shogo Fujita <shogo.fujita@legalontech.jp>
2023-08-09 18:28:57 +02:00
Guillaume "Vermeille" Sanchez
d533465150
add CFG for .generate() (#24654) 2023-08-06 20:15:24 +01:00
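
Classifier-free guidance, as added in #24654, is driven from .generate(). A hedged usage sketch, assuming the guidance_scale generation argument exposed by that PR (checkpoint name chosen purely for illustration):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Today, a dragon flew over Paris, France,", return_tensors="pt")
# guidance_scale > 1 pushes the output toward the prompt; 1 disables CFG.
outputs = model.generate(**inputs, max_new_tokens=40, guidance_scale=1.5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
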
Benjamin Badger
caf5e369fc
Contrastive Search peak memory reduction (#24120)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
2023-07-20 18:46:53 +01:00
Joao Gante
5f3efdf762
Generate: group_beam_search requires diversity_penalty>0.0 (#24456)
* add exception

* update docs
2023-06-27 10:46:39 +01:00
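
Group beam search is requested through num_beam_groups; after #24456, generate() raises an exception when diversity_penalty is left at 0.0, since the groups would then collapse onto the same beams. A hedged sketch (checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The meaning of life is", return_tensors="pt")

outputs = model.generate(
    **inputs,
    num_beams=4,
    num_beam_groups=2,
    diversity_penalty=1.0,  # must be > 0.0, otherwise group beam search is rejected
    max_new_tokens=20,
)
```
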
Bowen Bao
a28325e25e
Replace python random with torch.rand to enable dynamo.export (#24434)
* Replace python random with torch.rand to enable dynamo.export

* revert changes to flax model code

* Remove unused random import

* Fix torch template

* Move torch.manual_seed(0) to right location
2023-06-23 08:17:21 -04:00
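
The motivation for the change above: torch.dynamo cannot trace through Python's random module, so probabilistic branches in modeling code (e.g. LayerDrop) are rewritten with a torch op. A hedged before/after illustration:

```python
import random
import torch

def should_skip_layer_old(layerdrop: float) -> bool:
    # Opaque to dynamo.export: the randomness happens outside the traced graph.
    return random.random() < layerdrop

def should_skip_layer_new(layerdrop: float) -> torch.Tensor:
    # Traceable: torch.rand is a graph op that dynamo.export can capture.
    return torch.rand([]) < layerdrop
```
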
Joao Gante
612b2a1a6d
Generate: increase left-padding test atol (#23448)
increase atol
2023-06-07 11:56:57 +01:00
Yih-Dar
2406dbdcfa
Less flaky test_assisted_decoding_matches_greedy_search (#23451)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2023-05-18 17:28:22 +02:00
Joao Gante
aea7b23b57
Generate: skip left-padding tests on old models (#23437) 2023-05-18 11:04:51 +01:00
Joao Gante
918a06e25d
Generate: add test to check KV format (#23403)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2023-05-16 19:28:19 +01:00
Joao Gante
bbfb9fc22b
Generate: starcoder 🤜 🤛 assisted generation (#23182)
* starcoder has joined the chat

* indexing that works for all
2023-05-08 10:45:40 +01:00
Joao Gante
ce31e3c8bf
Generate: slow assisted generation test (#23125) 2023-05-03 14:24:50 +01:00
Joao Gante
849367ccf7
Generate: prepare assisted generation for release (#23052) 2023-04-29 10:53:30 +01:00
Joao Gante
e4a97f82bf
Generate: assisted generation with sample (take 2) (#22949)
* temperature controls speed
2023-04-24 19:54:55 +01:00
Joao Gante
78cda46f17
Generate: Add assisted generation (#22211)
* working mvp

* remove breakpoint

* fix commit

* standardize outputs

* tmp commit

* tests almost ready

* tmp commit

* skip a few models

* Add streaming; Docs and examples

* document limitations

* PR commits

* Amy PR comments
2023-04-18 17:36:56 +01:00
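
Usage of the feature introduced in #22211: a smaller assistant model drafts candidate tokens that the main model then validates in a single forward pass. A short sketch (checkpoint names are illustrative; the assistant must share the main model's tokenizer):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # smaller draft model

inputs = tokenizer("Einstein's theory of relativity states that", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
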
Yih-Dar
90247d3e01
Fix test_eos_token_id_int_and_list_top_k_top_sampling (#22826)
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2023-04-18 16:04:51 +02:00
Joao Gante
9dfd6a4baa
Generate: handle text conditioning with multimodal encoder-decoder models (#22748) 2023-04-13 19:51:13 +01:00
Joao Gante
502fec779b
Generate: add test for left-padding support (#22322) 2023-03-23 17:00:22 +00:00
Joao Gante
fd3eb3e3cd
Beef up Llama tests (#22314)
* tmp commit

* beef up llama tests
2023-03-22 15:20:48 +00:00
Yih-Dar
5110e5748e
🔥py38 + torch 2 🔥🔥🔥🚀 (#22204)
* py38 + torch 2

* increment cache versions

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2023-03-16 22:59:23 +01:00
Aaron Gokaslan
5e8c8eb5ba
Apply ruff flake8-comprehensions (#21694) 2023-02-22 09:14:54 +01:00
Joao Gante
13e03e619d
Generate: filter encoder inputs when its signature does not accept wildcards (#21603) 2023-02-14 10:46:46 +00:00
Joao Gante
fa4bdb0a40
Generate: correct default model input creation for decoder-only models (#21580) 2023-02-13 17:04:49 +00:00
Joao Gante
24273268b7
Generate: Fix flaky indexing error in test_constrained_beam_search_generate_dict_output (#21561) 2023-02-13 15:12:07 +00:00
Joao Gante
eb6c59bc78
Generate: TF supports multiple eos tokens (#21571) 2023-02-13 12:24:22 +00:00
Joao Gante
e69f9715eb
Generate: make TF .generate() signature == PT .generate() signature (#21525) 2023-02-09 11:10:13 +00:00
Motoki Wu
9960506cbe
Fix multiple eos_token_ids in model.generate(...) (#21461)
* add tests with multiple eos_token_ids

* use math.prod instead of sum

* make fixup

* fix long and also use np.prod since math.prod does not exist in Python < 3.8

* make fixup

* add prod util

* use prod util instead of np.prod

* make fixup

* previous .long location

* use tensor ops

* remove prod

* remove prod

* update device

* make fixup

* fix none
2023-02-08 13:48:46 -05:00
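
After #20727 and the fix above, eos_token_id accepts a list, so generation stops at whichever stop token appears first. A brief sketch (checkpoint and stop tokens are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# Stop on either the model's usual EOS token or the newline token.
stop_ids = [tokenizer.eos_token_id, tokenizer.encode("\n")[0]]
outputs = model.generate(**inputs, eos_token_id=stop_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
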
Joao Gante
1d9c26a4b8
Generate: TF compute_transition_scores (#21341) 2023-02-08 16:36:43 +00:00
Joao Gante
1e4cf8bb44
Generate: TF can now generate from embeddings in encoder-decoder models (#21475) 2023-02-07 11:18:23 +00:00
Sylvain Gugger
6f79d26442
Update quality tooling for formatting (#21480)
* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Put the right black version

* adapt black in check copies

* Fix copies
2023-02-06 18:10:56 -05:00
Joao Gante
4943331015
Generate: TF can now accept custom logits processors (#21454) 2023-02-06 15:44:47 +00:00
Yih-Dar
59d5edef34
Avoid flaky generation sampling tests (#21445)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2023-02-03 22:01:25 +01:00
Joao Gante
f21af26279
🚨🚨 Generate: standardize beam search behavior across frameworks (#21368) 2023-02-03 10:24:02 +00:00
Joao Gante
92ce53aab8
Generate: decoder-only models can generate with inputs_embeds (#21405) 2023-02-01 21:50:38 +00:00
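
What #21405 enables, sketched briefly: decoder-only models can start generation from an embedding tensor rather than from token ids (checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# When only inputs_embeds is given, the returned sequence holds just the newly generated ids.
outputs = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
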
Joao Gante
623346ab18
Template for framework-agnostic tests (#21348) 2023-01-31 11:33:18 +00:00
Joao Gante
42b60f8b02
Generate: Relaxed max_length and max_new_tokens coexistence (#21347)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-01-30 17:53:54 +00:00
Joao Gante
af37d183b3
Generate: documented function to compute the transition scores (#21191)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-01-20 12:50:01 +00:00
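
Usage of the function documented in #21191, which recovers per-token scores for the generated sequence (checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Today is", return_tensors="pt")

outputs = model.generate(
    **inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
# One score per generated token; normalize_logits renormalizes over the vocabulary.
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
```
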
Joao Gante
b91048968b
Generate: Fix CI related to #20727 (#21003) 2023-01-04 20:26:56 +00:00
Motoki Wu
45da7cec5a
Add custom stop token ids for generation (#20727)
* Add StopIdStoppingCriteria

* add a working test for stop id criteria

* add to global scope

* add stop_ids to generate

* add pipeline test

* use tokenizer encode in test

* add test to generation utils

* reformat

* fixup

* make-fix-copies

* rename to stop_token_id

* use stop_tokens instead

* add to text to text generation

* make fixup

* make repo-consistency

* Add support for list of ints for eos_token_id inside generation/utils.py

* Instead of having if elses, cast the eos_token_id into a List[int]

* Add List[int] support for logits_process.py

* add List[int] for beam_search.py

* add List[int] for forced_eos_token_id

* revert stop token id stopping criteria changes

* make fixup

* fix tests

* add eos_token_id to generation/utils.py and added tests test_utils.py

* add eos_token_id type hints and fix for pad tokens

* add comments

* remove some prints and remove forced false test

* fix

* put back test_stop_sequence_stopping_criteria

* remove unused import and make fixup

* add a none check

* update docstring

* add more docstring for list ints

* make fixup
2023-01-03 15:18:24 -05:00
Joao Gante
4cf38148dc
Generate: model_kwargs can also be an input to prepare_inputs_for_generation (#20353) 2022-11-21 16:20:27 +00:00
Joao Gante
938cb04789
Generate: add Bloom fixes for contrastive search (#20213) 2022-11-14 18:34:11 +00:00
Joao Gante
f270b960d6
Generate: move generation_*.py src files into generation/*.py (#20096)
* move generation_*.py src files into generation/*.py

* populate generation.__init__ with lazy loading

* move imports and references from generation.xxx.object to generation.object
2022-11-09 15:34:08 +00:00