@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    num_beams: Optional[int] = 5,
    decoder_start_token_id: Optional[int] = 2,
    decoder_end_token_id: Optional[Union[int, List[int]]] = 2,
    num_return_sequences: Optional[int] = 1,
    min_gen_seq_length: Optional[int] = 0,
    max_gen_seq_length: Optional[int] = 48,
    repetition_penalty: Optional[float] = 1.0,
    no_repeat_ngram_size: Optional[int] = 0,
    early_stopping: Optional[bool] = True,
    length_penalty: Optional[float] = 1.0,
    num_beam_groups: Optional[int] = 1,
    diversity_penalty: Optional[float] = 0.0,
    do_sample: Optional[bool] = False,
    temperature: Optional[float] = 1.0,
    top_k: Optional[int] = 50,
    top_p: Optional[float] = 1.0,
):
    """
    Run ``self.model.generate`` and pad the generated ids to a fixed width.

    Args:
        input_ids: Input tensor passed positionally to ``self.model.generate``.
        num_beams: Forwarded as ``num_beams``.
        decoder_start_token_id: Forwarded as ``bos_token_id``; also used as the
            fill value for positions beyond the generated length.
        decoder_end_token_id: Forwarded as ``eos_token_id``.
        num_return_sequences: Forwarded as ``num_return_sequences``; also used
            to group the generated rows in the returned tensor.
        min_gen_seq_length: Forwarded as ``min_length``.
        max_gen_seq_length: Forwarded as ``max_length``; also the width of the
            last dimension of the returned ``sequences``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        no_repeat_ngram_size: Forwarded as ``no_repeat_ngram_size``.
        early_stopping: Forwarded as ``early_stopping``.
        length_penalty: Forwarded as ``length_penalty``.
        num_beam_groups: Forwarded as ``num_beam_groups``.
        diversity_penalty: Forwarded as ``diversity_penalty``.
        do_sample: Forwarded as ``do_sample``.
        temperature: Forwarded as ``temperature``.
        top_k: Forwarded as ``top_k``.
        top_p: Forwarded as ``top_p``.

    Returns:
        GenericOutputs with:
            sequences: generated ids, in the dtype produced by the model, of
                shape ``(rows / num_return_sequences, num_return_sequences,
                max_gen_seq_length)``, flattened to
                ``(rows, max_gen_seq_length)`` when ``num_return_sequences == 1``.
            sequences_scores: ``outputs.sequences_scores`` as returned by the
                model, or ``None`` when the generation output has no such field.
    """
    outputs = self.model.generate(
        input_ids,
        max_length=max_gen_seq_length,
        min_length=min_gen_seq_length,
        num_beams=num_beams,
        do_sample=do_sample,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=early_stopping,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        num_return_sequences=num_return_sequences,
        bos_token_id=decoder_start_token_id,
        eos_token_id=decoder_end_token_id,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        return_dict_in_generate=True,
        output_scores=True,
    )

    # Group the flat (rows, length) output by return sequence.
    generated = outputs.sequences
    sequences = generated.reshape(-1, num_return_sequences, generated.size(-1))

    # Allocate the fixed-width buffer directly on the target device and in the
    # dtype of the generated ids. A default (float32) buffer would silently
    # turn the token ids into floats when they are copied in.
    padded = torch.full(
        (sequences.size(0), num_return_sequences, max_gen_seq_length),
        decoder_start_token_id,
        dtype=sequences.dtype,
        device=sequences.device,
    )
    padded[:, :, : sequences.size(-1)].copy_(sequences)

    if num_return_sequences == 1:
        padded = padded.reshape(-1, max_gen_seq_length)

    return GenericOutputs(
        sequences=padded,
        # Not every generation output object exposes sequence scores; fall back
        # to None instead of raising AttributeError.
        # NOTE(review): assumes GenericOutputs accepts None here — confirm.
        sequences_scores=getattr(outputs, "sequences_scores", None),
    )