guard torch version for uint16 (#36520)

* u16

* style

* fix
This commit is contained in:
Marc Sun 2025-03-05 11:27:01 +01:00 committed by GitHub
parent 66f29aaaf5
commit 752ef3fd4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -522,17 +522,19 @@ str_to_torch_dtype = {
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"U16": torch.uint16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"U32": torch.uint32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"U64": torch.uint64,
}
if is_torch_greater_or_equal("2.3.0"):
str_to_torch_dtype["U16"] = torch.uint16
str_to_torch_dtype["U32"] = torch.uint32
str_to_torch_dtype["U64"] = torch.uint64
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],