|
""" |
|
Module for training the 2D clutter filtering model with L2 loss. |
|
""" |
|
import os |
|
import argparse |
|
import json |
|
import numpy as np |
|
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau |
|
|
|
from utils import * |
|
from Model_ClutterFilter2D import clutter_filter_2D |
|
from DataGen import DataGen |
|
|
|
def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
    """Create the training and validation data generators.

    Args:
        in_ids_tr: Input entries for training, passed to ``DataGen`` as ``in_dir``.
        in_ids_val: Input entries for validation, passed to ``DataGen`` as ``in_dir``.
        out_ids_tr: Output entries for training, passed to ``DataGen`` as ``out_dir``.
        out_ids_val: Output entries for validation, passed to ``DataGen`` as ``out_dir``.
        config: Configuration mapping; reads ``config["network_prm"]["input_dim"]``
            and ``config["learning_prm"]["batch_size"]``.

    Returns:
        Tuple ``(tr_gen, val_gen)`` of ``DataGen`` instances.
    """

    def _generator_kwargs(inputs, targets):
        # The id list is simply 0..len(inputs)-1, one index per input entry.
        return dict(
            dim=config["network_prm"]["input_dim"],
            in_dir=inputs,
            out_dir=targets,
            id_list=np.arange(len(inputs)),
            batch_size=config["learning_prm"]["batch_size"],
            # NOTE(review): the validation generator is also built with
            # tr_phase=True -- confirm against DataGen that this is intended.
            tr_phase=True,
        )

    # Both argument sets are prepared before any generator is instantiated.
    train_kwargs = _generator_kwargs(in_ids_tr, out_ids_tr)
    val_kwargs = _generator_kwargs(in_ids_val, out_ids_val)
    return DataGen(**train_kwargs), DataGen(**val_kwargs)
|
|
|
def model_chkpnt(val_subject, te_subject, weight_dir, config):
    """Build the checkpointing and learning-rate callbacks for training.

    Args:
        val_subject: Validation subject identifier, embedded in the weight file name.
        te_subject: Test subject identifier, embedded in the weight file name.
        weight_dir: Directory in which the weight files are written.
        config: Configuration mapping; reads ``n_levels``, ``in_skip``,
            ``attention``, ``act`` and ``n_init_filters`` from
            ``config["network_prm"]`` and ``lr`` from ``config["learning_prm"]``.

    Returns:
        Tuple ``(model_checkpoint, reduce_lr)``: a ``ModelCheckpoint`` that
        monitors ``val_loss`` with ``save_best_only=True``, and a
        ``ReduceLROnPlateau`` that monitors ``val_loss`` with factor 0.1,
        patience 4 and a minimum learning rate of 1e-7.
    """
    net_prm = config["network_prm"]
    weight_name = (
        f'CF2D_ValTeSbj_{val_subject}_{te_subject}_nLvl{net_prm["n_levels"]}'
        f'_InSkp{net_prm["in_skip"]}_Att{net_prm["attention"]}'
        f'_Act{net_prm["act"]}_nInitFlt{net_prm["n_init_filters"]}'
        f'_lr{config["learning_prm"]["lr"]}')
    # Deliberately NOT an f-string: the {epoch}/{loss}/{val_loss} fields are
    # left as placeholders for ModelCheckpoint to format when it saves.
    file_name = (weight_name +
                 '_epc' + "{epoch:03d}" + '_trloss' + "{loss:.5f}" +
                 '_valloss' + "{val_loss:.5f}" + ".hdf5")
    # os.path.join instead of `weight_dir + '/'`: an empty weight_dir no longer
    # produces a root-level path, a trailing separator is not doubled, and
    # path-like objects are accepted.
    filepath = os.path.join(weight_dir, file_name)
    model_checkpoint = ModelCheckpoint(filepath=filepath,
                                       monitor="val_loss",
                                       verbose=0,
                                       save_best_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                  patience=4, min_lr=1e-7)
    return model_checkpoint, reduce_lr
|
|
|
def main(config):
    """Run one training session described by ``config``.

    Prepares the id lists and the weight directory, builds the data
    generators, the model and the callbacks, then fits the model.

    Args:
        config: Configuration mapping. It is forwarded to the preparation
            helpers, unpacked as keyword arguments into ``clutter_filter_2D``,
            and ``config["learning_prm"]["n_epochs"]`` sets the epoch count.

    Returns:
        None.
    """
    (in_ids_tr, in_ids_val,
     out_ids_tr, out_ids_val,
     val_subject, te_subject) = id_preparation(config)

    weight_dir = create_weight_dir(val_subject, te_subject, config)
    train_gen, valid_gen = data_generation(
        in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)

    # The whole config mapping is unpacked into the model constructor.
    model = clutter_filter_2D(**config)
    checkpoint_cb, lr_schedule_cb = model_chkpnt(
        val_subject, te_subject, weight_dir, config)

    model.fit(
        train_gen,
        validation_data=valid_gen,
        epochs=config["learning_prm"]["n_epochs"],
        verbose=1,
        callbacks=[checkpoint_cb, lr_schedule_cb],
    )
|
|
|
if __name__ == '__main__':
    # Command-line entry point: read the JSON config given by --config
    # (default "config.json") and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, and parser.error() gives a usage message and exit code 2
    # instead of a bare AssertionError traceback.
    if not os.path.isfile(args.config):
        parser.error(f"config file not found: {args.config}")
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)