mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Fix JetMoeIntegrationTest
(#32332)
JetMoeIntegrationTest Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
78d78cdf8a
commit
20a04497a8
@ -478,7 +478,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_model_8b_logits(self):
|
def test_model_8b_logits(self):
|
||||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||||
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
|
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
|
||||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out = model(input_ids).logits.cpu()
|
out = model(input_ids).logits.cpu()
|
||||||
@ -498,7 +498,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
|
|||||||
EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love"""
|
EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love"""
|
||||||
prompt = "My favourite condiment is "
|
prompt = "My favourite condiment is "
|
||||||
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
|
||||||
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
|
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
|
||||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||||
|
|
||||||
# greedy generation outputs
|
# greedy generation outputs
|
||||||
@ -521,7 +521,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
|
|||||||
"My favourite ",
|
"My favourite ",
|
||||||
]
|
]
|
||||||
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
|
||||||
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
|
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
|
||||||
input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)
|
input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)
|
||||||
print(input_ids)
|
print(input_ids)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user