unitorch.models.xpegasus
XPegasusProcessor
Bases: HfTextGenerationProcessor
Source code in src/unitorch/models/xpegasus/processing.py
XPegasusForGeneration
Bases: GenericModel
XPegasus model for text generation tasks.
Initializes the XPegasusForGeneration model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor | Tensor of input token IDs. | required |
| attention_mask | Tensor | Attention mask for the input tokens. | required |
| decoder_input_ids | Tensor | Tensor of decoder input token IDs. | required |
| decoder_attention_mask | Tensor | Attention mask for the decoder input tokens. | required |
Source code in src/unitorch/models/xpegasus/modeling.py
forward
```python
forward(
    input_ids: Tensor,
    attention_mask: Tensor,
    decoder_input_ids: Tensor,
    decoder_attention_mask: Tensor,
)
```
Performs forward pass of the XPegasusForGeneration model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor | Tensor of input token IDs. | required |
| attention_mask | Tensor | Attention mask for the input tokens. | required |
| decoder_input_ids | Tensor | Tensor of decoder input token IDs. | required |
| decoder_attention_mask | Tensor | Attention mask for the decoder input tokens. | required |
Returns:

| Type | Description |
|---|---|
| Tensor | The model's logits. |
Source code in src/unitorch/models/xpegasus/modeling.py
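`forward` takes decoder inputs alongside the encoder inputs. In seq2seq training, `decoder_input_ids` are conventionally the target token IDs shifted right behind the start token (teacher forcing). A minimal sketch of that shift, assuming illustrative token IDs and a hypothetical `shift_right` helper that is not part of the unitorch API:

```python
# Hedged sketch: how decoder inputs are commonly derived from target token
# IDs for seq2seq teacher forcing. The helper name `shift_right` and the
# token IDs below are illustrative, not part of the unitorch API.

def shift_right(target_ids, decoder_start_token_id=0):
    """Prepend the start token and drop the last target token."""
    return [decoder_start_token_id] + target_ids[:-1]

target_ids = [42, 17, 99, 1]              # target sequence ending in EOS (1)
decoder_input_ids = shift_right(target_ids)
print(decoder_input_ids)                  # [0, 42, 17, 99]
```

At each position the decoder then predicts the *next* target token, so the logits at step `t` are scored against `target_ids[t]`.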
generate
```python
generate(
    input_ids: Tensor,
    num_beams: Optional[int] = 5,
    decoder_start_token_id: Optional[int] = 0,
    decoder_end_token_id: Optional[Union[int, List[int]]] = 1,
    num_return_sequences: Optional[int] = 1,
    min_gen_seq_length: Optional[int] = 0,
    max_gen_seq_length: Optional[int] = 48,
    repetition_penalty: Optional[float] = 1.0,
    no_repeat_ngram_size: Optional[int] = 0,
    early_stopping: Optional[bool] = True,
    length_penalty: Optional[float] = 1.0,
    num_beam_groups: Optional[int] = 1,
    diversity_penalty: Optional[float] = 0.0,
    do_sample: Optional[bool] = False,
    temperature: Optional[float] = 1.0,
    top_k: Optional[int] = 50,
    top_p: Optional[float] = 1.0,
)
```
Generates sequences using the XPegasusForGeneration model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor | The input token IDs. | required |
| num_beams | int | The number of beams for beam search. Defaults to 5. | 5 |
| decoder_start_token_id | int | The decoder's start token ID. Defaults to 0. | 0 |
| decoder_end_token_id | int or List[int] | The decoder's end token ID(s). Defaults to 1. | 1 |
| num_return_sequences | int | The number of generated sequences to return. Defaults to 1. | 1 |
| min_gen_seq_length | int | The minimum length of the generated sequences. Defaults to 0. | 0 |
| max_gen_seq_length | int | The maximum length of the generated sequences. Defaults to 48. | 48 |
| repetition_penalty | float | The repetition penalty. Defaults to 1.0. | 1.0 |
| no_repeat_ngram_size | int | The size of n-grams to avoid repeating. Defaults to 0. | 0 |
| early_stopping | bool | Whether to stop generation early. Defaults to True. | True |
| length_penalty | float | The length penalty. Defaults to 1.0. | 1.0 |
| num_beam_groups | int | The number of beam groups for diverse beam search. Defaults to 1. | 1 |
| diversity_penalty | float | The diversity penalty. Defaults to 0.0. | 0.0 |
| do_sample | bool | Whether to use sampling for generation. Defaults to False. | False |
| temperature | float | The temperature for sampling. Defaults to 1.0. | 1.0 |
| top_k | int | The value for top-k sampling. Defaults to 50. | 50 |
| top_p | float | The value for top-p (nucleus) sampling. Defaults to 1.0. | 1.0 |
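`length_penalty` follows the usual beam-search convention of length-normalizing hypothesis scores: values above 1.0 favor longer outputs, values below 1.0 favor shorter ones. A minimal sketch of that normalization, assuming the common sum-of-log-probs divided by `length ** length_penalty` form rather than unitorch's exact internals:

```python
import math

def beam_score(log_probs, length_penalty=1.0):
    """Length-normalized score of a finished beam hypothesis:
    sum of token log-probs divided by length ** length_penalty.
    (Common GNMT-style normalization; a sketch, not unitorch's code.)"""
    return sum(log_probs) / (len(log_probs) ** length_penalty)

short = [math.log(0.5)] * 4    # 4 tokens, probability 0.5 each
long = [math.log(0.5)] * 8     # 8 tokens, same per-token probability

# With the default penalty of 1.0, both normalize to the same per-token score.
assert abs(beam_score(short) - beam_score(long)) < 1e-12
# A penalty > 1.0 divides by a larger power of length, boosting longer outputs.
assert beam_score(long, length_penalty=2.0) > beam_score(short, length_penalty=2.0)
```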
Returns:

| Name | Type | Description |
|---|---|---|
| outputs | GenericOutputs | The generated sequences and their scores. |
Source code in src/unitorch/models/xpegasus/modeling.py
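When `do_sample` is `True`, `top_k` and `top_p` prune the next-token distribution before sampling: `top_k` keeps only the k most probable tokens, and `top_p` then keeps the smallest set of them whose cumulative probability reaches p. A pure-Python sketch of that filtering, assuming the standard top-k then nucleus (top-p) scheme; the function name and probabilities are illustrative, not unitorch code:

```python
def filter_top_k_top_p(probs, top_k=50, top_p=1.0):
    """Keep the top_k most probable tokens, then the smallest prefix of
    them whose cumulative probability reaches top_p; renormalize.
    `probs` maps token id -> probability. Illustrative, not unitorch code."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    kept, cum = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cum += p
        if cum >= top_p:
            break
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}

probs = {0: 0.5, 1: 0.3, 2: 0.15, 3: 0.05}
print(filter_top_k_top_p(probs, top_k=2))            # {0: 0.625, 1: 0.375}
print(sorted(filter_top_k_top_p(probs, top_p=0.8)))  # [0, 1]
```

Sampling then draws from the renormalized distribution, so low-probability tail tokens can never be emitted.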