
unitorch.models.swin

SwinProcessor

Bases: HfImageClassificationProcessor

Initializes the SwinProcessor.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
vision_config_path | str | Path to the ViTImageProcessor configuration file (JSON). | required
Source code in src/unitorch/models/swin/processing.py
def __init__(
    self,
    vision_config_path: str,
):
    """
    Initializes the SwinProcessor.

    Args:
        vision_config_path (str): Path to the ViTImageProcessor configuration file.
    """
    vision_processor = ViTImageProcessor.from_json_file(vision_config_path)
    super().__init__(vision_processor=vision_processor)
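SwinProcessor only needs a path to a JSON file that `ViTImageProcessor.from_json_file` can read. The sketch below writes such a file using only the standard library. The values shown (224x224 resize, bicubic resampling, ImageNet mean/std) are illustrative assumptions, not settings shipped with unitorch, and `write_vit_processor_config` is a hypothetical helper.

```python
import json
import os
import tempfile


def write_vit_processor_config(path: str, image_size: int = 224) -> dict:
    """Hypothetical helper: write a ViTImageProcessor-style JSON config to `path`."""
    config = {
        "image_processor_type": "ViTImageProcessor",
        "do_resize": True,
        "size": {"height": image_size, "width": image_size},
        "resample": 3,  # PIL bicubic
        "do_rescale": True,
        "rescale_factor": 1 / 255,  # map 0..255 pixel values to 0..1
        "do_normalize": True,
        # ImageNet statistics (an assumption; match them to your checkpoint).
        "image_mean": [0.485, 0.456, 0.406],
        "image_std": [0.229, 0.224, 0.225],
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    return config


config_path = os.path.join(tempfile.mkdtemp(), "preprocessor_config.json")
write_vit_processor_config(config_path)
print(config_path)
```

The resulting file path is what the constructor expects, e.g. `SwinProcessor(vision_config_path=config_path)`.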

SwinForImageClassification

Bases: GenericModel

Initializes the SwinForImageClassification model.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
config_path | str | Path to the Swin Transformer configuration file. | required
num_classes | int | Number of output classes. Defaults to 1. | 1
Source code in src/unitorch/models/swin/modeling.py
def __init__(
    self,
    config_path: str,
    num_classes: Optional[int] = 1,
):
    """
    Initializes the SwinForImageClassification model.

    Args:
        config_path (str): Path to the Swin Transformer configuration file.
        num_classes (int, optional): Number of output classes. Defaults to 1.
    """
    super().__init__()
    config = SwinConfig.from_json_file(config_path)
    self.swin = SwinModel(config)
    self.classifier = nn.Linear(self.swin.num_features, num_classes)
    self.init_weights()
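The classifier's input width comes from `self.swin.num_features`. In Hugging Face's `SwinModel` that value is `embed_dim * 2 ** (len(depths) - 1)`, because each patch-merging step between stages doubles the channel width. The standard-library sketch below computes it from a SwinConfig-style JSON file. The Swin-T-like values (`embed_dim=96`, `depths=[2, 2, 6, 2]`), the label count, and the helper name `classifier_in_features` are illustrative assumptions.

```python
import json
import os
import tempfile


def classifier_in_features(config_path: str) -> int:
    """Hypothetical helper: width of the pooled Swin output fed to the classifier."""
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    num_stages = len(config["depths"])
    # Each patch-merging step between stages doubles the channel width.
    return int(config["embed_dim"] * 2 ** (num_stages - 1))


# Swin-T-like settings (assumed for illustration).
swin_config = {
    "model_type": "swin",
    "image_size": 224,
    "patch_size": 4,
    "num_channels": 3,
    "embed_dim": 96,
    "depths": [2, 2, 6, 2],
    "num_heads": [3, 6, 12, 24],
    "window_size": 7,
}
config_path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(swin_config, f)

in_features = classifier_in_features(config_path)
num_classes = 10  # hypothetical label count
print(f"classifier = nn.Linear({in_features}, {num_classes})")
```

With these settings the head is `nn.Linear(768, 10)`; keeping the default `num_classes=1` would instead give a single logit per image.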

swin (instance attribute)

swin = SwinModel(config)

classifier (instance attribute)

classifier = nn.Linear(self.swin.num_features, num_classes)

forward

forward(pixel_values: Tensor)

Forward pass of the SwinForImageClassification model.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
pixel_values | Tensor | Input image pixel values of shape (batch_size, num_channels, height, width). | required

Returns:

Type | Description
--- | ---
torch.Tensor | Classification logits of shape (batch_size, num_classes).

Source code in src/unitorch/models/swin/modeling.py
def forward(
    self,
    pixel_values: torch.Tensor,
):
    """
    Forward pass of the SwinForImageClassification model.

    Args:
        pixel_values (torch.Tensor): Input image pixel values.

    Returns:
        torch.Tensor: Classification logits.
    """
    outputs = self.swin(pixel_values=pixel_values)
    pooled_output = outputs[1]
    return self.classifier(pooled_output)
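`forward` returns `self.classifier(outputs[1])`. For a `SwinModel` built with its default pooling layer, `outputs[1]` is `pooler_output`: the layer-normalized final-stage tokens averaged over the sequence, one vector of length `num_features` per image. The standard-library sketch below traces tensor shapes through the forward pass. It assumes square images whose side is divisible by `patch_size * 2 ** (len(depths) - 1)`, so no stage needs padding; `trace_swin_shapes` is a hypothetical helper, not part of unitorch.

```python
def trace_swin_shapes(batch, image_size, patch_size, embed_dim, depths,
                      num_classes, num_channels=3):
    """Hypothetical helper: tensor shapes through SwinForImageClassification.forward.

    Assumes square inputs with image_size divisible by
    patch_size * 2 ** (len(depths) - 1), so no padding is required.
    """
    num_stages = len(depths)
    side = image_size // patch_size  # patches per side after patch embedding
    shapes = {
        "pixel_values": (batch, num_channels, image_size, image_size),
        "patch_embeddings": (batch, side * side, embed_dim),
    }
    # Patch merging between stages halves the grid side and doubles channels;
    # the last stage has no merging step, hence num_stages - 1.
    final_side = side // 2 ** (num_stages - 1)
    num_features = embed_dim * 2 ** (num_stages - 1)
    shapes["last_hidden_state"] = (batch, final_side * final_side, num_features)
    # outputs[1] (pooler_output): average over tokens, one vector per image.
    shapes["pooler_output"] = (batch, num_features)
    # self.classifier maps num_features -> num_classes.
    shapes["logits"] = (batch, num_classes)
    return shapes


for name, shape in trace_swin_shapes(2, 224, 4, 96, [2, 2, 6, 2], num_classes=10).items():
    print(f"{name:>18}: {shape}")
```

For Swin-T-like settings and a 224x224 input, the 56x56 patch grid shrinks to 7x7 (49 tokens of width 768) before pooling, so the logits have shape (2, 10).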