| """ | |
| File: config.py | |
| Author: Elena Ryumina and Dmitry Ryumin | |
| Description: Plotting statistical information. | |
| License: MIT License | |
| """ | |
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch

# Importing necessary components for the Gradio app
from app.config import DICT_PRED

def show_cam_on_image(
    img: np.ndarray,
    mask: np.ndarray,
    use_rgb: bool = False,
    colormap: int = cv2.COLORMAP_JET,
    image_weight: float = 0.5,
) -> np.ndarray:
    """Overlay the CAM mask on the image as a heatmap.

    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The CAM mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The final result is image_weight * img + (1 - image_weight) * mask.
    :returns: The image with the CAM overlay.

    Adapted from https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise ValueError("The input image should be np.float32 in the range [0, 1]")

    if image_weight < 0 or image_weight > 1:
        raise ValueError(f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    cam = (1 - image_weight) * heatmap + image_weight * img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

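# A minimal usage sketch (hypothetical inputs, not part of this module):
# `face` is a float32 crop scaled to [0, 1] and `cam_mask` is a Grad-CAM
# mask with the same spatial size.
#
#   face = np.float32(cv2.cvtColor(bgr_crop, cv2.COLOR_BGR2RGB)) / 255
#   overlay = show_cam_on_image(face, cam_mask, use_rgb=True, image_weight=0.6)
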
def get_heatmaps(
    gradients, activations, name_layer, face_image, use_rgb=True, image_weight=0.6
):
    # Grad-CAM: weight each activation channel by its gradient averaged
    # over the batch and spatial dimensions, then average across channels
    gradient = gradients[name_layer]
    activation = activations[name_layer]
    pooled_gradients = torch.mean(gradient[0], dim=[0, 2, 3])
    for i in range(activation.size()[1]):
        activation[:, i, :, :] *= pooled_gradients[i]
    heatmap = torch.mean(activation, dim=1).squeeze().cpu()
    # Keep only positive contributions and normalize to [0, 1]
    heatmap = torch.clamp(heatmap, min=0)
    heatmap /= torch.max(heatmap)
    heatmap = torch.unsqueeze(heatmap, -1)
    heatmap = cv2.resize(heatmap.detach().numpy(), (224, 224))
    # Overlay the heatmap on the face crop, rescaled to [0, 1]
    cur_face_hm = cv2.resize(face_image, (224, 224))
    cur_face_hm = np.float32(cur_face_hm) / 255
    heatmap = show_cam_on_image(
        cur_face_hm, heatmap, use_rgb=use_rgb, image_weight=image_weight
    )
    return heatmap

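# Sketch of the expected call (assumptions: `gradients` and `activations`
# are dicts populated by backward/forward hooks on the model, and
# "layer4" is a hypothetical layer name):
#
#   heatmap = get_heatmaps(
#       gradients, activations, "layer4", face_image, use_rgb=True, image_weight=0.6
#   )
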
def plot_compound_expression_prediction(
    dict_preds: dict[str, list[float]],
    save_path: str | None = None,
    frame_indices: list[int] | None = None,
    colors: list[str] = ["green", "orange", "red", "purple", "blue"],
    figsize: tuple = (12, 6),
    title: str = "Compound expression prediction",
) -> plt.Figure:
    fig, ax = plt.subplots(figsize=figsize)
    for idx, (k, v) in enumerate(dict_preds.items()):
        # Give each curve a small vertical offset so overlapping predictions
        # stay distinguishable; the curves at indices 2 and 3 swap offsets
        if idx == 2:
            offset = (idx + 1 - len(dict_preds) // 2) * 0.1
        elif idx == 3:
            offset = (idx - 1 - len(dict_preds) // 2) * 0.1
        else:
            offset = (idx - len(dict_preds) // 2) * 0.1
        shifted_v = [val + offset + 1 for val in v]
        ax.plot(
            range(1, len(shifted_v) + 1),
            shifted_v,
            color=colors[idx],
            linestyle="dotted",
            label=k,
        )
    ax.legend()
    ax.grid(True)
    ax.set_xlabel("Number of frames")
    ax.set_ylabel("Basic emotion / compound expression")
    ax.set_title(title)
    if frame_indices is not None:
        ax.set_xticks([i + 1 for i in frame_indices])
    ax.set_yticks(range(0, 21))
    ax.set_yticklabels([""] + list(DICT_PRED.values()) + [""])
    fig.tight_layout()
    if save_path:
        fig.savefig(
            save_path,
            format=save_path.rsplit(".", 1)[1],
            bbox_inches="tight",
            pad_inches=0,
        )
    return fig

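# Hypothetical call; the keys and the DICT_PRED-coded values below are
# made up for illustration:
#
#   fig = plot_compound_expression_prediction(
#       {"video model": [3, 3, 5], "audio model": [3, 4, 5], "fusion": [3, 3, 5]},
#       frame_indices=[0, 1, 2],
#   )
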
def display_frame_info(img, text, margin=1.0, box_scale=1.0):
    # Draw `text` on a semi-transparent white box in the top-right corner
    img_copy = img.copy()
    img_h, img_w, _ = img_copy.shape
    line_width = int(min(img_h, img_w) * 0.001)
    thickness = max(int(line_width / 3), 1)
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_color = (0, 0, 0)
    font_scale = thickness / 1.5

    t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]
    margin_n = int(t_h * margin)
    # Blend a white rectangle with the underlying region to get a
    # semi-transparent background for the text
    sub_img = img_copy[
        0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
        img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n,
    ]
    white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255
    img_copy[
        0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
        img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n,
    ] = cv2.addWeighted(sub_img, 0.5, white_rect, 0.5, 1.0)

    cv2.putText(
        img=img_copy,
        text=text,
        org=(
            img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
            0 + margin_n + t_h + int(2 * t_h * box_scale) // 2,
        ),
        fontFace=font_face,
        fontScale=font_scale,
        color=font_color,
        thickness=thickness,
        lineType=cv2.LINE_AA,
        bottomLeftOrigin=False,
    )
    return img_copy

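# Example (hypothetical frame and label):
#
#   annotated = display_frame_info(frame, "Frame 42", margin=1.0, box_scale=1.0)
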
def plot_audio(time_axis, waveform, frame_indices, fps, figsize=(10, 4)) -> plt.Figure:
    # Convert frame indices to their positions on the time axis (seconds)
    frame_times = np.array(frame_indices) / fps
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(time_axis, waveform[0])
    ax.set_xlabel("Time (frames)")
    ax.set_ylabel("Amplitude")
    ax.grid(True)
    # Label the time ticks with the corresponding (1-based) frame numbers
    ax.set_xticks(frame_times)
    ax.set_xticklabels([f"{int(frame_time * fps) + 1}" for frame_time in frame_times])
    fig.tight_layout()
    return fig

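# Sketch assuming torchaudio-style input, i.e. `waveform` of shape
# (channels, samples) and `time_axis` in seconds:
#
#   fig = plot_audio(time_axis, waveform, frame_indices=[0, 25, 50], fps=25)
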
def plot_images(image_paths):
    fig, axes = plt.subplots(1, len(image_paths), figsize=(12, 2))
    # With a single image, subplots() returns a bare Axes; normalize to an array
    axes = np.atleast_1d(axes)
    for ax, img in zip(axes, image_paths):
        # Items are expected to be already-loaded images (ndarrays or PIL
        # images), despite the parameter name, since they go straight to imshow
        ax.imshow(img)
        ax.axis("off")
    fig.tight_layout()
    return fig
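
# Example (hypothetical paths; images are loaded before the call because
# the items are passed straight to imshow):
#
#   from PIL import Image
#   fig = plot_images([Image.open(p) for p in face_crop_paths])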