Skip to content

Commit a3fef89

Browse files
authored
[GPT2] Propose fix for huggingface#21080 (huggingface#21853)
* Make sure position ids are masked
* Test that padded input produces the same results
* Fix failing tests
* Fixup
* Fix batch test
1 parent eee195b commit a3fef89

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

src/transformers/models/decision_transformer/modeling_decision_transformer.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -553,7 +553,14 @@ def forward(
553553
past_key_values = tuple([None] * len(self.h))
554554
else:
555555
past_length = past_key_values[0][0].size(-2)
556-
if position_ids is None:
556+
557+
if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
558+
# create position_ids on the fly for batch generation
559+
position_ids = attention_mask.long().cumsum(-1) - 1
560+
position_ids.masked_fill_(attention_mask == 0, 1)
561+
if past_length > 0:
562+
position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
563+
elif position_ids is None:
557564
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
558565
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
559566

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -797,7 +797,14 @@ def forward(
797797
past_key_values = tuple([None] * len(self.h))
798798
else:
799799
past_length = past_key_values[0][0].size(-2)
800-
if position_ids is None:
800+
801+
if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
802+
# create position_ids on the fly for batch generation
803+
position_ids = attention_mask.long().cumsum(-1) - 1
804+
position_ids.masked_fill_(attention_mask == 0, 1)
805+
if past_length > 0:
806+
position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
807+
elif position_ids is None:
801808
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
802809
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
803810

tests/models/gpt2/test_modeling_gpt2.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -590,6 +590,27 @@ def test_batch_generation(self):
590590
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
591591
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
592592

593+
@slow
594+
def test_batch_forward(self):
595+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
596+
tokenizer.padding_side = "left"
597+
598+
# This tokenizer has no pad token, so we have to set it in some way
599+
# Define PAD Token = EOS Token = 50256
600+
tokenizer.pad_token = tokenizer.eos_token
601+
602+
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
603+
sentences = ["Hello, my dog is a little bit of a mess. I'm not sure if he's"]
604+
inputs = tokenizer(sentences, padding=True, return_tensors="pt")
605+
logits = model(**inputs).logits[:, -1, :]
606+
indexes = torch.argmax(logits).item()
607+
608+
inputs_padded = tokenizer(sentences, padding="max_length", max_length=30, return_tensors="pt")
609+
logits_padded = model(**inputs_padded).logits[:, -1, :]
610+
indexes_padded = torch.argmax(logits_padded).item()
611+
612+
self.assertTrue(indexes == indexes_padded)
613+
593614
@slow
594615
def test_batch_generation_2heads(self):
595616
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

0 commit comments

Comments (0)