File size: 6,504 Bytes
a7b33df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin # Import the mixin

# --- Custom Model Definitions ---

class Identity(nn.Module):
    """Pass-through layer: the output is the input, unchanged.

    Useful as a drop-in replacement for a layer that should be disabled
    without altering the surrounding module structure.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` itself; no copy is made and no computation happens."""
        passthrough = x
        return passthrough

class AdditiveAttention(nn.Module):
    """Gate a projected copy of the input with additive-style attention weights.

    The same input tensor is linearly projected three times (query, key and
    value), each into ``hidden_dim`` features. The query and key projections
    are summed and squashed with ``tanh``, passed through one more linear
    layer, and normalised with a softmax over ``dim=1``. The value projection
    is then multiplied element-wise by those weights.

    Args:
        d_model: Size of the last dimension of the input tensor.
        hidden_dim: Size of the last dimension of every projection and of
            the returned tensor.
    """

    def __init__(self, d_model: int, hidden_dim: int = 128):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which random init is drawn.
        self.query_projection = nn.Linear(d_model, hidden_dim)
        self.key_projection = nn.Linear(d_model, hidden_dim)
        self.value_projection = nn.Linear(d_model, hidden_dim)
        self.attention_mechanism = nn.Linear(hidden_dim, hidden_dim)  # hidden_dim -> hidden_dim

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        """Return the value projection of ``query`` scaled by attention weights.

        Args:
            query: Tensor whose last dimension is ``d_model``. The softmax is
                taken over ``dim=1``; for a 2-D ``(batch, d_model)`` input
                that is the ``hidden_dim`` axis.

        Returns:
            Tensor with the input's leading dimensions and a last dimension
            of ``hidden_dim``.
        """
        projected_q = self.query_projection(query)
        projected_k = self.key_projection(query)

        # Additive combination of query and key, bounded by tanh.
        energy = torch.tanh(projected_q + projected_k)
        gate = F.softmax(self.attention_mechanism(energy), dim=1)

        # Element-wise product: the weights rescale the values, no reduction.
        return self.value_projection(query) * gate

class ResNet50Custom(nn.Module, PyTorchModelHubMixin): # Inherit from PyTorchModelHubMixin
    """ResNet-50 classifier with a configurable number of input channels.

    A torchvision ``resnet50`` is created with the ``IMAGENET1K_V1`` weights,
    then its first convolution and its final fully connected layer are
    replaced so the network accepts ``input_channels`` channels and emits
    ``num_classes`` logits. Both replacement layers are freshly initialised;
    only the layers in between keep the loaded weights.

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of outputs of the final linear layer.
        **kwargs: Extra entries that are only recorded in ``self.config``.
    """

    def __init__(self, input_channels: int, num_classes: int, **kwargs):
        super(ResNet50Custom, self).__init__()

        # Constructor arguments kept on the instance, intended for
        # PyTorchModelHubMixin's config.json.
        # NOTE(review): how the mixin picks up the config depends on the
        # installed huggingface_hub version - confirm.
        self.config = {
            "input_channels": input_channels,
            "num_classes": num_classes,
            **kwargs
        }

        self.input_channels = input_channels

        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # New stem convolution so any channel count is accepted. It is a fresh
        # layer, so the stem does not use the weights loaded above.
        self.model.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Remember the backbone's feature width before touching the head, so
        # get_feature_size() keeps working even while `fc` is temporarily
        # replaced by a module without `in_features` (e.g. Identity).
        self._feature_size = self.model.fc.in_features

        # The final FC layer is used when this model acts as a standalone
        # classifier; callers that want raw features may swap it out.
        self.model.fc = nn.Linear(self._feature_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped ResNet-50 on ``x`` and return its output."""
        return self.model(x)

    def get_feature_size(self) -> int:
        """Return the width of the features fed into the final ``fc`` layer.

        Reads ``in_features`` from the current head when it is an
        ``nn.Linear``; otherwise falls back to the value recorded at
        construction time instead of raising ``AttributeError``.
        """
        head = self.model.fc
        if isinstance(head, nn.Linear):
            return head.in_features
        return self._feature_size


class MultiModalModel(nn.Module, PyTorchModelHubMixin): # Inherit from PyTorchModelHubMixin
    """Three-branch classifier that fuses features from three image inputs.

    Each input (``inputs``, ``bathy_tensor``, ``sss_image``) goes through its
    own ``ResNet50Custom`` backbone with the final FC layer bypassed, then
    through its own ``AdditiveAttention`` block. The three attended vectors
    are concatenated and passed through three stacked linear layers.

    Args:
        image_input_channels: Channel count for the first backbone.
        bathy_input_channels: Channel count for the second backbone.
        sss_input_channels: Channel count for the third backbone.
        num_classes: Number of output logits; converted with ``int()``.
        attention_type: Stored on the instance and in ``self.config``; this
            class does not read it anywhere else.
        **kwargs: Extra entries that are only recorded in ``self.config``.
    """

    def __init__(self,
                 image_input_channels: int,
                 bathy_input_channels: int,
                 sss_input_channels: int,
                 num_classes: int,
                 attention_type: str = "scaled_dot_product",
                 **kwargs): # Added **kwargs for mixin compatibility
        super(MultiModalModel, self).__init__()

        # Normalise once so every layer below (backbone heads and fc2) gets a
        # plain int. int() raises TypeError/ValueError itself on bad input.
        num_classes = int(num_classes)

        # Constructor arguments kept on the instance, intended for
        # PyTorchModelHubMixin's config.json.
        # NOTE(review): how the mixin picks up the config depends on the
        # installed huggingface_hub version - confirm.
        self.config = {
            "image_input_channels": image_input_channels,
            "bathy_input_channels": bathy_input_channels,
            "sss_input_channels": sss_input_channels,
            "num_classes": num_classes,
            "attention_type": attention_type,
            **kwargs # Pass along any extra kwargs for mixin
        }

        # One backbone per modality. Their own FC heads stay registered (so
        # the state_dict layout is unchanged) but are bypassed in forward().
        self.image_model_feat = ResNet50Custom(input_channels=image_input_channels, num_classes=num_classes)
        self.bathy_model_feat = ResNet50Custom(input_channels=bathy_input_channels, num_classes=num_classes)
        self.sss_model_feat = ResNet50Custom(input_channels=sss_input_channels, num_classes=num_classes)

        # Width of the backbone features that feed the attention blocks.
        feature_dim = self.image_model_feat.get_feature_size()

        # Each attention block maps feature_dim -> attention_hidden_dim.
        attention_hidden_dim = 128
        self.attention_image = AdditiveAttention(feature_dim, hidden_dim=attention_hidden_dim)
        self.attention_bathy = AdditiveAttention(feature_dim, hidden_dim=attention_hidden_dim)
        self.attention_sss = AdditiveAttention(feature_dim, hidden_dim=attention_hidden_dim)

        # Final classification layers; input is the three attended vectors
        # concatenated (3 * attention_hidden_dim).
        # NOTE(review): forward() applies fc, fc1 and fc2 back to back with no
        # activation in between; sizes are left as-is so existing checkpoints
        # keep loading.
        self.fc = nn.Linear(3 * attention_hidden_dim, 1284)
        self.fc1 = nn.Linear(1284, 32)
        self.fc2 = nn.Linear(32, num_classes)
        self.attention_type = attention_type

    @staticmethod
    def _extract_features(extractor: "ResNet50Custom", x: torch.Tensor) -> torch.Tensor:
        """Run ``extractor`` with its final FC layer bypassed and return the features.

        The head is swapped for ``Identity`` only for the duration of the call
        and is restored in ``finally``, so an exception inside the backbone
        cannot leave the model without its classification head.
        """
        backbone = extractor.model
        saved_head = backbone.fc
        backbone.fc = Identity()
        try:
            return extractor(x)
        finally:
            backbone.fc = saved_head

    def forward(self, inputs: torch.Tensor, bathy_tensor: torch.Tensor, sss_image: torch.Tensor) -> torch.Tensor:
        """Classify one batch from its three modality tensors.

        Args:
            inputs: Batch for the image backbone.
            bathy_tensor: Batch for the bathy backbone.
            sss_image: Batch for the sss backbone.

        Returns:
            The output of ``fc2``: one row per sample, ``num_classes`` columns.
        """
        # Pre-head backbone features for each modality.
        image_features = self._extract_features(self.image_model_feat, inputs)
        bathy_features = self._extract_features(self.bathy_model_feat, bathy_tensor)
        sss_features = self._extract_features(self.sss_model_feat, sss_image)

        # Apply attention
        image_features_attended = self.attention_image(image_features)
        bathy_features_attended = self.attention_bathy(bathy_features)
        sss_features_attended = self.attention_sss(sss_features)

        # Concatenate attended features along the feature axis
        combined_features = torch.cat(
            [image_features_attended, bathy_features_attended, sss_features_attended],
            dim=1,
        )

        # Pass through final classification layers
        hidden = self.fc(combined_features)
        hidden = self.fc1(hidden)
        return self.fc2(hidden)