| import json | |
| from pathlib import Path | |
| from datasets import load_dataset | |
def load_hoho_dataset(testing: bool):
    """Load the HoHo25k data with ``datasets.load_dataset`` in streaming mode.

    Args:
        testing: When false, ``usm3d/hoho25k`` is loaded by repository name.
            When true, a dataset repository is first downloaded to
            ``/tmp/data`` and its ``validation`` / ``test`` splits are built
            from the ``*.tar`` shards found under directories whose names
            contain ``public`` / ``private`` respectively.  The repository
            name is read from the ``"dataset"`` key of ``params.json`` in the
            current working directory when that file exists, and defaults to
            ``usm3d/hoho25k_test_x`` otherwise.

    Returns:
        The object returned by ``load_dataset`` for the streaming request.

    Raises:
        KeyError: If ``params.json`` exists but has no ``"dataset"`` key.
    """
    if not testing:
        return load_dataset(
            "usm3d/hoho25k",
            streaming=True,
            trust_remote_code=True,
            writer_batch_size=100,
        )

    params_path = Path("params.json")
    if params_path.exists():
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with params_path.open(encoding="utf-8") as f:
            params = json.load(f)
        dataset_name = params["dataset"]
    else:
        dataset_name = "usm3d/hoho25k_test_x"

    data_path = Path("/tmp/data")

    # Imported lazily so the non-testing path does not need huggingface_hub.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=dataset_name,
        local_dir=str(data_path),
        repo_type="dataset",
    )

    # rglob yields entries in filesystem order, which is not stable across
    # machines; sort so the shard (and therefore sample) order is reproducible.
    data_files = {
        "validation": sorted(
            str(p) for p in data_path.rglob("*public*/**/*.tar")
        ),
        "test": sorted(
            str(p) for p in data_path.rglob("*private*/**/*.tar")
        ),
    }

    # NOTE(review): the loader script name is fixed even when params.json
    # selects another repository - confirm every such repository ships a
    # script under this name.
    # trust_remote_code=True allows the downloaded loading script to execute.
    return load_dataset(
        str(data_path / "hoho25k_test_x.py"),
        data_files=data_files,
        streaming=True,
        trust_remote_code=True,
        writer_batch_size=100,
    )