mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Consider inheritance in type checking for tensors (#31378)
* Consider inheritance in type checking for tensors Add an additional check to bypass type assertion when both tensors are torch.Tensor instances. * Fix the quality issue
This commit is contained in:
parent
3b5fa14fb8
commit
547b5582ec
@ -131,9 +131,10 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
|
||||
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
|
||||
nested list/tuples/dict of tensors.
|
||||
"""
|
||||
assert type(tensors) == type(
|
||||
new_tensors
|
||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||
if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
|
||||
assert (
|
||||
type(tensors) == type(new_tensors)
|
||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
|
||||
elif isinstance(tensors, torch.Tensor):
|
||||
|
Loading…
Reference in New Issue
Block a user