diff --git a/DIC.py b/DIC.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bd67353e13c054ea320882506624ac0e2050a91
--- /dev/null
+++ b/DIC.py
@@ -0,0 +1,17 @@
+import torch
+from pathlib import Path
+
+
+dir=Path.home() / f"tmp/resnet50/CUB2011/123456/"
+dic=torch.load(dir/ f"SlDD_Selection_50.pt")
+
+print (dic)
+
+#if 'linear.selection' in dic.keys():
+ #print("key 'linear.selection' exist")
+#else:
+ #print("no such key")
+
+
+
+
diff --git a/FeatureDiversityLoss.py b/FeatureDiversityLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..be5745ae71dbe298244271c3a942c80c2b3e9867
--- /dev/null
+++ b/FeatureDiversityLoss.py
@@ -0,0 +1,59 @@
+import torch
+from torch import nn
+
+"""
+Feature Diversity Loss:
+Usage to replicate paper:
+Call
+loss_function = FeatureDiversityLoss(0.196, linear)
+to inititalize loss with linear layer of model.
+At each mini batch get feature maps (Output of final convolutional layer) and add to Loss:
+loss += loss_function(feature_maps, outputs)
+"""
+
+
+class FeatureDiversityLoss(nn.Module):
+ def __init__(self, scaling_factor, linear):
+ super().__init__()
+ self.scaling_factor = scaling_factor #* 0
+ print("Scaling Factor: ", self.scaling_factor)
+ self.linearLayer = linear
+
+ def initialize(self, linearLayer):
+ self.linearLayer = linearLayer
+
+ def get_weights(self, outputs):
+ weight_matrix = self.linearLayer.weight
+ weight_matrix = torch.abs(weight_matrix)
+ top_classes = torch.argmax(outputs, dim=1)
+ relevant_weights = weight_matrix[top_classes]
+ return relevant_weights
+
+ def forward(self, feature_maps, outputs):
+ relevant_weights = self.get_weights(outputs)
+ relevant_weights = norm_vector(relevant_weights)
+ feature_maps = preserve_avg_func(feature_maps)
+ flattened_feature_maps = feature_maps.flatten(2)
+ batch, features, map_size = flattened_feature_maps.size()
+ relevant_feature_maps = flattened_feature_maps * relevant_weights[..., None]
+ diversity_loss = torch.sum(
+ torch.amax(relevant_feature_maps, dim=1))
+ return -diversity_loss / batch * self.scaling_factor
+
+
+def norm_vector(x):
+ return x / (torch.norm(x, dim=1) + 1e-5)[:, None]
+
+
+def preserve_avg_func(x):
+ avgs = torch.mean(x, dim=[2, 3])
+ max_avgs = torch.max(avgs, dim=1)[0]
+ scaling_factor = avgs / torch.clamp(max_avgs[..., None], min=1e-6)
+ softmaxed_maps = softmax_feature_maps(x)
+ scaled_maps = softmaxed_maps * scaling_factor[..., None, None]
+ return scaled_maps
+
+
+def softmax_feature_maps(x):
+ return torch.softmax(x.reshape(x.size(0), x.size(1), -1), 2).view_as(x)
+
diff --git a/ReadME.md b/ReadME.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea7057e74321e137281dfc1e7b4890ca2c193ef0
--- /dev/null
+++ b/ReadME.md
@@ -0,0 +1,138 @@
+# Q-SENN - Quantized Self-Explaining Neural Networks
+
+This repository contains the code for the AAAI 2024 paper
+[*Q-SENN: Quantized Self-Explaining Neural Network*](https://ojs.aaai.org/index.php/AAAI/article/view/30145) by Thomas
+Norrenbrock ,
+Marco Rudolph,
+and Bodo Rosenhahn.
+Additonally, the SLDD-model from [*Take 5:
+Interpretable Image Classification with a Handful of Features*](https://arxiv.org/pdf/2303.13166) (NeurIPS
+Workshop) from the same authors is included.
+
+
+
+
+
+
+---
+Abstract:
+>Explanations in Computer Vision are often desired, but most Deep Neural Networks can only provide saliency maps with questionable faithfulness. Self-Explaining Neural Networks (SENN) extract interpretable concepts with fidelity, diversity, and grounding to combine them linearly for decision-making. While they can explain what was recognized, initial realizations lack accuracy and general applicability. We propose the Quantized-Self-Explaining Neural Network Q-SENN. Q-SENN satisfies or exceeds the desiderata of SENN while being applicable to more complex datasets and maintaining most or all of the accuracy of an uninterpretable baseline model, out-performing previous work in all considered metrics. Q-SENN describes the relationship between every class and feature as either positive, negative or neutral instead of an arbitrary number of possible relations, enforcing more binary human-friendly features. Since every class is assigned just 5 interpretable features on average, Q-SENN shows convincing local and global interpretability. Additionally, we propose a feature alignment method, capable of aligning learned features with human language-based concepts without additional supervision. Thus, what is learned can be more easily verbalized.
+
+
+
+
+---
+
+## Installation
+You will need the usual libaries for deep learning, e.g. pytorch,
+torchvision, numpy, etc. Additionally, we use
+[GLM-Saga](https://github.com/MadryLab/glm_saga) that can be installed via pip.
+In case you are lazy (or like to spend your time otherwise), a suitable
+environment can be created using [Anaconda](https://www.anaconda.com/) and the
+provided environment.yml file:
+```shell
+conda env create -f environment.yml
+```
+
+## Data
+Supported datasets are:
+- [Cub2011](https://www.vision.caltech.edu/datasets/cub_200_2011/)
+- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
+- [TravelingBirds](https://worksheets.codalab.org/bundles/0x518829de2aa440c79cd9d75ef6669f27)
+- [ImageNet](https://www.image-net.org/)
+
+To use the data for training, the datasets have to be downloaded and put into the
+respective folder under ~/tmp/datasets such that the final structure looks like
+
+```shell
+~/tmp/datasets
+├── CUB200
+│ └── CUB_200_2011
+│ ├── ...
+├── StanfordCars
+│ ├── stanford_cars
+│ ├── ...
+├── TravelingBirds
+│ ├── CUB_fixed
+│ ├── ...
+├── imagenet
+│ ├── ...
+```
+
+The default paths could be changed in the dataset_classes or for Imagenet in
+get_data.py
+
+Note:
+If cropped images, like for PIP-Net, ProtoPool, etc. are desired, then the
+crop_root should be set to a folder containing the cropped images in the
+expected structure, obtained by following ProtoTree's instructions:
+https://github.com/M-Nauta/ProtoTree/blob/main/README.md#preprocessing-cub,
+default path is: PPCUB200 instead of CUB200 for Protopool. Using these images
+can be set using an additional flag `--cropGT` introduced later.
+
+
+
+## Usage
+The code to create a Q-SENN model can be started from the file main.py.
+Available parameters are:
+- `--dataset`: The dataset to use. Default: Cub2011
+- `--arch`: The backbone to use. Default: resnet50
+- `--model_type`: The model type to use. Default: qsenn
+- `--seed`: The seed to use. Default: None
+- `--do_dense`: Whether to train the dense model. Default: True
+- `--cropGT`: Whether to crop CUB/TravelingBirds based on GT Boundaries. Default: False
+- `--n_features`: How many features to select. Default: 50
+- `--n_per_class`: How many features to assign to each class. Default: 5
+- `--img_size`: Image size. Default: 448
+- `--reduced_strides`: Whether to use reduced strides for resnets. Default: False
+
+
+For Example the next command will start the creation of Q-SENN with resnet50 on
+StanfordCars using the default arguments in the paper.
+```shell
+python main.py --dataset StanfordCars
+```
+
+**Note:**
+All experiments on ImageNet in the paper skipped the dense training from
+scratch on ImageNet. The pretrained models are used directly.
+This can be replicated with the argument --do-dense False.
+## Citations
+Please cite this work as:\
+Q-SENN
+```bibtex
+@inproceedings{norrenbrock2024q,
+ title={Q-senn: Quantized self-explaining neural networks},
+ author={Norrenbrock, Thomas and Rudolph, Marco and Rosenhahn, Bodo},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={38},
+ number={19},
+ pages={21482--21491},
+ year={2024}
+}
+```
+SLDD-Model
+```bibtex
+@inproceedings{norrenbrocktake,
+ title={Take 5: Interpretable Image Classification with a Handful of Features},
+ author={Norrenbrock, Thomas and Rudolph, Marco and Rosenhahn, Bodo},
+ year={2022},
+ booktitle={Progress and Challenges in Building Trustworthy Embodied AI}
+}
+```
+## Pretrained Model
+One pretrained model for Q-SENN on CUB can be obtained via this link: https://drive.google.com/drive/folders/1agWqKhcWOVWueV4Fzaowr80lQroCJFYn?usp=drive_link
+## Acknowledgement
+This work was supported by the Federal Ministry of Education and Research (BMBF), Germany under the AI service center KISSKI (grant no. 01IS22093C) and the Deutsche Forschungsgemeinschaft (DFG) under Germany’s Excellence Strategy within the Cluster of Excellence PhoenixD (EXC 2122).
+This work was partially supported by Intel Corporation and by the German Federal Ministry
+of the Environment, Nature Conservation, Nuclear Safety
+and Consumer Protection (GreenAutoML4FAS project no.
+67KI32007A).
+
+The work was done at the Leibniz University Hannover and published at AAAI 2024.
+
+
+
+
+
+
diff --git a/__pycache__/get_data.cpython-310.pyc b/__pycache__/get_data.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db7e1ece40eba3119dbfd6595168c1599a3847bb
Binary files /dev/null and b/__pycache__/get_data.cpython-310.pyc differ
diff --git a/__pycache__/load_model.cpython-310.pyc b/__pycache__/load_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dc4bd3e12a89cd304e7c33e8cf190a0353a8698
Binary files /dev/null and b/__pycache__/load_model.cpython-310.pyc differ
diff --git a/architectures/FinalLayer.py b/architectures/FinalLayer.py
new file mode 100644
index 0000000000000000000000000000000000000000..af1a55a667c462ec8f256f9d28aefdc5e77d6cae
--- /dev/null
+++ b/architectures/FinalLayer.py
@@ -0,0 +1,36 @@
+import torch
+from torch import nn
+
+from architectures.SLDDLevel import SLDDLevel
+
+
+class FinalLayer():
+ def __init__(self, num_classes, n_features):
+ super().__init__()
+ self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
+ self.linear = nn.Linear(n_features, num_classes)
+ self.featureDropout = torch.nn.Dropout(0.2)
+ self.selection = None
+
+ def transform_output(self, feature_maps, with_feature_maps,
+ with_final_features):
+ if self.selection is not None:
+ feature_maps = feature_maps[:, self.selection]
+ x = self.avgpool(feature_maps)
+ pre_out = torch.flatten(x, 1)
+ final_features = self.featureDropout(pre_out)
+ final = self.linear(final_features)
+ final = [final]
+ if with_feature_maps:
+ final.append(feature_maps)
+ if with_final_features:
+ final.append(final_features)
+ if len(final) == 1:
+ final = final[0]
+ return final
+
+
+ def set_model_sldd(self, selection, weight, mean, std, bias = None):
+ self.selection = selection
+ self.linear = SLDDLevel(selection, weight, mean, std, bias)
+ self.featureDropout = torch.nn.Dropout(0.1)
\ No newline at end of file
diff --git a/architectures/SLDDLevel.py b/architectures/SLDDLevel.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc214c88f384690d29bda97d7ed82a8c01e866da
--- /dev/null
+++ b/architectures/SLDDLevel.py
@@ -0,0 +1,37 @@
+import torch.nn
+
+
+class SLDDLevel(torch.nn.Module):
+ def __init__(self, selection, weight_at_selection,mean, std, bias=None):
+ super().__init__()
+ self.register_buffer('selection', torch.tensor(selection, dtype=torch.long))
+ num_classes, n_features = weight_at_selection.shape
+ selected_mean = mean
+ selected_std = std
+ if len(selected_mean) != len(selection):
+ selected_mean = selected_mean[selection]
+ selected_std = selected_std[selection]
+ self.mean = torch.nn.Parameter(selected_mean)
+ self.std = torch.nn.Parameter(selected_std)
+ if bias is not None:
+ self.layer = torch.nn.Linear(n_features, num_classes)
+ self.layer.bias = torch.nn.Parameter(bias, requires_grad=False)
+ else:
+ self.layer = torch.nn.Linear(n_features, num_classes, bias=False)
+ self.layer.weight = torch.nn.Parameter(weight_at_selection, requires_grad=False)
+
+ @property
+ def weight(self):
+ return self.layer.weight
+
+ @property
+ def bias(self):
+ if self.layer.bias is None:
+ return torch.zeros(self.layer.out_features)
+ else:
+ return self.layer.bias
+
+
+ def forward(self, input):
+ input = (input - self.mean) / torch.clamp(self.std, min=1e-6)
+ return self.layer(input)
diff --git a/architectures/__pycache__/FinalLayer.cpython-310.pyc b/architectures/__pycache__/FinalLayer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a61e00b4a63ecca65185c2f0157f001dbead798
Binary files /dev/null and b/architectures/__pycache__/FinalLayer.cpython-310.pyc differ
diff --git a/architectures/__pycache__/SLDDLevel.cpython-310.pyc b/architectures/__pycache__/SLDDLevel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a5d50a5be90c92f0840009d5a78ce7a4a4821df
Binary files /dev/null and b/architectures/__pycache__/SLDDLevel.cpython-310.pyc differ
diff --git a/architectures/__pycache__/model_mapping.cpython-310.pyc b/architectures/__pycache__/model_mapping.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..515201d4169359545fb87981950d8c22a4181b40
Binary files /dev/null and b/architectures/__pycache__/model_mapping.cpython-310.pyc differ
diff --git a/architectures/__pycache__/resnet.cpython-310.pyc b/architectures/__pycache__/resnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93dccc8da3700ae7fdcbfe887f868ef16d935b95
Binary files /dev/null and b/architectures/__pycache__/resnet.cpython-310.pyc differ
diff --git a/architectures/__pycache__/utils.cpython-310.pyc b/architectures/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a224141ff65b31e88257f416c3f7bfc2aaafbfa
Binary files /dev/null and b/architectures/__pycache__/utils.cpython-310.pyc differ
diff --git a/architectures/model_mapping.py b/architectures/model_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..52df91009029b653b420ff03562616b2389eaa68
--- /dev/null
+++ b/architectures/model_mapping.py
@@ -0,0 +1,7 @@
+from architectures.resnet import resnet50
+
+
+def get_model(arch, num_classes, changed_strides=True):
+ if arch == "resnet50":
+ model = resnet50(True, num_classes=num_classes, changed_strides=changed_strides)
+ return model
\ No newline at end of file
diff --git a/architectures/resnet.py b/architectures/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaaa5d3c22e6ab85f9ac63b29462d20aec9594d3
--- /dev/null
+++ b/architectures/resnet.py
@@ -0,0 +1,420 @@
+import copy
+import time
+
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+from torchvision.models import get_model
+
+# from scripts.modelExtensions.crossModelfunctions import init_experiment_stuff
+
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2',
+ 'wide_resnet50_3', 'wide_resnet50_4', 'wide_resnet50_5',
+ 'wide_resnet50_6', ]
+
+from architectures.FinalLayer import FinalLayer
+from architectures.utils import SequentialWithArgs
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, features=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+
+ def forward(self, x, no_relu=False):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+
+
+ out += identity
+
+ if no_relu:
+ return out
+ return self.relu(out)
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, features=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ if features is None:
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ else:
+ self.conv3 = conv1x1(width, features)
+ self.bn3 = norm_layer(features)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x, no_relu=False, early_exit=False):
+ identity = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ if no_relu:
+ return out
+ return self.relu(out)
+
+
+class ResNet(nn.Module, FinalLayer):
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None, changed_strides=False,):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+ widths = [64, 128, 256, 512]
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.sstride = 2
+ if changed_strides:
+ self.sstride = 1
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=self.sstride,
+ dilate=replace_stride_with_dilation[1])
+ self.stride = 2
+
+ if changed_strides:
+ self.stride = 1
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=self.stride,
+ dilate=replace_stride_with_dilation[2])
+ FinalLayer.__init__(self, num_classes, 512 * block.expansion)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_block_f=None):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ krepeep = None
+ if last_block_f is not None and _ == blocks - 1:
+ krepeep = last_block_f
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer, features=krepeep))
+
+ return SequentialWithArgs(*layers)
+
+ def _forward(self, x, with_feature_maps=False, with_final_features=False):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ feature_maps = self.layer4(x, no_relu=True)
+ feature_maps = torch.functional.F.relu(feature_maps)
+ return self.transform_output( feature_maps, with_feature_maps,
+ with_final_features)
+
+ # Allow for accessing forward method in a inherited class
+ forward = _forward
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ if kwargs["num_classes"] == 1000:
+ state_dict["linear.weight"] = state_dict["fc.weight"]
+ state_dict["linear.bias"] = state_dict["fc.bias"]
+ model.load_state_dict(state_dict, strict=False)
+ return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_3(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-3 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 3
+ return _resnet('wide_resnet50_3', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_4(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-4 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 4
+ return _resnet('wide_resnet50_4', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_5(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-5 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 5
+ return _resnet('wide_resnet50_5', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_6(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-6 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 6
+ return _resnet('wide_resnet50_6', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
diff --git a/architectures/utils.py b/architectures/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed4cc78fc2799c675098bc73f0a9fb1719fb64b1
--- /dev/null
+++ b/architectures/utils.py
@@ -0,0 +1,17 @@
+import torch
+
+
+
+class SequentialWithArgs(torch.nn.Sequential):
+ def forward(self, input, *args, **kwargs):
+ vs = list(self._modules.values())
+ l = len(vs)
+ for i in range(l):
+ if i == l-1:
+ input = vs[i](input, *args, **kwargs)
+ else:
+ input = vs[i](input)
+ return input
+
+
+
diff --git a/configs/__pycache__/dataset_params.cpython-310.pyc b/configs/__pycache__/dataset_params.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f91c9298daab0217462fe2590a20695b80575e85
Binary files /dev/null and b/configs/__pycache__/dataset_params.cpython-310.pyc differ
diff --git a/configs/__pycache__/optim_params.cpython-310.pyc b/configs/__pycache__/optim_params.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0b3f696299233ff7e5b5d5936ea47d83c8f8b97
Binary files /dev/null and b/configs/__pycache__/optim_params.cpython-310.pyc differ
diff --git a/configs/architecture_params.py b/configs/architecture_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..21a5f2b7fb72dfbee2b4487cabca2d2e840ad938
--- /dev/null
+++ b/configs/architecture_params.py
@@ -0,0 +1 @@
+architecture_params = {"resnet50": {"beta":0.196}}
\ No newline at end of file
diff --git a/configs/dataset_params.py b/configs/dataset_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f227da674de51bf3c4ac0fe3a8faff2004775a6
--- /dev/null
+++ b/configs/dataset_params.py
@@ -0,0 +1,22 @@
+import torch
+
+from configs.optim_params import EvaluatedDict
+
+dataset_constants = {"CUB2011":{"num_classes":200},
+ "TravelingBirds":{"num_classes":200},
+ "ImageNet":{"num_classes":1000},
+ "StanfordCars":{"num_classes":196},
+ "FGVCAircraft": {"num_classes":100}}
+
+normalize_params = {"CUB2011":{"mean": torch.tensor([0.4853, 0.4964, 0.4295]),"std":torch.tensor([0.2300, 0.2258, 0.2625])},
+"TravelingBirds":{"mean": torch.tensor([0.4584, 0.4369, 0.3957]),"std":torch.tensor([0.2610, 0.2569, 0.2722])},
+ "ImageNet":{'mean': torch.tensor([0.485, 0.456, 0.406]),'std': torch.tensor([0.229, 0.224, 0.225])} ,
+"StanfordCars":{'mean': torch.tensor([0.4593, 0.4466, 0.4453]),'std': torch.tensor([0.2920, 0.2910, 0.2988])} ,
+ "FGVCAircraft":{'mean': torch.tensor([0.4827, 0.5130, 0.5352]),
+ 'std': torch.tensor([0.2236, 0.2170, 0.2478]),}
+ }
+
+
+dense_batch_size = EvaluatedDict({False: 16,True: 1024,}, lambda x: x == "ImageNet")
+
+ft_batch_size = EvaluatedDict({False: 16,True: 1024,}, lambda x: x == "ImageNet")# Untested
\ No newline at end of file
diff --git a/configs/optim_params.py b/configs/optim_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fad011caec798b4d51948b28b4d0885c414b59
--- /dev/null
+++ b/configs/optim_params.py
@@ -0,0 +1,22 @@
+# order: lr,weight_decay, step_lr, step_lr_gamma
+import math
+
+
+class EvaluatedDict:
+ def __init__(self, d, func):
+ self.dict = d
+ self.func = func
+
+ def __getitem__(self, key):
+ return self.dict[self.func(key)]
+
+dense_params = EvaluatedDict({False: [0.005, 0.0005, 30, 0.4, 150],True: [None,None,None,None,None],}, lambda x: x == "ImageNet")
+def calculate_lr_from_args( epochs, step_lr, start_lr, step_lr_decay):
+ # Gets the final learning rate after dense training with step_lr_schedule.
+ n_steps = math.floor((epochs - step_lr) / step_lr)
+ final_lr = start_lr * step_lr_decay ** n_steps
+ return final_lr
+
+ft_params =EvaluatedDict({False: [1e-4, 0.0005, 10, 0.4, 40],True:[[calculate_lr_from_args(150,30,0.005, 0.4), 0.0005, 10, 0.4, 40]]}, lambda x: x == "ImageNet")
+
+
diff --git a/configs/qsenn_training_params.py b/configs/qsenn_training_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca03c994ee04c47016c89357ff5d4953f634281
--- /dev/null
+++ b/configs/qsenn_training_params.py
@@ -0,0 +1,11 @@
+from configs.sldd_training_params import OptimizationScheduler
+
+
+class QSENNScheduler(OptimizationScheduler):
+ def get_params(self):
+ params = super().get_params()
+ if self.n_calls >= 2:
+ params[0] = params[0] * 0.9**(self.n_calls-2)
+ if 2 <= self.n_calls <= 4:
+ params[-2] = 10# Change num epochs to 10 for iterative finetuning
+ return params
diff --git a/configs/sldd_training_params.py b/configs/sldd_training_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a605602a1a399d0dd55e1f53d8cbaa8c5d73dc0
--- /dev/null
+++ b/configs/sldd_training_params.py
@@ -0,0 +1,17 @@
+from configs.optim_params import dense_params, ft_params
+
+
+class OptimizationScheduler:
+ def __init__(self, dataset):
+ self.dataset = dataset
+ self.n_calls = 0
+
+
+ def get_params(self):
+ if self.n_calls == 0: # Return Deńse Params
+ params = dense_params[self.dataset]+ [False]
+ else: # Return Finetuning Params
+ params = ft_params[self.dataset]+ [True]
+ self.n_calls += 1
+ return params
+
diff --git a/dataset_classes/__pycache__/cub200.cpython-310.pyc b/dataset_classes/__pycache__/cub200.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b65386a367820ce47ae2ecf095fb28397d58df2a
Binary files /dev/null and b/dataset_classes/__pycache__/cub200.cpython-310.pyc differ
diff --git a/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc b/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e54b5188bd94de8a82b0639237f4ddf557cd52a
Binary files /dev/null and b/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc differ
diff --git a/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc b/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7cc7e30e2027a3730c1873443049962b75998fb0
Binary files /dev/null and b/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc differ
diff --git a/dataset_classes/__pycache__/utils.cpython-310.pyc b/dataset_classes/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bb96db5a1c1e6f6cd7e6efc268cffd6d6e0004b
Binary files /dev/null and b/dataset_classes/__pycache__/utils.cpython-310.pyc differ
diff --git a/dataset_classes/cub200.py b/dataset_classes/cub200.py
new file mode 100644
index 0000000000000000000000000000000000000000..b59a933605948ed45365ccba82486c2c433d4173
--- /dev/null
+++ b/dataset_classes/cub200.py
@@ -0,0 +1,96 @@
+# Dataset should lie under /root/
+# root is currently set to ~/tmp/Datasets/CUB200
+# If cropped iamges, like for PIP-Net, ProtoPool, etc. are used, then the crop_root should be set to a folder containing the
+# cropped images in the expected structure, obtained by following ProtoTree's instructions.
+# https://github.com/M-Nauta/ProtoTree/blob/main/README.md#preprocessing-cub
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import default_loader
+
+from dataset_classes.utils import txt_load
+
+
+class CUB200Class(Dataset):
+ root = Path.home() / "tmp/Datasets/CUB200"
+ crop_root = Path.home() / "tmp/Datasets/PPCUB200"
+ base_folder = 'CUB_200_2011/images'
+ def __init__(self, train, transform, crop=True):
+ self.train = train
+ self.transform = transform
+ self.crop = crop
+ self._load_metadata()
+ self.loader = default_loader
+
+ if crop:
+ self.adapt_to_crop()
+
+ def _load_metadata(self):
+ images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
+ names=['img_id', 'filepath'])
+ image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
+ sep=' ', names=['img_id', 'target'])
+ train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
+ sep=' ', names=['img_id', 'is_training_img'])
+ data = images.merge(image_class_labels, on='img_id')
+ self.data = data.merge(train_test_split, on='img_id')
+ if self.train:
+ self.data = self.data[self.data.is_training_img == 1]
+ else:
+ self.data = self.data[self.data.is_training_img == 0]
+
+ def __len__(self):
+ return len(self.data)
+
+ def adapt_to_crop(self):
+ # ds_name = [x for x in self.cropped_dict.keys() if x in self.root][0]
+ self.root = self.crop_root
+ folder_name = "train" if self.train else "test"
+ folder_name = folder_name + "_cropped"
+ self.base_folder = 'CUB_200_2011/' + folder_name
+
+ def __getitem__(self, idx):
+ sample = self.data.iloc[idx]
+ path = os.path.join(self.root, self.base_folder, sample.filepath)
+ target = sample.target - 1 # Targets start at 1 by default, so shift to 0
+ img = self.loader(path)
+ img = self.transform(img)
+ return img, target
+
+ @classmethod
+ def get_image_attribute_labels(self, train=False):
+ image_attribute_labels = pd.read_csv(
+ os.path.join('/home/norrenbr/tmp/Datasets/CUB200', 'CUB_200_2011', "attributes",
+ 'image_attribute_labels.txt'),
+ sep=' ', names=['img_id', 'attribute', "is_present", "certainty", "time"], on_bad_lines="skip")
+ train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
+ sep=' ', names=['img_id', 'is_training_img'])
+ merged = image_attribute_labels.merge(train_test_split, on="img_id")
+ filtered_data = merged[merged["is_training_img"] == train]
+ return filtered_data
+
+
+ @staticmethod
+ def filter_attribute_labels(labels, min_certainty=3):
+ is_invisible_present = labels[labels["certainty"] == 1]["is_present"].sum()
+ if is_invisible_present != 0:
+ raise ValueError("Invisible present")
+ labels["img_id"] -= min(labels["img_id"])
+ labels["img_id"] = fillholes_in_array(labels["img_id"])
+ labels[labels["certainty"] == 1]["certainty"] = 4
+ labels = labels[labels["certainty"] >= min_certainty]
+ labels["attribute"] -= min(labels["attribute"])
+ labels = labels[["img_id", "attribute", "is_present"]]
+ labels["is_present"] = labels["is_present"].astype(bool)
+ return labels
+
+
+
+def fillholes_in_array(array):
+ unique_values = np.unique(array)
+ mapping = {x: i for i, x in enumerate(unique_values)}
+ array = array.map(mapping)
+ return array
diff --git a/dataset_classes/stanfordcars.py b/dataset_classes/stanfordcars.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be682a5d164a8b39cff5bd9cca82cc8cf5ebe53
--- /dev/null
+++ b/dataset_classes/stanfordcars.py
@@ -0,0 +1,121 @@
+import pathlib
+from typing import Callable, Optional, Any, Tuple
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+from torchvision.datasets import VisionDataset
+from torchvision.datasets.utils import download_and_extract_archive, download_url
+
+
+class StanfordCarsClass(VisionDataset):
+ """`Stanford Cars `_ Dataset
+
+ The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+ split into 8,144 training images and 8,041 testing images, where each class
+ has been split roughly in a 50-50 split
+
+ .. note::
+
+ This class needs `scipy `_ to load target files from `.mat` format.
+
+ Args:
+ root (string): Root directory of dataset
+ split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again."""
+ root = pathlib.Path.home() / "tmp" / "Datasets" / "StanfordCars"
+ def __init__(
+ self,
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = True,
+ ) -> None:
+
+ try:
+ import scipy.io as sio
+ except ImportError:
+ raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+ super().__init__(self.root, transform=transform, target_transform=target_transform)
+
+ self.train = train
+ self._base_folder = pathlib.Path(self.root) / "stanford_cars"
+ devkit = self._base_folder / "devkit"
+
+ if train:
+ self._annotations_mat_path = devkit / "cars_train_annos.mat"
+ self._images_base_path = self._base_folder / "cars_train"
+ else:
+ self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+ self._images_base_path = self._base_folder / "cars_test"
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ self.samples = [
+ (
+ str(self._images_base_path / annotation["fname"]),
+ annotation["class"] - 1, # Original target mapping starts from 1, hence -1
+ )
+ for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+ ]
+ self.targets = np.array([x[1] for x in self.samples])
+ self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+ self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+ def __len__(self) -> int:
+ return len(self.samples)
+
+ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+ """Returns pil_image and class_id for given index"""
+ image_path, target = self.samples[idx]
+ pil_image = Image.open(image_path).convert("RGB")
+
+ if self.transform is not None:
+ pil_image = self.transform(pil_image)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return pil_image, target
+
+ def download(self) -> None:
+ if self._check_exists():
+ return
+
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+ download_root=str(self._base_folder),
+ md5="c3b158d763b6e2245038c8ad08e45376",
+ )
+ if self.train:
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+ download_root=str(self._base_folder),
+ md5="065e5b463ae28d29e77c1b4b166cfe61",
+ )
+ else:
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+ download_root=str(self._base_folder),
+ md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+ )
+ download_url(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+ root=str(self._base_folder),
+ md5="b0a2b23655a3edd16d84508592a98d10",
+ )
+
+ def _check_exists(self) -> bool:
+ if not (self._base_folder / "devkit").is_dir():
+ return False
+
+ return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
diff --git a/dataset_classes/travelingbirds.py b/dataset_classes/travelingbirds.py
new file mode 100644
index 0000000000000000000000000000000000000000..551ce1fd46b9b84e572ea18f4adc6ecd73cea00d
--- /dev/null
+++ b/dataset_classes/travelingbirds.py
@@ -0,0 +1,59 @@
+# TravelingBirds dataset needs to be downloaded from https://worksheets.codalab.org/bundles/0x518829de2aa440c79cd9d75ef6669f27
+# as it comes from https://github.com/yewsiang/ConceptBottleneck
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from dataset_classes.cub200 import CUB200Class
+from dataset_classes.utils import index_list_with_sorting, mask_list
+
+
+class TravelingBirds(CUB200Class):
+ init_base_folder = 'CUB_fixed'
+ root = Path.home() / "tmp/Datasets/TravelingBirds"
+ crop_root = Path.home() / "tmp/Datasets/PPTravelingBirds"
+ def get_all_samples_dir(self, dir):
+
+ self.base_folder = os.path.join(self.init_base_folder, dir)
+ main_dir = Path(self.root) / self.init_base_folder / dir
+ return self.get_all_sample(main_dir)
+
+ def adapt_to_crop(self):
+ self.root = self.crop_root
+ folder_name = "train" if self.train else "test"
+ folder_name = folder_name + "_cropped"
+ self.base_folder = 'CUB_fixed/' + folder_name
+
+ def get_all_sample(self, dir):
+ answer = []
+ for i, sub_dir in enumerate(sorted(os.listdir(dir))):
+ class_dir = dir / sub_dir
+ for single_img in os.listdir(class_dir):
+ answer.append([Path(sub_dir) / single_img, i + 1])
+ return answer
+ def _load_metadata(self):
+ train_test_split = pd.read_csv(
+ os.path.join(Path(self.root).parent / "CUB200", 'CUB_200_2011', 'train_test_split.txt'),
+ sep=' ', names=['img_id', 'is_training_img'])
+ data = pd.read_csv(
+ os.path.join(Path(self.root).parent / "CUB200", 'CUB_200_2011', 'images.txt'),
+ sep=' ', names=['img_id', "path"])
+ img_dict = {x[1]: x[0] for x in data.values}
+ # TravelingBirds has all train+test images in both folders, just with different backgrounds.
+ # They are separated by train_test_split of CUB200.
+ if self.train:
+ samples = self.get_all_samples_dir("train")
+ mask = train_test_split["is_training_img"] == 1
+ else:
+ samples = self.get_all_samples_dir("test")
+ mask = train_test_split["is_training_img"] == 0
+ ids = np.array([img_dict[str(x[0])] for x in samples])
+ sorted = np.argsort(ids)
+ samples = index_list_with_sorting(samples, sorted)
+ samples = mask_list(samples, mask)
+ filepaths = [x[0] for x in samples]
+ labels = [x[1] for x in samples]
+ samples = pd.DataFrame({"filepath": filepaths, "target": labels})
+ self.data = samples
diff --git a/dataset_classes/utils.py b/dataset_classes/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f0039ba93d0966230da88f2568d02bb7cebeebf
--- /dev/null
+++ b/dataset_classes/utils.py
@@ -0,0 +1,16 @@
+def index_list_with_sorting(list_to_sort, sorting_list):
+ answer = []
+ for entry in sorting_list:
+ answer.append(list_to_sort[entry])
+ return answer
+
+
+def mask_list(list_input, mask):
+ return [x for i, x in enumerate(list_input) if mask[i]]
+
+
+def txt_load(filename):
+ with open(filename, 'r') as f:
+ data = f.read()
+ return data
+
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e4e2f7b3680115f3e38c80511baede60fda0db03
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,117 @@
+name: QSENNEnv
+channels:
+ - pytorch
+ - nvidia
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotli-python=1.0.9=py310h6a678d5_7
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2023.12.12=h06a4308_0
+ - certifi=2023.11.17=py310h06a4308_0
+ - cffi=1.16.0=py310h5eee18b_0
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - cryptography=41.0.7=py310hdda0065_0
+ - cuda-cudart=12.1.105=0
+ - cuda-cupti=12.1.105=0
+ - cuda-libraries=12.1.0=0
+ - cuda-nvrtc=12.1.105=0
+ - cuda-nvtx=12.1.105=0
+ - cuda-opencl=12.3.101=0
+ - cuda-runtime=12.1.0=0
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.13.1=py310h06a4308_0
+ - freetype=2.12.1=h4a9f257_0
+ - giflib=5.2.1=h5eee18b_3
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py310heeb90bb_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.4=py310h06a4308_0
+ - intel-openmp=2023.1.0=hdb19cb5_46306
+ - jinja2=3.1.2=py310h06a4308_0
+ - jpeg=9e=h5eee18b_1
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libcublas=12.1.0.26=0
+ - libcufft=11.0.2.4=0
+ - libcufile=1.8.1.2=0
+ - libcurand=10.3.4.107=0
+ - libcusolver=11.4.4.55=0
+ - libcusparse=12.0.2.55=0
+ - libdeflate=1.17=h5eee18b_1
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.4=h5eee18b_0
+ - libjpeg-turbo=2.0.0=h9bf148f_0
+ - libnpp=12.0.2.50=0
+ - libnvjitlink=12.1.105=0
+ - libnvjpeg=12.1.1.14=0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.5.1=h6a678d5_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.41.5=h5eee18b_0
+ - libwebp=1.3.2=h11a3e52_0
+ - libwebp-base=1.3.2=h5eee18b_0
+ - llvm-openmp=14.0.6=h9e868ea_0
+ - lz4-c=1.9.4=h6a678d5_0
+ - markupsafe=2.1.3=py310h5eee18b_0
+ - mkl=2023.1.0=h213fc3f_46344
+ - mkl-service=2.4.0=py310h5eee18b_1
+ - mkl_fft=1.3.8=py310h5eee18b_0
+ - mkl_random=1.2.4=py310hdb19cb5_0
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.3.0=py310h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=3.1=py310h06a4308_0
+ - numpy=1.26.3=py310h5f9d8c6_0
+ - numpy-base=1.26.3=py310hb5e798b_0
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.4.0=h3ad879b_0
+ - openssl=3.0.12=h7f8727e_0
+ - pillow=10.0.1=py310ha6cbd5a_0
+ - pip=23.3.1=py310h06a4308_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=23.2.0=py310h06a4308_0
+ - pysocks=1.7.1=py310h06a4308_0
+ - python=3.10.13=h955ad1f_0
+ - pytorch=2.1.2=py3.10_cuda12.1_cudnn8.9.2_0
+ - pytorch-cuda=12.1=ha16c6d3_5
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0.1=py310h5eee18b_0
+ - readline=8.2=h5eee18b_0
+ - requests=2.31.0=py310h06a4308_0
+ - setuptools=68.2.2=py310h06a4308_0
+ - sqlite=3.41.2=h5eee18b_0
+ - sympy=1.12=py310h06a4308_0
+ - tbb=2021.8.0=hdb19cb5_0
+ - tk=8.6.12=h1ccaba5_0
+ - torchaudio=2.1.2=py310_cu121
+ - torchtriton=2.1.0=py310
+ - torchvision=0.16.2=py310_cu121
+ - typing_extensions=4.7.1=py310h06a4308_0
+ - urllib3=1.26.18=py310h06a4308_0
+ - wheel=0.41.2=py310h06a4308_0
+ - xz=5.4.5=h5eee18b_0
+ - yaml=0.2.5=h7b6447c_0
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.5.5=hc292b87_0
+ - pip:
+ - fsspec==2023.12.2
+ - glm-saga==0.1.2
+ - pandas==2.1.4
+ - python-dateutil==2.8.2
+ - pytz==2023.3.post1
+ - six==1.16.0
+ - tqdm==4.66.1
+ - tzdata==2023.4
+prefix: /home/norrenbr/anaconda/tmp/envs/QSENN-Minimal
diff --git a/evaluation/Metrics/Dependence.py b/evaluation/Metrics/Dependence.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f1b26dfc19de0430925e38aac45ebcc33a94455
--- /dev/null
+++ b/evaluation/Metrics/Dependence.py
@@ -0,0 +1,21 @@
+import torch
+
+
+def compute_contribution_top_feature(features, outputs, weights, labels):
+ with torch.no_grad():
+ total_pre_softmax, predicted_classes = torch.max(outputs, dim=1)
+ feature_part = features * weights.to(features.device)[predicted_classes]
+ class_specific_feature_part = torch.zeros((weights.shape[0], features.shape[1],))
+ feature_class_part = torch.zeros((weights.shape[0], features.shape[1],))
+ for unique_class in predicted_classes.unique():
+ mask = predicted_classes == unique_class
+ class_specific_feature_part[unique_class] = feature_part[mask].mean(dim=0)
+ gt_mask = labels == unique_class
+ feature_class_part[unique_class] = feature_part[gt_mask].mean(dim=0)
+ abs_features = feature_part.abs()
+ abs_sum = abs_features.sum(dim=1)
+ fractions_abs = abs_features / abs_sum[:, None]
+ abs_max = fractions_abs.max(dim=1)[0]
+ mask = ~torch.isnan(abs_max)
+ abs_max = abs_max[mask]
+ return abs_max.mean()
\ No newline at end of file
diff --git a/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc b/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d59c7cc91e9d0c3e84e533ca3876e3ee9850c52
Binary files /dev/null and b/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc differ
diff --git a/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc b/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0de63438f9c10a07e3524cca02c53276c1d7622
Binary files /dev/null and b/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc differ
diff --git a/evaluation/Metrics/cub_Alignment.py b/evaluation/Metrics/cub_Alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4b41e427668f86ec530baab1796ac9d0678489
--- /dev/null
+++ b/evaluation/Metrics/cub_Alignment.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+from dataset_classes.cub200 import CUB200Class
+
+
+def get_cub_alignment_from_features(features_train_sorted):
+ metric_matrix = compute_metric_matrix(np.array(features_train_sorted), "train")
+ return np.mean(np.max(metric_matrix, axis=1))
+ pass
+
+
+def compute_metric_matrix(features, mode):
+ image_attribute_labels = CUB200Class.get_image_attribute_labels(train=mode == "train")
+ image_attribute_labels = CUB200Class.filter_attribute_labels(image_attribute_labels)
+ matrix_shape = (
+ features.shape[1], max(image_attribute_labels["attribute"]) + 1)
+ accuracy_matrix = np.zeros(matrix_shape)
+ sensitivity_matrix = np.zeros_like(accuracy_matrix)
+ grouped_attributes = image_attribute_labels.groupby("attribute")
+ for attribute_id, group in grouped_attributes:
+ is_present = group[group["is_present"]]
+ not_present = group[~group["is_present"]]
+ is_present_avg = np.mean(features[is_present["img_id"]], axis=0)
+ not_present_avg = np.mean(features[not_present["img_id"]], axis=0)
+ sensitivity_matrix[:, attribute_id] = not_present_avg
+ accuracy_matrix[:, attribute_id] = is_present_avg
+ metric_matrix = accuracy_matrix - sensitivity_matrix
+ no_abs_features = features - np.min(features, axis=0)
+ no_abs_feature_mean = metric_matrix / no_abs_features.mean(axis=0)[:, None]
+ return no_abs_feature_mean
diff --git a/evaluation/__pycache__/diversity.cpython-310.pyc b/evaluation/__pycache__/diversity.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d65f89a660c7660f26ce7578e34557c87b970b66
Binary files /dev/null and b/evaluation/__pycache__/diversity.cpython-310.pyc differ
diff --git a/evaluation/__pycache__/helpers.cpython-310.pyc b/evaluation/__pycache__/helpers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f7034adbe8c30d420975414880bb581f2052080
Binary files /dev/null and b/evaluation/__pycache__/helpers.cpython-310.pyc differ
diff --git a/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc b/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9460b6362b47ba4ddf699b74707b87b5a063ce73
Binary files /dev/null and b/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc differ
diff --git a/evaluation/__pycache__/utils.cpython-310.pyc b/evaluation/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e70bce079f72b6f83344a1decd675b1312655b6
Binary files /dev/null and b/evaluation/__pycache__/utils.cpython-310.pyc differ
diff --git a/evaluation/diversity.py b/evaluation/diversity.py
new file mode 100644
index 0000000000000000000000000000000000000000..033679ce9cf4546b74b0d1d4bdb6b8590c5c8865
--- /dev/null
+++ b/evaluation/diversity.py
@@ -0,0 +1,111 @@
+import numpy as np
+import torch
+
+from evaluation.helpers import softmax_feature_maps
+
+
+class MultiKCrossChannelMaxPooledSum:
+ def __init__(self, top_k_range, weights, interactions, func="softmax"):
+ self.top_k_range = top_k_range
+ self.weights = weights
+ self.failed = False
+ self.max_ks = self.get_max_ks(weights)
+ self.locality_of_used_features = torch.zeros(len(top_k_range), device=weights.device)
+ self.locality_of_exclusely_used_features = torch.zeros(len(top_k_range), device=weights.device)
+ self.ns_k = torch.zeros(len(top_k_range), device=weights.device)
+ self.exclusive_ns = torch.zeros(len(top_k_range), device=weights.device)
+ self.interactions = interactions
+ self.func = func
+
+ def get_max_ks(self, weights):
+ nonzeros = torch.count_nonzero(torch.tensor(weights), 1)
+ return nonzeros
+
+ def get_top_n_locality(self, outputs, initial_feature_maps, k):
+ feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs,
+ initial_feature_maps)
+ max_ks = self.max_ks[top_classes]
+ max_k_based_row_selection = max_ks >= k
+
+ result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps,
+ separated=True)
+ return result
+
+ def get_locality(self, outputs, initial_feature_maps, n):
+ answer = self.get_top_n_locality(outputs, initial_feature_maps, n)
+ return answer
+
+ def get_result(self):
+ # if torch.sum(self.exclusive_ns) ==0:
+ # end_idx = len(self.exclusive_ns) - 1
+ # else:
+
+ exclusive_array = torch.zeros_like(self.locality_of_exclusely_used_features)
+ local_array = torch.zeros_like(self.locality_of_used_features)
+ # if self.failed:
+ # return local_array, exclusive_array
+ cumulated = torch.cumsum(self.exclusive_ns, 0)
+ end_idx = torch.argmax(cumulated)
+ exclusivity_array = self.locality_of_exclusely_used_features[:end_idx + 1] / self.exclusive_ns[:end_idx + 1]
+ exclusivity_array[exclusivity_array != exclusivity_array] = 0
+ exclusive_array[:len(exclusivity_array)] = exclusivity_array
+ locality_array = self.locality_of_used_features[self.locality_of_used_features != 0] / self.ns_k[
+ self.locality_of_used_features != 0]
+ local_array[:len(locality_array)] = locality_array
+ return local_array, exclusive_array
+
+ def get_crosspooled(self, relevant_weights, mask, k, vector_size, feature_maps, separated=False):
+ relevant_indices = get_relevant_indices(relevant_weights, k)[mask]
+ # this should have size batch x k x featuremapsize squared]
+ indices = relevant_indices.unsqueeze(2).repeat(1, 1, vector_size)
+ sub_feature_maps = torch.gather(feature_maps[mask], 1, indices)
+ # shape batch x featuremapsquared: For each "pixel" the highest value
+ cross_pooled = torch.max(sub_feature_maps, 1)[0]
+ if separated:
+ return torch.sum(cross_pooled, 1) / k
+ else:
+ ns = len(cross_pooled)
+ result = torch.sum(cross_pooled) / (k)
+ # should be batch x map size
+
+ return ns, result
+
+ def adapt_feature_maps(self, outputs, initial_feature_maps):
+ if self.func == "softmax":
+ feature_maps = softmax_feature_maps(initial_feature_maps)
+ feature_maps = torch.flatten(feature_maps, 2)
+ vector_size = feature_maps.shape[2]
+ top_classes = torch.argmax(outputs, dim=1)
+ relevant_weights = self.weights[top_classes]
+ if relevant_weights.shape[1] != feature_maps.shape[1]:
+ feature_maps = self.interactions.get_localized_features(initial_feature_maps)
+ feature_maps = softmax_feature_maps(feature_maps)
+ feature_maps = torch.flatten(feature_maps, 2)
+ return feature_maps, relevant_weights, vector_size, top_classes
+
+ def calculate_locality(self, outputs, initial_feature_maps):
+ feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs,
+ initial_feature_maps)
+ max_ks = self.max_ks[top_classes]
+ for k in self.top_k_range:
+ # relevant_k_s = max_ks[]
+ max_k_based_row_selection = max_ks >= k
+ if torch.sum(max_k_based_row_selection) == 0:
+ break
+
+ exclusive_k = max_ks == k
+ if torch.sum(exclusive_k) != 0:
+ ns, result = self.get_crosspooled(relevant_weights, exclusive_k, k, vector_size, feature_maps)
+ self.locality_of_exclusely_used_features[k - 1] += result
+ self.exclusive_ns[k - 1] += ns
+ ns, result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps)
+ self.ns_k[k - 1] += ns
+ self.locality_of_used_features[k - 1] += result
+
+ def __call__(self, outputs, initial_feature_maps):
+ self.calculate_locality(outputs, initial_feature_maps)
+
+
+def get_relevant_indices(weights, top_k):
+ top_k = weights.topk(top_k)[1]
+ return top_k
\ No newline at end of file
diff --git a/evaluation/helpers.py b/evaluation/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4a9902103fe63df01994acb079127ab719c9f1
--- /dev/null
+++ b/evaluation/helpers.py
@@ -0,0 +1,6 @@
+import torch
+
+
+def softmax_feature_maps(x):
+ # done: verify that this applies softmax along first dimension
+ return torch.softmax(x.reshape(x.size(0), x.size(1), -1), 2).view_as(x)
\ No newline at end of file
diff --git a/evaluation/qsenn_metrics.py b/evaluation/qsenn_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb8f21b6f7dfe101c8e668e9c422c1d88ce8751
--- /dev/null
+++ b/evaluation/qsenn_metrics.py
@@ -0,0 +1,39 @@
+import numpy as np
+import torch
+
+from evaluation.Metrics.Dependence import compute_contribution_top_feature
+from evaluation.Metrics.cub_Alignment import get_cub_alignment_from_features
+from evaluation.diversity import MultiKCrossChannelMaxPooledSum
+from evaluation.utils import get_metrics_for_model
+
+
+def evaluateALLMetricsForComps(features_train, outputs_train, feature_maps_test,
+ outputs_test, linear_matrix, labels_train):
+ with torch.no_grad():
+ if len(features_train) < 7000: # recognize CUB and TravelingBirds
+ cub_alignment = get_cub_alignment_from_features(features_train)
+ else:
+ cub_alignment = 0
+ print("cub_alignment: ", cub_alignment)
+ localizer = MultiKCrossChannelMaxPooledSum(range(1, 6), linear_matrix, None)
+ batch_size = 300
+ for i in range(np.floor(len(features_train) / batch_size).astype(int)):
+ localizer(outputs_test[i * batch_size:(i + 1) * batch_size].to("cuda"),
+ feature_maps_test[i * batch_size:(i + 1) * batch_size].to("cuda"))
+
+ locality, exlusive_locality = localizer.get_result()
+ diversity = locality[4]
+ print("diversity@5: ", diversity)
+ abs_frac_mean = compute_contribution_top_feature(
+ features_train,
+ outputs_train,
+ linear_matrix,
+ labels_train)
+ print("Dependence ", abs_frac_mean)
+ answer_dict = {"diversity": diversity.item(), "Dependence": abs_frac_mean.item(), "Alignment":cub_alignment}
+ return answer_dict
+
+def eval_model_on_all_qsenn_metrics(model, test_loader, train_loader):
+ return get_metrics_for_model(train_loader, test_loader, model, evaluateALLMetricsForComps)
+
+
diff --git a/evaluation/utils.py b/evaluation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1b679fc9dac88e2fb897d69c34d959c19b3101
--- /dev/null
+++ b/evaluation/utils.py
@@ -0,0 +1,57 @@
+import torch
+from tqdm import tqdm
+
+
+
+def get_metrics_for_model(train_loader, test_loader, model, metric_evaluator):
+ (features_train, feature_maps_train, outputs_train, features_test, feature_maps_test,
+ outputs_test, labels) = [], [], [], [], [], [], []
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.eval()
+ model = model.to(device)
+ training_transforms = train_loader.dataset.transform
+ train_loader.dataset.transform = test_loader.dataset.transform # Use test transform for train
+ train_loader = torch.utils.data.DataLoader(train_loader.dataset, batch_size=100, shuffle=False) # Turn off shuffling
+ print("Going in get metrics")
+ linear_matrix = model.linear.weight
+ entries = torch.nonzero(linear_matrix)
+ rel_features = torch.unique(entries[:, 1])
+ with torch.no_grad():
+ iterator = tqdm(enumerate(train_loader), total=len(train_loader))
+ for batch_idx, (data, target) in iterator:
+ xs1 = data.to("cuda")
+ output, feature_maps, final_features = model(xs1, with_feature_maps=True, with_final_features=True,)
+ outputs_train.append(output.to("cpu"))
+ features_train.append(final_features.to("cpu"))
+ labels.append(target.to("cpu"))
+ total = 0
+ correct = 0
+ iterator = tqdm(enumerate(test_loader), total=len(test_loader))
+ for batch_idx, (data, target) in iterator:
+ xs1 = data.to("cuda")
+ output, feature_maps, final_features = model(xs1, with_feature_maps=True,
+ with_final_features=True, )
+ feature_maps_test.append(feature_maps[:, rel_features].to("cpu"))
+ outputs_test.append(output.to("cpu"))
+ total += target.size(0)
+ _, predicted = output.max(1)
+ correct += predicted.eq(target.to("cuda")).sum().item()
+ print("test accuracy: ", correct / total)
+ features_train = torch.cat(features_train)
+ outputs_train = torch.cat(outputs_train)
+ feature_maps_test = torch.cat(feature_maps_test)
+ outputs_test = torch.cat(outputs_test)
+ labels = torch.cat(labels)
+ linear_matrix = linear_matrix[:, rel_features]
+ print("Shape of linear matrix: ", linear_matrix.shape)
+ all_metrics_dict = metric_evaluator(features_train, outputs_train,
+ feature_maps_test,
+ outputs_test, linear_matrix, labels)
+ result_dict = {"Accuracy": correct / total, "NFfeatures": linear_matrix.shape[1],
+ "PerClass": torch.nonzero(linear_matrix).shape[0] / linear_matrix.shape[0],
+ }
+ result_dict.update(all_metrics_dict)
+ print(result_dict)
+ # Reset Train transforms
+ train_loader.dataset.transform = training_transforms
+ return result_dict
diff --git a/fig/AutoML4FAS_Logo.jpeg b/fig/AutoML4FAS_Logo.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..35d4066fa5cf5967553960097b57f80c2ac8c580
Binary files /dev/null and b/fig/AutoML4FAS_Logo.jpeg differ
diff --git a/fig/Bund.png b/fig/Bund.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c92a104515f9b3c61642f7cd3cc898163e5ef0e
Binary files /dev/null and b/fig/Bund.png differ
diff --git a/fig/LUH.png b/fig/LUH.png
new file mode 100644
index 0000000000000000000000000000000000000000..af168ab3e866e5c66c616b6a090ef9c4ac212e3b
Binary files /dev/null and b/fig/LUH.png differ
diff --git a/fig/birds.png b/fig/birds.png
new file mode 100644
index 0000000000000000000000000000000000000000..330ebdff52c39b989a5c0cd42e0a35fdbeb7c1ff
Binary files /dev/null and b/fig/birds.png differ
diff --git a/finetuning/map_function.py b/finetuning/map_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aa65c3fa6dee0dc55484bdaae3fb181786eed1b
--- /dev/null
+++ b/finetuning/map_function.py
@@ -0,0 +1,11 @@
+from finetuning.qsenn import finetune_qsenn
+from finetuning.sldd import finetune_sldd
+
+
+def finetune(key, model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule, per_class, n_features):
+ if key == 'sldd':
+ return finetune_sldd(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,per_class, n_features)
+ elif key == 'qsenn':
+ return finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,n_features,per_class, )
+ else:
+ raise ValueError(f"Unknown Finetuning key: {key}")
\ No newline at end of file
diff --git a/finetuning/qsenn.py b/finetuning/qsenn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce4dc8b65e6c703e51fe602c2ac897c97844897c
--- /dev/null
+++ b/finetuning/qsenn.py
@@ -0,0 +1,30 @@
+import os
+
+import torch
+
+from finetuning.utils import train_n_epochs
+from sparsification.qsenn import compute_qsenn_feature_selection_and_assignment
+
+
+def finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule ,n_features, n_per_class):
+ for iteration_epoch in range(4):
+ print(f"Starting iteration epoch {iteration_epoch}")
+ this_log_dir = log_dir / f"iteration_epoch_{iteration_epoch}"
+ this_log_dir.mkdir(parents=True, exist_ok=True)
+ feature_sel, sparse_layer,bias_sparse, current_mean, current_std = compute_qsenn_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ this_log_dir, n_classes, seed, n_features, n_per_class)
+ model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
+ if os.path.exists(this_log_dir / "trained_model.pth"):
+ model.load_state_dict(torch.load(this_log_dir / "trained_model.pth"))
+ _ = optimization_schedule.get_params() # count up, to have get correct lr
+ continue
+
+ model = train_n_epochs( model, beta, optimization_schedule, train_loader, test_loader)
+ torch.save(model.state_dict(), this_log_dir / "trained_model.pth")
+ print(f"Finished iteration epoch {iteration_epoch}")
+ return model
+
+
+
+
diff --git a/finetuning/sldd.py b/finetuning/sldd.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8ac0034b14cbbf460f0bf59e25dfd8188ee94b
--- /dev/null
+++ b/finetuning/sldd.py
@@ -0,0 +1,22 @@
+import numpy as np
+import torch
+
+from FeatureDiversityLoss import FeatureDiversityLoss
+from finetuning.utils import train_n_epochs
+from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
+from sparsification.sldd import compute_sldd_feature_selection_and_assignment
+from train import train, test
+from training.optim import get_optimizer
+
+
+
+
+def finetune_sldd(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,n_per_class, n_features, ):
+ feature_sel, weight, bias, mean, std = compute_sldd_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ log_dir, n_classes, seed,n_per_class, n_features)
+ model.set_model_sldd(feature_sel, weight, mean, std, bias)
+ model = train_n_epochs( model, beta, optimization_schedule, train_loader, test_loader)
+ return model
+
+
diff --git a/finetuning/utils.py b/finetuning/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af751e2094c5ba6b2f83adadb5059f692329db37
--- /dev/null
+++ b/finetuning/utils.py
@@ -0,0 +1,14 @@
+from FeatureDiversityLoss import FeatureDiversityLoss
+from train import train, test
+from training.optim import get_optimizer
+
+
+def train_n_epochs(model, beta,optimization_schedule, train_loader, test_loader):
+ optimizer, schedule, epochs = get_optimizer(model, optimization_schedule)
+ fdl = FeatureDiversityLoss(beta, model.linear)
+ for epoch in range(epochs):
+ model = train(model, train_loader, optimizer, fdl, epoch)
+ schedule.step()
+ if epoch % 5 == 0 or epoch+1 == epochs:
+ test(model, test_loader, epoch)
+ return model
\ No newline at end of file
diff --git a/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg b/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6450f6174bdd37cace75c6b32a029bcfa8761ed7
Binary files /dev/null and b/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg differ
diff --git a/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg b/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f9f6063f1c6130694ddd53c0231b317abe9ef03b
Binary files /dev/null and b/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg differ
diff --git a/flagged/log.csv b/flagged/log.csv
new file mode 100644
index 0000000000000000000000000000000000000000..5af3d3f8c5830b52178c12538580c9cd038fd2e4
--- /dev/null
+++ b/flagged/log.csv
@@ -0,0 +1,3 @@
+input,output,flag,username,timestamp
+flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg,,,,2024-10-21 12:37:51.541901
+flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg,"[{""image"": ""flagged/output/e2f704607c002e0c557d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1b4541c3e93f034d746d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f8727dcfa3c59de0d873/image.webp"", ""caption"": null}, {""image"": ""flagged/output/c4b75e9fbc946f6ead6d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/5b5ad2dd997a635f4917/image.webp"", ""caption"": null}, {""image"": ""flagged/output/b066004e4a0114aa705b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/036072cdcc620de8cb65/image.webp"", ""caption"": null}, {""image"": ""flagged/output/218135cb251eb6cd0b2c/image.webp"", ""caption"": null}, {""image"": ""flagged/output/2a0671ba5ac1aa3bd2b9/image.webp"", ""caption"": null}, {""image"": ""flagged/output/595953adce3a654bbd33/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f333c69915509927b2ff/image.webp"", ""caption"": null}, {""image"": ""flagged/output/a966f50f23644e5046e8/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1a8a9e53fd4990fe5231/image.webp"", ""caption"": null}, {""image"": ""flagged/output/d7bc2f0eb8d70a562542/image.webp"", ""caption"": null}, {""image"": ""flagged/output/53fd53c5eab644d30338/image.webp"", ""caption"": null}, {""image"": ""flagged/output/ddf6b8ddc855838cc3b5/image.webp"", ""caption"": null}, {""image"": ""flagged/output/41a99b70366ac01533b4/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1b4ae8362917e14cb7a7/image.webp"", ""caption"": null}, {""image"": ""flagged/output/b321456290561eacf170/image.webp"", ""caption"": null}, {""image"": ""flagged/output/42d34c69c2384bda376b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/35d0e9ae554c0b863ef3/image.webp"", ""caption"": null}, {""image"": ""flagged/output/799f55238c434907570f/image.webp"", ""caption"": null}, {""image"": ""flagged/output/db82081afaabf2fb505b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/fff73f12467314dce395/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1bd17ff3896c5045b453/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e31f93405e1526fe3e55/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e9c9ff1da0805da0c0d8/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e6ef5ba2d6c65b3c1d21/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f763a51fb4a6d8a13313/image.webp"", ""caption"": null}, {""image"": ""flagged/output/7bdb4562631122e4ced7/image.webp"", ""caption"": null}, {""image"": ""flagged/output/9f7495b7c7648ecb1a10/image.webp"", ""caption"": null}, {""image"": ""flagged/output/ecbe75612f5db6cc7370/image.webp"", ""caption"": null}, {""image"": ""flagged/output/31f824d9522d30106a44/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e06b9103e0bf90cd398a/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1441b4f37340c2afa3d0/image.webp"", ""caption"": null}]",,,2024-10-21 23:01:32.158338
diff --git a/flagged/output/036072cdcc620de8cb65/image.webp b/flagged/output/036072cdcc620de8cb65/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4e7a831b3e63d41cf7dc53178e8f19231f456648
Binary files /dev/null and b/flagged/output/036072cdcc620de8cb65/image.webp differ
diff --git a/flagged/output/1441b4f37340c2afa3d0/image.webp b/flagged/output/1441b4f37340c2afa3d0/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..46e20fcfccd3a763d3eae21a0fda7d2908c6f53b
Binary files /dev/null and b/flagged/output/1441b4f37340c2afa3d0/image.webp differ
diff --git a/flagged/output/1a8a9e53fd4990fe5231/image.webp b/flagged/output/1a8a9e53fd4990fe5231/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..30452bb9f913012c3a787e78f5af2a657bfc4a82
Binary files /dev/null and b/flagged/output/1a8a9e53fd4990fe5231/image.webp differ
diff --git a/flagged/output/1b4541c3e93f034d746d/image.webp b/flagged/output/1b4541c3e93f034d746d/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..41fb284622b8bf0e85dac87a497a4942011579f2
Binary files /dev/null and b/flagged/output/1b4541c3e93f034d746d/image.webp differ
diff --git a/flagged/output/1b4ae8362917e14cb7a7/image.webp b/flagged/output/1b4ae8362917e14cb7a7/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..3566fc9c4f4f8bc2d8be57ffbaf1fb0b84f6fed8
Binary files /dev/null and b/flagged/output/1b4ae8362917e14cb7a7/image.webp differ
diff --git a/flagged/output/1bd17ff3896c5045b453/image.webp b/flagged/output/1bd17ff3896c5045b453/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a2e8a49694c6a233177b8757e916860ec2c217cb
Binary files /dev/null and b/flagged/output/1bd17ff3896c5045b453/image.webp differ
diff --git a/flagged/output/218135cb251eb6cd0b2c/image.webp b/flagged/output/218135cb251eb6cd0b2c/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..986c085197db498a852f013f503db78b64b4f7c5
Binary files /dev/null and b/flagged/output/218135cb251eb6cd0b2c/image.webp differ
diff --git a/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp b/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5e9a54c48df42fa656f0bced1d9580acd75cf7ba
Binary files /dev/null and b/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp differ
diff --git a/flagged/output/31f824d9522d30106a44/image.webp b/flagged/output/31f824d9522d30106a44/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c91c7a09d9b8611a1430afa699da601d7d0efe21
Binary files /dev/null and b/flagged/output/31f824d9522d30106a44/image.webp differ
diff --git a/flagged/output/35d0e9ae554c0b863ef3/image.webp b/flagged/output/35d0e9ae554c0b863ef3/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2473cae43807f063aa4d3e568e06e17e4b569920
Binary files /dev/null and b/flagged/output/35d0e9ae554c0b863ef3/image.webp differ
diff --git a/flagged/output/41a99b70366ac01533b4/image.webp b/flagged/output/41a99b70366ac01533b4/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4121b433b87b66dd3fbb58722c67818906c67411
Binary files /dev/null and b/flagged/output/41a99b70366ac01533b4/image.webp differ
diff --git a/flagged/output/42d34c69c2384bda376b/image.webp b/flagged/output/42d34c69c2384bda376b/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7618c903c18dc2451d25e1f32656f4caf9fe6ddb
Binary files /dev/null and b/flagged/output/42d34c69c2384bda376b/image.webp differ
diff --git a/flagged/output/53fd53c5eab644d30338/image.webp b/flagged/output/53fd53c5eab644d30338/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2abbb10f1dbe66b93e37422b2470a0f071dea7cf
Binary files /dev/null and b/flagged/output/53fd53c5eab644d30338/image.webp differ
diff --git a/flagged/output/595953adce3a654bbd33/image.webp b/flagged/output/595953adce3a654bbd33/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..881add82b80c09007934e0467acd081e1b5fd7ac
Binary files /dev/null and b/flagged/output/595953adce3a654bbd33/image.webp differ
diff --git a/flagged/output/5b5ad2dd997a635f4917/image.webp b/flagged/output/5b5ad2dd997a635f4917/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e57d0e88fbfc7af54aeb69995fe44af657c0d8dd
Binary files /dev/null and b/flagged/output/5b5ad2dd997a635f4917/image.webp differ
diff --git a/flagged/output/799f55238c434907570f/image.webp b/flagged/output/799f55238c434907570f/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..85d8a8fe108f97bec9684ccb2c614db43035d88e
Binary files /dev/null and b/flagged/output/799f55238c434907570f/image.webp differ
diff --git a/flagged/output/7bdb4562631122e4ced7/image.webp b/flagged/output/7bdb4562631122e4ced7/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0b046a9a2ca40f025b7cc77df1b4c4f0613a7659
Binary files /dev/null and b/flagged/output/7bdb4562631122e4ced7/image.webp differ
diff --git a/flagged/output/9f7495b7c7648ecb1a10/image.webp b/flagged/output/9f7495b7c7648ecb1a10/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..bc21593a6869f1cc00f78de4dd9ebf912d18d795
Binary files /dev/null and b/flagged/output/9f7495b7c7648ecb1a10/image.webp differ
diff --git a/flagged/output/a966f50f23644e5046e8/image.webp b/flagged/output/a966f50f23644e5046e8/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ffb81c67f03b993798f710e71e65b0f43cd151ca
Binary files /dev/null and b/flagged/output/a966f50f23644e5046e8/image.webp differ
diff --git a/flagged/output/b066004e4a0114aa705b/image.webp b/flagged/output/b066004e4a0114aa705b/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b861d88dda4c0c7b783a87abaabc29f94dc943b2
Binary files /dev/null and b/flagged/output/b066004e4a0114aa705b/image.webp differ
diff --git a/flagged/output/b321456290561eacf170/image.webp b/flagged/output/b321456290561eacf170/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a10280498c93346105b7da59ba0808494004024c
Binary files /dev/null and b/flagged/output/b321456290561eacf170/image.webp differ
diff --git a/flagged/output/c4b75e9fbc946f6ead6d/image.webp b/flagged/output/c4b75e9fbc946f6ead6d/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..106535e80842768f14da47245baa981cabeea71b
Binary files /dev/null and b/flagged/output/c4b75e9fbc946f6ead6d/image.webp differ
diff --git a/flagged/output/d7bc2f0eb8d70a562542/image.webp b/flagged/output/d7bc2f0eb8d70a562542/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..686f7aac20f6a5449db99a97dc43c01dfdd99551
Binary files /dev/null and b/flagged/output/d7bc2f0eb8d70a562542/image.webp differ
diff --git a/flagged/output/db82081afaabf2fb505b/image.webp b/flagged/output/db82081afaabf2fb505b/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..af93f3eacee5ef0a4903995aaf8d2e2e5921976d
Binary files /dev/null and b/flagged/output/db82081afaabf2fb505b/image.webp differ
diff --git a/flagged/output/ddf6b8ddc855838cc3b5/image.webp b/flagged/output/ddf6b8ddc855838cc3b5/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..f5d206e37f97d1f45b91aceb45df630ca9fae223
Binary files /dev/null and b/flagged/output/ddf6b8ddc855838cc3b5/image.webp differ
diff --git a/flagged/output/e06b9103e0bf90cd398a/image.webp b/flagged/output/e06b9103e0bf90cd398a/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..8b4510986e28e5a132f4cc197ce9063b072b113b
Binary files /dev/null and b/flagged/output/e06b9103e0bf90cd398a/image.webp differ
diff --git a/flagged/output/e2f704607c002e0c557d/image.webp b/flagged/output/e2f704607c002e0c557d/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a7b2dafc246639799cfabd97306b3c4ba426cba6
Binary files /dev/null and b/flagged/output/e2f704607c002e0c557d/image.webp differ
diff --git a/flagged/output/e31f93405e1526fe3e55/image.webp b/flagged/output/e31f93405e1526fe3e55/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..955afaa8ead0e0bc67be7722bbc791dbfe4f35be
Binary files /dev/null and b/flagged/output/e31f93405e1526fe3e55/image.webp differ
diff --git a/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp b/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4e0429c00817f88878d5cfc460039b8ed169c74c
Binary files /dev/null and b/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp differ
diff --git a/flagged/output/e9c9ff1da0805da0c0d8/image.webp b/flagged/output/e9c9ff1da0805da0c0d8/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d73c97deb0388bfb5423ad36e686e5e3ca44ce8d
Binary files /dev/null and b/flagged/output/e9c9ff1da0805da0c0d8/image.webp differ
diff --git a/flagged/output/ecbe75612f5db6cc7370/image.webp b/flagged/output/ecbe75612f5db6cc7370/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7608061bb4d13a3e87696620071a61112463dea9
Binary files /dev/null and b/flagged/output/ecbe75612f5db6cc7370/image.webp differ
diff --git a/flagged/output/f333c69915509927b2ff/image.webp b/flagged/output/f333c69915509927b2ff/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..766635e3fd9996a671a9fd9e09bd37901f37a20e
Binary files /dev/null and b/flagged/output/f333c69915509927b2ff/image.webp differ
diff --git a/flagged/output/f763a51fb4a6d8a13313/image.webp b/flagged/output/f763a51fb4a6d8a13313/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..65b0c4ddd52dac9f92f898bc64ed67c18722c6ac
Binary files /dev/null and b/flagged/output/f763a51fb4a6d8a13313/image.webp differ
diff --git a/flagged/output/f8727dcfa3c59de0d873/image.webp b/flagged/output/f8727dcfa3c59de0d873/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..175792b0035f9df9db88d62f56d40356d8afbbfe
Binary files /dev/null and b/flagged/output/f8727dcfa3c59de0d873/image.webp differ
diff --git a/flagged/output/fff73f12467314dce395/image.webp b/flagged/output/fff73f12467314dce395/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0b51dcf8c9e672660529ad1adea9afe37e5a4f08
Binary files /dev/null and b/flagged/output/fff73f12467314dce395/image.webp differ
diff --git a/get_data.py b/get_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e9414c933a64f1124f4eda6ec0faa8cd8ed2ee
--- /dev/null
+++ b/get_data.py
@@ -0,0 +1,119 @@
+from pathlib import Path
+
+import torch
+import torchvision
+from torchvision.transforms import transforms, TrivialAugmentWide
+
+from configs.dataset_params import normalize_params
+from dataset_classes.cub200 import CUB200Class
+from dataset_classes.stanfordcars import StanfordCarsClass
+from dataset_classes.travelingbirds import TravelingBirds
+
+
+def get_data(dataset, crop = True, img_size=448):
+ batchsize = 16
+ if dataset == "CUB2011":
+ train_transform = get_augmentation(0.1, img_size, True,not crop, True, True, normalize_params["CUB2011"])
+ test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["CUB2011"])
+ train_dataset = CUB200Class(True, train_transform, crop)
+ test_dataset = CUB200Class(False, test_transform, crop)
+ elif dataset == "TravelingBirds":
+ train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["TravelingBirds"])
+ test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["TravelingBirds"])
+ train_dataset = TravelingBirds(True, train_transform, crop)
+ test_dataset = TravelingBirds(False, test_transform, crop)
+
+ elif dataset == "StanfordCars":
+ train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"])
+ test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"])
+ train_dataset = StanfordCarsClass(True, train_transform)
+ test_dataset = StanfordCarsClass(False, test_transform)
+ elif dataset == "FGVCAircraft":
+ raise NotImplementedError
+
+ elif dataset == "ImageNet":
+ # Defaults from the robustness package
+ if img_size != 224:
+ raise NotImplementedError("ImageNet is setup to only work with 224x224 images")
+ train_transform = transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(
+ brightness=0.1,
+ contrast=0.1,
+ saturation=0.1
+ ),
+ transforms.ToTensor(),
+ Lighting(0.05, IMAGENET_PCA['eigval'],
+ IMAGENET_PCA['eigvec'])
+ ])
+ """
+ Standard training data augmentation for ImageNet-scale datasets: Random crop,
+ Random flip, Color Jitter, and Lighting Transform (see https://git.io/fhBOc)
+ """
+ test_transform = transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ ])
+ imgnet_root = Path.home()/ "tmp" /"Datasets"/ "imagenet"
+ train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform)
+ test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform)
+ batchsize = 64
+
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
+ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
+ return train_loader, test_loader
+
+def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize):
+ augmentation = []
+ if random_center_crop:
+ augmentation.append(transforms.Resize(size))
+ else:
+ augmentation.append(transforms.Resize((size, size)))
+ if training:
+ if random_center_crop:
+ augmentation.append(transforms.RandomCrop(size, padding=4))
+ else:
+ if random_center_crop:
+ augmentation.append(transforms.CenterCrop(size))
+ if training:
+ if hflip:
+ augmentation.append(transforms.RandomHorizontalFlip())
+ if jitter:
+ augmentation.append(transforms.ColorJitter(jitter, jitter, jitter))
+ if trivialAug:
+ augmentation.append(TrivialAugmentWide())
+ augmentation.append(transforms.ToTensor())
+ augmentation.append(transforms.Normalize(**normalize))
+ return transforms.Compose(augmentation)
+
+class Lighting(object):
+ """
+ Lighting noise (see https://git.io/fhBOc)
+ """
+
+ def __init__(self, alphastd, eigval, eigvec):
+ self.alphastd = alphastd
+ self.eigval = eigval
+ self.eigvec = eigvec
+
+ def __call__(self, img):
+ if self.alphastd == 0:
+ return img
+
+ alpha = img.new().resize_(3).normal_(0, self.alphastd)
+ rgb = self.eigvec.type_as(img).clone() \
+ .mul(alpha.view(1, 3).expand(3, 3)) \
+ .mul(self.eigval.view(1, 3).expand(3, 3)) \
+ .sum(1).squeeze()
+
+ return img.add(rgb.view(3, 1, 1).expand_as(img))
+IMAGENET_PCA = {
+ 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
+ 'eigvec': torch.Tensor([
+ [-0.5675, 0.7192, 0.4009],
+ [-0.5808, -0.0045, -0.8140],
+ [-0.5836, -0.6948, 0.4203],
+ ])
+}
diff --git a/load_model.py b/load_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e962a7497da9a895feed12708fe9c55a24dbdd
--- /dev/null
+++ b/load_model.py
@@ -0,0 +1,51 @@
+from argparse import ArgumentParser
+from pathlib import Path
+
+import torch
+
+from architectures.model_mapping import get_model
+from configs.dataset_params import dataset_constants
+from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics
+from get_data import get_data
+
+def extract_sel_mean_std_bias_assignemnt(state_dict):
+ feature_sel = state_dict["linear.selection"]
+ #feature_sel = selection
+ weight_at_selection = state_dict["linear.layer.weight"]
+ mean = state_dict["linear.mean"]
+ std = state_dict["linear.std"]
+ bias = state_dict["linear.layer.bias"]
+ return feature_sel, weight_at_selection, mean, std, bias
+
+
+def eval_model(dataset, arch,seed=123456, model_type="qsenn",crop = True, n_features = 50, n_per_class=5, img_size=448, reduced_strides=False, folder = None):
+ n_classes = dataset_constants[dataset]["num_classes"]
+ train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
+ model = get_model(arch, n_classes, reduced_strides)
+ if folder is None:
+ folder = Path.home() / f"tmp/{arch}/{dataset}/{seed}/"
+ print(folder)
+ state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
+ selection= torch.load(folder / f"SlDD_Selection_50.pt")
+ state_dict['linear.selection']=selection
+ print(state_dict.keys())
+ feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
+ model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
+ model.load_state_dict(state_dict)
+ print(model)
+ metrics_finetuned = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader)
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"])
+ parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"])
+ parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"])
+ parser.add_argument('--seed', default=123456, type=int, help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model') # 769567, 552629
+ parser.add_argument('--cropGT', default=False, type=bool,
+ help='Whether to crop CUB/TravelingBirds based on GT Boundaries')
+ parser.add_argument('--n_features', default=50, type=int, help='How many features to select') #769567
+ parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class')
+ parser.add_argument('--img_size', default=448, type=int, help='Image size')
+ parser.add_argument('--reduced_strides', default=False, type=bool, help='Whether to use reduced strides for resnets')
+ args = parser.parse_args()
+ eval_model(args.dataset, args.arch, args.seed, args.model_type,args.cropGT, args.n_features, args.n_per_class, args.img_size, args.reduced_strides)
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a340862967c2c1d8befc1eff79bd00122223f93
--- /dev/null
+++ b/main.py
@@ -0,0 +1,79 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import numpy as np
+import torch
+from tqdm import trange
+
+from FeatureDiversityLoss import FeatureDiversityLoss
+from architectures.model_mapping import get_model
+from configs.architecture_params import architecture_params
+from configs.dataset_params import dataset_constants
+from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics
+from finetuning.map_function import finetune
+from get_data import get_data
+from saving.logging import Tee
+from saving.utils import json_save
+from train import train, test
+from training.optim import get_optimizer, get_scheduler_for_model
+
+
+def main(dataset, arch,seed=None, model_type="qsenn", do_dense=True,crop = True, n_features = 50, n_per_class=5, img_size=448, reduced_strides=False):
+ # create random seed, if seed is None
+ if seed is None:
+ seed = np.random.randint(0, 1000000)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ dataset_key = dataset
+ if crop:
+ assert dataset in ["CUB2011","TravelingBirds"]
+ dataset_key += "_crop"
+ log_dir = Path.home()/f"tmp/{arch}/{dataset_key}/{seed}/"
+ log_dir.mkdir(parents=True, exist_ok=True)
+ tee = Tee(log_dir / "log.txt") # save log to file
+ n_classes = dataset_constants[dataset]["num_classes"]
+ train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
+ model = get_model(arch, n_classes, reduced_strides)
+ fdl = FeatureDiversityLoss(architecture_params[arch]["beta"], model.linear)
+ OptimizationSchedule = get_scheduler_for_model(model_type, dataset)
+ optimizer, schedule, dense_epochs =get_optimizer(model, OptimizationSchedule)
+ if not os.path.exists(log_dir / "Trained_DenseModel.pth"):
+ if do_dense:
+ for epoch in trange(dense_epochs):
+ model = train(model, train_loader, optimizer, fdl, epoch)
+ schedule.step()
+ if epoch % 5 == 0:
+ test(model, test_loader,epoch)
+ else:
+ print("Using pretrained model, only makes sense for ImageNet")
+ torch.save(model.state_dict(), os.path.join(log_dir, f"Trained_DenseModel.pth"))
+ else:
+ model.load_state_dict(torch.load(log_dir / "Trained_DenseModel.pth"))
+ if not os.path.exists( log_dir/f"Results_DenseModel.json"):
+ metrics_dense = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader)
+ json_save(os.path.join(log_dir, f"Results_DenseModel.json"), metrics_dense)
+ final_model = finetune(model_type, model, train_loader, test_loader, log_dir, n_classes, seed, architecture_params[arch]["beta"], OptimizationSchedule, n_per_class, n_features)
+ torch.save(final_model.state_dict(), os.path.join(log_dir,f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth"))
+ metrics_finetuned = eval_model_on_all_qsenn_metrics(final_model, test_loader, train_loader)
+ json_save(os.path.join(log_dir, f"Results_{model_type}_{n_features}_{n_per_class}_FinetunedModel.json"), metrics_finetuned)
+ print("Done")
+ pass
+
+
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"])
+ parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"])
+ parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"])
+ parser.add_argument('--seed', default=None, type=int, help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model') # 769567, 552629
+ parser.add_argument('--do_dense', default=True, type=bool, help='whether to train dense model. Should be true for all datasets except (maybe) ImageNet')
+ parser.add_argument('--cropGT', default=False, type=bool,
+ help='Whether to crop CUB/TravelingBirds based on GT Boundaries')
+ parser.add_argument('--n_features', default=50, type=int, help='How many features to select') #769567
+ parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class')
+ parser.add_argument('--img_size', default=448, type=int, help='Image size')
+ parser.add_argument('--reduced_strides', default=False, type=bool, help='Whether to use reduced strides for resnets')
+ args = parser.parse_args()
+ main(args.dataset, args.arch, args.seed, args.model_type, args.do_dense,args.cropGT, args.n_features, args.n_per_class, args.img_size, args.reduced_strides)
diff --git a/saving/logging.py b/saving/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..377e31b33e06865bb588bfec32678c947e5c3bb3
--- /dev/null
+++ b/saving/logging.py
@@ -0,0 +1,27 @@
+import sys
+
+
+class Tee(object):
+ def __init__(self, name, file_only=False):
+ self.file = open(name, "a")
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+ sys.stdout = self
+ sys.stderr = self
+ self.file_only = file_only
+
+ def __del__(self):
+ sys.stdout = self.stdout
+ sys.stderr = self.stderr
+ self.file.close()
+
+ def write(self, data):
+ self.file.write(data)
+ if not self.file_only:
+ self.stdout.write(data)
+ self.flush()
+
+ def flush(self):
+ self.file.flush()
+
+
diff --git a/saving/utils.py b/saving/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bf35c3c5129e84ab6f13f85c4c04fba4f7e33b4
--- /dev/null
+++ b/saving/utils.py
@@ -0,0 +1,6 @@
+import json
+
+
+def json_save(filename, data):
+ with open(filename, "w") as f:
+ json.dump(data, f,indent=4)
\ No newline at end of file
diff --git a/sparsification/FeatureSelection.py b/sparsification/FeatureSelection.py
new file mode 100644
index 0000000000000000000000000000000000000000..885a7bc9fab8842e9eece0f07690636b1d623233
--- /dev/null
+++ b/sparsification/FeatureSelection.py
@@ -0,0 +1,473 @@
+from argparse import ArgumentParser
+import logging
+import math
+import os.path
+import sys
+import time
+import warnings
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from glm_saga.elasticnet import maximum_reg_loader, get_device, elastic_loss_and_acc_loader
+from torch import nn
+
+import torch as ch
+
+from sparsification.utils import safe_zip
+
+# TODO checkout this change: Marks changes to the group version of glmsaga
+
+"""
+This would need glm_saga to run
+usage to select 50 features with parameters as in paper:
+metadata contains information about the precomputed train features in feature_loaders
+args contains the default arguments for glm-saga, as described at the bottom
+def get_glm_to_zero(feature_loaders, metadata, args, num_classes, device, train_ds, Ntotal):
+ num_features = metadata["X"]["num_features"][0]
+ fittingClass = FeatureSelectionFitting(num_features, num_lasses, args, 0.8,
+ 50,
+ True,0.1,
+ lookback=3, tol=1e-4,
+ epsilon=1,)
+ to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device)
+ return to_drop
+
+to_drop is then used to remove the features from the downstream fitting and finetuning.
+"""
+
+
+class FeatureSelectionFitting:
+ def __init__(self, n_features, n_classes, args, selalpha, nKeep, lam_fac,out_dir, lookback=None, tol=None,
+ epsilon=None):
+ """
+ This is an adaption of the group version of glm-saga (https://github.com/MadryLab/DebuggableDeepNetworks)
+ The function extended_mask_max covers the changed operator,
+ Args:
+ n_features:
+ n_classes:
+ args: default args for glmsaga
+ selalpha: alpha for elastic net
+ nKeep: target number features
+ lam_fac: discount factor for lambda
+ parameters of glmsaga
+ lookback:
+ tol:
+ epsilon:
+ """
+ self.selected_features = torch.zeros(n_features, dtype=torch.bool)
+ self.num_features = n_features
+ self.selalpha = selalpha
+ self.lam_Fac = lam_fac
+ self.out_dir = out_dir
+ self.n_classes = n_classes
+ self.nKeep = nKeep
+ self.args = self.extend_args(args, lookback, tol, epsilon)
+
+ # Extended Proximal Operator for Feature Selection
+ def extended_mask_max(self, greater_to_keep, thresh):
+ prev = greater_to_keep[self.selected_features]
+ greater_to_keep[self.selected_features] = torch.min(greater_to_keep)
+ max_entry = torch.argmax(greater_to_keep)
+ greater_to_keep[self.selected_features] = prev
+ mask = torch.zeros_like(greater_to_keep)
+ mask[max_entry] = 1
+ final_mask = (greater_to_keep > thresh)
+ final_mask = final_mask * mask
+ allowed_to_keep = torch.logical_or(self.selected_features, final_mask)
+ return allowed_to_keep
+
+ def extend_args(self, args, lookback, tol, epsilon):
+ for key, entry in safe_zip(["lookbehind", "tol",
+ "lr_decay_factor", ], [lookback, tol, epsilon]):
+ if entry is not None:
+ setattr(args, key, entry)
+ return args
+
+ # Grouped L1 regularization
+ # proximal operator for f(weight) = lam * \|weight\|_2
+ # where the 2-norm is taken columnwise
+ def group_threshold(self, weight, lam):
+ norm = weight.norm(p=2, dim=0) + 1e-6
+ # print(ch.sum((norm > lam)))
+ return (weight - lam * weight / norm) * self.extended_mask_max(norm, lam)
+
+ # Elastic net regularization with group sparsity
+ # proximal operator for f(x) = alpha * \|x\|_1 + beta * \|x\|_2^2
+ # where the 2-norm is taken columnwise
+ def group_threshold_with_shrinkage(self, x, alpha, beta):
+ y = self.group_threshold(x, alpha)
+ return y / (1 + beta)
+
+ def threshold(self, weight_new, lr, lam):
+ alpha = self.selalpha
+ if alpha == 1:
+ # Pure L1 regularization
+ weight_new = self.group_threshold(weight_new, lr * lam * alpha)
+ else:
+ # Elastic net regularization
+ weight_new = self.group_threshold_with_shrinkage(weight_new, lr * lam * alpha,
+ lr * lam * (1 - alpha))
+ return weight_new
+
+ # Train an elastic GLM with proximal SAGA
+ # Since SAGA stores a scalar for each example-class pair, either pass
+ # the number of examples and number of classes or calculate it with an
+ # initial pass over the loaders
+ def train_saga(self, linear, loader, lr, nepochs, lam, alpha, group=True, verbose=None,
+ state=None, table_device=None, n_ex=None, n_classes=None, tol=1e-4,
+ preprocess=None, lookbehind=None, family='multinomial', logger=None):
+ if logger is None:
+ logger = print
+ with ch.no_grad():
+ weight, bias = list(linear.parameters())
+ if table_device is None:
+ table_device = weight.device
+
+ # get total number of examples and initialize scalars
+ # for computing the gradients
+ if n_ex is None:
+ n_ex = sum(tensors[0].size(0) for tensors in loader)
+ if n_classes is None:
+ if family == 'multinomial':
+ n_classes = max(tensors[1].max().item() for tensors in loader) + 1
+ elif family == 'gaussian':
+ for batch in loader:
+ y = batch[1]
+ break
+ n_classes = y.size(1)
+
+ # Storage for scalar gradients and averages
+ if state is None:
+ a_table = ch.zeros(n_ex, n_classes).to(table_device)
+ w_grad_avg = ch.zeros_like(weight).to(weight.device)
+ b_grad_avg = ch.zeros_like(bias).to(weight.device)
+ else:
+ a_table = state["a_table"].to(table_device)
+ w_grad_avg = state["w_grad_avg"].to(weight.device)
+ b_grad_avg = state["b_grad_avg"].to(weight.device)
+
+ obj_history = []
+ obj_best = None
+ nni = 0
+ for t in range(nepochs):
+ total_loss = 0
+ for n_batch, batch in enumerate(loader):
+ if len(batch) == 3:
+ X, y, idx = batch
+ w = None
+ elif len(batch) == 4:
+ X, y, w, idx = batch
+ else:
+ raise ValueError(
+ f"Loader must return (data, target, index) or (data, target, index, weight) but instead got a tuple of length {len(batch)}")
+
+ if preprocess is not None:
+ device = get_device(preprocess)
+ with ch.no_grad():
+ X = preprocess(X.to(device))
+ X = X.to(weight.device)
+ out = linear(X)
+
+ # split gradient on only the cross entropy term
+ # for efficient storage of gradient information
+ if family == 'multinomial':
+ if w is None:
+ loss = F.cross_entropy(out, y.to(weight.device), reduction='mean')
+ else:
+ loss = F.cross_entropy(out, y.to(weight.device), reduction='none')
+ loss = (loss * w).mean()
+ I = ch.eye(linear.weight.size(0))
+ target = I[y].to(weight.device) # change to OHE
+
+ # Calculate new scalar gradient
+ logits = F.softmax(linear(X))
+ elif family == 'gaussian':
+ if w is None:
+ loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='mean')
+ else:
+ loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='none')
+ loss = (loss * (w.unsqueeze(1))).mean()
+ target = y
+
+ # Calculate new scalar gradient
+ logits = linear(X)
+ else:
+ raise ValueError(f"Unknown family: {family}")
+ total_loss += loss.item() * X.size(0)
+
+ # BS x NUM_CLASSES
+ a = logits - target
+ if w is not None:
+ a = a * w.unsqueeze(1)
+ a_prev = a_table[idx].to(weight.device)
+
+ # weight parameter
+ w_grad = (a.unsqueeze(2) * X.unsqueeze(1)).mean(0)
+ w_grad_prev = (a_prev.unsqueeze(2) * X.unsqueeze(1)).mean(0)
+ w_saga = w_grad - w_grad_prev + w_grad_avg
+ weight_new = weight - lr * w_saga
+ weight_new = self.threshold(weight_new, lr, lam)
+ # bias parameter
+ b_grad = a.mean(0)
+ b_grad_prev = a_prev.mean(0)
+ b_saga = b_grad - b_grad_prev + b_grad_avg
+ bias_new = bias - lr * b_saga
+
+ # update table and averages
+ a_table[idx] = a.to(table_device)
+ w_grad_avg.add_((w_grad - w_grad_prev) * X.size(0) / n_ex)
+ b_grad_avg.add_((b_grad - b_grad_prev) * X.size(0) / n_ex)
+
+ if lookbehind is None:
+ dw = (weight_new - weight).norm(p=2)
+ db = (bias_new - bias).norm(p=2)
+ criteria = ch.sqrt(dw ** 2 + db ** 2)
+
+ if criteria.item() <= tol:
+ return {
+ "a_table": a_table.cpu(),
+ "w_grad_avg": w_grad_avg.cpu(),
+ "b_grad_avg": b_grad_avg.cpu()
+ }
+
+ weight.data = weight_new
+ bias.data = bias_new
+
+ saga_obj = total_loss / n_ex + lam * alpha * weight.norm(p=1) + 0.5 * lam * (1 - alpha) * (
+ weight ** 2).sum()
+
+ # save amount of improvement
+ obj_history.append(saga_obj.item())
+ if obj_best is None or saga_obj.item() + tol < obj_best:
+ obj_best = saga_obj.item()
+ nni = 0
+ else:
+ nni += 1
+
+ # Stop if no progress for lookbehind iterationsd:])
+ criteria = lookbehind is not None and (nni >= lookbehind)
+
+ nnz = (weight.abs() > 1e-5).sum().item()
+ total = weight.numel()
+ if verbose and (t % verbose) == 0:
+ if lookbehind is None:
+ logger(
+ f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) criteria {criteria:.4f} {dw} {db}")
+ else:
+ logger(
+ f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best}")
+
+ if lookbehind is not None and criteria:
+ logger(
+ f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best} [early stop at {t}]")
+ return {
+ "a_table": a_table.cpu(),
+ "w_grad_avg": w_grad_avg.cpu(),
+ "b_grad_avg": b_grad_avg.cpu()
+ }
+
+ logger(f"did not converge at {nepochs} iterations (criteria {criteria})")
+ return {
+ "a_table": a_table.cpu(),
+ "w_grad_avg": w_grad_avg.cpu(),
+ "b_grad_avg": b_grad_avg.cpu()
+ }
+
+ def glm_saga(self, linear, loader, max_lr, nepochs, alpha, dropout, tries,
+ table_device=None, preprocess=None, group=False,
+ verbose=None, state=None, n_ex=None, n_classes=None,
+ tol=1e-4, epsilon=0.001, k=100, checkpoint=None,
+ do_zero=True, lr_decay_factor=1, metadata=None,
+ val_loader=None, test_loader=None, lookbehind=None,
+ family='multinomial', encoder=None, tot_tries=1):
+ if encoder is not None:
+ warnings.warn("encoder argument is deprecated; please use preprocess instead", DeprecationWarning)
+ preprocess = encoder
+ device = get_device(linear)
+ checkpoint = self.out_dir
+ if preprocess is not None and (device != get_device(preprocess)):
+ raise ValueError(
+ f"Linear and preprocess must be on same device (got {get_device(linear)} and {get_device(preprocess)})")
+
+ if metadata is not None:
+ if n_ex is None:
+ n_ex = metadata['X']['num_examples']
+ if n_classes is None:
+ n_classes = metadata['y']['num_classes']
+ lam_fac = (1 + (tries - 1) / tot_tries)
+ print("Using lam_fac ", lam_fac)
+ max_lam = maximum_reg_loader(loader, group=group, preprocess=preprocess, metadata=metadata,
+ family=family) / max(
+ 0.001, alpha) * lam_fac
+ group_lam = maximum_reg_loader(loader, group=True, preprocess=preprocess, metadata=metadata,
+ family=family) / max(
+ 0.001, alpha) * lam_fac
+ min_lam = epsilon * max_lam
+ group_min_lam = epsilon * group_lam
+ # logspace is base 10 but log is base e so use log10
+ lams = ch.logspace(math.log10(max_lam), math.log10(min_lam), k)
+ lrs = ch.logspace(math.log10(max_lr), math.log10(max_lr / lr_decay_factor), k)
+ found = False
+ if do_zero:
+ lams = ch.cat([lams, lams.new_zeros(1)])
+ lrs = ch.cat([lrs, lrs.new_ones(1) * lrs[-1]])
+
+ path = []
+ best_val_loss = float('inf')
+
+ if checkpoint is not None:
+ os.makedirs(checkpoint, exist_ok=True)
+
+ file_handler = logging.FileHandler(filename=os.path.join(checkpoint, 'output.log'))
+ stdout_handler = logging.StreamHandler(sys.stdout)
+ handlers = [file_handler, stdout_handler]
+
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format='[%(asctime)s] %(levelname)s - %(message)s',
+ handlers=handlers
+ )
+ logger = logging.getLogger('glm_saga').info
+ else:
+ logger = print
+ while self.selected_features.sum() < self.nKeep: # TODO checkout this change, one iteration per feature
+ n_feature_to_keep = self.selected_features.sum()
+ for i, (lam, lr) in enumerate(zip(lams, lrs)):
+ lam = lam * self.lam_Fac
+ start_time = time.time()
+ self.selected_features = self.selected_features.to(device)
+ state = self.train_saga(linear, loader, lr, nepochs, lam, alpha,
+ table_device=table_device, preprocess=preprocess, group=group, verbose=verbose,
+ state=state, n_ex=n_ex, n_classes=n_classes, tol=tol, lookbehind=lookbehind,
+ family=family, logger=logger)
+
+ with ch.no_grad():
+ loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha, preprocess=preprocess,
+ family=family)
+ loss, acc = loss.item(), acc.item()
+
+ loss_val, acc_val = -1, -1
+ if val_loader:
+ loss_val, acc_val = elastic_loss_and_acc_loader(linear, val_loader, lam, alpha,
+ preprocess=preprocess,
+ family=family)
+ loss_val, acc_val = loss_val.item(), acc_val.item()
+
+ loss_test, acc_test = -1, -1
+ if test_loader:
+ loss_test, acc_test = elastic_loss_and_acc_loader(linear, test_loader, lam, alpha,
+ preprocess=preprocess, family=family)
+ loss_test, acc_test = loss_test.item(), acc_test.item()
+
+ params = {
+ "lam": lam,
+ "lr": lr,
+ "alpha": alpha,
+ "time": time.time() - start_time,
+ "loss": loss,
+ "metrics": {
+ "loss_tr": loss,
+ "acc_tr": acc,
+ "loss_val": loss_val,
+ "acc_val": acc_val,
+ "loss_test": loss_test,
+ "acc_test": acc_test,
+ },
+ "weight": linear.weight.detach().cpu().clone(),
+ "bias": linear.bias.detach().cpu().clone()
+
+ }
+ path.append(params)
+ if loss_val is not None and loss_val < best_val_loss:
+ best_val_loss = loss_val
+ best_params = params
+ found = True
+ nnz = (linear.weight.abs() > 1e-5).sum().item()
+ total = linear.weight.numel()
+ if family == 'multinomial':
+ logger(
+ f"{n_feature_to_keep} Feature ({i}) lambda {lam:.4f}, loss {loss:.4f}, acc {acc:.4f} [val acc {acc_val:.4f}] [test acc {acc_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
+ elif family == 'gaussian':
+ logger(
+ f"({i}) lambda {lam:.4f}, loss {loss:.4f} [val loss {loss_val:.4f}] [test loss {loss_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
+
+ if self.check_new_feature(linear.weight): # TODO checkout this change, canceling if new feature is used
+ if checkpoint is not None:
+ ch.save(params, os.path.join(checkpoint, f"params{n_feature_to_keep}.pth"))
+ break
+ if found:
+ return {
+ 'path': path,
+ 'best': best_params,
+ 'state': state
+ }
+ else:
+ return False
+
+ def check_new_feature(self, weight):
+ # TODO checkout this change, checking if new feature is used
+ copied_weight = torch.tensor(weight.cpu())
+ used_features = torch.unique(
+ torch.nonzero(copied_weight)[:, 1])
+ if len(used_features) > 0:
+ new_set = set(used_features.tolist())
+ old_set = set(torch.nonzero(self.selected_features)[:, 0].tolist())
+ diff = new_set - old_set
+ if len(diff) > 0:
+ self.selected_features[used_features] = True
+ return True
+ return False
+
+ def fit(self, feature_loaders, metadata, device):
+ # TODO checkout this change, glm saga code slightly adapted to return to_drop
+ print("Initializing linear model...")
+ linear = nn.Linear(self.num_features, self.n_classes).to(device)
+ for p in [linear.weight, linear.bias]:
+ p.data.zero_()
+
+ print("Preparing normalization preprocess and indexed dataloader")
+ preprocess = NormalizedRepresentation(feature_loaders['train'],
+ metadata=metadata,
+ device=linear.weight.device)
+
+ print("Calculating the regularization path")
+ mpl_logger = logging.getLogger("matplotlib")
+ mpl_logger.setLevel(logging.WARNING)
+ selected_features = self.glm_saga(linear,
+ feature_loaders['train'],
+ self.args.lr,
+ self.args.max_epochs,
+ self.selalpha, 0, 1,
+ val_loader=feature_loaders['val'],
+ test_loader=feature_loaders['test'],
+ n_classes=self.n_classes,
+ verbose=self.args.verbose,
+ tol=self.args.tol,
+ lookbehind=self.args.lookbehind,
+ lr_decay_factor=self.args.lr_decay_factor,
+ group=True,
+ epsilon=self.args.lam_factor,
+ metadata=metadata,
+ preprocess=preprocess, tot_tries=1)
+ to_drop = np.where(self.selected_features.cpu().numpy() == 0)[0]
+ test_acc = selected_features["path"][-1]["metrics"]["acc_test"]
+ torch.set_grad_enabled(True)
+ return to_drop, test_acc
+
+
+class NormalizedRepresentation(ch.nn.Module):
+ def __init__(self, loader, metadata, device='cuda', tol=1e-5):
+ super(NormalizedRepresentation, self).__init__()
+
+ assert metadata is not None
+ self.device = device
+ self.mu = metadata['X']['mean']
+ self.sigma = ch.clamp(metadata['X']['std'], tol)
+
+ def forward(self, X):
+ return (X - self.mu.to(self.device)) / self.sigma.to(self.device)
+
+
+
+
diff --git a/sparsification/data_helpers.py b/sparsification/data_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d48424564050c66238f9b731b433ca25d29d5a6b
--- /dev/null
+++ b/sparsification/data_helpers.py
@@ -0,0 +1,16 @@
+
+import torch
+
+
+class NormalizedRepresentation(torch.nn.Module):
+ def __init__(self, loader, metadata, device='cuda', tol=1e-5):
+ super(NormalizedRepresentation, self).__init__()
+
+ assert metadata is not None
+ self.device = device
+ self.mu = metadata['X']['mean']
+ self.sigma = torch.clamp(metadata['X']['std'], tol)
+
+ def forward(self, X):
+ return (X - self.mu.to(self.device)) / self.sigma.to(self.device)
+
diff --git a/sparsification/feature_helpers.py b/sparsification/feature_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c11867077be5ab067548f498dad17fb299fa162
--- /dev/null
+++ b/sparsification/feature_helpers.py
@@ -0,0 +1,378 @@
+import math
+import os
+import sys
+
+import torch.cuda
+
+import sparsification.utils
+
+sys.path.append('')
+import numpy as np
+import torch as ch
+from torch.utils.data import Subset
+from tqdm import tqdm
+
+
+
+# From glm_saga
+def get_features_batch(batch, model, device='cuda'):
+ if not torch.cuda.is_available():
+ device = "cpu"
+ ims, targets = batch
+ output, latents = model(ims.to(device), with_final_features=True )
+ return latents, targets
+
+
+def compute_features(loader, model, dataset_type, pooled_output,
+ batch_size, num_workers,
+ shuffle=False, device='cpu', n_epoch=1,
+ filename=None, chunk_threshold=20000, balance=False):
+ """Compute deep features for a given dataset using a modeln and returnss
+ them as a pytorch dataset and loader.
+ Args:
+ loader : Torch data loader
+ model: Torch model
+ dataset_type (str): One of vision or language
+ pooled_output (bool): Whether or not to pool outputs
+ (only relevant for some language models)
+ batch_size (int): Batch size for output loader
+ num_workers (int): Number of workers to use for output loader
+ shuffle (bool): Whether or not to shuffle output data loaoder
+ device (str): Device on which to keep the model
+ filename (str):Optional file to cache computed feature. Recommended
+ for large dataset_classes like ImageNet.
+ chunk_threshold (int): Size of shard while caching
+ balance (bool): Whether or not to balance output data loader
+ (only relevant for some language models)
+ Returns:
+ feature_dataset: Torch dataset with deep features
+ feature_loader: Torch data loader with deep features
+ """
+ if torch.cuda.is_available():
+ device = "cuda"
+ print("mem_get_info before", torch.cuda.mem_get_info())
+ torch.cuda.empty_cache()
+ print("mem_get_info after", torch.cuda.mem_get_info())
+ model = model.to(device)
+ if filename is None or not os.path.exists(os.path.join(filename, f'0_features.npy')):
+ model.eval()
+ all_latents, all_targets, all_images = [], [], []
+ Nsamples, chunk_id = 0, 0
+ for idx_epoch in range(n_epoch):
+ for batch_idx, batch in tqdm(enumerate(loader), total=len(loader)):
+ with ch.no_grad():
+ latents, targets = get_features_batch(batch, model,
+ device=device)
+ if batch_idx == 0:
+ print("Latents shape", latents.shape)
+ Nsamples += latents.size(0)
+
+ all_latents.append(latents.cpu())
+ if len(targets.shape) > 1:
+ targets = targets[:, 0]
+ all_targets.append(targets.cpu())
+ # all_images.append(batch[0])
+ if filename is not None and Nsamples > chunk_threshold:
+ if not os.path.exists(filename): os.makedirs(filename)
+ np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy())
+ np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy())
+
+ all_latents, all_targets, Nsamples = [], [], 0
+ chunk_id += 1
+
+ if filename is not None and Nsamples > 0:
+ if not os.path.exists(filename): os.makedirs(filename)
+ np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy())
+ np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy())
+ # np.save(os.path.join(filename, f'{chunk_id}_images.npy'), ch.cat(all_images).numpy())
+ feature_dataset = load_features(filename) if filename is not None else \
+ ch.utils.data.TensorDataset(ch.cat(all_latents), ch.cat(all_targets))
+ if balance:
+ feature_dataset = balance_dataset(feature_dataset)
+
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=shuffle)
+
+ return feature_dataset, feature_loader
+
+
+def load_feature_loader(out_dir_feats, val_frac, batch_size, num_workers, random_seed):
+ feature_loaders = {}
+ for mode in ['train', 'test']:
+ print(f"For {mode} set...")
+ sink_path = f"{out_dir_feats}/features_{mode}"
+ metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
+ feature_ds = load_features(sink_path)
+ feature_loader = ch.utils.data.DataLoader(feature_ds,
+ num_workers=num_workers,
+ batch_size=batch_size)
+ if mode == 'train':
+ metadata = calculate_metadata(feature_loader,
+ num_classes=2048,
+ filename=metadata_path)
+ split_datasets, split_loaders = split_dataset(feature_ds,
+ len(feature_ds),
+ val_frac=val_frac,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ random_seed=random_seed,
+ shuffle=True)
+ feature_loaders.update({mm: sparsification.utils.add_index_to_dataloader(split_loaders[mi])
+ for mi, mm in enumerate(['train', 'val'])})
+
+ else:
+ feature_loaders[mode] = feature_loader
+ return feature_loaders, metadata
+
+
+def balance_dataset(dataset):
+ """Balances a given dataset to have the same number of samples/class.
+ Args:
+ dataset : Torch dataset
+ Returns:
+ Torch dataset with equal number of samples/class
+ """
+
+ print("Balancing dataset...")
+ n = len(dataset)
+ labels = ch.Tensor([dataset[i][1] for i in range(n)]).int()
+ n0 = sum(labels).item()
+ I_pos = labels == 1
+
+ idx = ch.arange(n)
+ idx_pos = idx[I_pos]
+ ch.manual_seed(0)
+ I = ch.randperm(n - n0)[:n0]
+ idx_neg = idx[~I_pos][I]
+ idx_bal = ch.cat([idx_pos, idx_neg], dim=0)
+ return Subset(dataset, idx_bal)
+
+
+def load_metadata(feature_path):
+ return ch.load(os.path.join(feature_path, f'metadata_train.pth'))
+
+
+def get_mean_std(feature_path):
+ metadata = load_metadata(feature_path)
+ return metadata["X"]["mean"], metadata["X"]["std"]
+
+
+def load_features_dataset_mode(feature_path, mode='test',
+ num_workers=10, batch_size=128):
+ """Loads precomputed deep features corresponding to the
+ train/test set along with normalization statitic.
+ Args:
+ feature_path (str): Path to precomputed deep features
+ mode (str): One of train or tesst
+ num_workers (int): Number of workers to use for output loader
+ batch_size (int): Batch size for output loader
+
+ Returns:
+ features (np.array): Recovered deep features
+ feature_mean: Mean of deep features
+ feature_std: Standard deviation of deep features
+ """
+ feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}'))
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=False)
+ feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth'))
+ feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std']
+ return feature_loader, feature_mean, feature_std
+
+
+def load_joint_dataset(feature_path, mode='test',
+ num_workers=10, batch_size=128):
+ feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}'))
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=False)
+ features = []
+ labels = []
+ for _, (feature, label) in tqdm(enumerate(feature_loader), total=len(feature_loader)):
+ features.append(feature)
+ labels.append(label)
+ features = np.concatenate(features)
+ labels = np.concatenate(labels)
+ dataset = ch.utils.data.TensorDataset(torch.tensor(features), torch.tensor(labels))
+ return dataset
+
+
+def load_features_mode(feature_path, mode='test',
+ num_workers=10, batch_size=128):
+ """Loads precomputed deep features corresponding to the
+ train/test set along with normalization statitic.
+ Args:
+ feature_path (str): Path to precomputed deep features
+ mode (str): One of train or tesst
+ num_workers (int): Number of workers to use for output loader
+ batch_size (int): Batch size for output loader
+
+ Returns:
+ features (np.array): Recovered deep features
+ feature_mean: Mean of deep features
+ feature_std: Standard deviation of deep features
+ """
+ feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}'))
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=False)
+
+ feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth'))
+ feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std']
+
+ features = []
+
+ for _, (feature, _) in tqdm(enumerate(feature_loader), total=len(feature_loader)):
+ features.append(feature)
+
+ features = ch.cat(features).numpy()
+ return features, feature_mean, feature_std
+
+
+def load_features(feature_path):
+ """Loads precomputed deep features.
+ Args:
+ feature_path (str): Path to precomputed deep features
+
+ Returns:
+ Torch dataset with recovered deep features.
+ """
+ if not os.path.exists(os.path.join(feature_path, f"0_features.npy")):
+ raise ValueError(f"The provided location {feature_path} does not contain any representation files")
+
+ ds_list, chunk_id = [], 0
+ while os.path.exists(os.path.join(feature_path, f"{chunk_id}_features.npy")):
+ features = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_features.npy"))).float()
+ labels = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_labels.npy"))).long()
+ ds_list.append(ch.utils.data.TensorDataset(features, labels))
+ chunk_id += 1
+
+ print(f"==> loaded {chunk_id} files of representations...")
+ return ch.utils.data.ConcatDataset(ds_list)
+
+
+def calculate_metadata(loader, num_classes=None, filename=None):
+ """Calculates mean and standard deviation of the deep features over
+ a given set of images.
+ Args:
+ loader : torch data loader
+ num_classes (int): Number of classes in the dataset
+ filename (str): Optional filepath to cache metadata. Recommended
+ for large dataset_classes like ImageNet.
+
+ Returns:
+ metadata (dict): Dictionary with desired statistics.
+ """
+
+ if filename is not None and os.path.exists(filename):
+ print("loading Metadata from ", filename)
+ return ch.load(filename)
+
+ # Calculate number of classes if not given
+ if num_classes is None:
+ num_classes = 1
+ for batch in loader:
+ y = batch[1]
+ print(y)
+ num_classes = max(num_classes, y.max().item() + 1)
+
+ eye = ch.eye(num_classes)
+
+ X_bar, y_bar, y_max, n = 0, 0, 0, 0
+
+ # calculate means and maximum
+ print("Calculating means")
+ for ans in tqdm(loader, total=len(loader)):
+ X, y = ans[:2]
+ X_bar += X.sum(0)
+ y_bar += eye[y].sum(0)
+ y_max = max(y_max, y.max())
+ n += y.size(0)
+ X_bar = X_bar.float() / n
+ y_bar = y_bar.float() / n
+
+ # calculate std
+ X_std, y_std = 0, 0
+ print("Calculating standard deviations")
+ for ans in tqdm(loader, total=len(loader)):
+ X, y = ans[:2]
+ X_std += ((X - X_bar) ** 2).sum(0)
+ y_std += ((eye[y] - y_bar) ** 2).sum(0)
+ X_std = ch.sqrt(X_std.float() / n)
+ y_std = ch.sqrt(y_std.float() / n)
+
+ # calculate maximum regularization
+ inner_products = 0
+ print("Calculating maximum lambda")
+ for ans in tqdm(loader, total=len(loader)):
+ X, y = ans[:2]
+ y_map = (eye[y] - y_bar) / y_std
+ inner_products += X.t().mm(y_map) * y_std
+
+ inner_products_group = inner_products.norm(p=2, dim=1)
+
+ metadata = {
+ "X": {
+ "mean": X_bar,
+ "std": X_std,
+ "num_features": X.size()[1:],
+ "num_examples": n
+ },
+ "y": {
+ "mean": y_bar,
+ "std": y_std,
+ "num_classes": y_max + 1
+ },
+ "max_reg": {
+ "group": inner_products_group.abs().max().item() / n,
+ "nongrouped": inner_products.abs().max().item() / n
+ }
+ }
+
+ if filename is not None:
+ ch.save(metadata, filename)
+
+ return metadata
+
+
+def split_dataset(dataset, Ntotal, val_frac,
+ batch_size, num_workers,
+ random_seed=0, shuffle=True, balance=False):
+ """Splits a given dataset into train and validation
+ Args:
+ dataset : Torch dataset
+ Ntotal: Total number of dataset samples
+ val_frac: Fraction to reserve for validation
+ batch_size (int): Batch size for output loader
+ num_workers (int): Number of workers to use for output loader
+ random_seed (int): Random seed
+ shuffle (bool): Whether or not to shuffle output data loaoder
+ balance (bool): Whether or not to balance output data loader
+ (only relevant for some language models)
+
+ Returns:
+ split_datasets (list): List of dataset_classes (one each for train and val)
+ split_loaders (list): List of loaders (one each for train and val)
+ """
+
+ Nval = math.floor(Ntotal * val_frac)
+ train_ds, val_ds = ch.utils.data.random_split(dataset,
+ [Ntotal - Nval, Nval],
+ generator=ch.Generator().manual_seed(random_seed))
+ if balance:
+ val_ds = balance_dataset(val_ds)
+ split_datasets = [train_ds, val_ds]
+
+ split_loaders = []
+ for ds in split_datasets:
+ split_loaders.append(ch.utils.data.DataLoader(ds,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=shuffle))
+ return split_datasets, split_loaders
diff --git a/sparsification/glmBasedSparsification.py b/sparsification/glmBasedSparsification.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a681147b4394281069c3a6bf0596baf435ecae0
--- /dev/null
+++ b/sparsification/glmBasedSparsification.py
@@ -0,0 +1,130 @@
+import logging
+import os
+import shutil
+
+import numpy as np
+import pandas as pd
+import torch
+from glm_saga.elasticnet import glm_saga
+from torch import nn
+
+from sparsification.FeatureSelection import FeatureSelectionFitting
+from sparsification import data_helpers
+from sparsification.utils import get_default_args, compute_features_and_metadata, select_in_loader, get_feature_loaders
+
+
+def get_glm_selection(feature_loaders, metadata, args, num_classes, device, n_features_to_select, folder):
+ num_features = metadata["X"]["num_features"][0]
+ fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8,
+ n_features_to_select,
+ 0.1,folder,
+ lookback=3, tol=1e-4,
+ epsilon=1,)
+ to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device)
+ selected_features = torch.tensor([i for i in range(num_features) if i not in to_drop])
+ return selected_features
+
+
+def compute_feature_selection_and_assignment(model, train_loader, test_loader, log_folder,num_classes, seed, select_features = 50):
+ feature_loaders, metadata, device,args = get_feature_loaders(seed, log_folder,train_loader, test_loader, model, num_classes, )
+
+ if os.path.exists(log_folder / f"SlDD_Selection_{select_features}.pt"):
+ feature_selection = torch.load(log_folder / f"SlDD_Selection_{select_features}.pt")
+ else:
+ used_features = model.linear.weight.shape[1]
+ if used_features != select_features:
+ selection_folder = log_folder / "sldd_selection" # overwrite with None to prevent saving
+ feature_selection = get_glm_selection(feature_loaders, metadata, args,
+ num_classes,
+ device,select_features, selection_folder
+ )
+ else:
+ feature_selection = model.linear.selection
+ torch.save(feature_selection, log_folder / f"SlDD_Selection_{select_features}.pt")
+ feature_loaders = select_in_loader(feature_loaders, feature_selection)
+ mean, std = metadata["X"]["mean"], metadata["X"]["std"]
+ mean_to_pass_in = mean
+ std_to_pass_in = std
+ if len(mean) != feature_selection.shape[0]:
+ mean_to_pass_in = mean[feature_selection]
+ std_to_pass_in = std[feature_selection]
+
+ sparse_matrices, biases = fit_glm(log_folder, mean_to_pass_in, std_to_pass_in, feature_loaders, num_classes, select_features)
+
+ return feature_selection, sparse_matrices, biases, mean, std
+
+
+def fit_glm(log_dir,mean, std , feature_loaders, num_classes, select_features = 50):
+ output_folder = log_dir / "glm_path"
+ if not output_folder.exists() or len(list(output_folder.iterdir())) != 102:
+ shutil.rmtree(output_folder, ignore_errors=True)
+ output_folder.mkdir(exist_ok=True, parents=True)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ linear = nn.Linear(select_features, num_classes).to(device)
+ for p in [linear.weight, linear.bias]:
+ p.data.zero_()
+ print("Preparing normalization preprocess and indexed dataloader")
+ metadata = {"X": {"mean": mean, "std": std},}
+ preprocess = data_helpers.NormalizedRepresentation(feature_loaders['train'],
+ metadata=metadata,
+ device=linear.weight.device)
+
+ print("Calculating the regularization path")
+ mpl_logger = logging.getLogger("matplotlib")
+ mpl_logger.setLevel(logging.WARNING)
+ params = glm_saga(linear,
+ feature_loaders['train'],
+ 0.1,
+ 2000,
+ 0.99, k=100,
+ val_loader=feature_loaders['val'],
+ test_loader=feature_loaders['test'],
+ n_classes=num_classes,
+ checkpoint=str(output_folder),
+ verbose=200,
+ tol=1e-4, # Change for ImageNet
+ lookbehind=5,
+ lr_decay_factor=1,
+ group=False,
+ epsilon=0.001,
+ metadata=None, # To let it be recomputed
+ preprocess=preprocess, )
+ results = load_glm(output_folder)
+ sparse_matrices = results["weights"]
+ biases = results["biases"]
+
+ return sparse_matrices, biases
+
+def load_glm(result_dir):
+ Nlambda = max([int(f.split('params')[1].split('.pth')[0])
+ for f in os.listdir(result_dir) if 'params' in f]) + 1
+
+ print(f"Loading regularization path of length {Nlambda}")
+
+ params_dict = {i: torch.load(os.path.join(result_dir, f"params{i}.pth"),
+ map_location=torch.device('cpu')) for i in range(Nlambda)}
+
+ regularization_strengths = [params_dict[i]['lam'].item() for i in range(Nlambda)]
+ weights = [params_dict[i]['weight'] for i in range(Nlambda)]
+ biases = [params_dict[i]['bias'] for i in range(Nlambda)]
+
+ metrics = {'acc_tr': [], 'acc_val': [], 'acc_test': []}
+
+ for k in metrics.keys():
+ for i in range(Nlambda):
+ metrics[k].append(params_dict[i]['metrics'][k])
+ metrics[k] = 100 * np.stack(metrics[k])
+ metrics = pd.DataFrame(metrics)
+ metrics = metrics.rename(columns={'acc_tr': 'acc_train'})
+
+ # weights_stacked = ch.stack(weights)
+ # sparsity = ch.sum(weights_stacked != 0, dim=2).numpy()
+ sparsity = np.array([torch.sum(w != 0, dim=1).numpy() for w in weights])
+
+ return {'metrics': metrics,
+ 'regularization_strengths': regularization_strengths,
+ 'weights': weights,
+ 'biases': biases,
+ 'sparsity': sparsity,
+ 'weight_dense': weights[-1],
+ 'bias_dense': biases[-1]}
diff --git a/sparsification/qsenn.py b/sparsification/qsenn.py
new file mode 100644
index 0000000000000000000000000000000000000000..45eb1bde64846d26996962f7a1c3b8d8e0ffa6ab
--- /dev/null
+++ b/sparsification/qsenn.py
@@ -0,0 +1,63 @@
+import numpy as np
+import torch
+
+from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
+
+
+def compute_qsenn_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed,n_features, per_class = 5):
+ feature_sel, sparse_matrices, biases, mean, std = compute_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ log_folder, num_classes, seed, n_features)
+ weight_sparse, bias_sparse = get_sparsified_weights_for_factor(sparse_matrices[:-1], biases[:-1], per_class) # Last one in regularisation path has no regularisation
+ print(f"Number of nonzeros in weight matrix: {torch.sum(weight_sparse != 0)}")
+ return feature_sel, weight_sparse, bias_sparse, mean, std
+def get_sparsified_weights_for_factor(weights, biases, factor,):
+ no_reg_result_mat, no_reg_result_bias = weights[-1], biases[-1]
+ goal_nonzeros = factor * no_reg_result_mat.shape[0]
+ values = no_reg_result_mat.flatten()
+ values = values[values != 0]
+ values = -(torch.sort(-torch.abs(values))[0])
+ if goal_nonzeros < len(values):
+ threshold = (values[int(goal_nonzeros) - 1] + values[int(goal_nonzeros)]) / 2
+ else:
+ threshold = values[-1]
+ max_val = torch.max(torch.abs(values))
+ weight_sparse = discretize_2_bins_to_threshold(no_reg_result_mat, threshold, max_val)
+ sel_idx = len(weights) - 1
+ positive_weights_per_class = np.array(torch.sum(weight_sparse > 0, dim=1))
+ negative_weights_per_class = np.array(torch.sum(weight_sparse < 0, dim=1))
+ total_weight_count_per_class = positive_weights_per_class - negative_weights_per_class
+ max_bias = torch.max(torch.abs(biases[sel_idx]))
+ bias_sparse = torch.ones_like(biases[sel_idx]) * max_bias
+ diff_n_weight = total_weight_count_per_class - np.min(total_weight_count_per_class)
+ steps = np.max(diff_n_weight)
+ single_step = 2 * max_bias / steps
+ bias_sparse = bias_sparse - torch.tensor(diff_n_weight) * single_step
+ bias_sparse = torch.clamp(bias_sparse, -max_bias, max_bias)
+ return weight_sparse, bias_sparse
+
+
+def discretize_2_bins_to_threshold(data, treshold, max):
+ boundaries = torch.tensor([-max, -treshold, treshold, max], device=data.device)
+ bucketized_tensor = torch.bucketize(data, boundaries)
+ means = torch.tensor([-max, 0, max], device=data.device)
+ for i in range(len(means)):
+ if means[i] == 0:
+ break
+ positive_index = int(len(means) / 2 + 1) + i
+ positive_bucket = data[bucketized_tensor == positive_index + 1]
+ negative_bucket = data[bucketized_tensor == i + 1]
+ sum = 0
+ total = 0
+ for bucket in [positive_bucket, negative_bucket]:
+ if len(bucket) == 0:
+ continue
+ sum += torch.sum(torch.abs(bucket))
+ total += len(bucket)
+ if total == 0:
+ continue
+ avg = sum / total
+ means[i] = -avg
+ means[positive_index] = avg
+ discretized_tensor = means.cpu()[bucketized_tensor.cpu() - 1].to(bucketized_tensor.device)
+ return discretized_tensor
\ No newline at end of file
diff --git a/sparsification/sldd.py b/sparsification/sldd.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eeb3733797950107c6a0987fe242a0bfe2e732a
--- /dev/null
+++ b/sparsification/sldd.py
@@ -0,0 +1,44 @@
+import numpy as np
+import torch
+
+from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
+
+
+def compute_sldd_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed,
+ per_class=5, select_features=50):
+ feature_sel, sparse_matrices, biases, mean, std = compute_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ log_folder, num_classes,
+ seed, select_features=select_features)
+ weight_sparse, bias_sparse = get_sparsified_weights_for_factor(sparse_matrices,biases,
+ per_class) # Last one in regularisation path has none
+ return feature_sel, weight_sparse, bias_sparse, mean, std
+
+def get_sparsified_weights_for_factor(sparse_layer,biases,keep_per_class, drop_rate=0.5):
+ nonzero_entries = [torch.sum(torch.count_nonzero(sparse_layer[i])) for i in range(len(sparse_layer))]
+ mean_sparsity = np.array([nonzero_entries[i] / sparse_layer[i].shape[0] for i in range(len(sparse_layer))])
+ factor =keep_per_class / drop_rate
+ # Get layer with desired sparsity
+ sparse_enough = mean_sparsity <= factor
+ sel_idx = np.argmax(sparse_enough * mean_sparsity)
+ if sel_idx == 0 and np.sum(mean_sparsity) > 1: # sometimes first one is odd
+ sparse_enough[0] = False
+ sel_idx = np.argmax(sparse_enough * mean_sparsity)
+ selected_weight = sparse_layer[sel_idx]
+ selected_bias = biases[sel_idx]
+ # only keep 5 per class on average
+ weight_5_per_matrix = set_lowest_percent_to_zero(selected_weight,5)
+
+ return weight_5_per_matrix,selected_bias
+
+
+def set_lowest_percent_to_zero(matrix, keep_per):
+ nonzero_indices = torch.nonzero(matrix)
+ values = torch.tensor([matrix[x[0], x[1]] for x in nonzero_indices])
+ sorted_indices = torch.argsort(torch.abs(values))
+ total_allowed = int(matrix.shape[0] * keep_per)
+ sorted_indices = sorted_indices[:-total_allowed]
+ nonzero_indices_to_zero = [nonzero_indices[x] for x in sorted_indices]
+ for to_zero in nonzero_indices_to_zero:
+ matrix[to_zero[0], to_zero[1]] = 0
+ return matrix
\ No newline at end of file
diff --git a/sparsification/utils.py b/sparsification/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e960e5c4131032242e41bf7ff5a8a02b571fc89
--- /dev/null
+++ b/sparsification/utils.py
@@ -0,0 +1,159 @@
+from argparse import ArgumentParser
+
+import torch
+
+#from sparsification.glm_saga import glm_saga
+from sparsification import feature_helpers
+
+
+def safe_zip(*args):
+ for iterable in args[1:]:
+ if len(iterable) != len(args[0]):
+ print("Unequally sized iterables to zip, printing lengths")
+ for i, entry in enumerate(args):
+ print(i, len(entry))
+ raise ValueError("Unequally sized iterables to zip")
+ return zip(*args)
+
+
+def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats, num_classes,
+ ):
+ print("Computing/loading deep features...")
+
+ Ntotal = len(train_loader.dataset)
+ feature_loaders = {}
+ # Compute Features for not augmented train and test set
+ train_loader_transforms = train_loader.dataset.transform
+ test_loader_transforms = test_loader.dataset.transform
+ train_loader.dataset.transform = test_loader_transforms
+ for mode, loader in zip(['train', 'test', ], [train_loader, test_loader, ]): #
+ print(f"For {mode} set...")
+
+ sink_path = f"{out_dir_feats}/features_{mode}"
+ metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
+
+ feature_ds, feature_loader = feature_helpers.compute_features(loader,
+ model,
+ dataset_type=args.dataset_type,
+ pooled_output=None,
+ batch_size=args.batch_size,
+ num_workers=0, # args.num_workers,
+ shuffle=(mode == 'test'),
+ device=args.device,
+ filename=sink_path, n_epoch=1,
+ balance=False,
+ ) # args.balance if mode == 'test' else False)
+
+ if mode == 'train':
+ metadata = feature_helpers.calculate_metadata(feature_loader,
+ num_classes=num_classes,
+ filename=metadata_path)
+ if metadata["max_reg"]["group"] == 0.0:
+ return None, False
+ split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
+ Ntotal,
+ val_frac=args.val_frac,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ random_seed=args.random_seed,
+ shuffle=True,
+ balance=False)
+ feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi])
+ for mi, mm in enumerate(['train', 'val'])})
+
+ else:
+ feature_loaders[mode] = feature_loader
+ train_loader.dataset.transform = train_loader_transforms
+ return feature_loaders, metadata
+
+def get_feature_loaders(seed, log_folder,train_loader, test_loader, model, num_classes, ):
+ args = get_default_args()
+ args.random_seed = seed
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ feature_folder = log_folder / "features"
+ feature_loaders, metadata, = compute_features_and_metadata(args, train_loader, test_loader, model,
+ feature_folder
+ ,
+ num_classes,
+ )
+ return feature_loaders, metadata, device,args
+def add_index_to_dataloader(loader, sample_weight=None,):
+ return torch.utils.data.DataLoader(
+ IndexedDataset(loader.dataset, sample_weight=sample_weight),
+ batch_size=loader.batch_size,
+ sampler=loader.sampler,
+ num_workers=loader.num_workers,
+ collate_fn=loader.collate_fn,
+ pin_memory=loader.pin_memory,
+ drop_last=loader.drop_last,
+ timeout=loader.timeout,
+ worker_init_fn=loader.worker_init_fn,
+ multiprocessing_context=loader.multiprocessing_context
+ )
+
+
+class IndexedDataset(torch.utils.data.Dataset):
+ def __init__(self, ds, sample_weight=None):
+ super(torch.utils.data.Dataset, self).__init__()
+ self.dataset = ds
+ self.sample_weight = sample_weight
+
+ def __getitem__(self, index):
+ val = self.dataset[index]
+ if self.sample_weight is None:
+ return val + (index,)
+ else:
+ weight = self.sample_weight[index]
+ return val + (weight, index)
+
+ def __len__(self):
+ return len(self.dataset)
+
+
+def get_default_args():
+ # Default args from glm_saga, https://github.com/MadryLab/glm_saga
+ parser = ArgumentParser()
+ parser.add_argument('--dataset', type=str, help='dataset name')
+ parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
+ parser.add_argument('--dataset-path', type=str, help='path to dataset')
+ parser.add_argument('--model-path', type=str, help='path to model checkpoint')
+ parser.add_argument('--arch', type=str, help='model architecture type')
+ parser.add_argument('--out-path', help='location for saving results')
+ parser.add_argument('--cache', action='store_true', help='cache deep features')
+ parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')
+
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--random-seed', default=0)
+ parser.add_argument('--num-workers', type=int, default=2)
+ parser.add_argument('--batch-size', type=int, default=256)
+ parser.add_argument('--val-frac', type=float, default=0.1)
+ parser.add_argument('--lr-decay-factor', type=float, default=1)
+ parser.add_argument('--lr', type=float, default=0.1)
+ parser.add_argument('--alpha', type=float, default=0.99)
+ parser.add_argument('--max-epochs', type=int, default=2000)
+ parser.add_argument('--verbose', type=int, default=200)
+ parser.add_argument('--tol', type=float, default=1e-4)
+ parser.add_argument('--lookbehind', type=int, default=3)
+ parser.add_argument('--lam-factor', type=float, default=0.001)
+ parser.add_argument('--group', action='store_true')
+ args = parser.parse_args()
+
+ args = parser.parse_args()
+ return args
+
+
+def select_in_loader(feature_loaders, feature_selection):
+ for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets: # Val is indexed via the same dataset as train
+ tensors = list(dataset.tensors)
+ if tensors[0].shape[1] == len(feature_selection):
+ continue
+ tensors[0] = tensors[0][:, feature_selection]
+ dataset.tensors = tensors
+ for dataset in feature_loaders["test"].dataset.datasets:
+ tensors = list(dataset.tensors)
+ if tensors[0].shape[1] == len(feature_selection):
+ continue
+ tensors[0] = tensors[0][:, feature_selection]
+ dataset.tensors = tensors
+ return feature_loaders
+
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc42416ac52c799af9548dff401d98bb4805b4ce
--- /dev/null
+++ b/train.py
@@ -0,0 +1,59 @@
+import torch
+from tqdm import tqdm
+
+from training.utils import VariableLossLogPrinter
+
+
+def get_acc(outputs, targets):
+ _, predicted = torch.max(outputs.data, 1)
+ total = targets.size(0)
+ correct = (predicted == targets).sum().item()
+ return correct / total * 100
+
+
+
+def train(model, train_loader, optimizer, fdl, epoch):
+ model.train()
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ VariableLossPrinter = VariableLossLogPrinter()
+ model = model.to(device)
+ iterator = tqdm(enumerate(train_loader), total=len(train_loader))
+ for batch_idx, (data, target) in iterator:
+ on_device = data.to(device)
+ target_on_device = target.to(device)
+
+ output, feature_maps = model(on_device, with_feature_maps=True)
+ loss = torch.nn.functional.cross_entropy(output, target_on_device)
+
+ fdl_loss = fdl(feature_maps, output)
+ total_loss = loss + fdl_loss
+
+ optimizer.zero_grad()
+ total_loss.backward()
+ optimizer.step()
+ acc = get_acc(output, target_on_device)
+ VariableLossPrinter.log_loss("Train Acc", acc, on_device.size(0))
+ VariableLossPrinter.log_loss("CE-Loss", loss.item(), on_device.size(0))
+ VariableLossPrinter.log_loss("FDL", fdl_loss.item(), on_device.size(0))
+ VariableLossPrinter.log_loss("Total-Loss", total_loss.item(), on_device.size(0))
+ iterator.set_description(f"Train Epoch:{epoch} Metrics: {VariableLossPrinter.get_loss_string()}")
+ print("Trained model for one epoch ", epoch," with lr group 0: ", optimizer.param_groups[0]["lr"])
+ return model
+
+
+def test(model, test_loader, epoch):
+ model.eval()
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ VariableLossPrinter = VariableLossLogPrinter()
+ iterator = tqdm(enumerate(test_loader), total=len(test_loader))
+ with torch.no_grad():
+ for batch_idx, (data, target) in iterator:
+ on_device = data.to(device)
+ target_on_device = target.to(device)
+ output, feature_maps = model(on_device, with_feature_maps=True)
+ loss = torch.nn.functional.cross_entropy(output, target_on_device)
+ acc = get_acc(output, target_on_device)
+ VariableLossPrinter.log_loss("Test Acc", acc, on_device.size(0))
+ VariableLossPrinter.log_loss("CE-Loss", loss.item(), on_device.size(0))
+ iterator.set_description(f"Test Epoch:{epoch} Metrics: {VariableLossPrinter.get_loss_string()}")
diff --git a/training/img_net.py b/training/img_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b12214045c4b83a74c25ddd68cab9e57d1f0598
--- /dev/null
+++ b/training/img_net.py
@@ -0,0 +1,7 @@
+def get_default_img_optimizer(model):
+ raise NotImplementedError("TODO: Implement get_default_img_optimizer")
+ pass
+
+def get_default_img_schedule(default_img_optimizer):
+ raise NotImplementedError("TODO: Implement get_default_img_schedule")
+ pass
\ No newline at end of file
diff --git a/training/optim.py b/training/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff2759f583b49318a8ea60c3229e2fb90bb0135
--- /dev/null
+++ b/training/optim.py
@@ -0,0 +1,45 @@
+from torch.optim import SGD, lr_scheduler
+
+from configs.qsenn_training_params import QSENNScheduler
+from configs.sldd_training_params import OptimizationScheduler
+from training.img_net import get_default_img_schedule, get_default_img_optimizer
+
+
+def get_optimizer(model, schedulingClass):
+ lr,weight_decay, step_lr, step_lr_gamma, n_epochs, finetune = schedulingClass.get_params()
+ print("Optimizer LR set to ", lr)
+ if lr is None: # Dense Training on ImageNet
+ print("Learning rate is None, using Default Recipe for Resnet50")
+ default_img_optimizer = get_default_img_optimizer(model)
+ default_img_schedule = get_default_img_schedule(default_img_optimizer)
+ return default_img_optimizer, default_img_schedule, 600
+ if finetune:
+ param_list = [x for x in model.parameters() if x.requires_grad]
+ else:
+ param_list = model.parameters()
+
+
+ if finetune:
+ optimizer = SGD(param_list,lr, momentum=0.95,
+ weight_decay=weight_decay)
+ else:
+ classifier_params_name = ["linear.bias","linear.weight"]
+ classifier_params = [x[1] for x in
+ list(filter(lambda kv: kv[0] in classifier_params_name, model.named_parameters()))]
+ base_params = [x[1] for x in list(
+ filter(lambda kv: kv[0] not in classifier_params_name, model.named_parameters()))]
+
+ optimizer = SGD([
+ {'params': base_params},
+ {"params": classifier_params, 'lr': 0.01}
+ ], momentum=0.9, lr=lr, weight_decay=weight_decay)
+ # Make schedule
+ schedule = lr_scheduler.StepLR(optimizer, step_size=step_lr, gamma=step_lr_gamma)
+ return optimizer, schedule, n_epochs
+
+
+def get_scheduler_for_model(model, dataset):
+ if model == "qsenn":
+ return QSENNScheduler(dataset)
+ elif model == "sldd":
+ return OptimizationScheduler(dataset)
\ No newline at end of file
diff --git a/training/utils.py b/training/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..07dc2d99536f0b243669f6741c34909264f7c9c2
--- /dev/null
+++ b/training/utils.py
@@ -0,0 +1,32 @@
+
+#from robustness.robustness.tools.helpers https://github.com/MadryLab/robustness
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+class VariableLossLogPrinter():
+ def __init__(self):
+ self.losses = {}
+
+ def log_loss(self, key, val, n=1):
+ if not key in self.losses:
+ self.losses[key] = AverageMeter()
+ self.losses[key].update(val, n)
+
+ def get_loss_string(self):
+ loss_string = " | ".join([f"{key}: {self.losses[key].avg:.4f}" for key in self.losses])
+
+ return loss_string
diff --git a/visualization.py b/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9d7582bd5fbe8c9d5c458cb2707aefc3d4f205
--- /dev/null
+++ b/visualization.py
@@ -0,0 +1,143 @@
+import gradio as gr
+from load_model import extract_sel_mean_std_bias_assignemnt
+from pathlib import Path
+from architectures.model_mapping import get_model
+from configs.dataset_params import dataset_constants
+import torch
+import torchvision.transforms as transforms
+import pandas as pd
+import cv2
+import numpy as np
+
+def overlapping_features_on_input(model,output, feature_maps, input, target):
+ W=model.linear.layer.weight
+ output=output.detach().cpu().numpy()
+ feature_maps=feature_maps.detach().cpu().numpy().squeeze()
+
+ if target !=None:
+ label=target
+ else:
+ label=np.argmax(output)+1
+
+ Interpretable_Selection= W[label,:]
+ print("W",Interpretable_Selection)
+ input_np=np.array(input)
+ h,w= input.shape[:2]
+ print("h,w:",h,w)
+ Interpretable_Features=[]
+ Feature_image_list=[]
+ for S in range(len(Interpretable_Selection)):
+ if Interpretable_Selection[S] > 0:
+ Interpretable_Features.append(feature_maps[S])
+ Feature_image=cv2.resize(feature_maps[S],(w,h))
+ Feature_image=((Feature_image-np.min(Feature_image))/(np.max(Feature_image)-np.min(Feature_image)))*255
+ Feature_image=Feature_image.astype(np.uint8)
+ Feature_image=cv2.applyColorMap(Feature_image,cv2.COLORMAP_JET)
+ Feature_image=0.3*Feature_image+0.7*input_np
+ Feature_image=np.clip(Feature_image, 0, 255).astype(np.uint8)
+ Feature_image_list.append(Feature_image)
+ #path_to_featureimage=f"/home/qixuan/tmp/FeatureImage/FI{S}.jpg"
+ #cv2.imwrite(path_to_featureimage,Feature_image)
+ print("len of Features:",len(Interpretable_Features))
+
+ return Feature_image_list
+
+
+def genreate_intepriable_output(input,dataset="CUB2011", arch="resnet50",seed=123456, model_type="qsenn", n_features = 50, n_per_class=5, img_size=448, reduced_strides=False, folder = None):
+ n_classes = dataset_constants[dataset]["num_classes"]
+
+ model = get_model(arch, n_classes, reduced_strides)
+ tr=transforms.ToTensor()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if folder is None:
+ folder = Path.home() / f"tmp/{arch}/{dataset}/{seed}/"
+
+ state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
+ selection= torch.load(folder / f"SlDD_Selection_50.pt")
+ state_dict['linear.selection']=selection
+
+ feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
+ model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
+ model.load_state_dict(state_dict)
+
+ input = tr(input)
+ input= input.unsqueeze(0)
+ input= input.to(device)
+ model = model.to(device)
+ output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
+ print("final features:",final_features)
+ output=output.detach().cpu().numpy()
+ output= np.argmax(output)+1
+
+
+ print("outputclass:",output)
+ data_dir=Path.home()/"tmp/Datasets/CUB200/CUB_200_2011/"
+ labels = pd.read_csv(data_dir/"image_class_labels.txt", sep=' ', names=['img_id', 'target'])
+ namelist=pd.read_csv(data_dir/"images.txt",sep=' ',names=['img_id','file_name'])
+ classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
+ options_output=labels[labels['target']==output]
+ options_output=options_output.sample(1)
+ others=labels[labels['target']!=output]
+ options_others=others.sample(3)
+ options = pd.concat([options_others, options_output], ignore_index=True)
+ shuffled_options = options.sample(frac=1).reset_index(drop=True)
+ print("shuffled:",shuffled_options)
+ op=[]
+
+ for i in shuffled_options['img_id']:
+ print(i)
+ filenames=namelist.loc[namelist['img_id']==i,'file_name'].values[0]
+ targets=shuffled_options.loc[shuffled_options['img_id']==i,'target'].values[0]
+ print("targets",targets)
+ print("name",filenames)
+
+ classes=classlist.loc[classlist['cl_id']==targets, 'class_name'].values[0]
+ print(data_dir/f"images/{filenames}")
+
+ op_img=cv2.imread(data_dir/f"images/{filenames}")
+
+ op_images=tr(op_img)
+ op_images=op_images.unsqueeze(0)
+ op_images=op_images.to(device)
+ OP, feature_maps_op =model(op_images,with_feature_maps=True,with_final_features=False)
+ print("OP:",OP,
+ "feature_maps_op:",feature_maps_op.shape)
+ opt= overlapping_features_on_input(model,OP, feature_maps_op,op_img,targets)
+ op+=opt
+
+ return op
+
+def post_next_image(op):
+ if len(op)<=1:
+ return [],None, "all done, thank you!"
+ else:
+ op=op[1:len(op)]
+ return op,op[0], "Is this feature also in your input?"
+
+def get_features_on_interface(input):
+ op=genreate_intepriable_output(input,dataset="CUB2011",
+ arch="resnet50",seed=123456,
+ model_type="qsenn", n_features = 50,n_per_class=5,
+ img_size=448, reduced_strides=False, folder = None)
+ return op, op[0],"Is this feature also in your input?",gr.update(interactive=False)
+
+
+with gr.Blocks() as demo:
+
+ gr.Markdown("Interiable Bird Classification
")
+ image_input=gr.Image()
+ image_output=gr.Image()
+ text_output=gr.Markdown()
+ but_generate=gr.Button("Get some interpriable Features")
+ but_feedback_y=gr.Button("Yes")
+ but_feedback_n=gr.Button("No")
+ image_list = gr.State([])
+ but_generate.click(fn=get_features_on_interface, inputs=image_input, outputs=[image_list,image_output,text_output,but_generate])
+ but_feedback_y.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output])
+ but_feedback_n.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output])
+
+demo.launch()
+
+
+
+
\ No newline at end of file