|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A few things commonly used across A LOT of config files.""" |
|
|
|
|
|
import string |
|
|
|
|
|
import ml_collections as mlc |
|
|
|
|
|
|
|
|
def input_for_quicktest(config_input, quicktest):
    """Overrides input settings with tiny values when running a quick test.

    Args:
        config_input: object whose attributes are overwritten in place
            (presumably the input section of a config -- verify against
            callers).
        quicktest: when falsy, `config_input` is left untouched.

    Returns:
        None; `config_input` is mutated in place.
    """
    if not quicktest:
        return

    # Applied in this order: small batch, small shuffle buffer, no raw caching.
    quicktest_overrides = (
        ('batch_size', 8),
        ('shuffle_buffer_size', 10),
        ('cache_raw', False),
    )
    for field, value in quicktest_overrides:
        setattr(config_input, field, value)
|
|
|
|
|
|
|
|
def parse_arg(arg, lazy=False, **spec):
    """Turns the single string handed to `get_config` into typed options.

    Typical use inside a config file:

      import big_vision.configs.common as bvcc

      def get_config(arg):
        arg = bvcc.parse_arg(arg,
                             res=(224, int),
                             runlocal=False,
                             schedule='short',
        )

        # ...

        config.shuffle_buffer = 250_000 if not arg.runlocal else 50

    Accepted forms when launching:

      --config amazing.py:runlocal,schedule=long,res=128
      --config amazing.py:res=128
      --config amazing.py:runlocal  # A bare boolean name means "true".
      --config amazing.py:runlocal=False  # Explicit false boolean.
      --config amazing.py:128  # A lone value goes to the first spec entry.

    Options whose default is a bool are converted strictly: 'True'/'true'
    become True, while 'False'/'false'/'' become False.

    Args:
        arg: the string argument that is passed to get_config. A falsy value
            (e.g. None) is treated like the empty string, i.e. no overrides.
        lazy: if true, options that are not listed in `spec` are accepted and
            converted with `autotype` (bool, then int, then float, else kept
            as a string). If false, such options raise.
        **spec: names and defaults of the expected options. A 2-element tuple
            is read as `(default, converter)`, where the converter is called
            on the raw string. Any other value is the default itself and its
            type is used as the converter.

    Returns:
        A ConfigDict (created with `type_safe=False`) holding one entry per
        spec name, plus any extra options when `lazy` is set.

    Raises:
        ValueError: if options outside of `spec` are passed and `lazy` is
            false.
    """
    arg = arg or ''
    typed_spec = {
        key: get_type_with_default(value) for key, value in spec.items()
    }

    # The ConfigDict is only used for convenient dot-access to the values.
    result = mlc.ConfigDict(type_safe=False)

    # Convenience forms: one lone token with neither `,` nor `=` in it.
    is_lone_token = bool(arg) and ',' not in arg and '=' not in arg
    if is_lone_token:
        if not typed_spec or arg in typed_spec:
            # Think `:runlocal`. A known name (or any name when there is no
            # spec at all) is shorthand for setting that flag to true.
            arg = f'{arg}=True'
        else:
            # Think `:128`. Anything else is the value of the first spec
            # entry; this relies on keyword arguments keeping their order.
            first_name = next(iter(typed_spec))
            arg = f'{first_name}={arg}'

    # Split the string into raw (still untyped) key/value pairs. An entry
    # without `=` counts as a flag set to 'True'.
    raw_kv = {}
    for token in arg.split(','):
        if not token:
            continue
        key, separator, raw_value = token.partition('=')
        raw_kv[key] = raw_value if separator else 'True'

    # Walk the spec, taking the provided value where given, else the default.
    for name, (default, type_fn) in typed_spec.items():
        if name in raw_kv:
            result[name] = type_fn(raw_kv.pop(name))
        else:
            result[name] = default

    # Whatever is left over was not declared in the spec.
    if raw_kv and not lazy:
        raise ValueError(f'Unhandled config args remain: {raw_kv}')
    for extra_name, extra_value in raw_kv.items():
        result[extra_name] = autotype(extra_value)

    return result
|
|
|
|
|
|
|
|
def get_type_with_default(v):
    """Returns `(default, converter)` for one spec value, with strict bools.

    The converter is what turns the raw string from the command line into the
    option's value.

    Args:
        v: a spec value, one of:
            - a bool default, paired with a strict string-to-bool parser that
              accepts only 'true'/'false' (in any case) or the empty string;
            - a 2-element tuple or list `(default, converter)`, where the
              converter is any callable (a type such as `int`, or a function);
            - any other default, paired with `type(v)`.

    Returns:
        A `(default, converter)` tuple.

    Raises:
        AssertionError: if a tuple/list is not of the `(default, converter)`
            form. Real list/tuple-valued options are unsupported because `,`
            is the option delimiter.
    """
    # Bools need a safe conversion: `bool('False')` would be True.
    if isinstance(v, bool):
        def strict_bool(x):
            lowered = x.lower()
            assert lowered in {'true', 'false', ''}, (
                f"Expected 'true', 'false' or '' for a bool option, got {x!r}")
            return lowered == 'true'
        return (v, strict_bool)

    # Already a (default, converter) pair. Any callable is accepted as the
    # converter, not only types, so plain functions and lambdas work too.
    if isinstance(v, (tuple, list)):
        assert len(v) == 2 and callable(v[1]), (
            'List or tuple types are currently not supported because we use `,` as'
            ' dumb delimiter. Contributions (probably using ast) welcome. You can'
            ' unblock by using a string with eval(s.replace(";", ",")) or similar')
        return (v[0], v[1])

    # Otherwise derive the converter from the default value itself.
    return (v, type(v))
|
|
|
|
|
|
|
|
def autotype(x):
    """Converts a string to bool, int or float when it parses as one.

    Args:
        x: the string to convert.

    Returns:
        A bool for 'true'/'false' (in any case); otherwise the int value if
        `int(x)` succeeds, otherwise the float value if `float(x)` succeeds,
        otherwise `x` unchanged.
    """
    assert isinstance(x, str)

    lowered = x.lower()
    if lowered in {'true', 'false'}:
        return lowered == 'true'

    # Try the narrowest numeric type first, so '12' stays an int.
    for convert in (int, float):
        try:
            return convert(x)
        except ValueError:
            continue
    return x
|
|
|
|
|
|
|
|
def pack_arg(**kw):
    """Serializes keyword args into the string form read by `parse_arg()`.

    Args:
        **kw: option names and values; each is rendered as `name=value`.

    Returns:
        The rendered options joined by `,` (empty string when there are none).

    Raises:
        AssertionError: if a rendered value contains `,`, which is the
            delimiter between options.
    """
    rendered_pairs = []
    for key, value in kw.items():
        rendered = f'{value}'
        assert ',' not in rendered, f"Can't use `,` in config_arg value: {value}"
        rendered_pairs.append(f'{key}={rendered}')
    return ','.join(rendered_pairs)
|
|
|
|
|
|
|
|
def arg(**kw):
    """Builds kwargs that carry both `config_arg` and the raw values.

    Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg.

    Args:
        **kw: the options to pack.

    Returns:
        A dict with `config_arg` set to the packed string of `kw`, followed
        by the entries of `kw` themselves (a `config_arg` key inside `kw`
        takes precedence over the packed string).
    """
    packed = {'config_arg': pack_arg(**kw)}
    packed.update(kw)
    return packed
|
|
|
|
|
|
|
|
def _get_field_ref(config_dict, field_name):
    """Returns `get_ref` of the field addressed by a dotted path.

    Args:
        config_dict: the root object; every path component but the last is
            resolved with attribute access, starting from here.
        field_name: field name, with `.` separating nesting levels
            (e.g. `'a.b.c'`).

    Returns:
        Whatever `get_ref(<last component>)` returns on the innermost object
        (presumably a FieldReference for a ConfigDict -- verify).
    """
    *parent_names, leaf_name = field_name.split('.')
    node = config_dict
    for parent_name in parent_names:
        node = getattr(node, parent_name)
    return node.get_ref(leaf_name)
|
|
|
|
|
|
|
|
def format_str(format_string, config):
    """Builds a string from a template whose fields are config references.

    This makes it easy to build preprocess strings that refer to fields which
    are edited afterwards. E.g.:

    ```
    config = mlc.ConfigDict()
    config.res = (256, 256)
    config.pp = bvcc.format_str('resize({res})', config)
    ...
    # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
    config.res = (512, 512)
    assert config.pp == 'resize((512, 512))'
    ```

    Only plain `{name}` placeholders (optionally dotted, e.g. `{a.b}`) are
    supported; format specs and conversions are rejected.

    Args:
        format_string: string to format with references.
        config: ConfigDict to get references to format the string.

    Returns:
        A reference field which renders a string using references to config
        fields. If `format_string` has no placeholders, the result is just
        its literal text.

    Raises:
        AssertionError: if a placeholder carries a format spec or conversion.
    """
    rendered = ''
    template_parts = string.Formatter().parse(format_string)
    for literal, name, spec, conv in template_parts:
        # `{x:>3}` or `{x!r}` cannot be expressed through references.
        assert not (spec or conv)
        rendered += literal
        if not name:
            continue
        # Appending the stringified reference (rather than its current value)
        # is what keeps the result tied to the config field.
        rendered += _get_field_ref(config, name).to_str()
    return rendered
|
|
|