import logging
import os
import sys

import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F

from .norm import SimpleRMSNorm as SimpleRMSNormTorch
from .srmsnorm_triton import SimpleRMSNorm as SimpleRMSNormTriton

# Flags parsed from the environment: "use_triton" selects the Triton kernel
# implementation of SimpleRMSNorm, "debug" enables extra logging.
# eval turns the "True"/"False" strings into booleans.
use_triton = eval(os.environ.get("use_triton", default="True"))
debug = eval(os.environ.get("debug", default="False"))

if use_triton:
    SimpleRMSNorm = SimpleRMSNormTriton
else:
    SimpleRMSNorm = SimpleRMSNormTorch

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("print_config")

BASE_DIM = 256


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def logging_info(string):
    # Log only from rank 0 to avoid duplicated output in distributed runs.
    if is_main_process():
        logger.info(string)


def print_params(**kwargs):
    # Print every keyword argument except the implicit "self"/"__class__"
    # entries, and only from the main process.
    if is_main_process():
        logger.info(f"start print config of {kwargs['__class__']}")
        for key in kwargs:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {kwargs[key]}")
        logger.info(f"end print config of {kwargs['__class__']}")


def print_config(config):
    if is_main_process():
        logger.info(f"start print config of {config['__class__']}")
        for key in config:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {config[key]}")
        logger.info(f"end print config of {config['__class__']}")


def print_module(module):
    # Collect the names of all submodules, then build a string representation
    # of the parameters that live directly on `module` (i.e. whose top-level
    # name is not a submodule), since the default repr does not show them.
    named_modules = set()
    for p in module.named_modules():
        named_modules.update([p[0]])
    named_modules = list(named_modules)

    string_repr = ""
    for p in module.named_parameters():
        name = p[0].split(".")[0]
        if name not in named_modules:
            string_repr = (string_repr + "(" + name + "): " + "Tensor(" +
                           str(tuple(p[1].shape)) + ", requires_grad=" +
                           str(p[1].requires_grad) + ")\n")

    return string_repr.rstrip("\n")
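# Illustrative output (hypothetical parameter name and shape): for a module that
# owns a bare weight tensor, print_module would return something like
#   "(weight): Tensor((512, 256), requires_grad=True)"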


def get_activation_fn(activation):
    # Map an activation name from the config to a callable.
    if debug:
        logger.info(f"activation: {activation}")
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return torch.sigmoid
    elif activation == "exp":
        # Numerically stabilized exponential: subtract the per-row max
        # (treated as a constant, hence no_grad) before exponentiating.
        def f(x):
            with torch.no_grad():
                x_max = torch.max(x, dim=-1, keepdim=True).values
            y = torch.exp(x - x_max)

            return y

        return f
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":

        def f(x):
            return 1 + F.elu(x)

        return f
    elif activation == "2+elu":

        def f(x):
            return 2 + F.elu(x)

        return f
    elif activation == "silu" or activation == "swish":
        return F.silu
    elif activation == "sine":
        return torch.sin
    else:
        logger.info(f"activation {activation} is not supported, using identity")
        return lambda x: x
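# Usage sketch (illustrative, not part of the original module): the returned
# callable is applied elementwise to a tensor, e.g.
#   act_fn = get_activation_fn("1+elu")
#   y = act_fn(torch.randn(2, 4))  # strictly positive, since 1 + F.elu(x) > 0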


def get_norm_fn(norm_type):
    if norm_type == "simplermsnorm":
        return SimpleRMSNorm
    else:
        return nn.LayerNorm
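# Usage sketch (illustrative, not part of the original module): the returned
# class is typically instantiated with the feature dimension, like nn.LayerNorm,
#   norm = get_norm_fn("simplermsnorm")(embed_dim)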


def convert_to_multiple_of_base(x):
    # Round x up to the nearest multiple of BASE_DIM.
    return BASE_DIM * ((x + BASE_DIM - 1) // BASE_DIM)
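# Worked example (not in the original file): with BASE_DIM = 256,
# convert_to_multiple_of_base(300) == 512 and convert_to_multiple_of_base(256) == 256.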