| import torch |
| import torchvision |
| from torch import nn |
def create_effnetb2_model(num_classes: int = 3,
                          seed: int = 42):
    """Create an EfficientNet-B2 feature extractor with a fresh classifier head.

    The backbone is built with ``EfficientNet_B2_Weights.DEFAULT`` and all of
    its parameters are frozen (``requires_grad = False``); only the newly
    created classifier head is trainable. The model is NOT moved to any
    device here — call ``model.to(device)`` yourself.

    Note: ``torch.manual_seed(seed)`` is called before the head is built so
    its initialization is reproducible; this reseeds torch's global RNG as a
    side effect.

    Args:
        num_classes (int, optional): number of output classes of the
            classifier head. Defaults to 3.
        seed (int, optional): random seed applied right before creating the
            classifier head. Defaults to 42.

    Returns:
        tuple: ``(model, transforms)`` where
            model (torch.nn.Module): EffNetB2 feature extractor model.
            transforms: the preprocessing transform returned by
                ``weights.transforms()`` for the pretrained weights.
    """
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    # ``weights`` is keyword-only in torchvision >= 0.13; passing it
    # positionally relies on a deprecated compatibility shim.
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze the pretrained backbone so only the new head gets gradients.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the existing final Linear layer
    # rather than hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features

    # Seed immediately before creating the head so its random init is
    # reproducible across calls.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return model, transforms
|
|