|
|
|
import glob |
|
import os |
|
import os.path as osp |
|
import warnings |
|
from typing import Union |
|
|
|
from mmengine.config import Config, ConfigDict |
|
from mmengine.logging import print_log |
|
|
|
|
|
def find_latest_checkpoint(path, suffix='pth'): |
|
"""Find the latest checkpoint from the working directory. |
|
|
|
Args: |
|
path(str): The path to find checkpoints. |
|
suffix(str): File extension. |
|
Defaults to pth. |
|
|
|
Returns: |
|
latest_path(str | None): File path of the latest checkpoint. |
|
References: |
|
.. [1] https://github.com/microsoft/SoftTeacher |
|
/blob/main/ssod/utils/patch.py |
|
""" |
|
if not osp.exists(path): |
|
warnings.warn('The path of checkpoints does not exist.') |
|
return None |
|
if osp.exists(osp.join(path, f'latest.{suffix}')): |
|
return osp.join(path, f'latest.{suffix}') |
|
|
|
checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) |
|
if len(checkpoints) == 0: |
|
warnings.warn('There are no checkpoints in the path.') |
|
return None |
|
latest = -1 |
|
latest_path = None |
|
for checkpoint in checkpoints: |
|
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) |
|
if count > latest: |
|
latest = count |
|
latest_path = checkpoint |
|
return latest_path |
|
|
|
|
|
def update_data_root(cfg, logger=None): |
|
"""Update data root according to env MMDET_DATASETS. |
|
|
|
If set env MMDET_DATASETS, update cfg.data_root according to |
|
MMDET_DATASETS. Otherwise, using cfg.data_root as default. |
|
|
|
Args: |
|
cfg (:obj:`Config`): The model config need to modify |
|
logger (logging.Logger | str | None): the way to print msg |
|
""" |
|
assert isinstance(cfg, Config), \ |
|
f'cfg got wrong type: {type(cfg)}, expected mmengine.Config' |
|
|
|
if 'MMDET_DATASETS' in os.environ: |
|
dst_root = os.environ['MMDET_DATASETS'] |
|
print_log(f'MMDET_DATASETS has been set to be {dst_root}.' |
|
f'Using {dst_root} as data root.') |
|
else: |
|
return |
|
|
|
assert isinstance(cfg, Config), \ |
|
f'cfg got wrong type: {type(cfg)}, expected mmengine.Config' |
|
|
|
def update(cfg, src_str, dst_str): |
|
for k, v in cfg.items(): |
|
if isinstance(v, ConfigDict): |
|
update(cfg[k], src_str, dst_str) |
|
if isinstance(v, str) and src_str in v: |
|
cfg[k] = v.replace(src_str, dst_str) |
|
|
|
update(cfg.data, cfg.data_root, dst_root) |
|
cfg.data_root = dst_root |
|
|
|
|
|
def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict: |
|
"""Get the test dataset pipeline from entire config. |
|
|
|
Args: |
|
cfg (str or :obj:`ConfigDict`): the entire config. Can be a config |
|
file or a ``ConfigDict``. |
|
|
|
Returns: |
|
:obj:`ConfigDict`: the config of test dataset. |
|
""" |
|
if isinstance(cfg, str): |
|
cfg = Config.fromfile(cfg) |
|
|
|
def _get_test_pipeline_cfg(dataset_cfg): |
|
if 'pipeline' in dataset_cfg: |
|
return dataset_cfg.pipeline |
|
|
|
elif 'dataset' in dataset_cfg: |
|
return _get_test_pipeline_cfg(dataset_cfg.dataset) |
|
|
|
elif 'datasets' in dataset_cfg: |
|
return _get_test_pipeline_cfg(dataset_cfg.datasets[0]) |
|
|
|
raise RuntimeError('Cannot find `pipeline` in `test_dataloader`') |
|
|
|
return _get_test_pipeline_cfg(cfg.test_dataloader.dataset) |
|
|