Skip to content

unitorch.cli.models.segformer¤

SegformerProcessor¤

Tip

SegformerProcessor reads its configuration from the core/process/segformer section.

Bases: SegformerProcessor

Segformer processor for image segmentation tasks.

Source code in src/unitorch/cli/models/segformer/processing.py
21
22
23
24
25
26
27
def __init__(self, vision_config_path: str):
    """Initialize the processor by delegating to the parent class.

    Args:
        vision_config_path: Forwarded unchanged to the parent initializer
            under the same keyword name.
    """
    super().__init__(vision_config_path=vision_config_path)

from_config classmethod ¤

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/segformer/processing.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
@classmethod
@config_defaults_init("core/process/segformer")
def from_config(cls, config, **kwargs):
    """Resolve the processor's constructor arguments from ``config``.

    Options are read from the ``core/process/segformer`` section:

    * ``pretrained_name`` (default ``"segformer-b2-human-parse-24"``) selects
      an entry in ``pretrained_segformer_infos``.
    * ``vision_config_path`` (default ``None``) is handed to ``pop_value``
      together with that entry's ``"vision_config"`` value.
      NOTE(review): ``pop_value`` presumably prefers the explicit option and
      falls back to the registry value -- confirm against its definition.

    Returns:
        A dict with the single key ``"vision_config_path"`` whose value is
        the result of passing the resolved location through ``cached_path``.
    """
    config.set_default_section("core/process/segformer")

    name = config.getoption("pretrained_name", "segformer-b2-human-parse-24")
    explicit_path = config.getoption("vision_config_path", None)
    registry_path = nested_dict_value(
        pretrained_segformer_infos,
        name,
        "vision_config",
    )
    resolved_path = cached_path(pop_value(explicit_path, registry_path))

    return {"vision_config_path": resolved_path}

_segmentation_inputs ¤

_segmentation_inputs(image: Union[Image, str])
Source code in src/unitorch/cli/models/segformer/processing.py
50
51
52
53
54
55
56
57
58
@register_process("core/process/segformer/image_segmentation")
def _segmentation_inputs(
    self,
    image: Union[Image.Image, str],
):
    """Turn an image into model-ready pixel values.

    Registered under ``core/process/segformer/image_segmentation``.

    Args:
        image: Either an ``Image.Image`` instance or a string, in which case
            it is opened with ``Image.open`` first.

    Returns:
        A ``TensorInputs`` carrying the ``pixel_values`` produced by the
        parent class's ``classification`` method.
    """
    pil_image = Image.open(image) if isinstance(image, str) else image
    processed = super().classification(image=pil_image)
    return TensorInputs(pixel_values=processed.pixel_values)

SegformerForSegmentation¤

Tip

SegformerForSegmentation reads its configuration from the core/model/segmentation/segformer section.

Bases: SegformerForSegmentation

Source code in src/unitorch/cli/models/segformer/modeling.py
25
26
27
28
29
30
31
def __init__(self, config_path: str):
    """Initialize the model by delegating to the parent class.

    Args:
        config_path: Forwarded unchanged to the parent initializer under the
            same keyword name.
    """
    super().__init__(config_path=config_path)

from_config classmethod ¤

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/segformer/modeling.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@classmethod
@config_defaults_init("core/model/segmentation/segformer")
def from_config(cls, config, **kwargs):
    """Build a model instance from ``config`` and optionally load weights.

    Options are read from the ``core/model/segmentation/segformer`` section:

    * ``pretrained_name`` (default ``"segformer-b2-human-parse-24"``) selects
      an entry in ``pretrained_segformer_infos``.
    * ``config_path`` (default ``None``) is combined via ``pop_value`` with
      that entry's ``"config"`` value, then passed through ``cached_path``.
    * ``pretrained_weight_path`` (default ``None``) is combined via
      ``pop_value(..., check_none=False)`` with the entry's ``"weight"``
      value; weights are loaded only when the result is not ``None``.

    NOTE(review): ``pop_value`` presumably prefers the explicit option over
    the registry value -- confirm against its definition.

    Returns:
        The constructed instance, after ``from_pretrained`` has been called
        on it when a weight location was resolved.
    """
    config.set_default_section("core/model/segmentation/segformer")
    name = config.getoption("pretrained_name", "segformer-b2-human-parse-24")

    # Model architecture: explicit option combined with the registry entry.
    explicit_config = config.getoption("config_path", None)
    registry_config = nested_dict_value(pretrained_segformer_infos, name, "config")
    model = cls(config_path=cached_path(pop_value(explicit_config, registry_config)))

    # Weights are optional, hence check_none=False and the None guard below.
    explicit_weight = config.getoption("pretrained_weight_path", None)
    registry_weight = nested_dict_value(pretrained_segformer_infos, name, "weight")
    weight_path = pop_value(explicit_weight, registry_weight, check_none=False)
    if weight_path is None:
        return model

    model.from_pretrained(weight_path)
    return model

forward ¤

forward()
Source code in src/unitorch/cli/models/segformer/modeling.py
61
62
63
64
65
@autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"))
def forward(self):
    """Placeholder forward pass; always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    raise NotImplementedError()

segment ¤

segment(pixel_values: Tensor)
Source code in src/unitorch/cli/models/segformer/modeling.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@config_defaults_method("core/model/segmentation/segformer")
@torch.no_grad()
def segment(
    self,
    pixel_values: torch.Tensor,
):
    """Run segmentation and return per-class masks with their class indices.

    The parent ``forward`` is called with ``pixel_values`` and its ``logits``
    are soft-maxed over dim 1. Every probability that is not the maximum
    along dim 1 is zeroed, so each position keeps only its winning class
    score.

    Args:
        pixel_values: Input tensor forwarded to the parent ``forward``.
            Assumes a batched image tensor -- TODO confirm exact layout.

    Returns:
        ``SegmentationOutputs`` with:

        * ``masks``: same shape as the logits; softmax probabilities with
          non-maximal entries along dim 1 set to zero.
        * ``classes``: integer tensor of shape ``(batch, num_classes)`` where
          each row is ``0 .. num_classes - 1``.
    """
    logits = super().forward(
        pixel_values=pixel_values,
    ).logits
    batch = logits.shape[0]
    # The class axis is dim 1 -- the same axis softmax/max reduce over below.
    # Reading shape[-1] would pick up the last spatial dimension for 4-D
    # logits and yield a class list that does not match the masks.
    num_classes = logits.shape[1]

    probs = torch.softmax(logits, dim=1)
    # Keep only the arg-max probability at each position; zero the rest.
    winners = probs == probs.max(dim=1, keepdim=True).values
    masks = probs * winners.float()

    classes = (
        torch.arange(num_classes, device=masks.device)
        .unsqueeze(0)
        .expand(batch, -1)
    )

    return SegmentationOutputs(
        masks=masks,
        classes=classes,
    )