Spaces:
Runtime error
Runtime error
| import tempfile | |
| import os | |
| import spaces | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from evo.tools.file_interface import read_kitti_poses_file | |
| from pathlib import Path | |
| import rerun as rr | |
| from typing import Optional, Dict | |
| from visualization.logger import SimulationLogger | |
| from scipy.spatial.transform import Rotation | |
def load_trajectory_data(traj_file: str, char_file: str) -> Dict:
    """Read a KITTI pose file and a ``.npy`` character file into tensors.

    Args:
        traj_file: Path handed to ``read_kitti_poses_file``.
        char_file: Path handed to ``np.load``.

    Returns:
        A dict with four entries:
            ``traj_filename``     - base name of ``traj_file``.
            ``char_filename``     - base name of ``char_file``.
            ``char_feat``         - contents of ``char_file`` as a float32 tensor.
            ``matrix_trajectory`` - the ``poses_se3`` matrices of the parsed
                                    trajectory, stacked into a float32 tensor.
    """
    # Trajectory first, character second: same read order as before so a
    # failure surfaces from the same file.
    pose_stack = np.array(read_kitti_poses_file(traj_file).poses_se3)
    pose_tensor = torch.from_numpy(pose_stack).to(torch.float32)

    char_array = np.load(char_file)
    char_tensor = torch.from_numpy(char_array).to(torch.float32)

    loaded = {}
    loaded["traj_filename"] = Path(traj_file).name
    loaded["char_filename"] = Path(char_file).name
    loaded["char_feat"] = char_tensor
    loaded["matrix_trajectory"] = pose_tensor
    return loaded
class ETLogger(SimulationLogger):
    """Rerun logger for a camera trajectory and a character point set.

    Constructing an instance calls ``rr.init("et_visualization")`` and logs a
    right-handed, Y-up view-coordinate convention on the ``world`` entity.
    """

    def __init__(self):
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)

        # Fixed 3x3 intrinsics handed to rr.Pinhole for every frame; the
        # principal point (320, 240) is the centre of the 640x480 image
        # declared in log_trajectory.
        focal, centre_x, centre_y = 500, 320, 240
        self.K = np.array([
            [focal, 0, centre_x],
            [0, focal, centre_y],
            [0, 0, 1],
        ])

    def log_trajectory(self, trajectory: np.ndarray):
        """Log the camera path statically and the camera pose per frame.

        Args:
            trajectory: Stack of pose matrices, indexed as
                ``trajectory[k, :3, :3]`` (rotation) and
                ``trajectory[k, :3, 3]`` (translation).
                Presumably (N, 4, 4) SE(3) poses - confirm against caller.
        """
        path_rgba = [0.0, 0.8, 0.8, 1.0]
        centres = trajectory[:, :3, 3]
        n_centres = len(centres)

        # Whole path as a timeless point cloud.
        path_points = rr.Points3D(
            centres,
            colors=np.full((n_centres, 4), path_rgba),
        )
        rr.log("world/trajectory/points", path_points, timeless=True)

        # Consecutive positions paired into two-point strips; needs >= 2 poses.
        if n_centres > 1:
            segments = np.stack([centres[:-1], centres[1:]], axis=1)
            path_line = rr.LineStrips3D(segments, colors=[tuple(path_rgba)])
            rr.log("world/trajectory/line", path_line, timeless=True)

        # One pose + pinhole per step on the "frame_idx" timeline.
        for frame_idx, pose in enumerate(trajectory):
            rr.set_time_sequence("frame_idx", frame_idx)

            # SciPy's as_quat() default ordering is (x, y, z, w), matching
            # the xyzw keyword used below.
            quat_xyzw = Rotation.from_matrix(pose[:3, :3]).as_quat()
            camera_pose = rr.Transform3D(
                translation=pose[:3, 3],
                rotation=rr.Quaternion(xyzw=quat_xyzw),
            )
            rr.log("world/camera", camera_pose)

            camera_model = rr.Pinhole(
                image_from_camera=self.K,
                width=640,
                height=480,
            )
            rr.log("world/camera/image", camera_model)

    def log_character(self, char_feature: np.ndarray):
        """Log the character feature array as a timeless red point cloud.

        Args:
            char_feature: Array reshaped to ``(-1, 3)`` before logging, so its
                total element count must be a multiple of three.
        """
        points = char_feature.reshape(-1, 3)
        point_colors = np.full((points.shape[0], 4), [0.8, 0.2, 0.2, 1.0])
        rr.log(
            "world/character",
            rr.Points3D(points, colors=point_colors),
            timeless=True,
        )
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Write a Rerun recording for a trajectory/character file pair.

    Loads both inputs with ``load_trajectory_data``, logs them through a fresh
    ``ETLogger`` and calls ``rr.save`` on ``et_visualization.rrd`` inside a
    newly created temporary directory.

    Args:
        traj_file: Path to the trajectory file.
        char_file: Path to the character feature file.

    Returns:
        The path of the ``.rrd`` file on success. The temporary directory that
        holds it is left in place; removing it is the caller's job.
        ``None`` if any step raised; the error is printed and the temporary
        directory (if one was already created) is removed.
    """
    # Function-scope import: only the failure path needs it.
    import shutil

    temp_dir = None
    try:
        data = load_trajectory_data(traj_file, char_file)

        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")

        logger = ETLogger()
        logger.log_trajectory(data["matrix_trajectory"].numpy())
        logger.log_character(data["char_feat"].numpy())

        rr.save(rrd_path)
        return rrd_path
    except Exception as e:
        # Deliberate catch-all: callers rely on None instead of an exception.
        print(f"Error visualizing E.T. data: {e}")
        # mkdtemp() never cleans up after itself; without this every failed
        # call would leave a stray directory behind.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None