mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add gpt2 test on XPU (#37028)
* add gpt2 test on XPU Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * auto dtype has been fixed Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * convert model to train mode Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
parent
4b13a02920
commit
3a6ab46a0b
@ -626,7 +626,6 @@ class Bnb4BitTestTraining(Base4bitTest):
|
||||
|
||||
|
||||
@apply_skip_if_not_implemented
|
||||
@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed")
|
||||
class Bnb4BitGPT2Test(Bnb4BitTest):
|
||||
model_name = "openai-community/gpt2-xl"
|
||||
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
|
||||
|
@ -889,6 +889,7 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
|
||||
# Step 1: freeze all parameters
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
|
||||
model.train()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
|
||||
@ -914,14 +915,9 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device)
|
||||
|
||||
# Step 4: Check if the gradient is not None
|
||||
if torch_device in {"xpu", "cpu"}:
|
||||
# XPU and CPU finetune do not support autocast for now.
|
||||
with torch.autocast(torch_device):
|
||||
out = model.forward(**batch)
|
||||
out.logits.norm().backward()
|
||||
else:
|
||||
with torch.autocast(torch_device):
|
||||
out = model.forward(**batch)
|
||||
out.logits.norm().backward()
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, LoRALayer):
|
||||
@ -932,7 +928,6 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
|
||||
|
||||
@apply_skip_if_not_implemented
|
||||
@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed")
|
||||
class MixedInt8GPT2Test(MixedInt8Test):
|
||||
model_name = "openai-community/gpt2-xl"
|
||||
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
|
||||
|
Loading…
Reference in New Issue
Block a user