mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Experimental loading of MLX files (#29511)
* Experimental loading of MLX files * Update exception message * Add test * Style * Use model from hf-internal-testing
This commit is contained in:
parent
73a27345d4
commit
b382a09e28
@ -3297,9 +3297,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
elif metadata.get("format") == "flax":
|
elif metadata.get("format") == "flax":
|
||||||
from_flax = True
|
from_flax = True
|
||||||
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
|
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
|
||||||
|
elif metadata.get("format") == "mlx":
|
||||||
|
# This is a mlx file, we assume weights are compatible with pt
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
|
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from_pt = not (from_tf | from_flax)
|
from_pt = not (from_tf | from_flax)
|
||||||
|
@ -1256,6 +1256,26 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertEqual(len(logs.output), 1)
|
self.assertEqual(len(logs.output), 1)
|
||||||
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
|
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_model_from_pretrained_from_mlx(self):
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-mistral-mlx")
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
with safe_open(os.path.join(tmp_dir, "model.safetensors"), framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
self.assertEqual(metadata.get("format"), "pt")
|
||||||
|
new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
input_ids = torch.randint(100, 1000, (1, 10))
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids)
|
||||||
|
outputs_from_saved = new_model(input_ids)
|
||||||
|
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
Loading…
Reference in New Issue
Block a user