unitorch.models.mbart

MBartProcessor

Bases: HfTextGenerationProcessor

Initializes an MBartProcessor with the provided parameters.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
vocab_path | str | The path to the vocabulary file. | required
max_seq_length | int | The maximum sequence length. | 128
max_gen_seq_length | int | The maximum generation sequence length. | 48
special_input_ids | Dict | A dictionary of special input IDs. | dict()
Source code in src/unitorch/models/mbart/processing.py
def __init__(
    self,
    vocab_path: str,
    max_seq_length: Optional[int] = 128,
    max_gen_seq_length: Optional[int] = 48,
    special_input_ids: Optional[Dict] = dict(),
):
    """
    Initializes an MBartProcessor with the provided parameters.

    Args:
        vocab_path (str): The path to the vocabulary file.
        max_seq_length (int, optional): The maximum sequence length. Defaults to 128.
        max_gen_seq_length (int, optional): The maximum generation sequence length. Defaults to 48.
        special_input_ids (Dict, optional): A dictionary of special input IDs. Defaults to an empty dictionary.
    """
    tokenizer = get_mbart_tokenizer(
        vocab_path,
        special_input_ids=special_input_ids,
    )
    super().__init__(
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        max_gen_seq_length=max_gen_seq_length,
    )
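
A minimal construction sketch; the vocabulary path is hypothetical and should point to a real MBart SentencePiece model file:

processor = MBartProcessor(
    vocab_path="/path/to/sentencepiece.bpe.model",  # hypothetical path
    max_seq_length=128,
    max_gen_seq_length=48,
)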

MBartForGeneration

Bases: GenericModel

Initializes an MBartForGeneration model with the provided configuration.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
config_path | str | The path to the model configuration file. | required
freeze_input_embedding | bool | Whether to freeze the input embeddings. | True
gradient_checkpointing | bool | Whether to use gradient checkpointing. | False
Source code in src/unitorch/models/mbart/modeling.py
def __init__(
    self,
    config_path: str,
    freeze_input_embedding: Optional[bool] = True,
    gradient_checkpointing: Optional[bool] = False,
):
    """
    Initializes an MBartForGeneration model with the provided configuration.

    Args:
        config_path (str): The path to the model configuration file.
        freeze_input_embedding (bool, optional): Whether to freeze the input embeddings. Defaults to True.
        gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
    """
    super().__init__()
    self.config = MBartConfig.from_json_file(config_path)
    self.config.gradient_checkpointing = gradient_checkpointing
    self.model = MBartForConditionalGeneration(self.config)

    if freeze_input_embedding:
        for param in self.model.get_input_embeddings().parameters():
            param.requires_grad = False

    self.init_weights()
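
A minimal construction sketch; the configuration path is hypothetical and should point to a standard MBart config.json:

model = MBartForGeneration(
    config_path="/path/to/mbart/config.json",  # hypothetical path
    freeze_input_embedding=True,
    gradient_checkpointing=False,
)

As the constructor above shows, freeze_input_embedding=True sets requires_grad=False on the input embedding parameters, so the large multilingual embedding matrix stays fixed during fine-tuning.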

forward

forward(
    input_ids: Tensor,
    attention_mask: Tensor,
    decoder_input_ids: Tensor,
    decoder_attention_mask: Tensor,
)

Performs forward pass of the MBartForGeneration model.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
input_ids | Tensor | The input token IDs. | required
attention_mask | Tensor | The encoder attention mask. | required
decoder_input_ids | Tensor | The decoder input token IDs. | required
decoder_attention_mask | Tensor | The decoder attention mask. | required

Returns:

Type | Description
---- | -----------
Tensor | The model's logits.
Source code in src/unitorch/models/mbart/modeling.py
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    decoder_input_ids: torch.Tensor,
    decoder_attention_mask: torch.Tensor,
):
    """
    Performs forward pass of the MBartForGeneration model.

    Args:
        input_ids (torch.Tensor): Tensor of input token IDs.
        attention_mask (torch.Tensor): Tensor of attention mask.
        decoder_input_ids (torch.Tensor): Tensor of decoder input token IDs.
        decoder_attention_mask (torch.Tensor): Tensor of decoder attention mask.

    Returns:
        (torch.Tensor): The model's logits.
    """
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )
    logits = outputs.logits
    return logits
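
A hedged sketch of a forward call with dummy inputs; the shapes and the vocabulary bound below are illustrative assumptions, not fixed by the API:

import torch

batch_size, src_len, tgt_len = 2, 16, 8
input_ids = torch.randint(0, 1000, (batch_size, src_len))            # dummy source tokens
attention_mask = torch.ones(batch_size, src_len, dtype=torch.long)   # no padding in this toy batch
decoder_input_ids = torch.randint(0, 1000, (batch_size, tgt_len))    # dummy target prefix
decoder_attention_mask = torch.ones(batch_size, tgt_len, dtype=torch.long)

logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
# logits has shape (batch_size, tgt_len, vocab_size)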

generate

generate(
    input_ids: Tensor,
    num_beams: Optional[int] = 5,
    decoder_start_token_id: Optional[int] = 2,
    decoder_end_token_id: Optional[Union[int, List[int]]] = 2,
    num_return_sequences: Optional[int] = 1,
    min_gen_seq_length: Optional[int] = 0,
    max_gen_seq_length: Optional[int] = 48,
    repetition_penalty: Optional[float] = 1.0,
    no_repeat_ngram_size: Optional[int] = 0,
    early_stopping: Optional[bool] = True,
    length_penalty: Optional[float] = 1.0,
    num_beam_groups: Optional[int] = 1,
    diversity_penalty: Optional[float] = 0.0,
    do_sample: Optional[bool] = False,
    temperature: Optional[float] = 1.0,
    top_k: Optional[int] = 50,
    top_p: Optional[float] = 1.0,
)

Generates sequences using the MBartForGeneration model.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
input_ids | Tensor | The input token IDs. | required
num_beams | int | The number of beams for beam search. | 5
decoder_start_token_id | int | The decoder's start token ID. | 2
decoder_end_token_id | int or List[int] | The decoder's end token ID. | 2
num_return_sequences | int | The number of generated sequences to return. | 1
min_gen_seq_length | int | The minimum length of the generated sequences. | 0
max_gen_seq_length | int | The maximum length of the generated sequences. | 48
repetition_penalty | float | The repetition penalty. | 1.0
no_repeat_ngram_size | int | The size of n-grams to avoid repeating. | 0
early_stopping | bool | Whether to stop generation early. | True
length_penalty | float | The length penalty. | 1.0
num_beam_groups | int | The number of beam groups for diverse beam search. | 1
diversity_penalty | float | The diversity penalty. | 0.0
do_sample | bool | Whether to use sampling for generation. | False
temperature | float | The temperature for sampling. | 1.0
top_k | int | The value for top-k sampling. | 50
top_p | float | The value for top-p (nucleus) sampling. | 1.0

Returns:

Type | Description
---- | -----------
GenericOutputs | The generated sequences and their scores.
Source code in src/unitorch/models/mbart/modeling.py
@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    num_beams: Optional[int] = 5,
    decoder_start_token_id: Optional[int] = 2,
    decoder_end_token_id: Optional[Union[int, List[int]]] = 2,
    num_return_sequences: Optional[int] = 1,
    min_gen_seq_length: Optional[int] = 0,
    max_gen_seq_length: Optional[int] = 48,
    repetition_penalty: Optional[float] = 1.0,
    no_repeat_ngram_size: Optional[int] = 0,
    early_stopping: Optional[bool] = True,
    length_penalty: Optional[float] = 1.0,
    num_beam_groups: Optional[int] = 1,
    diversity_penalty: Optional[float] = 0.0,
    do_sample: Optional[bool] = False,
    temperature: Optional[float] = 1.0,
    top_k: Optional[int] = 50,
    top_p: Optional[float] = 1.0,
):
    """
    Generates sequences using the MBartForGeneration model.

    Args:
        input_ids (torch.Tensor): The input token IDs.
        num_beams (int, optional): The number of beams for beam search. Defaults to 5.
        decoder_start_token_id (int, optional): The decoder's start token ID. Defaults to 2.
        decoder_end_token_id (int or List[int], optional): The decoder's end token ID. Defaults to 2.
        num_return_sequences (int, optional): The number of generated sequences to return. Defaults to 1.
        min_gen_seq_length (int, optional): The minimum length of the generated sequences. Defaults to 0.
        max_gen_seq_length (int, optional): The maximum length of the generated sequences. Defaults to 48.
        repetition_penalty (float, optional): The repetition penalty. Defaults to 1.0.
        no_repeat_ngram_size (int, optional): The size of n-grams to avoid repeating. Defaults to 0.
        early_stopping (bool, optional): Whether to stop generation early. Defaults to True.
        length_penalty (float, optional): The length penalty. Defaults to 1.0.
        num_beam_groups (int, optional): The number of beam groups for diverse beam search. Defaults to 1.
        diversity_penalty (float, optional): The diversity penalty. Defaults to 0.0.
        do_sample (bool, optional): Whether to use sampling for generation. Defaults to False.
        temperature (float, optional): The temperature for sampling. Defaults to 1.0.
        top_k (int, optional): The value for top-k sampling. Defaults to 50.
        top_p (float, optional): The value for top-p (nucleus) sampling. Defaults to 1.0.

    Returns:
        GenericOutputs: The generated sequences and their scores.
    """
    outputs = self.model.generate(
        input_ids,
        max_length=max_gen_seq_length,
        min_length=min_gen_seq_length,
        num_beams=num_beams,
        do_sample=do_sample,
        decoder_start_token_id=decoder_start_token_id,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=early_stopping,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        num_return_sequences=num_return_sequences,
        bos_token_id=decoder_start_token_id,
        eos_token_id=decoder_end_token_id,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        return_dict_in_generate=True,
        output_scores=True,
    )

    sequences = outputs.sequences.reshape(
        -1, num_return_sequences, outputs.sequences.size(-1)
    )
    outputs.sequences = torch.zeros(
        sequences.size(0), num_return_sequences, max_gen_seq_length
    ).to(device=sequences.device)
    outputs.sequences[:, :, : sequences.size(-1)].copy_(sequences)

    if num_return_sequences == 1:
        outputs.sequences = outputs.sequences.reshape(-1, max_gen_seq_length)

    return GenericOutputs(
        sequences=outputs.sequences,
        sequences_scores=outputs.sequences_scores,
    )
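
A hedged end-to-end sketch of generation, reusing the model and input_ids from the forward example above (the default start/end token ID of 2 corresponds to MBart's </s> token in the standard vocabulary):

outputs = model.generate(
    input_ids,
    num_beams=5,
    max_gen_seq_length=48,
)
sequences = outputs.sequences        # (batch_size, 48) when num_return_sequences == 1
scores = outputs.sequences_scores    # one beam score per returned sequence
token_ids = sequences.long()         # the padding buffer above is float, so cast before decoding

For sampling instead of beam search, set do_sample=True and tune temperature, top_k, and top_p; for diverse beam search, num_beam_groups must evenly divide num_beams and diversity_penalty should be positive.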