
unitorch.models.beit

BeitProcessor

Bases: HfImageClassificationProcessor

Image classification processor for BEiT models.

Source code in src/unitorch/models/beit/processing.py
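def __init__(self, vision_config_path: str) -> None:
    # Build the Hugging Face BEiT image processor from a JSON config file
    # and pass it to the generic image-classification processor base class.
    super().__init__(
        vision_processor=BeitImageProcessor.from_json_file(vision_config_path),
    )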
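A minimal sketch of the step `BeitProcessor` delegates to: loading `transformers.BeitImageProcessor` from a JSON file with `from_json_file` and turning an image into `pixel_values`. It calls the Hugging Face class directly rather than unitorch, and assumes `transformers`, `numpy`, and Pillow are installed. The config keys and values (224x224 resize, mean/std 0.5) and the file name `preprocessor_config.json` are illustrative assumptions, not the contents of any particular checkpoint.

```python
import json
import os
import tempfile

import numpy as np
from transformers import BeitImageProcessor

# Illustrative (assumed) preprocessing settings: resize to 224x224,
# rescale to [0, 1], then normalize each channel with mean/std 0.5.
config = {
    "do_resize": True,
    "size": {"height": 224, "width": 224},
    "do_center_crop": False,
    "do_rescale": True,
    "rescale_factor": 1 / 255,
    "do_normalize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_std": [0.5, 0.5, 0.5],
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "preprocessor_config.json")
    with open(path, "w") as f:
        json.dump(config, f)
    # The same call BeitProcessor.__init__ makes with vision_config_path.
    image_processor = BeitImageProcessor.from_json_file(path)

# A random 300x400 RGB image in height x width x channels uint8 layout.
image = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
batch = image_processor(images=image, return_tensors="np")
pixel_values = batch["pixel_values"]
print(pixel_values.shape)  # (1, 3, 224, 224)
```

With these settings every value lands in [-1, 1], since (x / 255 - 0.5) / 0.5 maps [0, 255] onto that range.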

BeitForImageClassification

Bases: GenericModel

BEiT model for image classification.

Source code in src/unitorch/models/beit/modeling.py
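def __init__(
    self,
    config_path: str,
    num_classes: int = 1,
) -> None:
    super().__init__()
    # Load the BEiT architecture hyperparameters from a JSON config file.
    config = BeitConfig.from_json_file(config_path)
    # BEiT backbone; the pooling layer yields one feature vector per image.
    self.beit = BeitModel(config, add_pooling_layer=True)
    # Linear head mapping the pooled features to num_classes logits.
    self.classifier = nn.Linear(config.hidden_size, num_classes)
    # Initialize the model weights.
    self.init_weights()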
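To see what the constructor builds, it can be mirrored with plain `transformers` and PyTorch. This is a sketch under stated assumptions: `TinyBeitClassifier` is a hypothetical stand-in, not part of unitorch. It subclasses `nn.Module` instead of unitorch's `GenericModel` and skips the `init_weights()` call. The tiny config values are assumptions chosen only so the model builds quickly on CPU.

```python
import json
import os
import tempfile

import torch.nn as nn
from transformers import BeitConfig, BeitModel


class TinyBeitClassifier(nn.Module):
    """Hypothetical stand-in mirroring BeitForImageClassification.__init__."""

    def __init__(self, config_path: str, num_classes: int = 1) -> None:
        super().__init__()
        config = BeitConfig.from_json_file(config_path)
        # Backbone with the pooling layer enabled, so pooler_output exists.
        self.beit = BeitModel(config, add_pooling_layer=True)
        # Linear head: hidden_size pooled features -> num_classes logits.
        self.classifier = nn.Linear(config.hidden_size, num_classes)


# Deliberately tiny (assumed) settings.
tiny_config = {
    "hidden_size": 32,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "intermediate_size": 64,
    "image_size": 32,
    "patch_size": 8,
    "num_channels": 3,
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    with open(path, "w") as f:
        json.dump(tiny_config, f)
    model = TinyBeitClassifier(path, num_classes=10)

head = model.classifier
print(head.in_features, head.out_features)        # 32 10
print(sum(p.numel() for p in head.parameters()))  # 330 = 32 * 10 weights + 10 biases
```

The head adds only `hidden_size * num_classes + num_classes` parameters on top of the backbone. The rest of the model is the BEiT encoder.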

beit (instance attribute)

beit = BeitModel(config, add_pooling_layer=True)
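Because the backbone is built with `add_pooling_layer=True`, its output includes `pooler_output`, which `forward` feeds to `classifier`. In the Hugging Face implementation, when `use_mean_pooling` is enabled (the `BeitConfig` default), the pooler averages the patch-token states, skipping the leading [CLS] token, and then applies a LayerNorm. The sketch below checks this on a randomly initialized tiny backbone. The config values are assumptions chosen for speed. The check also relies on the internals of the current `transformers` pooler (`beit.pooler.layernorm`), so it may not hold in other versions.

```python
import torch
from transformers import BeitConfig, BeitModel

torch.manual_seed(0)

# Tiny (assumed) configuration; use_mean_pooling=True is the BeitConfig default.
config = BeitConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    image_size=32,
    patch_size=8,
    use_mean_pooling=True,
)
beit = BeitModel(config, add_pooling_layer=True).eval()

pixel_values = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    out = beit(pixel_values=pixel_values)
    # Token 0 is [CLS]; mean-pool the remaining patch tokens, then LayerNorm.
    manual = beit.pooler.layernorm(out.last_hidden_state[:, 1:, :].mean(dim=1))

print(out.last_hidden_state.shape)  # torch.Size([2, 17, 32]): 1 CLS + 16 patch tokens
print(out.pooler_output.shape)      # torch.Size([2, 32])
```

A 32x32 image split into 8x8 patches gives 16 patch tokens. The pooled vector therefore has one `hidden_size` entry per image, regardless of sequence length.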

classifier (instance attribute)

classifier = nn.Linear(config.hidden_size, num_classes)

forward

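forward(pixel_values: Tensor) -> Tensor

Runs the BEiT backbone on pixel_values of shape (batch_size, num_channels, height, width) and returns classification logits of shape (batch_size, num_classes), computed by the linear classifier from the backbone's pooled output.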
Source code in src/unitorch/models/beit/modeling.py
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
    # Pooled BEiT features (batch_size, hidden_size) -> logits (batch_size, num_classes).
    return self.classifier(self.beit(pixel_values=pixel_values).pooler_output)
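This is a hedged end-to-end sketch of the forward pass, using a hypothetical tiny stand-in (`TinyBeitClassifier`) built from `transformers` rather than unitorch itself. Unlike the unitorch class, it takes a `BeitConfig` object instead of a config path. Pixel values go through the backbone, the pooled vector goes through the linear head, and the result is one raw logit per class. The softmax and argmax only illustrate how such logits might be consumed. The documented `forward` returns the logits unchanged. The config values are assumptions chosen to keep the model small.

```python
import torch
import torch.nn as nn
from transformers import BeitConfig, BeitModel


class TinyBeitClassifier(nn.Module):
    """Hypothetical stand-in mirroring BeitForImageClassification."""

    def __init__(self, config: BeitConfig, num_classes: int = 1) -> None:
        super().__init__()
        self.beit = BeitModel(config, add_pooling_layer=True)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # (batch, hidden_size) pooled features -> (batch, num_classes) logits.
        pooled = self.beit(pixel_values=pixel_values).pooler_output
        return self.classifier(pooled)


torch.manual_seed(0)
config = BeitConfig(  # tiny, assumed values
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    image_size=32,
    patch_size=8,
)
model = TinyBeitClassifier(config, num_classes=5).eval()

pixel_values = torch.randn(4, 3, 32, 32)  # stands in for processor output
with torch.no_grad():
    logits = model(pixel_values)
probs = logits.softmax(dim=-1)   # per-class probabilities, rows sum to 1
preds = probs.argmax(dim=-1)     # predicted class index per image
print(logits.shape)  # torch.Size([4, 5])
```

With the default `num_classes=1`, the head emits a single logit per image instead of a class distribution.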