Skip to content

unitorch.cli.models.sam¤

SamProcessor¤

Tip

SamProcessor is configured through the core/process/sam section of the config.

Bases: SamProcessor

Source code in src/unitorch/cli/models/sam/processing.py
21
22
23
24
25
26
27
28
29
30
31
32
def __init__(
    self,
    vision_config_path: str,
    output_folder: Optional[str] = None,
):
    """Initialise the processor and make sure the output folder exists.

    Args:
        vision_config_path: Passed through unchanged to the parent
            initializer.
        output_folder: Directory stored on ``self.output_folder``. It is
            required even though the signature defaults it to ``None``;
            the directory is created when it does not exist yet.

    Raises:
        AssertionError: If ``output_folder`` is ``None``.
    """
    super().__init__(vision_config_path=vision_config_path)
    assert output_folder is not None
    self.output_folder = output_folder
    if os.path.exists(output_folder):
        # The path is already present, so there is nothing to create.
        return
    os.makedirs(self.output_folder, exist_ok=True)

output_folder instance-attribute ¤

output_folder = output_folder

from_config classmethod ¤

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/sam/processing.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@classmethod
@config_defaults_init("core/process/sam")
def from_config(cls, config, **kwargs):
    """Resolve constructor arguments from the ``core/process/sam`` section.

    Reads ``pretrained_name`` (default ``"sam-vit-base"``) and the
    optional ``vision_config_path`` option, then passes the resolved
    value through ``cached_path``.

    Returns:
        dict: A mapping with the single key ``"vision_config_path"``.
    """
    config.set_default_section("core/process/sam")
    pretrained_name = config.getoption("pretrained_name", "sam-vit-base")

    # The explicit option and the ``vision_config`` entry registered for
    # ``pretrained_name`` are both handed to pop_value.
    # NOTE(review): presumably pop_value returns the first non-None
    # argument -- confirm against its definition.
    resolved_path = cached_path(
        pop_value(
            config.getoption("vision_config_path", None),
            nested_dict_value(pretrained_sam_infos, pretrained_name, "vision_config"),
        )
    )
    return {"vision_config_path": resolved_path}

save ¤

save(image: Image)
Source code in src/unitorch/cli/models/sam/processing.py
51
52
53
54
55
56
def save(self, image: Image.Image):
    """Write ``image`` into the output folder under a content-derived name.

    The file name is the MD5 hex digest of ``image.tobytes()`` followed by
    ``.jpg``, so images with identical bytes map to the same file name.

    Args:
        image: Image to store; it must provide ``tobytes()`` and ``save()``.

    Returns:
        str: The generated file name (without the folder part).
    """
    digest = hashlib.md5(image.tobytes()).hexdigest()
    name = digest + ".jpg"
    image.save(f"{self.output_folder}/{name}")
    return name

_segmentation_inputs ¤

_segmentation_inputs(
    image: Union[Image, str],
    points_per_crop: Optional[int] = 32,
)
Source code in src/unitorch/cli/models/sam/processing.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@register_process("core/process/sam/segmentation/inputs")
def _segmentation_inputs(
    self,
    image: Union[Image.Image, str],
    points_per_crop: Optional[int] = 32,
):
    """Build segmentation inputs for one image and wrap them as TensorInputs.

    Args:
        image: Either an already opened image or a string, which is
            opened with ``Image.open``.
        points_per_crop: Forwarded to the parent ``segmentation_inputs``.

    Returns:
        TensorInputs: Holds ``pixel_values``, ``original_sizes``,
        ``reshaped_input_sizes``, ``input_points``, ``input_labels`` and
        ``input_boxes`` copied from the parent's result.
    """
    pil_image = Image.open(image) if isinstance(image, str) else image
    encoded = super().segmentation_inputs(
        image=pil_image,
        points_per_crop=points_per_crop,
    )
    # Only these fields of the parent's result are exposed downstream.
    forwarded_fields = (
        "pixel_values",
        "original_sizes",
        "reshaped_input_sizes",
        "input_points",
        "input_labels",
        "input_boxes",
    )
    return TensorInputs(
        **{field: getattr(encoded, field) for field in forwarded_fields}
    )

_processing_masks ¤

_processing_masks(outputs: SegmentationOutputs)
Source code in src/unitorch/cli/models/sam/processing.py
79
80
81
82
83
84
85
86
87
@register_process("core/postprocess/sam/segmentation")
def _processing_masks(self, outputs: SegmentationOutputs):
    """Save every predicted mask to disk and record the file names.

    Each entry of ``outputs.masks`` is iterated as a collection of masks.
    Every mask is converted with ``Image.fromarray(mask.numpy())`` and
    stored through ``self.save``; the resulting file names are joined
    with ``";"`` into the ``mask_image`` column.

    Args:
        outputs: Segmentation outputs exposing ``masks`` and
            ``to_pandas()``.

    Returns:
        WriterOutputs: Wraps the frame from ``outputs.to_pandas()`` with
        the added ``mask_image`` column.

    Raises:
        AssertionError: If the frame is non-empty and its row count
            differs from ``len(outputs.masks)``.
    """
    results = outputs.to_pandas()
    assert results.shape[0] == 0 or results.shape[0] == len(outputs.masks)
    # Iterate ``outputs.masks`` -- the field the assertion above measures and
    # the one ``segment`` populates -- rather than ``outputs.outputs``, so the
    # new column lines up row-for-row with ``results``.
    results["mask_image"] = [
        ";".join(self.save(Image.fromarray(mask.numpy())) for mask in sample_masks)
        for sample_masks in outputs.masks
    ]
    return WriterOutputs(results)

SamForSegmentation¤

Tip

SamForSegmentation is configured through the core/model/segmentation/sam section of the config.

Bases: SamForSegmentation

Source code in src/unitorch/cli/models/sam/modeling.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def __init__(
    self,
    config_path: str,
    vision_config_path: str,
    mask_threshold: Optional[float] = 0.0,
    pred_iou_thresh: Optional[float] = 0.88,
    stability_score_thresh: Optional[float] = 0.95,
    stability_score_offset: Optional[int] = 1,
    crops_nms_thresh: Optional[float] = 0.7,
):
    """Create the model together with its processor and filter settings.

    Args:
        config_path: Passed to the parent initializer.
        vision_config_path: Used to construct ``self.processor``.
        mask_threshold: Stored on the instance; read by ``segment``.
        pred_iou_thresh: Stored on the instance; read by ``segment``.
        stability_score_thresh: Stored on the instance; read by ``segment``.
        stability_score_offset: Stored on the instance; read by ``segment``.
        crops_nms_thresh: Stored on the instance; read by ``segment``.
    """
    super().__init__(config_path=config_path)

    # Processor that ``segment`` uses to post-process the raw masks.
    self.processor = SamProcessor(vision_config_path=vision_config_path)

    # Settings handed to ``processor.processing_masks`` by ``segment``.
    self.mask_threshold = mask_threshold
    self.pred_iou_thresh = pred_iou_thresh
    self.stability_score_thresh = stability_score_thresh
    self.stability_score_offset = stability_score_offset
    self.crops_nms_thresh = crops_nms_thresh

processor instance-attribute ¤

processor = SamProcessor(
    vision_config_path=vision_config_path
)

mask_threshold instance-attribute ¤

mask_threshold = mask_threshold

pred_iou_thresh instance-attribute ¤

pred_iou_thresh = pred_iou_thresh

stability_score_thresh instance-attribute ¤

stability_score_thresh = stability_score_thresh

stability_score_offset instance-attribute ¤

stability_score_offset = stability_score_offset

crops_nms_thresh instance-attribute ¤

crops_nms_thresh = crops_nms_thresh

from_config classmethod ¤

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/sam/modeling.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@classmethod
@config_defaults_init("core/model/segmentation/sam")
def from_config(cls, config, **kwargs):
    """Build a model from the ``core/model/segmentation/sam`` section.

    Options read: ``pretrained_name`` (default ``"sam-vit-base"``),
    ``config_path``, ``vision_config_path``, ``pretrained_weight_path``,
    ``pretrained_lora_weight_path``, ``pretrained_lora_weight`` (default
    ``1.0``) and ``pretrained_lora_alpha`` (default ``32.0``).

    Returns:
        The constructed instance, after ``from_pretrained`` has been
        called when a weight path resolves and ``load_lora_weights`` when
        a LoRA weight path is configured.
    """
    config.set_default_section("core/model/segmentation/sam")
    pretrained_name = config.getoption("pretrained_name", "sam-vit-base")

    # Each path comes from its explicit option and the entry registered for
    # ``pretrained_name``; both are handed to pop_value.
    # NOTE(review): presumably pop_value returns the first non-None
    # argument -- confirm against its definition.
    model_config_file = cached_path(
        pop_value(
            config.getoption("config_path", None),
            nested_dict_value(pretrained_sam_infos, pretrained_name, "config"),
        )
    )
    vision_config_file = cached_path(
        pop_value(
            config.getoption("vision_config_path", None),
            nested_dict_value(pretrained_sam_infos, pretrained_name, "vision_config"),
        )
    )

    inst = cls(
        config_path=model_config_file,
        vision_config_path=vision_config_file,
    )

    # check_none=False is passed here, and a None result skips loading.
    weight_path = pop_value(
        config.getoption("pretrained_weight_path", None),
        nested_dict_value(pretrained_sam_infos, pretrained_name, "weight"),
        check_none=False,
    )
    if weight_path is not None:
        inst.from_pretrained(weight_path)

    lora_path = config.getoption("pretrained_lora_weight_path", None)
    lora_weight = config.getoption("pretrained_lora_weight", 1.0)
    lora_alpha = config.getoption("pretrained_lora_alpha", 32.0)
    if lora_path is None:
        return inst

    inst.load_lora_weights(
        lora_path,
        lora_weights=lora_weight,
        lora_alphas=lora_alpha,
        save_base_state=False,
    )
    return inst

forward ¤

forward()
Source code in src/unitorch/cli/models/sam/modeling.py
94
95
96
97
98
@autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"))
def forward(self):
    """Not implemented for this model.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

segment ¤

segment(
    pixel_values: Tensor,
    input_points: Tensor,
    input_boxes: Tensor,
    original_sizes: Tensor,
    reshaped_input_sizes: Tensor,
    input_labels: Optional[Tensor] = None,
)
Source code in src/unitorch/cli/models/sam/modeling.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
@config_defaults_method("core/model/segmentation/sam")
@torch.no_grad()
def segment(
    self,
    pixel_values: torch.Tensor,
    input_points: torch.Tensor,
    input_boxes: torch.Tensor,
    original_sizes: torch.Tensor,
    reshaped_input_sizes: torch.Tensor,
    input_labels: Optional[torch.Tensor] = None,
):
    """Run segmentation and post-process the predicted masks.

    The parent ``segment`` receives ``pixel_values``, ``input_points`` and
    ``input_labels``. ``input_boxes``, ``original_sizes`` and
    ``reshaped_input_sizes`` are used only by the post-processing step.

    Returns:
        SegmentationOutputs: Built from the ``masks`` field of the result
        of ``self.processor.processing_masks``.
    """
    raw = super().segment(
        pixel_values=pixel_values,
        input_points=input_points,
        input_labels=input_labels,
    )

    # Filtering settings captured at construction time.
    filter_settings = {
        "mask_threshold": self.mask_threshold,
        "pred_iou_thresh": self.pred_iou_thresh,
        "stability_score_thresh": self.stability_score_thresh,
        "stability_score_offset": self.stability_score_offset,
        "crops_nms_thresh": self.crops_nms_thresh,
    }
    processed = self.processor.processing_masks(
        masks=raw.masks,
        scores=raw.scores,
        input_boxes=input_boxes,
        original_sizes=original_sizes,
        reshaped_input_sizes=reshaped_input_sizes,
        **filter_settings,
    )
    return SegmentationOutputs(masks=processed.masks)