Skip to content

unitorch.cli.models.kolors¤

KolorsMPSProcessor¤

Tip

core/process/kolors/mps is the section for configuration of KolorsMPSProcessor.

Bases: KolorsMPSProcessor

Processor for the Kolors MPS model.

Source code in src/unitorch/cli/models/kolors/processing.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def __init__(
    self,
    vocab_path: str,
    merge_path: str,
    vision_config_path: str,
    max_seq_length: Optional[int] = 77,
    position_start_id: Optional[int] = 0,
):
    """Initialize the CLI processor by delegating to the parent processor.

    Args:
        vocab_path: Path to the tokenizer vocabulary file.
        merge_path: Path to the tokenizer merges file.
        vision_config_path: Path to the vision (image) processing config.
        max_seq_length: Forwarded unchanged to the parent; defaults to 77.
        position_start_id: Forwarded unchanged to the parent; defaults to 0.
    """
    # Every argument is handed to the base class by keyword, untouched.
    parent_kwargs = dict(
        vocab_path=vocab_path,
        merge_path=merge_path,
        vision_config_path=vision_config_path,
        max_seq_length=max_seq_length,
        position_start_id=position_start_id,
    )
    super().__init__(**parent_kwargs)

from_config classmethod ¤

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/kolors/processing.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@classmethod
@config_defaults_init("core/process/kolors/mps")
def from_config(cls, config, **kwargs):
    """Resolve the file paths needed to build the processor from a config.

    Options are read from the ``core/process/kolors/mps`` section. Each of
    ``vocab_path``, ``merge_path`` and ``vision_config_path`` may be given
    explicitly; otherwise the entry registered for ``pretrained_name`` in
    ``pretrained_kolors_infos`` is used as the fallback.

    Args:
        config: Configuration object exposing ``set_default_section`` and
            ``getoption``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A dict with the keys ``vocab_path``, ``merge_path`` and
        ``vision_config_path``, each holding the value returned by
        ``cached_path``.
    """
    # The section name must match the one given to ``config_defaults_init``
    # above; a mismatching name makes every ``getoption`` call below read
    # from the wrong section, so user overrides are silently ignored.
    config.set_default_section("core/process/kolors/mps")
    pretrained_name = config.getoption("pretrained_name", "kolors-mps-overall")

    # NOTE(review): ``pop_value`` presumably picks the first usable value
    # (explicit option, then pretrained default) -- confirm against helper.
    vocab_path = config.getoption("vocab_path", None)
    vocab_path = pop_value(
        vocab_path,
        nested_dict_value(pretrained_kolors_infos, pretrained_name, "vocab"),
    )
    vocab_path = cached_path(vocab_path)

    merge_path = config.getoption("merge_path", None)
    merge_path = pop_value(
        merge_path,
        nested_dict_value(pretrained_kolors_infos, pretrained_name, "merge"),
    )
    merge_path = cached_path(merge_path)

    vision_config_path = config.getoption("vision_config_path", None)
    vision_config_path = pop_value(
        vision_config_path,
        nested_dict_value(
            pretrained_kolors_infos, pretrained_name, "vision_config"
        ),
    )
    vision_config_path = cached_path(vision_config_path)

    return {
        "vocab_path": vocab_path,
        "merge_path": merge_path,
        "vision_config_path": vision_config_path,
    }

_classification ¤

_classification(
    text: str,
    image: Union[Image, str],
    condition: str,
    max_seq_length: Optional[int] = None,
)
Source code in src/unitorch/cli/models/kolors/processing.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
@register_process("core/process/kolors/mps/classification")
def _classification(
    self,
    text: str,
    image: Union[Image.Image, str],
    condition: str,
    max_seq_length: Optional[int] = None,
):
    """Preprocess one (text, image, condition) example for classification.

    Args:
        text: Text input forwarded to the parent ``classification``.
        image: A PIL image, or a string that is opened with ``Image.open``.
        condition: Condition text forwarded to the parent ``classification``.
        max_seq_length: Optional override forwarded to the parent.

    Returns:
        ``TensorInputs`` carrying the text, image and condition fields
        produced by the parent processor.
    """
    # A string is treated as a path and loaded; anything else passes through.
    pil_image = Image.open(image) if isinstance(image, str) else image

    encoded = super().classification(
        text=text,
        image=pil_image,
        condition=condition,
        max_seq_length=max_seq_length,
    )

    # Copy exactly these attributes from the parent output, by the same name.
    forwarded_fields = (
        "input_ids",
        "attention_mask",
        "position_ids",
        "pixel_values",
        "condition_input_ids",
        "condition_attention_mask",
        "condition_position_ids",
    )
    return TensorInputs(
        **{field: getattr(encoded, field) for field in forwarded_fields}
    )

KolorsMPSModel¤

Tip

core/model/classification/kolors/mps is the section for configuration of KolorsMPSModel.

Bases: KolorsMPSModel

Kolors MPS model for image-text classification.

Source code in src/unitorch/cli/models/kolors/modeling.py
24
25
26
27
def __init__(self, config_path: str):
    """Initialize the CLI model by delegating to the parent model.

    Args:
        config_path: Path to the model configuration file, forwarded
            unchanged to the parent constructor.
    """
    super().__init__(config_path=config_path)

from_config classmethod ¤

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/kolors/modeling.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@classmethod
@config_defaults_init("core/model/classification/kolors/mps")
def from_config(cls, config, **kwargs):
    """Build the model from a config and optionally load pretrained weights.

    Options are read from the ``core/model/classification/kolors/mps``
    section. ``config_path`` and ``pretrained_weight_path`` fall back to the
    entries registered for ``pretrained_name`` in ``pretrained_kolors_infos``.

    Args:
        config: Configuration object exposing ``set_default_section`` and
            ``getoption``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The constructed model instance, with ``from_pretrained`` applied
        when a weight path was resolved.
    """
    config.set_default_section("core/model/classification/kolors/mps")
    pretrained_name = config.getoption("pretrained_name", "kolors-mps-overall")

    user_config_path = config.getoption("config_path", None)
    default_config_path = nested_dict_value(
        pretrained_kolors_infos, pretrained_name, "config"
    )
    resolved_config_path = cached_path(
        pop_value(user_config_path, default_config_path)
    )

    model = cls(config_path=resolved_config_path)

    user_weight_path = config.getoption("pretrained_weight_path", None)
    default_weight_path = nested_dict_value(
        pretrained_kolors_infos, pretrained_name, "weight"
    )
    # ``check_none=False`` is passed so that a missing weight path is allowed;
    # the model is then returned without loading weights.
    weight_path = pop_value(
        user_weight_path,
        default_weight_path,
        check_none=False,
    )
    if weight_path is None:
        return model

    model.from_pretrained(weight_path)
    return model

forward ¤

forward(
    input_ids: Tensor,
    pixel_values: Tensor,
    condition_input_ids: Tensor,
    attention_mask: Optional[Tensor] = None,
    position_ids: Optional[Tensor] = None,
    condition_attention_mask: Optional[Tensor] = None,
    condition_position_ids: Optional[Tensor] = None,
    labels: Optional[Tensor] = None,
)
Source code in src/unitorch/cli/models/kolors/modeling.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
@autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"))
def forward(
    self,
    input_ids: torch.Tensor,
    pixel_values: torch.Tensor,
    condition_input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    condition_attention_mask: Optional[torch.Tensor] = None,
    condition_position_ids: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
):
    """Run the parent model under autocast and wrap the result for the CLI.

    Args:
        input_ids: Text token ids.
        pixel_values: Image tensor.
        condition_input_ids: Condition token ids.
        attention_mask: Optional attention mask for the text.
        position_ids: Optional position ids for the text.
        condition_attention_mask: Optional attention mask for the condition.
        condition_position_ids: Optional position ids for the condition.
        labels: Targets; required when the module is in training mode.

    Returns:
        ``LossOutputs`` wrapping the parent's output in training mode,
        otherwise ``ClassificationOutputs`` wrapping the parent's output.

    Raises:
        AssertionError: If the module is in training mode and ``labels``
            is ``None``.
    """
    # Arguments shared by the training and inference calls to the parent.
    model_inputs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        condition_input_ids=condition_input_ids,
        condition_attention_mask=condition_attention_mask,
        condition_position_ids=condition_position_ids,
    )

    if self.training:
        assert labels is not None
        loss = super().forward(labels=labels, **model_inputs)
        # The loss container's field is ``loss``; passing ``outputs=`` does
        # not match the LossOutputs dataclass.
        return LossOutputs(loss=loss)

    outputs = super().forward(**model_inputs)
    return ClassificationOutputs(outputs=outputs)