# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os

import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
from timm.data.transforms import str_to_interp_mode

from .cached_image_folder import CachedImageFolder
from .samplers import SubsetRandomSampler
from .dataset_fg import DatasetMeta

def build_loader(config):
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    config.freeze()
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built train dataset")
    dataset_val, _ = build_dataset(is_train=False, config=config)
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built val dataset")

    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
        # with a partially cached zip dataset, each rank samples only the strided subset it holds
        indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
        sampler_train = SubsetRandomSampler(indices)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    # each rank evaluates a strided subset of the validation set
    indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
    sampler_val = SubsetRandomSampler(indices)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False
    )

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
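
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). build_loader
# assumes torch.distributed is already initialized, because it queries the
# process group for rank and world size, and it expects a YACS-style config
# exposing LOCAL_RANK plus the DATA / AUG / MODEL / TEST nodes read above;
# `config` below stands in for however the surrounding project builds that
# object. A minimal single-process driver could look like:
#
#     dist.init_process_group(backend='gloo',
#                             init_method='tcp://127.0.0.1:23456',
#                             rank=0, world_size=1)
#     dataset_train, dataset_val, loader_train, loader_val, mixup_fn = \
#         build_loader(config)
#     for samples, targets in loader_train:
#         if mixup_fn is not None:
#             samples, targets = mixup_fn(samples, targets)  # mixed inputs, soft labels
#         ...  # forward / backward pass
# ---------------------------------------------------------------------------
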

def build_dataset(is_train, config):
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        if config.DATA.ZIP_MODE:
            ann_file = prefix + "_map.txt"
            prefix = prefix + ".zip@/"
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
        else:
            # root = os.path.join(config.DATA.DATA_PATH, prefix)
            root = './datasets/imagenet'
            dataset = datasets.ImageFolder(root, transform=transform)
    elif config.DATA.DATASET == 'inaturelist2021':
        root = './datasets/inaturelist2021'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'inaturelist2021_mini':
        root = './datasets/inaturelist2021_mini'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'inaturelist2017':
        root = './datasets/inaturelist2017'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'inaturelist2018':
        root = './datasets/inaturelist2018'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'cub-200':
        root = './datasets/cub-200'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'stanfordcars':
        root = './datasets/stanfordcars'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'oxfordflower':
        root = './datasets/oxfordflower'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'stanforddogs':
        root = './datasets/stanforddogs'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'nabirds':
        root = './datasets/nabirds'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    elif config.DATA.DATASET == 'aircraft':
        root = './datasets/aircraft'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)
    else:
        # any other dataset name falls back to the directory given by DATA.DATASET_ROOT
        root = config.DATA.DATASET_ROOT
        dataset = DatasetMeta(root=root, transform=transform, train=is_train,
                              aux_info=config.DATA.ADD_META, dataset=config.DATA.DATASET)

    nb_classes = len(dataset.class_to_idx)
    return dataset, nb_classes
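
# Illustrative note (not part of the original file): every branch above
# returns a dataset exposing a torchvision-style `class_to_idx` mapping,
# which is where the class count comes from. A minimal sketch of that
# relationship for an ImageFolder-style layout (one sub-directory per class);
# `_demo_num_classes` is a hypothetical helper, not used elsewhere:


def _demo_num_classes(root='./datasets/imagenet'):
    """Count classes the same way build_dataset does, via len(class_to_idx).

    DatasetMeta is assumed to expose the same attribute for the
    fine-grained datasets handled above.
    """
    dataset = datasets.ImageFolder(root)  # builds class_to_idx from sub-directories of root
    return len(dataset.class_to_idx)
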

def build_transform(is_train, config):
    resize_im = config.DATA.IMG_SIZE > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.TRAIN_INTERPOLATION,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
        return transform

    t = []
    if resize_im:
        if config.TEST.CROP:
            size = int((256 / 224) * config.DATA.IMG_SIZE)
            t.append(
                transforms.Resize(size, interpolation=str_to_interp_mode(config.DATA.INTERPOLATION)),
                # to maintain same ratio w.r.t. 224 images
            )
            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
        else:
            t.append(
                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                                  interpolation=str_to_interp_mode(config.DATA.INTERPOLATION))
            )

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
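
# A minimal, runnable sketch (not part of the original file) of the evaluation
# pipeline build_transform produces for the common 224x224 setting. It uses
# types.SimpleNamespace as a stand-in for the project's config object, filling
# in only the fields read on the is_train=False path and assuming 'bicubic'
# interpolation with center-crop testing enabled:


def _demo_eval_transform(img_size=224):
    """Hypothetical helper: returns Resize(int(256/224 * img_size)) ->
    CenterCrop(img_size) -> ToTensor -> Normalize(ImageNet mean/std)."""
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        DATA=SimpleNamespace(IMG_SIZE=img_size, INTERPOLATION='bicubic'),
        TEST=SimpleNamespace(CROP=True),
    )
    return build_transform(is_train=False, config=cfg)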