Consider inheritance in type checking for tensors (#31378)

* Consider inheritance in type checking for tensors

Add an additional check to bypass type assertion when both tensors are
torch.Tensor instances.

* Fix the quality issue
This commit is contained in:
Daemyung Jang 2024-06-19 21:05:20 +09:00 committed by GitHub
parent 3b5fa14fb8
commit 547b5582ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -131,9 +131,10 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
nested list/tuples/dict of tensors.
"""
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
assert (
type(tensors) == type(new_tensors)
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
elif isinstance(tensors, torch.Tensor):