open-clip / src /open_clip_train /distributed.py
letitiaaa's picture
Initial commit
e66e8cc
import os
import warnings
from typing import Optional
import torch
import torch.distributed as dist
try:
import horovod.torch as hvd
except ImportError:
hvd = None
def is_global_master(args):
    """Return True when ``args.rank`` (the global rank) is 0."""
    global_rank = args.rank
    return global_rank == 0
def is_local_master(args):
    """Return True when ``args.local_rank`` is 0."""
    node_rank = args.local_rank
    return node_rank == 0
def is_master(args, local=False):
    """Report whether this process is a master.

    With ``local`` truthy the check is delegated to ``is_local_master``;
    otherwise it is delegated to ``is_global_master``.
    """
    if local:
        return is_local_master(args)
    return is_global_master(args)
def is_device_available(device):
    """Check whether the backend for ``device`` is usable.

    Returns:
        A ``(is_avail, is_known)`` tuple. ``is_known`` is True only for the
        device types this function probes ('cuda', 'npu', 'mps', 'cpu');
        for any other type both values are False.
    """
    kind = torch.device(device).type

    # Probes are wrapped in lambdas so only the matching backend is touched.
    probes = {
        'cuda': lambda: torch.cuda.is_available(),
        # NOTE autoload device extension needed for this not to error out on this check
        'npu': lambda: torch.npu.is_available(),
        'mps': lambda: torch.backends.mps.is_available(),
        'cpu': lambda: True,
    }

    probe = probes.get(kind)
    if probe is None:
        return False, False
    return probe(), True
def set_device(device):
    """Make ``device`` the current device for its backend.

    Only indexed 'cuda:' and 'npu:' device strings are acted on; any other
    string is a no-op.
    """
    if device.startswith('cuda:'):
        torch.cuda.set_device(device)
        return
    if device.startswith('npu:'):
        torch.npu.set_device(device)
        return
def is_using_horovod():
    """Return True if a complete OMPI or PMI rank/size variable pair is set."""
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
    env = os.environ
    has_ompi = "OMPI_COMM_WORLD_RANK" in env and "OMPI_COMM_WORLD_SIZE" in env
    has_pmi = "PMI_RANK" in env and "PMI_SIZE" in env
    return has_ompi or has_pmi
def is_using_distributed():
    """Return True when the environment declares more than one process.

    'WORLD_SIZE' takes precedence over 'SLURM_NTASKS'; with neither set the
    result is False.
    """
    for name in ('WORLD_SIZE', 'SLURM_NTASKS'):
        raw = os.environ.get(name)
        if raw is not None:
            # The first variable present decides, even if it says 1.
            return int(raw) > 1
    return False
def world_info_from_env():
    """Read local rank, global rank and world size from the environment.

    For each value the candidate variables are checked in order and the first
    one present wins. Defaults are 0, 0 and 1 when none is set.

    Returns:
        A ``(local_rank, global_rank, world_size)`` tuple of ints.
    """

    def first_int(names, fallback):
        # Integer value of the first variable in `names` that is set.
        found = next((os.environ[n] for n in names if n in os.environ), None)
        return fallback if found is None else int(found)

    local_rank = first_int(
        ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'), 0)
    global_rank = first_int(
        ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'), 0)
    world_size = first_int(
        ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'), 1)
    return local_rank, global_rank, world_size
def init_distributed_device(args):
    """Initialise distributed state from ``args`` and record the result on it.

    Optional settings (``device``, ``dist_backend``, ``dist_url``, ``horovod``,
    ``no_set_device_rank``) are read from ``args`` with defaults, passed to
    ``init_distributed_device_so``, and the resolved ``device``,
    ``world_size``, ``rank``, ``local_rank`` and ``distributed`` values are
    written back onto ``args``.

    Returns:
        ``torch.device`` built from the resolved ``args.device``.
    """
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    # Single-process defaults are set first, before the helper runs.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0

    settings = {
        'device': getattr(args, 'device', 'cuda'),
        'dist_backend': getattr(args, 'dist_backend', None),
        'dist_url': getattr(args, 'dist_url', None),
        'horovod': getattr(args, 'horovod', False),
        'no_set_device_rank': getattr(args, 'no_set_device_rank', False),
    }
    resolved = init_distributed_device_so(**settings)

    # (attribute on args, key in the returned dict)
    for attr_name, key in (
            ('device', 'device'),
            ('world_size', 'world_size'),
            ('rank', 'global_rank'),
            ('local_rank', 'local_rank'),
            ('distributed', 'distributed'),
    ):
        setattr(args, attr_name, resolved[key])

    return torch.device(args.device)
def init_distributed_device_so(
        device: str = 'cuda',
        dist_backend: Optional[str] = None,
        dist_url: Optional[str] = None,
        horovod: bool = False,
        no_set_device_rank: bool = False,
):
    """Resolve the device and initialise the distributed backend, if any.

    Args:
        device: Device string, optionally with an index (e.g. ``'cuda:1'``).
            If the device type is known but unavailable, CPU is used instead.
        dist_backend: ``torch.distributed`` backend name. When None it is
            picked from the device type, defaulting to ``'gloo'``.
        dist_url: Init method for ``torch.distributed``; ``'env://'`` when
            not given.
        horovod: Use horovod instead of ``torch.distributed``.
        no_set_device_rank: When True, keep the device string as given rather
            than replacing its index with the local rank in distributed mode.

    Returns:
        dict with keys ``device``, ``global_rank``, ``local_rank``,
        ``world_size`` and ``distributed``.

    Raises:
        ImportError: If ``horovod`` is requested but cannot be imported.
    """
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    distributed = False
    world_size = 1
    global_rank = 0
    local_rank = 0

    # device_idx is an empty list when no ':<index>' suffix was given.
    device_type, *device_idx = device.split(':', maxsplit=1)
    is_avail, is_known = is_device_available(device_type)
    if not is_known:
        warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
    elif not is_avail:
        warnings.warn(f"Device {device} was not available, falling back to CPU.")
        device_type = device = 'cpu'

    if horovod:
        # Imported lazily so horovod is only required when requested. An
        # explicit error replaces the former assert, which could never fire
        # because the import itself raises first.
        try:
            import horovod.torch as hvd
        except ImportError as err:
            raise ImportError("Horovod is not installed") from err
        hvd.init()
        local_rank = int(hvd.local_rank())
        global_rank = hvd.rank()
        world_size = hvd.size()
        distributed = True
    elif is_using_distributed():
        if dist_backend is None:
            dist_backends = {
                "cuda": "nccl",
                "hpu": "hccl",
                "npu": "hccl",
                "xpu": "ccl",
            }
            dist_backend = dist_backends.get(device_type, 'gloo')
        dist_url = dist_url or 'env://'

        if 'SLURM_PROCID' in os.environ:
            # DDP via SLURM
            local_rank, global_rank, world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ['LOCAL_RANK'] = str(local_rank)
            os.environ['RANK'] = str(global_rank)
            os.environ['WORLD_SIZE'] = str(world_size)
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
                world_size=world_size,
                rank=global_rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
            )
            # Rank and world size come from the initialised process group.
            world_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()
        distributed = True

    if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'):
        # Ignore manually specified device index in distributed mode and
        # override with resolved local rank, fewer headaches in most setups.
        if device_idx:
            warnings.warn(f'device index {device_idx[0]} removed from specified ({device}).')
        device = f'{device_type}:{local_rank}'
        set_device(device)

    return dict(
        device=device,
        global_rank=global_rank,
        local_rank=local_rank,
        world_size=world_size,
        distributed=distributed,
    )
def broadcast_object(args, obj, src=0):
    """Broadcast a pickle-able python object from rank ``src`` to all ranks.

    Uses horovod when ``args.horovod`` is truthy, otherwise
    ``torch.distributed``. Returns the broadcast object.
    """
    if args.horovod:
        return hvd.broadcast_object(obj, root_rank=src)

    # Only the source rank supplies a payload; other ranks pass a placeholder
    # slot for broadcast_object_list to fill.
    payload = [obj if args.rank == src else None]
    dist.broadcast_object_list(payload, src=src)
    return payload[0]
def all_gather_object(args, obj, dst=0):
    """Gather a pickle-able python object across all ranks.

    Uses horovod when ``args.horovod`` is truthy, otherwise
    ``torch.distributed`` with one output slot per rank
    (``args.world_size``). ``dst`` is accepted but not used.
    """
    if args.horovod:
        return hvd.allgather_object(obj)

    gathered = [None] * args.world_size
    dist.all_gather_object(gathered, obj)
    return gathered