making gpt2 fx traceable (#34633)

* making gpt2 fx tracable

* running make fix-copies

* Revert "running make fix-copies"

This reverts commit 5a3437cb5b.
This commit is contained in:
xuzifei-dmatrix 2024-11-25 10:30:38 -08:00 committed by GitHub
parent 95c10fedb3
commit bfc3556b20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1101,7 +1101,8 @@ class GPT2Model(GPT2PreTrainedModel):
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i in range(len(self.h)):
block, layer_past = self.h[i], past_key_values[i]
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)