unitorch.models¤

CheckpointMixin¤

from_checkpoint ¤

from_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
)

Load model weights from a checkpoint.

Parameters:

- ckpt_dir (str): Directory path of the checkpoint. Required.
- weight_name (str, optional): Name of the weight file. Defaults to None.

Returns:

None

Source code in src/unitorch/models/__init__.py
def from_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
):
    """
    Load model weights from a checkpoint.

    Args:
        ckpt_dir (str): Directory path of the checkpoint.
        weight_name (str): Name of the weight file.

    Returns:
        None
    """
    if weight_name is None:
        weight_name = self.checkpoint_name
    weight_path = os.path.join(ckpt_dir, weight_name)
    if not os.path.exists(weight_path):
        return

    if weight_path.endswith(".safetensors"):
        state_dict = safetensors.torch.load_file(weight_path)
    else:
        state_dict = torch.load(weight_path, map_location="cpu")
    self.load_state_dict(state_dict)
    logging.info(
        f"{type(self).__name__} model load weight from checkpoint {weight_path}"
    )
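A minimal usage sketch. MyModel and the checkpoint directory are hypothetical; any CheckpointMixin subclass works the same way:

# Hypothetical example: restore weights written by save_checkpoint.
model = MyModel()                       # any CheckpointMixin / GenericModel subclass
model.from_checkpoint("./checkpoints")  # loads <ckpt_dir>/<checkpoint_name> if it exists
model.from_checkpoint("./checkpoints", weight_name="pytorch_model.safetensors")

Note that the method returns silently when the resolved weight file does not exist.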

from_pretrained ¤

from_pretrained(
    weight_path: Union[str, List[str]] = None,
    state_dict: Union[Dict, List[Dict]] = None,
    replace_keys: Optional[Dict] = dict(),
    prefix_keys: Optional[Dict] = dict(),
)

Load pretrained weights into the model.

Parameters:

- weight_path (str or List[str], optional): Path(s) to the pretrained weight file(s). Defaults to None.
- state_dict (Dict or List[Dict], optional): Pretrained state_dict(s) to load weights from. Defaults to None.
- replace_keys (Dict, optional): Regex patterns mapped to replacement strings, applied to keys in the pretrained state_dict so they match the model's keys. Defaults to dict().
- prefix_keys (Dict, optional): Regex patterns mapped to prefixes that are prepended to matching keys in the pretrained state_dict. Defaults to dict().

Returns:

None

Source code in src/unitorch/models/__init__.py
def from_pretrained(
    self,
    weight_path: Union[str, List[str]] = None,
    state_dict: Union[Dict, List[Dict]] = None,
    replace_keys: Optional[Dict] = dict(),
    prefix_keys: Optional[Dict] = dict(),
):
    """
    Load pretrained weights into the model.

    Args:
        weight_path (str or List[str]): Path(s) to the pretrained weight file(s).
        state_dict (Dict or List[Dict]): Pretrained state_dict(s) to load weights from.
        replace_keys (Dict): Regex patterns mapped to replacement strings, applied to keys in the pretrained state_dict so they match the model's keys.
        prefix_keys (Dict): Regex patterns mapped to prefixes prepended to matching keys in the pretrained state_dict.

    Returns:
        None
    """
    assert weight_path or state_dict, "weight_path or state_dict must be set"

    # Load state_dict(s) based on the provided weight_path or state_dict
    state_dicts = []
    if weight_path:
        if isinstance(weight_path, str):
            weight_path = [weight_path]
        for path in weight_path:
            logging.debug(f"Loading weights from {path}")
        state_dicts += [load_weight(path) for path in weight_path]

    if state_dict:
        state_dicts += state_dict if isinstance(state_dict, list) else [state_dict]

    self_state_dict = self.state_dict()  # Get the current state_dict of the model
    load_keys = []  # Keep track of the keys loaded from the state_dict(s)
    non_load_keys = []  # Keep track of the keys not loaded from the state_dict(s)

    if isinstance(self.replace_keys_in_state_dict, dict):
        replace_keys = {**self.replace_keys_in_state_dict, **replace_keys}

    if isinstance(self.prefix_keys_in_state_dict, dict):
        prefix_keys = {**self.prefix_keys_in_state_dict, **prefix_keys}

    # Iterate over the state_dict(s) and load the matching keys into the model's state_dict
    for _state_dict in state_dicts:
        if not _state_dict:
            continue
        for key, value in list(_state_dict.items()):
            for rkey, prefix in prefix_keys.items():
                if re.match(rkey, key):
                    key = prefix + key
                    break

            for rkey, nkey in replace_keys.items():
                key = re.sub(rkey, nkey, key)
            if key in self_state_dict and value.shape == self_state_dict[key].shape:
                self_state_dict[key] = value
                if key not in load_keys:
                    load_keys.append(key)
            else:
                non_load_keys.append(key)

    self.load_state_dict(self_state_dict, False)
    load_percent = (
        len(load_keys) / len(self_state_dict) * 100
    )  # Calculate the percentage of loaded keys
    logging.debug(f"Non load keys in pretrain weights: {list(non_load_keys)}")
    logging.debug(
        f"{type(self).__name__} missed keys: {list(self_state_dict.keys() - load_keys)}"
    )
    logging.info(f"{type(self).__name__} loaded weights ({int(load_percent)}%)")

save_checkpoint ¤

save_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
)

Save the model's current state as a checkpoint.

Parameters:

- ckpt_dir (str): Directory path to save the checkpoint. Required.
- weight_name (str, optional): Name of the weight file. Defaults to None.

Returns:

None

Source code in src/unitorch/models/__init__.py
def save_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
):
    """
    Save the model's current state as a checkpoint.

    Args:
        ckpt_dir (str): Directory path to save the checkpoint.
        weight_name (str): Name of the weight file.

    Returns:
        None
    """
    if weight_name is None:
        weight_name = self.checkpoint_name
    state_dict = self.state_dict()
    weight_path = os.path.join(ckpt_dir, weight_name)
    if weight_path.endswith(".safetensors"):
        safetensors.torch.save_file(state_dict, weight_path)
    else:
        torch.save(state_dict, weight_path)
    logging.info(f"{type(self).__name__} model save checkpoint to {weight_path}")

GenericModel¤

Bases: Module, CheckpointMixin

Source code in src/unitorch/models/__init__.py
def __init__(self):
    super().__init__()
    pass

device property ¤

device

Returns the device of the model's parameters.

Returns:

- torch.device: The device of the model's parameters.

dtype property ¤

dtype: dtype

Returns the data type of the model's parameters.

Returns:

- torch.dtype: The data type of the model's parameters.

init_weights ¤

init_weights()

Initialize the weights of the model.

Source code in src/unitorch/models/__init__.py
def init_weights(self):
    """
    Initialize the weights of the model.
    """
    self.apply(self._init_weights)
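A minimal subclass sketch (hypothetical class and initialization scheme). Since init_weights applies self._init_weights to every submodule, a subclass is expected to define that hook:

import torch.nn as nn
from unitorch.models import GenericModel

class TinyClassifier(GenericModel):
    def __init__(self, hidden: int = 128, num_labels: int = 2):
        super().__init__()
        self.proj = nn.Linear(hidden, num_labels)
        self.init_weights()  # applies _init_weights to all submodules

    def _init_weights(self, module):
        # Hypothetical scheme: small normal init for linear layers.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, features):
        return self.proj(features)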

HfTextGenerationProcessor¤

Processor for text generation tasks.

Parameters:

- tokenizer (PreTrainedTokenizer): The tokenizer to use for text generation. Required.
- max_seq_length (int, optional): Maximum sequence length. Defaults to 128.
- max_gen_seq_length (int, optional): Maximum generated sequence length. Defaults to 48.
Source code in src/unitorch/models/processing_utils.py
def __init__(
    self,
    tokenizer: PreTrainedTokenizer,
    max_seq_length: Optional[int] = 128,
    max_gen_seq_length: Optional[int] = 48,
):
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.max_gen_seq_length = max_gen_seq_length
    self.pad_token = self.tokenizer.pad_token
    self.bos_token = self.tokenizer.bos_token
    self.eos_token = self.tokenizer.eos_token
    self.mask_token = self.tokenizer.mask_token
    self.pad_token_id = self.tokenizer.pad_token_id
    self.vocab_size = len(self.tokenizer.get_vocab())
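A construction sketch, assuming the transformers package is available; the checkpoint name is just an example of a tokenizer that defines bos/eos/pad/mask tokens:

from transformers import AutoTokenizer
from unitorch.models.processing_utils import HfTextGenerationProcessor

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
processor = HfTextGenerationProcessor(tokenizer, max_seq_length=128, max_gen_seq_length=48)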

detokenize ¤

detokenize(
    sequences: Tensor,
    skip_special_tokens: Optional[bool] = True,
)

Detokenize the sequences.

Parameters:

- sequences (torch.Tensor): The sequences to detokenize.
- skip_special_tokens (bool, optional): Whether to skip special tokens. Defaults to True.

Returns:

- list: The detokenized sequences.

Source code in src/unitorch/models/processing_utils.py
def detokenize(
    self,
    sequences: torch.Tensor,
    skip_special_tokens: Optional[bool] = True,
):
    """
    Detokenize the sequences.

    Args:
        sequences (torch.Tensor): The sequences to detokenize.
        skip_special_tokens (bool, optional): Whether to skip special tokens. Defaults to True.

    Returns:
        list: The detokenized sequences.
    """
    if sequences.dim() == 3:
        _, num_return_sequences, sequences_length = sequences.size()
        sequences = sequences.reshape(-1, sequences_length).clamp_max(
            self.vocab_size
        )
        sequences = sequences.clamp_min(0)
        sequences[sequences == self.vocab_size] = self.pad_token_id
        decode_tokens = self.tokenizer.batch_decode(
            sequences,
            skip_special_tokens=skip_special_tokens,
        )
        decode_tokens = [
            decode_tokens[i : i + num_return_sequences]
            for i in range(0, len(decode_tokens), num_return_sequences)
        ]
    elif sequences.dim() == 2:
        sequences = sequences.clamp_min(0).clamp_max(self.vocab_size)
        sequences[sequences == self.vocab_size] = self.pad_token_id
        decode_tokens = self.tokenizer.batch_decode(
            sequences,
            skip_special_tokens=skip_special_tokens,
        )
    else:
        raise ValueError(f"Can't decode the tensor with shape {sequences.shape}")

    return decode_tokens
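A decoding sketch, reusing the processor from the constructor example above; the token ids are purely illustrative:

import torch

# (batch, seq_len) -> list of strings, one per row
sequences = torch.tensor([[0, 713, 16, 10, 1296, 2, 1, 1]])
texts = processor.detokenize(sequences)
# A (batch, num_return_sequences, seq_len) tensor instead yields a list of lists,
# one inner list of strings per batch element.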

generation ¤

generation(
    text: str,
    text_pair: str,
    max_seq_length: Optional[int] = None,
    max_gen_seq_length: Optional[int] = None,
)

Generate inputs, labels, and tokens for text generation.

Parameters:

- text (str): The input text.
- text_pair (str): The paired text.
- max_seq_length (int, optional): Maximum sequence length. Defaults to None.
- max_gen_seq_length (int, optional): Maximum generated sequence length. Defaults to None.

Returns:

- GenericOutputs: Input token ids and attention masks for the text and the paired text, plus the label token ids and attention mask.

Source code in src/unitorch/models/processing_utils.py
def generation(
    self,
    text: str,
    text_pair: str,
    max_seq_length: Optional[int] = None,
    max_gen_seq_length: Optional[int] = None,
):
    """
    Generate inputs, labels, and tokens for text generation.

    Args:
        text (str): The input text.
        text_pair (str): The paired text.
        max_seq_length (int, optional): Maximum sequence length. Defaults to None.
        max_gen_seq_length (int, optional): Maximum generated sequence length. Defaults to None.

    Returns:
        GenericOutputs: The generated input tokens, attention masks, label tokens, and attention masks.
    """
    max_seq_length = pop_value(max_seq_length, self.max_seq_length)
    max_gen_seq_length = pop_value(max_gen_seq_length, self.max_gen_seq_length)

    tokens = self.generation_inputs(text, max_seq_length)
    tokens_pair = self.generation_inputs(text_pair, max_gen_seq_length)
    labels = self.generation_labels(text_pair, max_gen_seq_length)

    return GenericOutputs(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        input_ids_pair=tokens_pair.input_ids,
        attention_mask_pair=tokens_pair.attention_mask,
        input_ids_label=labels.input_ids,
        attention_mask_label=labels.attention_mask,
    )
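A usage sketch (hypothetical text pair), e.g. for a sequence-to-sequence training example:

outputs = processor.generation(
    text="translate English to German: Hello world",
    text_pair="Hallo Welt",
)
outputs.input_ids          # source inputs, shape (max_seq_length,)
outputs.input_ids_pair     # target-side inputs, shape (max_gen_seq_length,)
outputs.input_ids_label    # shifted labels, shape (max_gen_seq_length,)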

generation_inputs ¤

generation_inputs(
    text: str, max_seq_length: Optional[int] = None
)

Generate inputs for text generation.

Parameters:

- text (str): The input text.
- max_seq_length (int, optional): Maximum sequence length. Defaults to None.

Returns:

- GenericOutputs: The generated input tokens and attention mask.

Source code in src/unitorch/models/processing_utils.py
def generation_inputs(
    self,
    text: str,
    max_seq_length: Optional[int] = None,
):
    """
    Generate inputs for text generation.

    Args:
        text (str): The input text.
        max_seq_length (int, optional): Maximum sequence length. Defaults to None.

    Returns:
        GenericOutputs: The generated input tokens and attention mask.
    """
    max_seq_length = pop_value(max_seq_length, self.max_seq_length)
    tokens = self.tokenizer.tokenize(str(text))
    tokens = tokens[: max_seq_length - 2]
    tokens = [self.bos_token] + tokens + [self.eos_token]
    input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
    input_ids = input_ids[:max_seq_length]
    attention_mask = [1] * len(input_ids)

    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += [self.pad_token_id] * len(padding)
    attention_mask += padding

    assert len(input_ids) == max_seq_length
    assert len(attention_mask) == max_seq_length

    return GenericOutputs(
        input_ids=torch.tensor(input_ids, dtype=torch.long),
        attention_mask=torch.tensor(attention_mask, dtype=torch.long),
    )
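A short sketch: the output is padded to the requested length, with the attention mask marking real tokens.

enc = processor.generation_inputs("Hello world", max_seq_length=16)
enc.input_ids.shape        # torch.Size([16])
enc.attention_mask.sum()   # number of non-padding positions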

generation_labels ¤

generation_labels(
    text: str, max_gen_seq_length: Optional[int] = None
)

Generate labels for text generation.

Parameters:

- text (str): The input text.
- max_gen_seq_length (int, optional): Maximum generated sequence length. Defaults to None.

Returns:

- GenericOutputs: The generated label tokens and attention mask.

Source code in src/unitorch/models/processing_utils.py
def generation_labels(
    self,
    text: str,
    max_gen_seq_length: Optional[int] = None,
):
    """
    Generate labels for text generation.

    Args:
        text (str): The input text.
        max_gen_seq_length (int, optional): Maximum generated sequence length. Defaults to None.

    Returns:
        GenericOutputs: The generated label tokens and attention mask.
    """
    max_gen_seq_length = pop_value(max_gen_seq_length, self.max_gen_seq_length)
    tokens = self.tokenizer.tokenize(str(text))
    tokens = tokens[: max_gen_seq_length - 2]
    tokens = [self.bos_token] + tokens + [self.eos_token]
    input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
    input_ids = input_ids[1:max_gen_seq_length]
    attention_mask = [1] * len(input_ids)

    padding = [0] * (max_gen_seq_length - len(input_ids))
    input_ids += [self.pad_token_id] * len(padding)
    attention_mask += padding

    assert len(input_ids) == max_gen_seq_length

    return GenericOutputs(
        input_ids=torch.tensor(input_ids, dtype=torch.long),
        attention_mask=torch.tensor(attention_mask, dtype=torch.long),
    )
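A sketch of how the labels relate to the target-side inputs: the labels drop the leading BOS token, so position i of the labels is the token the decoder should predict at step i.

dec = processor.generation_inputs("Hallo Welt", max_seq_length=8)
lab = processor.generation_labels("Hallo Welt", max_gen_seq_length=8)
# dec.input_ids: [BOS, t1, t2, ..., EOS, PAD, ...]
# lab.input_ids: [t1, t2, ..., EOS, PAD, PAD, ...]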

HfTextClassificationProcessor¤

Processor for text classification tasks.

Parameters:

- tokenizer (PreTrainedTokenizer): The tokenizer to use for text classification. Required.
- max_seq_length (int, optional): Maximum sequence length. Defaults to 128.
- source_type_id (int, optional): Source type ID. Defaults to 0.
- target_type_id (int, optional): Target type ID. Defaults to 1.
- position_start_id (int, optional): Start position ID. Defaults to 0.
Source code in src/unitorch/models/processing_utils.py
def __init__(
    self,
    tokenizer: PreTrainedTokenizer,
    max_seq_length: Optional[int] = 128,
    source_type_id: Optional[int] = 0,
    target_type_id: Optional[int] = 1,
    position_start_id: Optional[int] = 0,
):
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.pad_token = self.tokenizer.pad_token
    self.sep_token = self.tokenizer.sep_token
    self.cls_token = self.tokenizer.cls_token
    self.mask_token = self.tokenizer.mask_token
    self.pad_token_id = self.tokenizer.pad_token_id
    self.source_type_id = source_type_id
    self.target_type_id = target_type_id
    self.position_start_id = position_start_id
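A construction sketch, assuming the transformers package is available; any BERT-style tokenizer that defines cls/sep/pad/mask tokens works:

from transformers import AutoTokenizer
from unitorch.models.processing_utils import HfTextClassificationProcessor

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = HfTextClassificationProcessor(tokenizer, max_seq_length=128)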

classification ¤

classification(
    text: str,
    text_pair: Optional[str] = None,
    max_seq_length: Optional[int] = None,
)

Generate inputs for text classification.

Parameters:

- text (str): The input text.
- text_pair (str, optional): The paired text. Defaults to None.
- max_seq_length (int, optional): Maximum sequence length. Defaults to None.

Returns:

- GenericOutputs: The generated input tokens, token type IDs, attention mask, and position IDs.

Source code in src/unitorch/models/processing_utils.py
def classification(
    self,
    text: str,
    text_pair: Optional[str] = None,
    max_seq_length: Optional[int] = None,
):
    """
    Generate inputs for text classification.

    Args:
        text (str): The input text.
        text_pair (str, optional): The paired text. Defaults to None.
        max_seq_length (int, optional): Maximum sequence length. Defaults to None.

    Returns:
        GenericOutputs: The generated input tokens, token type IDs, attention mask, and position IDs.
    """
    max_seq_length = pop_value(
        max_seq_length,
        self.max_seq_length,
    )

    tokens = self.tokenizer.tokenize(str(text))
    if text_pair is None:
        tokens = tokens[: max_seq_length - 2]
        tokens = [self.cls_token] + tokens + [self.sep_token]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        token_type_ids = [self.source_type_id] * len(input_ids)
        attention_mask = [1] * len(input_ids)
    else:
        tokens_pair = self.tokenizer.tokenize(str(text_pair))
        truncate_sequence_pair(tokens, tokens_pair, max_seq_length - 3)
        token_type_ids = (
            [self.source_type_id]
            + [self.source_type_id] * len(tokens)
            + [self.source_type_id]
            + [self.target_type_id] * len(tokens_pair)
            + [self.target_type_id]
        )
        tokens = (
            [self.cls_token]
            + tokens
            + [self.sep_token]
            + tokens_pair
            + [self.sep_token]
        )
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_ids)

    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += len(padding) * [self.pad_token_id]
    attention_mask += padding
    token_type_ids += len(padding) * [self.target_type_id]

    assert len(input_ids) == max_seq_length
    assert len(attention_mask) == max_seq_length
    assert len(token_type_ids) == max_seq_length
    return GenericOutputs(
        input_ids=torch.tensor(input_ids, dtype=torch.long),
        token_type_ids=torch.tensor(token_type_ids, dtype=torch.long),
        attention_mask=torch.tensor(attention_mask, dtype=torch.long),
        position_ids=torch.tensor(
            list(
                range(
                    self.position_start_id,
                    self.position_start_id + max_seq_length,
                )
            ),
            dtype=torch.long,
        ),
    )
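A usage sketch (hypothetical sentences), covering both the single-sentence and sentence-pair cases:

single = processor.classification("The movie was great.")
pair = processor.classification(
    "A man is playing a guitar.",
    text_pair="Someone is making music.",
)
# Each result carries input_ids, token_type_ids, attention_mask, and position_ids,
# all padded to max_seq_length.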

HfImageClassificationProcessor¤

Processor for image classification tasks.

Initialize the HfImageClassificationProcessor.

Parameters:

- vision_processor (BaseImageProcessor): The vision processor object used for image transformations. Required.
Source code in src/unitorch/models/processing_utils.py
def __init__(
    self,
    vision_processor: BaseImageProcessor,
):
    """
    Initialize the HfImageClassificationProcessor.

    Args:
        vision_processor (BaseImageProcessor): The vision processor object used for image transformations.
    """
    self.vision_processor = vision_processor

    self.size = getattr(self.vision_processor, "size", None)

    self.resample = getattr(self.vision_processor, "resample", None)

    self.crop_size = getattr(self.vision_processor, "crop_size", None)
    self.pad_size = getattr(self.vision_processor, "pad_size", None)

    self.rescale_factor = getattr(self.vision_processor, "rescale_factor", None)

    self.image_mean = getattr(self.vision_processor, "image_mean", None)
    self.image_std = getattr(self.vision_processor, "image_std", None)
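A construction sketch, assuming the transformers package is available; the checkpoint name is only an example of an image processor that carries size, resample, rescale, and normalization settings:

from transformers import AutoImageProcessor
from unitorch.models.processing_utils import HfImageClassificationProcessor

vision_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
processor = HfImageClassificationProcessor(vision_processor)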

classification ¤

classification(image: Union[Image, str])

Preprocess the given image for image classification.

Parameters:

- image (Image.Image or str): The input image, or a path to an image file.

Returns:

- GenericOutputs: The preprocessed image as channel-first pixel values.

Source code in src/unitorch/models/processing_utils.py
def classification(
    self,
    image: Union[Image.Image, str],
):
    """
    Perform image classification on the given image.

    Args:
        image (Image.Image): The input image.

    Returns:
        GenericOutputs: The output of the image classification, including pixel values.
    """
    if isinstance(image, str):
        image = Image.open(image)

    if self.size is not None:
        image = self.vision_processor.resize(
            image=to_numpy_array(image.convert("RGB")),
            size=self.size,
            resample=self.resample,
        )

    if self.crop_size is not None:
        image = self.vision_processor.center_crop(
            image,
            size=self.crop_size,
        )

    if self.rescale_factor is not None:
        image = self.vision_processor.rescale(
            image,
            self.rescale_factor,
        )

    if self.image_mean is not None and self.image_std is not None:
        image = self.vision_processor.normalize(
            image=image,
            mean=self.image_mean,
            std=self.image_std,
        )

    if self.pad_size is not None:
        image = self.vision_processor.pad_image(
            image,
            size=self.pad_size,
        )

    image = to_channel_dimension_format(image, ChannelDimension.FIRST)

    return GenericOutputs(
        pixel_values=torch.tensor(image),
    )
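A short sketch (hypothetical file name): a PIL image or a file path both work, and the result is a channel-first tensor.

from PIL import Image

outputs = processor.classification(Image.open("cat.jpg"))  # or processor.classification("cat.jpg")
outputs.pixel_values.shape  # (channels, height, width)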

ExponentialMovingAverage¤

Bases: Module

Exponential Moving Average (EMA) for model parameters.

Initializes the ExponentialMovingAverage.

Parameters:

- model (Module): The model to apply EMA to. Required.
- decay (float, optional): Decay rate for the EMA. Defaults to 0.9999.
- tau (int, optional): Time constant for the EMA. Defaults to 2000.
- num_steps (int, optional): Number of steps taken for the EMA. Defaults to 0.
Source code in src/unitorch/models/modeling_ema.py
def __init__(
    self,
    model,
    decay: Optional[float] = 0.9999,
    tau: Optional[int] = 2000,
    num_steps: Optional[int] = 0,
):
    """
    Initializes the ExponentialMovingAverage.

    Args:
        model (nn.Module): The model to apply EMA to.
        decay (float, optional): Decay rate for the EMA. Defaults to 0.9999.
        tau (int, optional): Time constant for the EMA. Defaults to 2000.
        num_steps (int, optional): Number of steps taken for the EMA. Defaults to 0.
    """
    super().__init__()
    self.model = deepcopy(model)
    self.num_steps = num_steps
    self.decay = lambda x: decay * (1 - math.exp(-x / tau))

    for p in self.model.parameters():
        p.requires_grad = False
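A construction sketch (hypothetical model variable): the EMA keeps a frozen deep copy of the model, so it can be built once before training starts.

from unitorch.models.modeling_ema import ExponentialMovingAverage

ema = ExponentialMovingAverage(model, decay=0.9999, tau=2000)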

forward ¤

forward(*args, **kwargs)

Forward pass through the model.

Parameters:

- *args: Variable length argument list.
- **kwargs: Arbitrary keyword arguments.

Returns:

The output of the model.

Source code in src/unitorch/models/modeling_ema.py
def forward(self, *args, **kwargs):
    """
    Forward pass through the model.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Returns:
        The output of the model.
    """
    return self.model(*args, **kwargs)

from_checkpoint ¤

from_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
)

Load model weights from a checkpoint.

Parameters:

- ckpt_dir (str): Directory path of the checkpoint. Required.
- weight_name (str, optional): Name of the weight file. Defaults to None.

Returns:

None

Source code in src/unitorch/models/modeling_ema.py
def from_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
):
    """
    Load model weights from a checkpoint.

    Args:
        ckpt_dir (str): Directory path of the checkpoint.
        weight_name (str): Name of the weight file.

    Returns:
        None
    """
    if weight_name is None:
        weight_name = self.checkpoint_name
    weight_path = os.path.join(ckpt_dir, weight_name)
    if not os.path.exists(weight_path):
        return
    state_dict = torch.load(weight_path, map_location="cpu")
    self.model.load_state_dict(state_dict)
    logging.info(
        f"{type(self).__name__} model load weight from checkpoint {weight_path}"
    )

save_checkpoint ¤

save_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
)

Save the model's current state as a checkpoint.

Parameters:

- ckpt_dir (str): Directory path to save the checkpoint. Required.
- weight_name (str, optional): Name of the weight file. Defaults to None.

Returns:

None

Source code in src/unitorch/models/modeling_ema.py
def save_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
):
    """
    Save the model's current state as a checkpoint.

    Args:
        ckpt_dir (str): Directory path to save the checkpoint.
        weight_name (str): Name of the weight file.

    Returns:
        None
    """
    if weight_name is None:
        weight_name = self.checkpoint_name
    state_dict = self.model.state_dict()
    weight_path = os.path.join(ckpt_dir, weight_name)
    torch.save(state_dict, weight_path)
    logging.info(f"{type(self).__name__} model save checkpoint to {weight_path}")

step ¤

step(model)

Performs a step of EMA.

Parameters:

- model (Module): The model to update the EMA with. Required.
Source code in src/unitorch/models/modeling_ema.py
@torch.no_grad()
def step(self, model):
    """
    Performs a step of EMA.

    Args:
        model (nn.Module): The model to update the EMA with.
    """
    self.num_steps += 1
    rate = self.decay(self.num_steps)

    new_state = model.state_dict()
    for key, value in self.model.state_dict().items():
        if not value.dtype.is_floating_point:
            continue
        value *= rate
        value += (1 - rate) * new_state[key].detach()
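A training-loop sketch tying the pieces together; dataloader, optimizer, loss_fn, and the batch handling are placeholders:

# Hypothetical loop: update the EMA copy after each optimizer step and
# evaluate / save with the averaged weights.
for batch in dataloader:
    loss = loss_fn(model(batch))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model)            # blend the live weights into the EMA copy

predictions = ema(batch)       # forward pass through the EMA model
ema.save_checkpoint("./checkpoints", weight_name="ema_model.bin")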