mirror of https://github.com/huggingface/transformers.git
Fix JetMoeIntegrationTest (#32332)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 78d78cdf8a
commit 20a04497a8
@@ -478,7 +478,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
     @slow
     def test_model_8b_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
-        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
+        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
         input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
         with torch.no_grad():
             out = model(input_ids).logits.cpu()
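For context, the logits check above boils down to the following standalone sketch. The expected-value assertions are omitted since they are not part of this hunk; the final print is illustrative only.

import torch
from transformers import JetMoeForCausalLM

# Fixed prompt token ids, as hard-coded in the hunk above.
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
# Move the inputs to whichever device holds the embedding weights,
# so the forward pass works regardless of where the model was placed.
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)

with torch.no_grad():
    out = model(input_ids).logits.cpu()

# The test compares slices of `out` against hard-coded expected values
# (not visible in this hunk); here we only report the output shape.
print(out.shape)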
@@ -498,7 +498,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
         EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love"""
         prompt = "My favourite condiment is "
         tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
-        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
+        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

         # greedy generation outputs
@@ -521,7 +521,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
             "My favourite ",
         ]
         tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
-        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
+        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
         input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)
         print(input_ids)
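The batched-generation hunk follows the same device-placement pattern. A minimal runnable sketch under stated assumptions: the two-prompt batch is hypothetical (the test's full prompt list is truncated in the hunk, only "My favourite " is visible), and the eos-as-pad fallback plus left padding are assumptions, not shown in the diff.

import torch
from transformers import AutoTokenizer, JetMoeForCausalLM

# Hypothetical batch; only the second prompt appears in the hunk above.
prompt = ["My favourite condiment is ", "My favourite "]

tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
if tokenizer.pad_token is None:
    # Assumption: the slow tokenizer ships without a pad token.
    tokenizer.pad_token = tokenizer.eos_token
# Assumption: left-pad, the usual choice for decoder-only generation.
tokenizer.padding_side = "left"

model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")

# Batched, padded encoding placed on the embedding weights' device,
# exactly as the changed line in the hunk does.
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)

# Greedy decoding, mirroring the "greedy generation outputs" comment above.
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=10, do_sample=False)
for text in tokenizer.batch_decode(generated, skip_special_tokens=True):
    print(text)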