Skip to content

unitorch.models¤

CheckpointMixin¤

Mixin that adds checkpoint save/load and pretrained-weight loading to a model.

checkpoint_name class-attribute instance-attribute ¤

checkpoint_name = 'pytorch_model.bin'

replace_keys_in_state_dict class-attribute instance-attribute ¤

replace_keys_in_state_dict: Dict[str, str] = {}

prefix_keys_in_state_dict class-attribute instance-attribute ¤

prefix_keys_in_state_dict: Dict[str, str] = {}

from_checkpoint ¤

from_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
) -> None

Load model weights from ckpt_dir.

Source code in src/unitorch/models/__init__.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def from_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
) -> None:
    """Load model weights from *ckpt_dir*.

    The file ``ckpt_dir/weight_name`` (``weight_name`` falls back to
    ``self.checkpoint_name``) is read with ``safetensors`` when its name ends
    in ``.safetensors`` and with ``torch.load`` (onto the CPU) otherwise, then
    passed to ``self.load_state_dict``. A missing file is silently ignored.
    Extra ``kwargs`` are accepted and unused.
    """
    path = os.path.join(ckpt_dir, weight_name or self.checkpoint_name)
    if os.path.exists(path):
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only point this at trusted checkpoints.
        loaded = (
            safetensors.torch.load_file(path)
            if path.endswith(".safetensors")
            else torch.load(path, map_location="cpu", weights_only=False)
        )
        self.load_state_dict(loaded)
        logging.info("%s loaded weights from %s", type(self).__name__, path)

save_checkpoint ¤

save_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
) -> None

Save model weights to ckpt_dir.

Source code in src/unitorch/models/__init__.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def save_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
) -> None:
    """Save model weights to *ckpt_dir*.

    ``self.state_dict()`` is written to ``ckpt_dir/weight_name``
    (``weight_name`` falls back to ``self.checkpoint_name``): in safetensors
    format when the name ends in ``.safetensors``, otherwise via
    ``torch.save``. The directory is not created here. Extra ``kwargs`` are
    accepted and unused.
    """
    target = os.path.join(ckpt_dir, weight_name or self.checkpoint_name)
    weights = self.state_dict()
    # Pick the serializer from the file extension.
    writer = (
        safetensors.torch.save_file
        if target.endswith(".safetensors")
        else torch.save
    )
    writer(weights, target)
    logging.info("%s saved checkpoint to %s", type(self).__name__, target)

from_pretrained ¤

from_pretrained(
    weight_path: Optional[Union[str, List[str]]] = None,
    state_dict: Optional[Union[Dict, List[Dict]]] = None,
    replace_keys: Optional[Dict[str, str]] = None,
    prefix_keys: Optional[Dict[str, str]] = None,
) -> None

Load pretrained weights into the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weight_path` | `Optional[Union[str, List[str]]]` | Path(s) to pretrained weight file(s). | `None` |
| `state_dict` | `Optional[Union[Dict, List[Dict]]]` | Pretrained state dict(s) to load from. | `None` |
| `replace_keys` | `Optional[Dict[str, str]]` | Regex substitution rules `{pattern: replacement}` applied to each key before matching. | `None` |
| `prefix_keys` | `Optional[Dict[str, str]]` | Regex prefix rules `{pattern: prefix}` — the first matching pattern prepends `prefix` to the key. | `None` |
Source code in src/unitorch/models/__init__.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def from_pretrained(
    self,
    weight_path: Optional[Union[str, List[str]]] = None,
    state_dict: Optional[Union[Dict, List[Dict]]] = None,
    replace_keys: Optional[Dict[str, str]] = None,
    prefix_keys: Optional[Dict[str, str]] = None,
) -> None:
    """Load pretrained weights into the model.

    Each incoming key is first prefixed by the first ``prefix_keys`` pattern
    that ``re.match``-es it, then rewritten by every ``replace_keys`` rule in
    turn (``re.sub``). A tensor is copied only when the rewritten key exists
    in the model and the shapes agree; anything else is skipped and logged at
    debug level. When several sources provide the same key, the later one
    wins. Loading is non-strict, so model keys that receive nothing keep
    their current values.

    Args:
        weight_path: Path(s) to pretrained weight file(s), read with
            ``load_weight``.
        state_dict: Pretrained state dict(s) to load from; applied after any
            files from *weight_path*.
        replace_keys: Regex substitution rules ``{pattern: replacement}``
            applied to each key before matching. Merged over
            ``self.replace_keys_in_state_dict``.
        prefix_keys: Regex prefix rules ``{pattern: prefix}`` — the first
            matching pattern prepends *prefix* to the key. Merged over
            ``self.prefix_keys_in_state_dict``.
    """
    assert weight_path or state_dict, "weight_path or state_dict must be provided"

    # Call-site rules extend, and on key collision override, the class defaults.
    replace_keys = {**self.replace_keys_in_state_dict, **(replace_keys or {})}
    prefix_keys = {**self.prefix_keys_in_state_dict, **(prefix_keys or {})}

    state_dicts: List[Dict] = []
    if weight_path:
        if isinstance(weight_path, str):
            weight_path = [weight_path]
        for path in weight_path:
            logging.debug("Loading weights from %s", path)
            state_dicts.append(load_weight(path))
    if state_dict:
        state_dicts += state_dict if isinstance(state_dict, list) else [state_dict]

    self_state_dict = self.state_dict()
    # A set keeps the membership test O(1); a list here made the loop below
    # quadratic in the number of model keys.
    load_keys = set()

    for sd in state_dicts:
        if not sd:
            continue
        for key, value in sd.items():
            for rkey, pfx in prefix_keys.items():
                if re.match(rkey, key):
                    key = pfx + key
                    break
            for rkey, nkey in replace_keys.items():
                key = re.sub(rkey, nkey, key)
            if key in self_state_dict and value.shape == self_state_dict[key].shape:
                self_state_dict[key] = value
                load_keys.add(key)
            else:
                logging.debug(
                    "Key %s with shape %s does not match model shape %s",
                    key,
                    value.shape,
                    self_state_dict.get(key, torch.empty(0)).shape,
                )

    self.load_state_dict(self_state_dict, strict=False)
    # Sorted so the debug output is stable from run to run.
    for key in sorted(self_state_dict.keys() - load_keys):
        logging.debug(
            "%s key %s not in pretrained weights (shape %s)",
            type(self).__name__,
            key,
            self_state_dict[key].shape,
        )
    load_percent = (
        len(load_keys) / len(self_state_dict) * 100 if self_state_dict else 0
    )
    logging.info("%s loaded weights (%d%%)", type(self).__name__, int(load_percent))

GenericModel¤

Bases: Module, CheckpointMixin

Base class for all unitorch models.

Source code in src/unitorch/models/__init__.py
140
141
def __init__(self) -> None:
    """Initialise the base model.

    Only calls the next ``__init__`` in the MRO; no attributes are set here.
    """
    super().__init__()

dtype property ¤

dtype: dtype

Data type of the model's parameters.

device property ¤

device: device

Device of the model's parameters.

_init_weights ¤

_init_weights(module: Module) -> None
Source code in src/unitorch/models/__init__.py
143
144
145
146
147
148
149
150
def _init_weights(self, module: nn.Module) -> None:
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    if isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

init_weights ¤

init_weights() -> None

Initialise all submodule weights with the default scheme.

Source code in src/unitorch/models/__init__.py
152
153
154
def init_weights(self) -> None:
    """Run ``self._init_weights`` over the module tree via ``self.apply``.

    The per-module initialisation rules live in ``_init_weights``; this
    method only drives the traversal and returns nothing.
    """
    initializer = self._init_weights
    self.apply(initializer)

HfTextGenerationProcessor¤

Processor for encoder-decoder text generation tasks.

Source code in src/unitorch/models/processing_utils.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def __init__(
    self,
    tokenizer: PreTrainedTokenizer,
    max_seq_length: int = 128,
    max_gen_seq_length: int = 48,
) -> None:
    """Wrap *tokenizer* for encoder-decoder generation preprocessing.

    Args:
        tokenizer: Tokenizer providing tokenisation, special tokens and the
            vocabulary.
        max_seq_length: Default padded length of encoder inputs.
        max_gen_seq_length: Default padded length of decoder inputs/labels.
    """
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.max_gen_seq_length = max_gen_seq_length

    # Cache the special tokens used to frame and pad sequences.
    tok = tokenizer
    self.pad_token, self.bos_token, self.eos_token, self.mask_token = (
        tok.pad_token,
        tok.bos_token,
        tok.eos_token,
        tok.mask_token,
    )
    self.pad_token_id = tok.pad_token_id
    # Used by ``detokenize`` as the exclusive upper bound of valid IDs.
    self.vocab_size = len(tok.get_vocab())

tokenizer instance-attribute ¤

tokenizer = tokenizer

max_seq_length instance-attribute ¤

max_seq_length = max_seq_length

max_gen_seq_length instance-attribute ¤

max_gen_seq_length = max_gen_seq_length

pad_token instance-attribute ¤

pad_token = pad_token

bos_token instance-attribute ¤

bos_token = bos_token

eos_token instance-attribute ¤

eos_token = eos_token

mask_token instance-attribute ¤

mask_token = mask_token

pad_token_id instance-attribute ¤

pad_token_id = pad_token_id

vocab_size instance-attribute ¤

vocab_size = len(get_vocab())

generation_inputs ¤

generation_inputs(
    text: str, max_seq_length: Optional[int] = None
) -> GenericOutputs

Tokenise text into padded encoder input IDs and attention mask.

Source code in src/unitorch/models/processing_utils.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def generation_inputs(
    self,
    text: str,
    max_seq_length: Optional[int] = None,
) -> GenericOutputs:
    """Tokenise *text* into padded encoder input IDs and attention mask.

    The text is tokenised, framed with BOS/EOS when the tokenizer defines
    them, truncated to *max_seq_length* (default ``self.max_seq_length``) and
    right-padded with ``pad_token_id``.

    Returns:
        GenericOutputs with ``input_ids`` and ``attention_mask``, both 1-D
        ``torch.long`` tensors of length *max_seq_length*.
    """
    limit = pop_value(max_seq_length, self.max_seq_length)
    pieces = self.tokenizer.tokenize(str(text))[:limit]
    if self.bos_token is not None:
        # Reserve the first slot for BOS.
        pieces = [self.bos_token, *pieces[: limit - 1]]
    if self.eos_token is not None:
        # Reserve the last slot for EOS.
        pieces = [*pieces[: limit - 1], self.eos_token]

    ids = self.tokenizer.convert_tokens_to_ids(pieces)[:limit]
    n_pad = limit - len(ids)
    mask = [1] * len(ids) + [0] * n_pad
    ids = ids + [self.pad_token_id] * n_pad

    assert len(ids) == limit
    assert len(mask) == limit
    return GenericOutputs(
        input_ids=torch.tensor(ids, dtype=torch.long),
        attention_mask=torch.tensor(mask, dtype=torch.long),
    )

generation_labels ¤

generation_labels(
    text: str, max_gen_seq_length: Optional[int] = None
) -> GenericOutputs

Tokenise text into padded decoder label IDs and attention mask.

Source code in src/unitorch/models/processing_utils.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def generation_labels(
    self,
    text: str,
    max_gen_seq_length: Optional[int] = None,
) -> GenericOutputs:
    """Tokenise *text* into padded decoder label IDs and attention mask.

    Framing and truncation follow the same steps as ``generation_inputs``
    (optional BOS/EOS, length *max_gen_seq_length*, default
    ``self.max_gen_seq_length``). The first position is then dropped, so the
    labels are the decoder inputs shifted left by one, and the result is
    right-padded with ``pad_token_id``.

    Returns:
        GenericOutputs with ``input_ids`` and ``attention_mask``, both 1-D
        ``torch.long`` tensors of length *max_gen_seq_length*.
    """
    limit = pop_value(max_gen_seq_length, self.max_gen_seq_length)
    pieces = self.tokenizer.tokenize(str(text))[:limit]
    if self.bos_token is not None:
        pieces = [self.bos_token, *pieces[: limit - 1]]
    if self.eos_token is not None:
        pieces = [*pieces[: limit - 1], self.eos_token]

    # Skip index 0: label i is the token at decoder-input position i + 1.
    label_ids = self.tokenizer.convert_tokens_to_ids(pieces)[1:limit]
    n_pad = limit - len(label_ids)
    mask = [1] * len(label_ids) + [0] * n_pad
    label_ids = label_ids + [self.pad_token_id] * n_pad

    assert len(label_ids) == limit
    assert len(mask) == limit
    return GenericOutputs(
        input_ids=torch.tensor(label_ids, dtype=torch.long),
        attention_mask=torch.tensor(mask, dtype=torch.long),
    )

generation ¤

generation(
    text: str,
    text_pair: str,
    max_seq_length: Optional[int] = None,
    max_gen_seq_length: Optional[int] = None,
) -> GenericOutputs

Return encoder inputs, decoder inputs, and decoder labels for a text pair.

Source code in src/unitorch/models/processing_utils.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def generation(
    self,
    text: str,
    text_pair: str,
    max_seq_length: Optional[int] = None,
    max_gen_seq_length: Optional[int] = None,
) -> GenericOutputs:
    """Return encoder inputs, decoder inputs, and decoder labels for a text pair.

    *text* is encoded with ``generation_inputs`` at *max_seq_length*.
    *text_pair* is encoded twice at *max_gen_seq_length*: once as decoder
    input (``generation_inputs``) and once as labels
    (``generation_labels``). Both lengths default to the instance settings.

    Returns:
        GenericOutputs with ``input_ids``/``attention_mask`` (encoder),
        ``input_ids_pair``/``attention_mask_pair`` (decoder input) and
        ``input_ids_label``/``attention_mask_label`` (labels).
    """
    enc_len = pop_value(max_seq_length, self.max_seq_length)
    dec_len = pop_value(max_gen_seq_length, self.max_gen_seq_length)

    encoder = self.generation_inputs(text, enc_len)
    decoder = self.generation_inputs(text_pair, dec_len)
    labels = self.generation_labels(text_pair, dec_len)

    return GenericOutputs(
        input_ids=encoder.input_ids,
        attention_mask=encoder.attention_mask,
        input_ids_pair=decoder.input_ids,
        attention_mask_pair=decoder.attention_mask,
        input_ids_label=labels.input_ids,
        attention_mask_label=labels.attention_mask,
    )

detokenize ¤

detokenize(
    sequences: Tensor, skip_special_tokens: bool = True
) -> list

Decode a 2-D or 3-D token-ID tensor back to strings.

Source code in src/unitorch/models/processing_utils.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def detokenize(
    self,
    sequences: torch.Tensor,
    skip_special_tokens: bool = True,
) -> list:
    """Decode a 2-D or 3-D token-ID tensor back to strings.

    Any ID outside ``[0, vocab_size)``, such as a negative ignore index or
    an ID past the vocabulary, is replaced by ``pad_token_id`` before
    decoding. The input tensor is not modified.

    Args:
        sequences: ``(batch, seq_len)`` or
            ``(batch, num_return_sequences, seq_len)`` tensor of token IDs.
        skip_special_tokens: Forwarded to ``tokenizer.batch_decode``.

    Returns:
        ``list[str]`` for 2-D input; for 3-D input, a list with one
        ``list[str]`` of length ``num_return_sequences`` per batch item.

    Raises:
        ValueError: if *sequences* is neither 2-D nor 3-D.
    """
    ndim = sequences.dim()
    if ndim not in (2, 3):
        raise ValueError(f"Cannot decode tensor with shape {sequences.shape}")

    num_return_sequences = sequences.size(1) if ndim == 3 else 1
    flat = sequences.reshape(-1, sequences.size(-1))
    # Map every out-of-range ID to padding. The previous clamp(0, vocab_size)
    # only did this above the vocabulary and turned negative IDs into token 0.
    invalid = (flat < 0) | (flat >= self.vocab_size)
    flat = flat.masked_fill(invalid, self.pad_token_id)
    decoded = self.tokenizer.batch_decode(
        flat, skip_special_tokens=skip_special_tokens
    )
    if ndim == 2:
        return decoded
    return [
        decoded[i : i + num_return_sequences]
        for i in range(0, len(decoded), num_return_sequences)
    ]

HfTextClassificationProcessor¤

Processor for BERT-style text classification tasks.

Source code in src/unitorch/models/processing_utils.py
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
def __init__(
    self,
    tokenizer: PreTrainedTokenizer,
    max_seq_length: int = 128,
    source_type_id: int = 0,
    target_type_id: int = 1,
    position_start_id: int = 0,
) -> None:
    """Wrap *tokenizer* for BERT-style classification preprocessing.

    Args:
        tokenizer: Tokenizer supplying tokenisation and special tokens.
        max_seq_length: Default padded sequence length.
        source_type_id: Segment id for the first text.
        target_type_id: Segment id for the paired text (also used for
            padding positions by ``classification``).
        position_start_id: First value of the generated ``position_ids``.
    """
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length

    # Special tokens used to frame and pad sequences.
    tok = tokenizer
    self.pad_token, self.sep_token, self.cls_token, self.mask_token = (
        tok.pad_token,
        tok.sep_token,
        tok.cls_token,
        tok.mask_token,
    )
    self.pad_token_id = tok.pad_token_id

    self.source_type_id = source_type_id
    self.target_type_id = target_type_id
    self.position_start_id = position_start_id

tokenizer instance-attribute ¤

tokenizer = tokenizer

max_seq_length instance-attribute ¤

max_seq_length = max_seq_length

pad_token instance-attribute ¤

pad_token = pad_token

sep_token instance-attribute ¤

sep_token = sep_token

cls_token instance-attribute ¤

cls_token = cls_token

mask_token instance-attribute ¤

mask_token = mask_token

pad_token_id instance-attribute ¤

pad_token_id = pad_token_id

source_type_id instance-attribute ¤

source_type_id = source_type_id

target_type_id instance-attribute ¤

target_type_id = target_type_id

position_start_id instance-attribute ¤

position_start_id = position_start_id

classification ¤

classification(
    text: str,
    text_pair: Optional[str] = None,
    max_seq_length: Optional[int] = None,
) -> GenericOutputs

Tokenise text (and optional pair) for sequence classification.

Source code in src/unitorch/models/processing_utils.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
def classification(
    self,
    text: str,
    text_pair: Optional[str] = None,
    max_seq_length: Optional[int] = None,
) -> GenericOutputs:
    """Tokenise text (and optional pair) for sequence classification.

    Layouts (``[CLS]`` is omitted when the tokenizer has no CLS token):

    - single: ``[CLS] text [SEP]``
    - pair:   ``[CLS] text [SEP] text_pair [SEP]``

    Segment ids are ``source_type_id`` for the first segment, including its
    CLS/SEP. They are ``target_type_id`` for the paired segment and its
    trailing SEP. Tokens are truncated to fit *max_seq_length* (default
    ``self.max_seq_length``) and right-padded. Padding positions get
    ``pad_token_id``, attention 0 and ``target_type_id``.

    Returns:
        GenericOutputs with 1-D ``torch.long`` tensors ``input_ids``,
        ``token_type_ids``, ``attention_mask`` and ``position_ids`` (counting
        up from ``position_start_id``), each of length *max_seq_length*.
    """
    max_seq_length = pop_value(max_seq_length, self.max_seq_length)
    tokens = self.tokenizer.tokenize(str(text))

    if text_pair is None:
        if self.cls_token is not None:
            tokens = (
                [self.cls_token] + tokens[: max_seq_length - 2] + [self.sep_token]
            )
        else:
            tokens = tokens[: max_seq_length - 1] + [self.sep_token]
        token_type_ids = [self.source_type_id] * len(tokens)
    else:
        tokens_pair = self.tokenizer.tokenize(str(text_pair))
        if self.cls_token is not None:
            # Leave room for [CLS] and two [SEP]s (truncates in place).
            truncate_sequence_pair(tokens, tokens_pair, max_seq_length - 3)
            first = [self.cls_token] + tokens + [self.sep_token]
        else:
            # Leave room for two [SEP]s (truncates in place).
            truncate_sequence_pair(tokens, tokens_pair, max_seq_length - 2)
            first = tokens + [self.sep_token]
        second = tokens_pair + [self.sep_token]
        # Segment ids are derived from the framed segments so they always
        # cover the SEP tokens. The no-CLS layout used to leave them out,
        # which made token_type_ids two entries short.
        token_type_ids = [self.source_type_id] * len(first) + [
            self.target_type_id
        ] * len(second)
        tokens = first + second

    input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
    attention_mask = [1] * len(input_ids)

    pad_len = max_seq_length - len(input_ids)
    input_ids += [self.pad_token_id] * pad_len
    attention_mask += [0] * pad_len
    token_type_ids += [self.target_type_id] * pad_len

    assert len(input_ids) == max_seq_length
    assert len(attention_mask) == max_seq_length
    assert len(token_type_ids) == max_seq_length
    return GenericOutputs(
        input_ids=torch.tensor(input_ids, dtype=torch.long),
        token_type_ids=torch.tensor(token_type_ids, dtype=torch.long),
        attention_mask=torch.tensor(attention_mask, dtype=torch.long),
        position_ids=torch.arange(
            self.position_start_id,
            self.position_start_id + max_seq_length,
            dtype=torch.long,
        ),
    )

HfImageClassificationProcessor¤

Processor for image classification tasks using a HuggingFace vision processor.

Source code in src/unitorch/models/processing_utils.py
412
413
def __init__(self, vision_processor: BaseImageProcessor) -> None:
    """Store the image processor used by ``classification``.

    Args:
        vision_processor: Processor whose
            ``preprocess(image, return_tensors="pt")`` result exposes
            ``pixel_values``.
    """
    self.vision_processor = vision_processor

vision_processor instance-attribute ¤

vision_processor = vision_processor

classification ¤

classification(image: Union[Image, str]) -> GenericOutputs

Preprocess image into pixel values ready for a vision model.

Source code in src/unitorch/models/processing_utils.py
415
416
417
418
419
420
421
422
def classification(self, image: Union[Image.Image, str]) -> GenericOutputs:
    """Preprocess *image* into pixel values ready for a vision model.

    Args:
        image: A PIL image, or a filesystem path opened with
            ``PIL.Image.open``. Images opened from a path are closed before
            returning; caller-supplied images are left open.

    Returns:
        GenericOutputs with ``pixel_values``: the first item of the
        ``return_tensors="pt"`` batch produced by
        ``vision_processor.preprocess``.
    """
    import contextlib

    # Only files opened here are closed here. nullcontext leaves a
    # caller-owned image untouched.
    source = (
        Image.open(image) if isinstance(image, str) else contextlib.nullcontext(image)
    )
    with source as img:
        pixel_values = self.vision_processor.preprocess(
            img, return_tensors="pt"
        ).pixel_values[0]
    return GenericOutputs(pixel_values=pixel_values)

ExponentialMovingAverage¤

Bases: Module

Exponential Moving Average (EMA) wrapper for a model's parameters.

The effective decay at step $t$ is $\mathrm{decay}\cdot\left(1 - e^{-t/\tau}\right)$. It starts at 0 and rises toward `decay` as $t$ grows, so early updates follow the live model closely instead of being over-smoothed.

Source code in src/unitorch/models/modeling_ema.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(
    self,
    model: nn.Module,
    decay: float = 0.9999,
    tau: int = 2000,
    num_steps: int = 0,
) -> None:
    """Create an EMA shadow of *model*.

    Args:
        model: Module to track. A deep copy is stored as ``self.model`` and
            its parameters are frozen (``requires_grad=False``).
        decay: Upper limit approached by the effective decay.
        tau: Time constant, in steps, of the decay ramp-up.
        num_steps: Initial value of the step counter.
    """
    super().__init__()
    self.model = deepcopy(model)
    self.num_steps = num_steps

    def _decay_fn(step: int) -> float:
        # Starts at 0 and rises toward ``decay`` as ``step`` grows.
        return decay * (1 - math.exp(-step / tau))

    self._decay_fn = _decay_fn

    for param in self.model.parameters():
        param.requires_grad_(False)

checkpoint_name class-attribute instance-attribute ¤

checkpoint_name = 'pytorch_ema_model.bin'

model instance-attribute ¤

model = deepcopy(model)

num_steps instance-attribute ¤

num_steps = num_steps

_decay_fn instance-attribute ¤

_decay_fn = lambda x: decay * (1 - exp(-x / tau))

from_checkpoint ¤

from_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
) -> None

Load EMA weights from ckpt_dir.

Source code in src/unitorch/models/modeling_ema.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def from_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
) -> None:
    """Load EMA weights from *ckpt_dir*.

    Reads ``ckpt_dir/weight_name`` (``weight_name`` falls back to
    ``self.checkpoint_name``) with ``torch.load`` onto the CPU. The result is
    loaded into ``self.model``. Does nothing if the file does not exist.
    Extra ``kwargs`` are accepted and unused.
    """
    path = os.path.join(ckpt_dir, weight_name or self.checkpoint_name)
    if not os.path.exists(path):
        return
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only point this at trusted checkpoints.
    ema_weights = torch.load(path, map_location="cpu", weights_only=False)
    self.model.load_state_dict(ema_weights)
    logging.info("%s loaded weights from %s", type(self).__name__, path)

save_checkpoint ¤

save_checkpoint(
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs
) -> None

Save EMA weights to ckpt_dir.

Source code in src/unitorch/models/modeling_ema.py
54
55
56
57
58
59
60
61
62
63
64
def save_checkpoint(
    self,
    ckpt_dir: str,
    weight_name: Optional[str] = None,
    **kwargs,
) -> None:
    """Save EMA weights to *ckpt_dir*.

    ``self.model.state_dict()`` is written with ``torch.save`` to
    ``ckpt_dir/weight_name`` (``weight_name`` falls back to
    ``self.checkpoint_name``). Extra ``kwargs`` are accepted and unused.
    """
    name = weight_name if weight_name else self.checkpoint_name
    target = os.path.join(ckpt_dir, name)
    torch.save(self.model.state_dict(), target)
    logging.info("%s saved checkpoint to %s", type(self).__name__, target)

forward ¤

forward(*args, **kwargs)

Delegate forward pass to the EMA model.

Source code in src/unitorch/models/modeling_ema.py
66
67
68
def forward(self, *args, **kwargs):
    """Delegate forward pass to the EMA model.

    All positional and keyword arguments are passed unchanged to
    ``self.model``, and its result is returned as-is.
    """
    ema_model = self.model
    return ema_model(*args, **kwargs)

step ¤

step(model: Module) -> None

Update EMA parameters with one step from model.

Source code in src/unitorch/models/modeling_ema.py
70
71
72
73
74
75
76
77
78
@torch.no_grad()
def step(self, model: nn.Module) -> None:
    """Update EMA parameters with one step from *model*.

    Increments ``num_steps`` and computes ``rate = self._decay_fn(num_steps)``.
    Every floating-point entry of the EMA state dict is then updated in place
    as ``ema = rate * ema + (1 - rate) * live``. Non-floating entries, such as
    integer buffers, are left untouched. A ``KeyError`` is raised if *model*
    lacks a floating-point key present in the EMA copy.
    """
    self.num_steps += 1
    rate = self._decay_fn(self.num_steps)
    live_state = model.state_dict()
    ema_state = self.model.state_dict()
    float_keys = (k for k, v in ema_state.items() if v.dtype.is_floating_point)
    for name in float_keys:
        # state_dict tensors share storage with the EMA module, so the
        # in-place ops below update self.model directly.
        shadow = ema_state[name]
        shadow.mul_(rate)
        shadow.add_(live_state[name].detach(), alpha=1 - rate)