Skip to content

unitorch.models.mask2former

Mask2FormerProcessor

Bases: HfImageClassificationProcessor

Source code in src/unitorch/models/mask2former/processing.py (lines 11-19)
def __init__(
    self,
    vision_config_path: str,
):
    """Create the processor from an image-processor JSON config.

    Args:
        vision_config_path: Path to a JSON file that is loaded with
            ``Mask2FormerImageProcessor.from_json_file``; the resulting
            image processor is handed to the parent processor as
            ``vision_processor``.
    """
    super().__init__(
        vision_processor=Mask2FormerImageProcessor.from_json_file(
            vision_config_path
        ),
    )

Mask2FormerForSegmentation

Bases: GenericModel

Source code in src/unitorch/models/mask2former/modeling.py (lines 16-32)
def __init__(
    self,
    config_path: str,
):
    """Build the segmentation model from a Mask2Former JSON config.

    Args:
        config_path: Path to a JSON file read with
            ``Mask2FormerConfig.from_json_file``.

    The config supplies the backbone model, the loss-term weights
    (``class_weight``, ``mask_weight``, ``dice_weight``), the classifier
    head sizes (``hidden_dim``, ``num_labels``) and the loss criterion.
    ``init_weights()`` is called last, after every submodule exists.
    """
    super().__init__()
    cfg = Mask2FormerConfig.from_json_file(config_path)

    # Submodule creation order is kept as-is: it determines registration
    # order and random-number consumption during initialization.
    self.model = Mask2FormerModel(cfg)
    self.weight_dict: Dict[str, float] = dict(
        loss_cross_entropy=cfg.class_weight,
        loss_mask=cfg.mask_weight,
        loss_dice=cfg.dice_weight,
    )

    # One logit per label plus one extra slot (presumably the "no object"
    # class, as in the reference Mask2Former head -- confirm).
    num_classes = cfg.num_labels + 1
    self.class_predictor = nn.Linear(cfg.hidden_dim, num_classes)
    self.criterion = Mask2FormerLoss(config=cfg, weight_dict=self.weight_dict)
    self.init_weights()

model (instance attribute)

# Mask2FormerModel built from the JSON config; `segment` reads its
# `transformer_decoder_intermediate_states` and `masks_queries_logits` outputs.
model = Mask2FormerModel(config)

weight_dict (instance attribute)

# Loss-term weights keyed by loss name, taken from the config's
# class_weight / mask_weight / dice_weight and passed to the criterion.
weight_dict: Dict[str, float] = {
    "loss_cross_entropy": class_weight,
    "loss_mask": mask_weight,
    "loss_dice": dice_weight,
}

class_predictor (instance attribute)

# Linear head mapping hidden_dim features to num_labels + 1 class logits
# (the extra logit is presumably a "no object" class -- confirm).
class_predictor = Linear(hidden_dim, num_labels + 1)

criterion (instance attribute)

# Mask2Former loss built from the config and `weight_dict`; `segment`
# does not call it.
criterion = Mask2FormerLoss(
    config=config, weight_dict=weight_dict
)

forward

forward()
Source code in src/unitorch/models/mask2former/modeling.py (lines 34-35)
def forward(self):
    """Unsupported entry point; see ``segment`` for the segmentation pass.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

segment

segment(
    pixel_values: Tensor,
    pixel_mask: Optional[Tensor] = None,
)
Source code in src/unitorch/models/mask2former/modeling.py (lines 37-58)
def segment(
    self,
    pixel_values: torch.Tensor,
    pixel_mask: Optional[torch.Tensor] = None,
):
    """Run the model and return final-layer mask and class logits.

    Args:
        pixel_values: Image batch forwarded to ``self.model``.
        pixel_mask: Optional mask forwarded to ``self.model`` unchanged.

    Returns:
        GenericOutputs with
          - ``masks``: the last entry of ``masks_queries_logits``, upsampled
            by a factor of 4 with bilinear interpolation
            (``align_corners=False``);
          - ``classes``: ``class_predictor`` applied to the last entry of
            ``transformer_decoder_intermediate_states`` after swapping its
            first two axes.
    """
    model_outputs = self.model(
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        output_hidden_states=True,
    )

    # Decoder states appear to be query-major (num_queries, batch, hidden);
    # swapping axes 0 and 1 presumably yields batch-major input -- confirm.
    last_decoder_state = model_outputs.transformer_decoder_intermediate_states[-1]
    class_logits = self.class_predictor(last_decoder_state.transpose(0, 1))

    # Fixed 4x upsampling of the final layer's mask logits (assumed to undo a
    # stride-4 mask feature map -- confirm against the pixel decoder).
    final_mask_logits = model_outputs.masks_queries_logits[-1]
    upsampled_masks = nn.functional.interpolate(
        final_mask_logits,
        scale_factor=4,
        mode="bilinear",
        align_corners=False,
    )
    return GenericOutputs(masks=upsampled_masks, classes=class_logits)