Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
""" | |
import copy | |
import os | |
from typing import Any, Dict, List, Optional | |
import yaml | |
from .workspace import GLOBAL_CONFIG | |
# Public names exported by ``from <this module> import *``.
__all__ = ["load_config", "merge_config", "merge_dict", "parse_cli"]

# Key under which a YAML file lists other YAML files that ``load_config``
# loads first; the including file's own keys are merged on top of them.
INCLUDE_KEY = "__include__"
def load_config(file_path, cfg=None):
    """Load a YAML config file, recursively resolving ``__include__`` entries.

    Files listed under ``INCLUDE_KEY`` are loaded first, in order, and merged
    into ``cfg``. The including file's own keys are then merged on top, so
    they take precedence over values coming from its includes. Relative
    include paths are resolved against the directory of ``file_path``, and a
    leading ``~`` is expanded.

    Args:
        file_path: Path to a ``.yml`` / ``.yaml`` file.
        cfg: Dict to merge into, in place. Defaults to a fresh empty dict for
            each top-level call.

    Returns:
        The merged config (``cfg`` itself), or a new empty dict when the file
        is empty. The file's ``__include__`` entry is kept in the result.

    Raises:
        AssertionError: If ``file_path`` does not have a YAML extension.
        OSError: If the file or one of its includes cannot be opened.
    """
    # A ``dict()`` default would be created once and then mutated by every
    # call, leaking config between unrelated loads.
    if cfg is None:
        cfg = {}

    _, ext = os.path.splitext(file_path)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type stays AssertionError for existing callers.
    if ext not in (".yml", ".yaml"):
        raise AssertionError("only support yaml files")

    # NOTE(security): ``yaml.Loader`` can construct arbitrary Python objects;
    # only load config files from trusted sources.
    with open(file_path) as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)
    if file_cfg is None:
        return {}

    if INCLUDE_KEY in file_cfg:
        includes = file_cfg[INCLUDE_KEY] or []
        # Accept a single include given as a plain string; ``list(str)`` would
        # otherwise split it into characters.
        if isinstance(includes, str):
            includes = [includes]
        for base_yaml in includes:
            base_yaml = os.path.expanduser(base_yaml)
            if not os.path.isabs(base_yaml):
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)
            # The recursive call merges the include into ``cfg`` in place, so
            # its return value does not need to be merged again.
            load_config(base_yaml, cfg)

    return merge_dict(cfg, file_cfg)
def merge_dict(dct, another_dct, inplace=True) -> Dict:
    """Recursively merge ``another_dct`` into ``dct`` and return the result.

    When a key maps to a dict on both sides, the two dicts are merged
    recursively. Otherwise the value from ``another_dct`` replaces (or adds)
    the entry in ``dct``. With ``inplace=False``, a deep copy of ``dct`` is
    merged and returned and ``dct`` itself is left untouched.

    Note: values taken from ``another_dct`` are stored by reference, not
    copied, so the result may share nested objects with ``another_dct``.
    """
    target = dct if inplace else copy.deepcopy(dct)

    def _merge_into(base, update):
        for key in update:
            incoming = update[key]
            if key in base and isinstance(base[key], dict) and isinstance(incoming, dict):
                _merge_into(base[key], incoming)
                continue
            base[key] = incoming
        return base

    return _merge_into(target, another_dct)
def dictify(s: str, v: Any) -> Dict:
    """Expand a dotted key into nested single-key dicts.

    ``dictify("a.b.c", 1)`` returns ``{"a": {"b": {"c": 1}}}``; a key with no
    dots yields ``{s: v}``. Empty segments (e.g. from ``"a..b"``) become
    empty-string keys.
    """
    nested = v
    # Build from the innermost segment outwards.
    for part in reversed(s.split(".")):
        nested = {part: nested}
    return nested
def parse_cli(nargs: List[str]) -> Dict:
    """Parse ``key=value`` command-line overrides into a nested dict.

    Dotted keys become nested dicts and each value is parsed as YAML, so
    ``["a.c=3", "b=10"]`` becomes ``{"a": {"c": 3}, "b": 10}``. Later
    arguments are merged over earlier ones.

    Args:
        nargs: Override strings of the form ``key=value``, or ``None``.

    Returns:
        The merged overrides; an empty dict when ``nargs`` is ``None`` or empty.

    Raises:
        ValueError: If an argument does not contain ``=``.
    """
    cfg = {}
    if not nargs:
        return cfg
    for arg in nargs:
        # Split on the first "=" only, so values may themselves contain "=".
        key, sep, raw_value = arg.strip().partition("=")
        if not sep:
            raise ValueError(f"invalid override {arg!r}: expected the form key=value")
        # NOTE(security): ``yaml.Loader`` can build arbitrary Python objects from
        # tags such as ``!!python/object``; only pass trusted arguments here.
        value = yaml.load(raw_value, Loader=yaml.Loader)
        cfg = merge_dict(cfg, dictify(key, value))
    return cfg
def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False):
    """
    Merge another_cfg into cfg, return the merged config.

    Keys missing from cfg are taken from another_cfg. Keys present on both
    sides whose values are both dicts are merged recursively. Any other key
    already present in cfg is replaced only when ``overwrite`` is True.

    Args:
        cfg: Config dict to merge into.
        another_cfg: Source of additional entries; defaults to ``GLOBAL_CONFIG``.
        inplace: If False (default), merge into a deep copy of cfg and leave
            cfg untouched; if True, mutate cfg directly.
        overwrite: If True, values from another_cfg replace conflicting
            non-dict values in cfg.

    Returns:
        The merged config (cfg itself when inplace=True, otherwise the copy).

    Note: entries taken from another_cfg are stored by reference, not copied.

    Example:
        cfg1 = load_config('./dfine_r18vd_6x_coco.yml')
        cfg1 = merge_config(cfg1, inplace=True)
        cfg2 = load_config('./dfine_r50vd_6x_coco.yml')
        cfg2 = merge_config(cfg2, inplace=True)
        model1 = create(cfg1['model'], cfg1)
        model2 = create(cfg2['model'], cfg2)
    """
    def _merge(dct, another):
        for k in another:
            if k not in dct:
                dct[k] = another[k]
            elif isinstance(dct[k], dict) and isinstance(another[k], dict):
                _merge(dct[k], another[k])
            elif overwrite:
                dct[k] = another[k]
        # Return the dict that was merged into, not the closed-over ``cfg``,
        # so the helper does not silently depend on how it is called.
        return dct

    if not inplace:
        cfg = copy.deepcopy(cfg)
    return _merge(cfg, another_cfg)