mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
66f29aaaf5
commit
752ef3fd4e
@ -522,17 +522,19 @@ str_to_torch_dtype = {
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"U16": torch.uint16,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I32": torch.int32,
|
||||
"U32": torch.uint32,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"U64": torch.uint64,
|
||||
}
|
||||
|
||||
if is_torch_greater_or_equal("2.3.0"):
|
||||
str_to_torch_dtype["U16"] = torch.uint16
|
||||
str_to_torch_dtype["U32"] = torch.uint32
|
||||
str_to_torch_dtype["U64"] = torch.uint64
|
||||
|
||||
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
|
Loading…
Reference in New Issue
Block a user