unitorch.models.mt5¤
MT5Processor¤
Bases: HfTextGenerationProcessor
Initializes an MT5Processor.
Parameters:
Name | Type | Description | Default
---|---|---|---
vocab_path | str | The path to the vocabulary file. | required
special_input_ids | Dict | Dictionary of special input tokens and their corresponding IDs. Defaults to an empty dictionary. | dict()
max_seq_length | int | The maximum length of input sequences. Defaults to 128. | 128
max_gen_seq_length | int | The maximum length of generated sequences. Defaults to 48. | 48
Source code in src/unitorch/models/mt5/processing.py
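A minimal construction sketch based on the parameters above, assuming the processor is importable from `unitorch.models.mt5` and that `vocab_path` points to an MT5 sentencepiece vocabulary file; the path below is a placeholder.

```python
from unitorch.models.mt5 import MT5Processor

# Build the processor from a local vocabulary file (placeholder path).
processor = MT5Processor(
    vocab_path="/path/to/mt5/spiece.model",
    max_seq_length=128,
    max_gen_seq_length=48,
)
```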
MT5ForGeneration¤
Bases: GenericModel
Initializes an MT5ForGeneration model with the provided configuration.
Parameters:
Name | Type | Description | Default
---|---|---|---
config_path | str | The path to the model configuration file. | required
gradient_checkpointing | bool | Whether to use gradient checkpointing. Defaults to False. | False
Source code in src/unitorch/models/mt5/modeling.py
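A minimal construction sketch, assuming the class is importable from `unitorch.models.mt5`; the configuration path is a placeholder.

```python
from unitorch.models.mt5 import MT5ForGeneration

# Build the model from an MT5 configuration file (placeholder path).
model = MT5ForGeneration(
    config_path="/path/to/mt5/config.json",
    gradient_checkpointing=False,
)
```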
forward ¤
forward(
input_ids: Tensor,
attention_mask: Tensor,
decoder_input_ids: Tensor,
decoder_attention_mask: Tensor,
)
Performs a forward pass of the MT5ForGeneration model.
Parameters:
Name | Type | Description | Default
---|---|---|---
input_ids | Tensor | Tensor of input token IDs. | required
attention_mask | Tensor | Tensor of attention mask. | required
decoder_input_ids | Tensor | Tensor of decoder input token IDs. | required
decoder_attention_mask | Tensor | Tensor of decoder attention mask. | required
Returns:
Type | Description
---|---
Tensor | The model's logits.
Source code in src/unitorch/models/mt5/modeling.py
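A forward-pass sketch with dummy tensors, assuming `model` is the MT5ForGeneration instance built above and is callable like a standard torch.nn.Module; shapes and token IDs are arbitrary and only illustrate the expected inputs.

```python
import torch

# Dummy encoder/decoder inputs: batch of 2, encoder length 8, decoder length 4.
# Token IDs are arbitrary placeholders within a typical MT5 vocabulary range.
input_ids = torch.randint(0, 250000, (2, 8))
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = torch.zeros((2, 4), dtype=torch.long)
decoder_attention_mask = torch.ones_like(decoder_input_ids)

logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
# Expected shape: (batch_size, decoder_seq_length, vocab_size).
```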
generate ¤
generate(
input_ids: Tensor,
num_beams: Optional[int] = 5,
decoder_start_token_id: Optional[int] = 0,
decoder_end_token_id: Optional[
Union[int, List[int]]
] = 1,
num_return_sequences: Optional[int] = 1,
min_gen_seq_length: Optional[int] = 0,
max_gen_seq_length: Optional[int] = 48,
repetition_penalty: Optional[float] = 1.0,
no_repeat_ngram_size: Optional[int] = 0,
early_stopping: Optional[bool] = True,
length_penalty: Optional[float] = 1.0,
num_beam_groups: Optional[int] = 1,
diversity_penalty: Optional[float] = 0.0,
do_sample: Optional[bool] = False,
temperature: Optional[float] = 1.0,
top_k: Optional[int] = 50,
top_p: Optional[float] = 1.0,
)
Generates sequences using the MT5ForGeneration model.
Parameters:
Name | Type | Description | Default
---|---|---|---
input_ids | Tensor | The input token IDs. | required
num_beams | int | The number of beams for beam search. Defaults to 5. | 5
decoder_start_token_id | int | The decoder's start token ID. Defaults to 0. | 0
decoder_end_token_id | int or List[int] | The decoder's end token ID. Defaults to 1. | 1
num_return_sequences | int | The number of generated sequences to return. Defaults to 1. | 1
min_gen_seq_length | int | The minimum length of the generated sequences. Defaults to 0. | 0
max_gen_seq_length | int | The maximum length of the generated sequences. Defaults to 48. | 48
repetition_penalty | float | The repetition penalty. Defaults to 1.0. | 1.0
no_repeat_ngram_size | int | The size of n-grams to avoid repeating. Defaults to 0. | 0
early_stopping | bool | Whether to stop generation early. Defaults to True. | True
length_penalty | float | The length penalty. Defaults to 1.0. | 1.0
num_beam_groups | int | The number of beam groups for diverse beam search. Defaults to 1. | 1
diversity_penalty | float | The diversity penalty. Defaults to 0.0. | 0.0
do_sample | bool | Whether to use sampling for generation. Defaults to False. | False
temperature | float | The temperature for sampling. Defaults to 1.0. | 1.0
top_k | int | The value for top-k sampling. Defaults to 50. | 50
top_p | float | The value for top-p (nucleus) sampling. Defaults to 1.0. | 1.0
Returns:
Name | Type | Description
---|---|---
GenericOutputs | | The generated sequences and their scores.
Source code in src/unitorch/models/mt5/modeling.py
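A generation sketch using the defaults documented above; `model` and `input_ids` follow the earlier examples, and the attribute names on the returned GenericOutputs are assumptions.

```python
# Beam search decode; only the arguments shown here restate or deviate from
# the documented defaults.
outputs = model.generate(
    input_ids=input_ids,
    num_beams=5,
    max_gen_seq_length=48,
)
# `outputs` is a GenericOutputs bundling the generated sequences and their
# scores; the exact attribute names (e.g. outputs.sequences) are assumptions.
```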