diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index a5b053d5f4e..6ca32f868af 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -359,7 +359,7 @@ class JanusVisionAttention(nn.Module): output = self.projection_layer(attn_output) output = self.projection_dropout(output) - return output + return output, attn_weights class JanusVisionMLP(nn.Module): diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 208a7135c01..25311d2774d 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -529,7 +529,7 @@ class JanusVisionAttention(nn.Module): output = self.projection_layer(attn_output) output = self.projection_dropout(output) - return output + return output, attn_weights class JanusVisionMLP(nn.Module): diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index ea7d9d52b81..9b6997e8b5b 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -952,10 +952,7 @@ def can_return_tuple(func): @wraps(func) def wrapper(self, *args, **kwargs): return_dict = self.config.use_return_dict if hasattr(self, "config") else True - if "return_dict" in kwargs: - return_dict = kwargs.get("return_dict", self.config.use_return_dict) - if "return_dict" in kwargs: - kwargs["return_dict"] = True + return_dict = kwargs.pop("return_dict", self.config.use_return_dict) output = func(self, *args, **kwargs) if "return_dict" in kwargs and return_dict is False: