Skip to content

unitorch.models.sam¤

SamProcessor¤

Initializes the SamProcessor.

Parameters:

Name Type Description Default
vision_config_path str

Path to the SamImageProcessor configuration file.

required
Source code in src/unitorch/models/sam/processing.py
13
14
15
16
17
18
19
20
21
22
23
def __init__(
    self,
    vision_config_path: str,
):
    """
    Create a SamProcessor backed by a SamImageProcessor.

    Args:
        vision_config_path (str): Path to the SamImageProcessor configuration file.
    """
    # Load the image processor from its JSON configuration and keep it on the instance.
    image_processor = SamImageProcessor.from_json_file(vision_config_path)
    self.vision_processor = image_processor

vision_processor instance-attribute ¤

vision_processor = from_json_file(vision_config_path)

segmentation_inputs ¤

segmentation_inputs(
    image: Union[Image, str],
    crops_n_layers: int = 0,
    crop_overlap_ratio: float = 512 / 1500,
    points_per_crop: Optional[int] = 32,
    crop_n_points_downscale_factor: Optional[int] = 1,
)

Generates segmentation inputs using grid-based point prompts.

Parameters:

Name Type Description Default
image Image or str

Input image or path.

required
crops_n_layers int

Number of crop layers. Defaults to 0.

0
crop_overlap_ratio float

Overlap ratio between crops. Defaults to 512/1500.

512 / 1500
points_per_crop int

Number of grid points per crop. Defaults to 32.

32
crop_n_points_downscale_factor int

Downscale factor for points. Defaults to 1.

1

Returns:

Name Type Description
GenericOutputs

Processed pixel values, original sizes, reshaped input sizes, grid points, labels, and crop boxes (the first entry of each).

Source code in src/unitorch/models/sam/processing.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def segmentation_inputs(
    self,
    image: Union[Image.Image, str],
    crops_n_layers: int = 0,
    crop_overlap_ratio: float = 512 / 1500,
    points_per_crop: Optional[int] = 32,
    crop_n_points_downscale_factor: Optional[int] = 1,
):
    """
    Build segmentation inputs from an image using grid-based point prompts.

    Args:
        image (PIL.Image.Image or str): Input image, or a path that is opened with Image.open.
        crops_n_layers (int, optional): Number of crop layers. Defaults to 0.
        crop_overlap_ratio (float, optional): Overlap ratio between crops. Defaults to 512/1500.
        points_per_crop (int, optional): Number of grid points per crop. Defaults to 32.
        crop_n_points_downscale_factor (int, optional): Downscale factor for points. Defaults to 1.

    Returns:
        GenericOutputs: Fields pixel_values, original_sizes, reshaped_input_sizes,
            input_points, input_labels and input_boxes. Each field holds the
            first entry ([0]) of the corresponding processor output.
    """
    # A string is treated as a file path; anything else is used as given.
    pil_image = Image.open(image) if isinstance(image, str) else image

    processor = self.vision_processor
    longest_edge = processor.size["longest_edge"]

    # Arguments are passed positionally, in the same order as the original call.
    crop_outputs = processor.generate_crop_boxes(
        pil_image,
        longest_edge,
        crops_n_layers,
        crop_overlap_ratio,
        points_per_crop,
        crop_n_points_downscale_factor,
    )
    boxes, points, crops, labels = crop_outputs

    encoded = processor(crops, return_tensors="pt")

    # Only index 0 of every output is kept.
    first_pixels = encoded.get("pixel_values")[0]
    first_original_size = encoded.get("original_sizes")[0]
    first_reshaped_size = encoded.get("reshaped_input_sizes")[0]
    return GenericOutputs(
        pixel_values=first_pixels,
        original_sizes=first_original_size,
        reshaped_input_sizes=first_reshaped_size,
        input_points=points[0],
        input_labels=labels[0],
        input_boxes=boxes[0],
    )

processing_masks ¤

processing_masks(
    masks: Tensor,
    scores: Tensor,
    original_sizes: Union[Tensor, List[Tuple[int, int]]],
    reshaped_input_sizes: Union[
        Tensor, List[Tuple[int, int]]
    ],
    input_boxes: Tensor,
    mask_threshold: Optional[float] = 0.0,
    pred_iou_thresh: Optional[float] = 0.88,
    stability_score_thresh: Optional[float] = 0.95,
    stability_score_offset: Optional[int] = 1,
    crops_nms_thresh: Optional[float] = 0.7,
)

Post-processes predicted masks and scores.

Parameters:

Name Type Description Default
masks Tensor

Raw predicted masks.

required
scores Tensor

IoU scores.

required
original_sizes Union[Tensor, List[Tuple[int, int]]]

Original image sizes.

required
reshaped_input_sizes Union[Tensor, List[Tuple[int, int]]]

Reshaped input sizes.

required
input_boxes Tensor

Crop boxes.

required
mask_threshold float

Threshold for mask binarization. Defaults to 0.0.

0.0
pred_iou_thresh float

IoU score threshold. Defaults to 0.88.

0.88
stability_score_thresh float

Stability score threshold. Defaults to 0.95.

0.95
stability_score_offset int

Stability score offset. Defaults to 1.

1
crops_nms_thresh float

NMS threshold for crops. Defaults to 0.7.

0.7

Returns:

Name Type Description
GenericOutputs

Filtered masks, scores, RLE masks, and bounding boxes.

Source code in src/unitorch/models/sam/processing.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def processing_masks(
    self,
    masks: torch.Tensor,
    scores: torch.Tensor,
    original_sizes: Union[torch.Tensor, List[Tuple[int, int]]],
    reshaped_input_sizes: Union[torch.Tensor, List[Tuple[int, int]]],
    input_boxes: torch.Tensor,
    mask_threshold: Optional[float] = 0.0,
    pred_iou_thresh: Optional[float] = 0.88,
    stability_score_thresh: Optional[float] = 0.95,
    stability_score_offset: Optional[int] = 1,
    crops_nms_thresh: Optional[float] = 0.7,
):
    """
    Post-process predicted masks and scores into filtered per-item results.

    Args:
        masks (torch.Tensor): Raw predicted masks.
        scores (torch.Tensor): IoU scores.
        original_sizes (torch.Tensor or List[Tuple[int, int]]): Original image sizes.
        reshaped_input_sizes (torch.Tensor or List[Tuple[int, int]]): Reshaped input sizes.
        input_boxes (torch.Tensor): Crop boxes.
        mask_threshold (float, optional): Threshold for mask binarization. Defaults to 0.0.
        pred_iou_thresh (float, optional): IoU score threshold. Defaults to 0.88.
        stability_score_thresh (float, optional): Stability score threshold. Defaults to 0.95.
        stability_score_offset (int, optional): Stability score offset. Defaults to 1.
        crops_nms_thresh (float, optional): NMS threshold for crops. Defaults to 0.7.

    Returns:
        GenericOutputs: Lists named masks, scores, rle_mask and bounding_boxes,
            with one entry per item iterated.
    """
    # Tensor sizes are converted to plain Python lists before use.
    if isinstance(reshaped_input_sizes, torch.Tensor):
        reshaped_input_sizes = reshaped_input_sizes.tolist()
    if isinstance(original_sizes, torch.Tensor):
        original_sizes = original_sizes.tolist()

    processor = self.vision_processor
    resized_masks = processor.post_process_masks(
        masks,
        original_sizes,
        reshaped_input_sizes,
        mask_threshold=mask_threshold,
        binarize=False,
    )

    kept_masks = []
    kept_scores = []
    kept_rle_masks = []
    kept_boxes = []

    # zip stops at the shortest of the four sequences.
    per_item = zip(resized_masks, scores, original_sizes, input_boxes)
    for item_masks, item_scores, item_size, crop_box in per_item:
        filtered = processor.filter_masks(
            item_masks,
            item_scores,
            item_size,
            crop_box,
            pred_iou_thresh,
            stability_score_thresh,
            mask_threshold,
            stability_score_offset,
        )
        item_masks, item_scores, item_boxes = filtered

        generated = processor.post_process_for_mask_generation(
            item_masks, item_scores, item_boxes, crops_nms_thresh
        )
        item_masks, item_scores, item_rle, item_bboxes = generated

        # The returned masks are packed into one tensor by way of a NumPy array.
        kept_masks.append(torch.from_numpy(np.array(item_masks)))
        kept_scores.append(item_scores)
        kept_rle_masks.append(item_rle)
        kept_boxes.append(item_bboxes)

    return GenericOutputs(
        masks=kept_masks,
        scores=kept_scores,
        rle_mask=kept_rle_masks,
        bounding_boxes=kept_boxes,
    )

SamForSegmentation¤

Bases: GenericModel, PeftWeightLoaderMixin

SAM model for segmentation tasks.

Initializes the SamForSegmentation model.

Parameters:

Name Type Description Default
config_path str

Path to the SAM configuration file.

required
Source code in src/unitorch/models/sam/modeling.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def __init__(
    self,
    config_path: str,
):
    """
    Build the SamForSegmentation model from a JSON configuration.

    Args:
        config_path (str): Path to the SAM configuration file.
    """
    super().__init__()
    # The SamModel is constructed from the loaded SamConfig, then init_weights() is called.
    self.sam = SamModel(SamConfig.from_json_file(config_path))
    self.init_weights()

prefix_keys_in_state_dict class-attribute instance-attribute ¤

prefix_keys_in_state_dict = {
    "^mask_decoder.*": "sam.",
    "^vision_encoder.*": "sam.",
    "^prompt_encoder.*": "sam.",
    "^shared_image_embedding.*": "sam.",
}

replace_keys_in_peft_state_dict class-attribute instance-attribute ¤

replace_keys_in_peft_state_dict = {
    "peft_model.base_model.model.": ""
}

sam instance-attribute ¤

sam = SamModel(config)

forward ¤

forward()
Source code in src/unitorch/models/sam/modeling.py
39
40
def forward(self):
    """
    Not implemented for this model.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

segment ¤

segment(
    pixel_values: Tensor,
    input_points: Tensor,
    input_labels: Optional[Tensor] = None,
    input_boxes: Optional[Tensor] = None,
    input_masks: Optional[Tensor] = None,
)

Runs segmentation inference.

Parameters:

Name Type Description Default
pixel_values Tensor

Input image pixel values.

required
input_points Tensor

Input point prompts.

required
input_labels Tensor

Labels for input points. Defaults to None.

None
input_boxes Tensor

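Input box prompts. Accepted by the signature, but the source listed below does not pass it to the underlying SAM model. Defaults to None.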

None
input_masks Tensor

Input mask prompts. Defaults to None.

None

Returns:

Name Type Description
GenericOutputs

Predicted masks and IoU scores.

Source code in src/unitorch/models/sam/modeling.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def segment(
    self,
    pixel_values: torch.Tensor,
    input_points: torch.Tensor,
    input_labels: Optional[torch.Tensor] = None,
    input_boxes: Optional[torch.Tensor] = None,
    input_masks: Optional[torch.Tensor] = None,
):
    """
    Run segmentation inference through the wrapped SAM model.

    Args:
        pixel_values (torch.Tensor): Input image pixel values.
        input_points (torch.Tensor): Input point prompts.
        input_labels (torch.Tensor, optional): Labels for input points. Defaults to None.
        input_boxes (torch.Tensor, optional): Input box prompts. Accepted by the
            signature but not passed to the model by this method. Defaults to None.
        input_masks (torch.Tensor, optional): Input mask prompts. Defaults to None.

    Returns:
        GenericOutputs: `masks` taken from the model's `pred_masks` and
            `scores` taken from its `iou_scores`.
    """
    # NOTE(review): input_boxes is intentionally left out of the call here to
    # keep behavior unchanged; confirm whether it should be forwarded.
    prompt_kwargs = {
        "input_points": input_points,
        "input_labels": input_labels,
        "input_masks": input_masks,
    }
    prediction = self.sam(pixel_values, **prompt_kwargs)
    return GenericOutputs(
        masks=prediction.pred_masks,
        scores=prediction.iou_scores,
    )