Fix JetMoeIntegrationTest (#32332)

JetMoeIntegrationTest

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2024-08-14 16:22:06 +02:00 committed by GitHub
parent 78d78cdf8a
commit 20a04497a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -478,7 +478,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
@slow
def test_model_8b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.cpu()
@ -498,7 +498,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love"""
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
@ -521,7 +521,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
"My favourite ",
]
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)
print(input_ids)