Skip to content

unitorch.cli.models.bria

BRIAProcessor

Tip

core/process/bria is the section for configuration of BRIAProcessor.

Bases: BRIAProcessor

Source code in src/unitorch/cli/models/bria/processing.py (lines 18–22)
def __init__(
    self,
    image_size: Optional[int] = 1024,
):
    """Create the CLI-facing BRIA processor.

    This override only forwards its argument to the parent
    ``BRIAProcessor`` constructor.

    Args:
        image_size: Passed unchanged to the parent constructor as
            ``image_size``. Defaults to 1024.
    """
    super().__init__(image_size=image_size)

from_config (classmethod)

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/bria/processing.py (lines 24–32)
@classmethod
@config_defaults_init("core/process/bria")
def from_config(cls, config, **kwargs):
    """Collect constructor keyword arguments from ``core/process/bria``.

    Args:
        config: Project config object. Its default section is switched to
            ``core/process/bria`` before the option is read.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        dict: ``{"image_size": ...}``, read via
        ``config.getoption("image_size", 1024)``. The second argument is
        presumably the fallback used when the option is unset (TODO confirm
        against ``getoption``).
    """
    config.set_default_section("core/process/bria")
    return {"image_size": config.getoption("image_size", 1024)}

_segmentation_inputs

_segmentation_inputs(image: Union[Image, str])
Source code in src/unitorch/cli/models/bria/processing.py (lines 34–42)
@register_process("core/process/bria/segmentation/inputs")
def _segmentation_inputs(
    self,
    image: Union[Image.Image, str],
):
    """Turn one image into model-ready ``TensorInputs`` for inference.

    Args:
        image: A PIL image, or a path string that is opened with
            ``Image.open`` first.

    Returns:
        TensorInputs: the parent processor's ``image`` tensor as ``images``,
        together with its ``sizes``.
    """
    pil_image = Image.open(image) if isinstance(image, str) else image
    processed = super().segmentation_inputs(image=pil_image)
    return TensorInputs(
        images=processed.image,
        sizes=processed.sizes,
    )

_segmentation

_segmentation(
    image: Union[Image, str], mask: Union[Image, str]
)
Source code in src/unitorch/cli/models/bria/processing.py (lines 44–58)
@register_process("core/process/bria/segmentation")
def _segmentation(
    self,
    image: Union[Image.Image, str],
    mask: Union[Image.Image, str],
):
    """Build an (inputs, targets) training pair from an image and its mask.

    Args:
        image: Input PIL image, or a path string that is opened with
            ``Image.open``.
        mask: Ground-truth mask as a PIL image or path string. It is run
            through the parent's ``segmentation_labels``.

    Returns:
        tuple: ``(TensorInputs(images=...), SegmentationTargets(targets=...))``.
        Unlike the inference path, no ``sizes`` are attached to the inputs.
    """

    def _as_image(source):
        # Accept either an already-loaded image or a filesystem path.
        return Image.open(source) if isinstance(source, str) else source

    image = _as_image(image)
    mask = _as_image(mask)
    processed_image = super().segmentation_inputs(image=image)
    processed_mask = super().segmentation_labels(image=mask)
    inputs = TensorInputs(images=processed_image.image)
    targets = SegmentationTargets(targets=processed_mask.image)
    return inputs, targets

BRIAForSegmentation

Tip

core/model/segmentation/bria is the section for configuration of BRIAForSegmentation.

Bases: BRIAForSegmentation

Source code in src/unitorch/cli/models/bria/modeling.py (lines 23–24)
def __init__(self):
    """Create the CLI-facing BRIA segmentation model.

    Takes no arguments. This override does nothing beyond calling the
    parent ``BRIAForSegmentation`` constructor. Any pretrained weights are
    not loaded by this method itself.
    """
    super().__init__()

from_config (classmethod)

from_config(config, **kwargs)
Source code in src/unitorch/cli/models/bria/modeling.py (lines 26–36)
@classmethod
@config_defaults_init("core/model/segmentation/bria")
def from_config(cls, config, **kwargs):
    """Build the model from the ``core/model/segmentation/bria`` section.

    Args:
        config: Project config object. Its default section is switched
            before options are read.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        A new instance of ``cls``. When the ``pretrained_weight_path``
        option is set (not None), ``from_pretrained`` is called with it
        before the instance is returned.
    """
    config.set_default_section("core/model/segmentation/bria")

    model = cls()
    pretrained = config.getoption("pretrained_weight_path", None)
    if pretrained is None:
        return model
    model.from_pretrained(pretrained)
    return model

forward

forward(images)
Source code in src/unitorch/cli/models/bria/modeling.py (lines 38–43)
@autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"))
def forward(self, images):
    """Run the parent model under autocast and wrap its raw output.

    The autocast device is chosen once, when the decorator is evaluated:
    "cuda" if CUDA is available, otherwise "cpu".

    Args:
        images: Passed straight to the parent ``forward``.

    Returns:
        SegmentationOutputs whose ``masks`` is exactly what the parent
        ``forward`` returned, with no post-processing.
    """
    return SegmentationOutputs(masks=super().forward(images))

segment

segment(
    images, sizes: Optional[List[Tuple[int, int]]] = None
)
Source code in src/unitorch/cli/models/bria/modeling.py (lines 45–61)
def segment(self, images, sizes: Optional[List[Tuple[int, int]]] = None):
    """Run the model and optionally post-process each predicted mask.

    Args:
        images: Passed straight to the parent ``forward``. Presumably the
            batched tensor produced by the BRIA processor (TODO confirm).
        sizes: Optional per-image target ``(height, width)`` pairs. When
            given, each mask is resized to its size with bilinear
            interpolation, converted to channel-last layout, and min-max
            scaled into ``[0, 1]``. Pairing uses ``zip``, so extra masks or
            sizes are silently ignored.

    Returns:
        SegmentationOutputs: when ``sizes`` is None, ``masks`` is the raw
        ``outputs.logits``, unscaled and channel-first. Otherwise it is a
        list with one ``(height, width, channels)`` tensor per image.
        A mask whose values are all equal cannot be min-max scaled; it is
        returned as all zeros instead of NaN.
    """
    outputs = super().forward(images)
    # NOTE(review): ``forward`` wraps the parent's return value directly as
    # ``masks``, but this method reads ``.logits`` from it. Confirm which
    # form the parent actually returns.
    logits = outputs.logits
    if sizes is None:
        return SegmentationOutputs(
            masks=logits,
        )

    masks = []
    for mask, size in zip(logits, sizes):
        # ``F.interpolate`` needs a batch dim for 2-D bilinear resizing:
        # add it, resize the last two dims to ``size``, then drop it again.
        resized = F.interpolate(
            mask.unsqueeze(0), size=list(size), mode="bilinear"
        ).squeeze(0)
        # (channels, H, W) -> (H, W, channels).
        resized = resized.permute(1, 2, 0)
        lowest = resized.min()
        value_range = resized.max() - lowest
        resized = resized - lowest
        # A constant mask has zero range. The old ``x / 0`` produced an
        # all-NaN mask here; after the shift it is already all zeros.
        if value_range > 0:
            resized = resized / value_range
        masks.append(resized)
    return SegmentationOutputs(
        masks=masks,
    )