import json
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Union

from omegaconf import DictConfig
from pytorch_lightning.utilities import rank_zero_only

from src.utils.logging_utils import close_loggers, get_pylogger
from src.utils.rich_utils import enforce_tags, print_config_tree

log = get_pylogger(__name__)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling the `utils.extras()` before the task is started
    - Calling the `utils.close_loggers()` after the task is finished
    - Logging the exception if occurs
    - Logging the task total execution time
    - Logging the output dir
    """

    def wrap(cfg: DictConfig):

        # apply extra utilities
        extras(cfg)

        # execute the task
        start_time = time.time()
        try:
            task_result = task_func(cfg=cfg)
        except Exception:
            log.exception("")  # save the exception to the `.log` file
            raise  # re-raise with the original traceback
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.pipeline_type}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return task_result

    return wrap
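

# Example (illustrative sketch): wrapping a task function. The `train` body and the
# surrounding @hydra.main entry point are hypothetical, not part of this module.
#
# @task_wrapper
# def train(cfg: DictConfig):
#     ...  # run the pipeline; failures and execution time are logged automatically
#
# # called from a @hydra.main entry point: task_result = train(cfg)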


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring Python warnings
    - Setting tags from the command line
    - Rich config printing
    """

    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)
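

# Example (illustrative): the config shape that `extras` expects. The values below
# are an assumption for demonstration, not taken from this repo's configs.
#
# from omegaconf import OmegaConf
# cfg = OmegaConf.create(
#     {"extras": {"ignore_warnings": True, "enforce_tags": False, "print_config": False}}
# )
# extras(cfg)  # disables warnings; skips tag enforcement and config printing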


@rank_zero_only
def save_file(path: Union[str, Path], content: str) -> None:
    """Save a file in rank-zero mode (only on one process in a multi-GPU setup)."""
    with open(path, "w") as file:
        file.write(content)


def load_value_from_file(path: str, split_path_key: str = ":", split_key_parts: str = "/") -> Any:
    """Load a value from a file. The path may also address an element within the file (appended
    after `split_path_key`), and that element may be nested (nested keys separated by
    `split_key_parts`). For now, only .json files are supported.

    Args:
        path: path to the file, optionally followed by a key within the file,
            e.g. "metrics.json:model/f1"
        split_path_key: separator between the file path and the key within the file
        split_key_parts: separator between the parts of a nested key
    """

    parts_path = path.split(split_path_key, maxsplit=1)
    file_extension = os.path.splitext(parts_path[0])[1]
    if file_extension == ".json":
        with open(parts_path[0], "r") as f:
            data = json.load(f)
    else:
        raise ValueError(f"Expected .json file, got {file_extension}")

    if len(parts_path) == 1:
        return data

    keys = parts_path[1].split(split_key_parts)
    for key in keys:
        data = data[key]
    return data
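

# Example (illustrative): given a hypothetical "metrics.json" containing
# {"model": {"f1": 0.7, "output_dir": "runs/1"}}:
#
# load_value_from_file("metrics.json")           # -> the whole dict
# load_value_from_file("metrics.json:model")     # -> {"f1": 0.7, "output_dir": "runs/1"}
# load_value_from_file("metrics.json:model/f1")  # -> 0.7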


def replace_sys_args_with_values_from_files(
    load_prefix: str = "LOAD_ARG:",
    load_multi_prefix: str = "LOAD_MULTI_ARG:",
    **load_value_from_file_kwargs,
) -> None:
    """Replaces arguments in sys.argv with values loaded from files.

    Examples:
        # job_return_value.json contains {"a": 1, "b": 2}
        python train.py LOAD_ARG:job_return_value.json
        # this will pass "{a:1,b:2}" as the first argument to train.py

        # job_return_value.json contains [1, 2, 3]
        python train.py LOAD_MULTI_ARG:job_return_value.json
        # this will pass "1,2,3" as the first argument to train.py

        # job_return_value.json contains {"model": {"output_dir": ["path1", "path2"], "f1": [0.7, 0.6]}}
        python train.py load_model=LOAD_ARG:job_return_value.json:model/output_dir
        # this will pass "load_model=[path1,path2]" to train.py

    Args:
        load_prefix: the prefix to use for loading a single value from a file
        load_multi_prefix: the prefix to use for loading a list of values from a file
        **load_value_from_file_kwargs: additional kwargs to pass to load_value_from_file
    """

    updated_args = []
    for arg in sys.argv[1:]:
        is_multirun_arg = False
        if load_prefix in arg:
            parts = arg.split(load_prefix, maxsplit=1)
        elif load_multi_prefix in arg:
            parts = arg.split(load_multi_prefix, maxsplit=1)
            is_multirun_arg = True
        else:
            updated_args.append(arg)
            continue
        if len(parts) == 2:
            log.warning(f'Replacing argument value for "{parts[0]}" with content from {parts[1]}')
            json_value = load_value_from_file(parts[1], **load_value_from_file_kwargs)
            json_value_str = json.dumps(json_value)
            # replace quotes and spaces
            json_value_str = json_value_str.replace('"', "").replace(" ", "")
            # remove outer brackets so a list becomes a comma-separated multirun sweep
            if is_multirun_arg:
                if not isinstance(json_value, list):
                    raise ValueError(
                        f"Expected list for multirun argument, got {type(json_value)}. If you just want "
                        f"to set a single value, use {load_prefix} instead of {load_multi_prefix}."
                    )
                json_value_str = json_value_str[1:-1]
            # prepend the original argument name (e.g. "load_model=") to the loaded value
            modified_arg = f"{parts[0]}{json_value_str}"
            updated_args.append(modified_arg)
        else:
            updated_args.append(arg)
    # Set sys.argv to the updated arguments
    sys.argv = [sys.argv[0]] + updated_args
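

# Example (illustrative sketch): run the substitution before Hydra parses the command
# line. The `main` entry point below is hypothetical.
#
# if __name__ == "__main__":
#     replace_sys_args_with_values_from_files()
#     main()  # a @hydra.main entry point; it now sees the substituted sys.argv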