<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DBRX

## Overview

DBRX is a [transformer-based](https://www.isattentionallyouneed.com/) decoder-only large language model (LLM) that was trained using next-token prediction.
It uses a *fine-grained* mixture-of-experts (MoE) architecture with 132B total parameters, of which 36B are active on any input.
It was pre-trained on 12T tokens of text and code data.
Compared to other open MoE models like Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts. DBRX has 16 experts and chooses 4, while Mixtral-8x7B and Grok-1 have 8 experts and choose 2.
This provides 65x more possible combinations of experts, and we found that this improves model quality.
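As a quick sanity check of the 65x figure: routing each token to 4 of 16 experts allows C(16, 4) = 1820 expert combinations, versus C(8, 2) = 28 when choosing 2 of 8, and 1820 / 28 = 65.

```python
import math

# Fine-grained MoE (DBRX): choose 4 of 16 experts per token
fine_grained = math.comb(16, 4)  # 1820

# Coarser MoE (Mixtral-8x7B, Grok-1): choose 2 of 8 experts per token
coarse = math.comb(8, 2)  # 28

print(fine_grained // coarse)  # 65
```
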
DBRX uses rotary position encodings (RoPE), gated linear units (GLU), and grouped query attention (GQA).
It is a BPE-based model and uses the GPT-4 tokenizer as described in the [tiktoken](https://github.com/openai/tiktoken) repository.
We made these choices based on exhaustive evaluation and scaling experiments.
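For illustration only, the GPT-4 encoding is exposed in tiktoken under the name `cl100k_base` (an assumption for this sketch; in practice, load DBRX's own tokenizer with `AutoTokenizer` as in the usage examples below, since the hosted tokenizer files are the source of truth for any added special tokens):

```python
import tiktoken

# tiktoken's name for the GPT-4 encoding that DBRX's tokenizer is based on
enc = tiktoken.get_encoding("cl100k_base")

ids = enc.encode("What does it take to build a great LLM?")
print(len(ids), ids[:5])
```
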
DBRX was pretrained on 12T tokens of carefully curated data with a maximum context length of 32K tokens.
We estimate that this data is at least 2x better token-for-token than the data we used to pretrain the MPT family of models.
This new dataset was developed using the full suite of Databricks tools, including Apache Spark™ and Databricks notebooks for data processing, and Unity Catalog for data management and governance.
We used curriculum learning for pretraining, changing the data mix during training in ways that we found substantially improve model quality.

More detailed information about DBRX Instruct and DBRX Base can be found in our [technical blog post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm).
This model was contributed by [eitan-turok](https://huggingface.co/eitanturok) and [abhi-db](https://huggingface.co/abhi-db). The original code can be found [here](https://github.com/databricks/dbrx-instruct).

## Usage Examples

The `generate()` method can be used to generate text using DBRX. You can generate using the standard attention implementation, flash-attention, and the PyTorch scaled dot product attention. The last two attention implementations give speed-ups.

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

If you have flash-attention installed (`pip install flash-attn`), it is possible to generate faster. (The Hugging Face documentation for flash-attention can be found [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).)

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="flash_attention_2",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

You can also generate faster using the PyTorch scaled dot product attention. (The Hugging Face documentation for scaled dot product attention can be found [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention).)

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="sdpa",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```

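Note that in all three examples, `generate()` returns the prompt tokens followed by the completion, so `tokenizer.decode(outputs[0])` prints the prompt as well. To print only the newly generated text, you can slice off the prompt first (a small convenience sketch that works with any of the snippets above):

```python
# Number of tokens in the chat-formatted prompt
prompt_length = input_ids["input_ids"].shape[1]

# Decode only the tokens generated after the prompt
print(tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True))
```
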
## DbrxConfig

[[autodoc]] DbrxConfig

## DbrxModel

[[autodoc]] DbrxModel
    - forward

## DbrxForCausalLM

[[autodoc]] DbrxForCausalLM
    - forward