Spaces:
Runtime error
Runtime error
| from .arkit import ArkitScene | |
| from .blendedmvs import BlendMVS | |
| from .co3d import Co3d | |
| from .habitat import habitat | |
| from .scannet import Scannet | |
| from .scannetpp import Scannetpp | |
| from .seven_scenes import SevenScenes | |
| from .nrgbd import NRGBD | |
| from .dtu import DTU | |
| from .demo import Demo | |
| from dust3r.datasets.utils.transforms import * | |
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: A dataset object, or a string holding a Python expression
            that is evaluated to obtain one.
        batch_size: Forwarded to ``dataset.make_sampler`` and to the DataLoader.
        num_workers: Forwarded to the DataLoader.
        shuffle: Forwarded to the sampler; also selects the random vs.
            sequential sampler in the non-distributed fallback.
        drop_last: Forwarded to the sampler and to the DataLoader.
        pin_mem: Forwarded to the DataLoader as ``pin_memory``.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.

    The sampler comes from ``dataset.make_sampler(...)`` when that call
    succeeds. If it raises ``AttributeError`` or ``NotImplementedError``
    (including from inside ``make_sampler`` itself), a stock PyTorch sampler
    is used instead.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # NOTE(review): eval() runs arbitrary Python, so string specs must only
    # come from trusted configuration. The expression is resolved against the
    # names visible in this function and module.
    if isinstance(dataset, str):
        dataset = eval(dataset)

    data_utils = torch.utils.data
    num_replicas, replica_rank = get_world_size(), get_rank()

    try:
        # Prefer the dataset's own sampler when it offers one.
        sampler = dataset.make_sampler(
            batch_size,
            shuffle=shuffle,
            world_size=num_replicas,
            rank=replica_rank,
            drop_last=drop_last,
        )
    except (AttributeError, NotImplementedError):
        # No dataset-specific sampler: fall back to the generic ones.
        if torch.distributed.is_initialized():
            sampler = data_utils.DistributedSampler(
                dataset,
                num_replicas=num_replicas,
                rank=replica_rank,
                shuffle=shuffle,
                drop_last=drop_last,
            )
        elif shuffle:
            sampler = data_utils.RandomSampler(dataset)
        else:
            sampler = data_utils.SequentialSampler(dataset)

    return data_utils.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
def build_dataset(dataset, batch_size, num_workers, test=False):
    """Build a train or test data loader and print a short summary.

    Args:
        dataset: Dataset object or dataset expression string, passed through
            unchanged to ``get_data_loader``.
        batch_size: Batch size passed to ``get_data_loader``.
        num_workers: Worker count passed to ``get_data_loader``.
        test: Any truthy value selects the test configuration (no shuffling,
            last incomplete batch kept); any falsy value selects training
            (shuffling on, last incomplete batch dropped).

    Returns:
        The loader returned by ``get_data_loader``.
    """
    # Derive the label from truthiness, exactly like the shuffle/drop_last
    # flags below. Indexing a two-item list with ``test`` raised for values
    # such as None, 2 or "yes" even though the flags accepted them.
    is_test = bool(test)
    split = 'Test' if is_test else 'Train'

    print(f'Building {split} Data loader for dataset: ', dataset)
    loader = get_data_loader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_mem=True,
                             shuffle=not is_test,
                             drop_last=not is_test)

    # NOTE(review): this is len() of the loader, not of the underlying
    # dataset - presumably the number of batches; the message text is kept
    # as-is because it is runtime output.
    print(f"{split} dataset length: ", len(loader))
    return loader