import os import logging try: from google.cloud import storage # type: ignore exists_gcs = True except Exception: exists_gcs = False def _parse_gs_uri(uri: str): uri = uri.strip() if not uri.startswith('gs://'): raise ValueError(f"Invalid GCS URI: {uri}") path = uri[5:] parts = path.split('/', 1) bucket = parts[0] blob = parts[1] if len(parts) > 1 else '' if not bucket or not blob: raise ValueError(f"Invalid GCS URI (missing bucket/blob): {uri}") return bucket, blob def _ensure_dir(path: str): os.makedirs(path, exist_ok=True) def ensure_weights_available(): """Ensure model weights exist at MODEL_DIR. If WEIGHTS_URIS is set to comma-separated gs:// URIs, download any missing files. If google-cloud-storage is unavailable or URIs are not set, this is a no-op. """ model_dir = os.environ.get('MODEL_DIR', '/models') if not os.access(model_dir, os.W_OK): # Fallback to /tmp if not writeable model_dir = '/tmp/models' os.environ.setdefault('MODEL_DIR', model_dir) _ensure_dir(model_dir) weights_uris = os.environ.get('WEIGHTS_URIS', '').strip() if not weights_uris: logging.info("No WEIGHTS_URIS provided; skipping GCS download.") return if not exists_gcs: logging.warning("google-cloud-storage not installed; cannot download weights. Skipping.") return client = storage.Client() # Uses ADC for uri in [u.strip() for u in weights_uris.split(',') if u.strip()]: try: bucket_name, blob_name = _parse_gs_uri(uri) filename = os.path.basename(blob_name) dest_path = os.path.join(model_dir, filename) if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0: logging.info(f"Weights already present: {dest_path}") continue logging.info(f"Downloading {uri} -> {dest_path}") bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name) blob.download_to_filename(dest_path) logging.info(f"Downloaded: {dest_path}") except Exception as e: logging.error(f"Failed to download {uri}: {e}")