MahdiTabassian commited on
Commit
6477265
·
1 Parent(s): aea2e47

Filtering models and example video clips

Browse files
Filter2D_Lrec/DataGen.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A module for generating batches of 2D images.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import random
7
+ import scipy.io as sio
8
+ import tensorflow as tf
9
+
10
+ def _shift_first_frame(vol_in, vol_out, tr_phase):
11
+ n_frames = vol_in.shape[-1]
12
+ n, p = 1, 0.5 # p is the probability of shifting the first frame.
13
+ if tr_phase:
14
+ if np.random.binomial(n,p):
15
+ first_frm = np.random.permutation(np.arange(n_frames))[0]
16
+ vol_in = np.concatenate((vol_in[:,:,first_frm:], vol_in[:,:,:first_frm]), axis=-1)
17
+ vol_out = np.concatenate((vol_out[:,:,first_frm:], vol_out[:,:,:first_frm]), axis=-1)
18
+ return vol_in, vol_out
19
+
20
+ def _image_vol_normalization(vol):
21
+ vol = vol/255
22
+ vol[vol < 0] = 0
23
+ vol[vol > 1] = 1
24
+ return vol
25
+
26
+ def _reshape_vol(vol):
27
+ new_vol = np.empty([vol.shape[-1], vol.shape[0], vol.shape[1]])
28
+ for i in range(vol.shape[-1]):
29
+ new_vol[i,:,:] = vol[:,:,i]
30
+ return new_vol
31
+
32
+ def _image_vol_augmentation(vol_in, vol_out, tr_phase):
33
+ """
34
+ Augmenting and normalizing the input and output image volumes.
35
+ """
36
+ vol_in, vol_out = _shift_first_frame(vol_in, vol_out, tr_phase)
37
+ vol_in_norm = _image_vol_normalization(vol_in)
38
+ vol_out_norm = _image_vol_normalization(vol_out)
39
+ return [_reshape_vol(vol_in_norm), _reshape_vol(vol_out_norm)]
40
+
41
+ class DataGen(tf.keras.utils.Sequence):
42
+ """
43
+ Generating batches of input cluttered volumes and their corresponding
44
+ clutter-free output volumes
45
+ """
46
+ def __init__(
47
+ self,
48
+ dim:list,
49
+ in_dir:str,
50
+ out_dir:str,
51
+ id_list:list,
52
+ batch_size:int,
53
+ tr_phase=True,
54
+ *args,
55
+ **kwargs):
56
+ 'Initialization'
57
+ self.dim = dim
58
+ self.in_dir = in_dir
59
+ self.out_dir = out_dir
60
+ self.id_list = id_list
61
+ self.batch_size = batch_size
62
+ self.tr_phase = tr_phase
63
+
64
+ def __len__(self):
65
+ return len(self.id_list)
66
+
67
+ def __getitem__(self, idx):
68
+ 'Generate one or more batches of data'
69
+
70
+ # Initialization
71
+ in_out_shape = [self.batch_size, self.dim[0], self.dim[1], self.dim[2]]
72
+ x_aug, y_aug = np.empty(in_out_shape), np.empty(in_out_shape)
73
+ vol_id = self.id_list[idx:(idx + 1)]
74
+
75
+ # Generate the data
76
+ for i, ID in enumerate(vol_id):
77
+ # Store sample
78
+ vol_in = sio.loadmat(self.in_dir[ID])['data_artf']
79
+ vol_out = sio.loadmat(self.out_dir[ID])['data_org']
80
+ # Call the data augmentation function
81
+ aug_vols = _image_vol_augmentation(vol_in, vol_out, self.tr_phase)
82
+ x_aug[:,:,:,0], y_aug[:,:,:,0] = aug_vols[0], aug_vols[1]
83
+
84
+ if self.tr_phase:
85
+ return x_aug, y_aug
86
+ else:
87
+ return x_aug
Filter2D_Lrec/Error_analysis.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A module with functions for computing different MAE and coherence errors.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scipy.io as sio
8
+
9
+ def _compute_sample_temporal_coherency_score(filtered_smp, org_smp):
10
+ abs_diff_flt_org_smp = np.abs(filtered_smp-org_smp)
11
+ frm_pixel_sum = [np.sum(abs_diff_flt_org_smp[:,:,i]) for i in range(abs_diff_flt_org_smp.shape[-1])]
12
+ frm_pixel_sum_shifted = np.roll(frm_pixel_sum, -1)
13
+ frm_diff = np.abs(frm_pixel_sum - frm_pixel_sum_shifted)
14
+ frm_diff = frm_diff[:-1]
15
+ return np.mean(frm_diff)
16
+
17
+ def _compute_sample_mae(smp, in_ids, filtered_smp, clutter_class):
18
+ clt_smp = sio.loadmat(in_ids[smp])['data_artf']
19
+ org_smp = sio.loadmat(
20
+ in_ids[smp].split(f'data_{clutter_class}')[0] + 'data_org/1.mat')['data_org']
21
+ mae_CltFiltered_CltFree = np.mean(np.abs(255*filtered_smp-org_smp))
22
+ mae_Cltrd_CltFree = np.mean(np.abs(clt_smp-org_smp))
23
+ temporal_coherency_score = _compute_sample_temporal_coherency_score(255*filtered_smp, org_smp)
24
+ return mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score
25
+
26
+ def _make_res_dct():
27
+ res_dct = {'Clutter_class': [],
28
+ 'Clutter_spec': [],
29
+ 'View': [],
30
+ 'Vendor': [],
31
+ 'MAE_CltFiltered_CltFree': [],
32
+ 'MAE_Cltrd_CltFree': [],
33
+ 'temporal_coherency_score': []
34
+ }
35
+ return res_dct
36
+
37
+ def _id_separation(in_id):
38
+ id_part0 = in_id.split('/A')[0].split('/')
39
+ id_part1 = in_id.split('/data_')[1].split('/')
40
+ v = [v for v in in_id.split('/') if 'A' in v and 'C' in v]
41
+ view = v[0]
42
+ vendor = id_part0[-1]
43
+ clutter_class = id_part1[0]
44
+ clutter_spec = id_part1[1]
45
+ return view, vendor, clutter_class, clutter_spec
46
+
47
+ def compute_mae(in_ids, filtered_dta, te_subsample=False, te_frames=50):
48
+ res_dct = _make_res_dct()
49
+ for i in range(len(in_ids)):
50
+ view, vendor, clutter_class, clutter_spec = _id_separation(in_ids[i])
51
+ res_dct['Clutter_class'].append(clutter_class)
52
+ res_dct['Clutter_spec'].append(clutter_spec)
53
+ res_dct['Vendor'].append(vendor)
54
+ res_dct['View'].append(view)
55
+ mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score = _compute_sample_mae(
56
+ smp=i, in_ids=in_ids, filtered_smp=filtered_dta[i],
57
+ clutter_class=clutter_class)
58
+ res_dct['MAE_CltFiltered_CltFree'].append(mae_CltFiltered_CltFree)
59
+ res_dct['MAE_Cltrd_CltFree'].append(mae_Cltrd_CltFree)
60
+ res_dct['temporal_coherency_score'].append(temporal_coherency_score)
61
+ return pd.DataFrame(res_dct)
Filter2D_Lrec/Model_ClutterFilter2D.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions of the 2D clutter filtering algorithm.
3
+ """
4
+ import tensorflow as tf
5
+ from tensorflow.keras import backend as K
6
+ from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Activation, BatchNormalization, Add,
7
+ Dropout, Concatenate, UpSampling2D, multiply, Input, Lambda)
8
+ from tensorflow.keras.models import Model
9
+ from tensorflow.keras.optimizers import Adam
10
+
11
+ def tensor_expansion(tensor, rep, axs):
12
+ expanded_tensor = Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=axs),
13
+ arguments={'repnum': rep})(tensor)
14
+ return expanded_tensor
15
+
16
+ def attention_gate_block_2D(x, g, n_inter_filters=None, name=None, **prm):
17
+ """
18
+ Attention gate block.
19
+ """
20
+ shape_x = K.int_shape(x)
21
+ shape_g = K.int_shape(g)
22
+ if n_inter_filters is None:
23
+ n_inter_filters = shape_x[-1] // 2
24
+
25
+ theta_x = Conv2D(n_inter_filters, 3, strides=(2, 2), padding='same', name=f"{name}_theta_x")(x)
26
+ phi_g = Conv2D(n_inter_filters, 1, strides=1, padding='valid', name=f"{name}_phi_g")(g)
27
+ concat_xg = Add()([phi_g, theta_x])
28
+ act_xg = Activation('relu')(concat_xg)
29
+ psi = Conv2D(1, 1, padding='same', name=f"{name}_psi")(act_xg)
30
+ sigmoid_xg = Activation('sigmoid')(psi)
31
+ shape_sigmoid = K.int_shape(sigmoid_xg)
32
+ upsample_sigmoid = UpSampling2D(size=(2, 2), name=f"{name}_upsampled_sig")(sigmoid_xg)
33
+ upsample_sigmoid_rep = tensor_expansion(upsample_sigmoid, rep=shape_x[-1], axs=-1)
34
+ y = multiply([upsample_sigmoid_rep, x], name=f"{name}_weighted_x")
35
+ return y
36
+
37
+ def conv_block(x, filters_list, act='linear', kernel_size=3, stride=1, pad='same', drp=0.05, name=None):
38
+ """
39
+ Blocks of 2D conv filters.
40
+ """
41
+ for i in range(len(filters_list)):
42
+ x = Conv2D(filters_list[i], kernel_size, padding=pad, strides=stride, name=f"{name}_blk{i+1}")(x)
43
+ x = BatchNormalization(name=f"{name}_bn{i+1}")(x)
44
+ x = Activation(act, name=f"{name}_act{i+1}")(x)
45
+ x = Dropout(drp)(x)
46
+ return x
47
+
48
+ def encoding_block(x_in, name, **config):
49
+ """
50
+ Encoding block of the 2D Unet.
51
+ """
52
+ encoding_dct = {}
53
+ for i in range(config["network_prm"]["n_levels"]):
54
+ if i == 0:
55
+ x = x_in
56
+ n_filters = config["network_prm"]["n_init_filters"]
57
+ else:
58
+ n_filters = (2**i)*config["network_prm"]["n_init_filters"]
59
+ x = MaxPooling2D(pool_size=config["network_prm"]["pool_size"], name=f"{name}_encd_pool{i}")(x)
60
+ x = conv_block(x, filters_list=[n_filters, 2*n_filters], act=config["network_prm"]["act"],
61
+ kernel_size=config["network_prm"]["kernel_size"],
62
+ stride=config["network_prm"]["conv_stride"],
63
+ pad=config["network_prm"]["padding"],
64
+ drp=config["learning_prm"]['drp'], name=f"{name}_encd_conv_lvl{i}")
65
+ encoding_dct[f"{name}_out_lvl{i}"] = x
66
+ return encoding_dct
67
+
68
+ def decoding_block(encoding_dct, name, **config):
69
+ """
70
+ Decoding block of the 2D Unet.
71
+ """
72
+ decoding_dct = {}
73
+ n_levels = config["network_prm"]["n_levels"]
74
+ for i in range(n_levels-1):
75
+ if i == 0:
76
+ x = encoding_dct[f"{name}_out_lvl{n_levels-i-1}"]
77
+ # upsampling via Conv(Upsampling)
78
+ x_shape = K.int_shape(x)
79
+ x_up = Conv2D(x_shape[-1], 2, activation=config["network_prm"]["act"], padding='same', strides=1,
80
+ name=f"{name}_decd_upsmpl{i}")(UpSampling2D(size=(2,2))(x))
81
+ x_up_shape = K.int_shape(x_up)
82
+ # concatenation
83
+ if config["network_prm"]['attention']:
84
+ if i == 0:
85
+ g = encoding_dct[f"{name}_out_lvl{n_levels-1}"]
86
+ else:
87
+ g = decoding_dct[f"{name}_out_lvl{i-1}"]
88
+ x_encd = attention_gate_block_2D(x=encoding_dct[f"{name}_out_lvl{n_levels-i-2}"],
89
+ g=g, name=f"{name}_att_blk{i}")
90
+ else:
91
+ x_encd = encoding_dct[f"{name}_out_lvl{n_levels-i-2}"]
92
+ x_concat = Concatenate(axis=-1, name=f"{name}_decd_concat{i}")([x_encd, x_up])
93
+ n_filters = x_up_shape[-1]//2
94
+ x = conv_block(x_concat, filters_list=[n_filters, n_filters], act=config["network_prm"]["act"],
95
+ kernel_size=config["network_prm"]["kernel_size"],
96
+ stride=config["network_prm"]["conv_stride"],
97
+ pad=config["network_prm"]["padding"],
98
+ drp=config["learning_prm"]['drp'], name=f"{name}_decd_conv_lvl{i}")
99
+ decoding_dct[f"{name}_out_lvl{i}"] = x
100
+ x = conv_block(x, filters_list=[1], act=config["network_prm"]["act"], kernel_size=1, stride=1,
101
+ pad='same', drp=1e-4, name=f"{name}_final_decd_conv")
102
+ decoding_dct[f"{name}_final_conv"] = x
103
+ return decoding_dct
104
+
105
+ def Unet2D(x_in, name, **config):
106
+ """
107
+ Spatial clutter filtering based on 2D Unet.
108
+ """
109
+ encoding_dct = encoding_block(x_in, name, **config)
110
+ decoding_dct = decoding_block(encoding_dct, name, **config)
111
+ # Add the Main input here
112
+ if config["network_prm"]["in_skip"]:
113
+ out_Unet = Add()([x_in, decoding_dct[f"{name}_final_conv"]])
114
+ else:
115
+ out_Unet = decoding_dct[f"{name}_final_conv"]
116
+ return out_Unet
117
+
118
+ def clutter_filter_2D(**config):
119
+ """
120
+ The main function for designing the clutter filtering algorithm.
121
+ """
122
+ main_in = Input(config["network_prm"]["input_dim"])
123
+ filter_out = Unet2D(x_in=main_in, name="CF", **config)
124
+ model = Model(inputs=main_in, outputs=filter_out, name=config['model_name'])
125
+ opt = Adam(learning_rate=config["learning_prm"]["lr"])
126
+ model.compile(optimizer=opt, loss=config["learning_prm"]["loss"], metrics=config["learning_prm"]["metrics"])
127
+ model.summary()
128
+ return model
Filter2D_Lrec/TestClutterFilter2D.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for testing the 3D clutter filtering model.
3
+ """
4
+ import os
5
+ import argparse
6
+ import json
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from utils import *
11
+ from Model_ClutterFilter2D import clutter_filter_2D
12
+ from DataGen import DataGen
13
+ from Error_analysis import compute_mae
14
+
15
+ def data_generation(in_ids_te, out_ids_te, config):
16
+ DtaGenTe_prm = {
17
+ 'dim': config["network_prm"]["input_dim"],
18
+ 'in_dir': in_ids_te,
19
+ 'out_dir': out_ids_te,
20
+ 'id_list': np.arange(len(in_ids_te)),
21
+ 'batch_size': config["learning_prm"]["batch_size"],
22
+ 'tr_phase': False}
23
+ return DataGen(**DtaGenTe_prm)
24
+
25
+ def main(config):
26
+ in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
27
+ te_gen = data_generation(in_ids_te, out_ids_te, config)
28
+ model = clutter_filter_2D(**config)
29
+ weight_dir = create_weight_dir(val_subject, te_subject, config)
30
+ model.load_weights(
31
+ os.path.join(weight_dir, config["weight_name"] + ".hdf5"))
32
+ results_te = model.predict_generator(te_gen, verbose=2)
33
+ df_errors = compute_mae(in_ids_te, results_te)
34
+ df_errors.to_csv(
35
+ os.path.join(weight_dir, config["weight_name"] + ".csv"))
36
+ return None
37
+
38
+ if __name__ == '__main__':
39
+ parser = argparse.ArgumentParser()
40
+ parser.add_argument("--config", help="path of the config file", default="config.json")
41
+ args = parser.parse_args()
42
+ assert os.path.isfile(args.config)
43
+ with open(args.config, "r") as read_file:
44
+ config = json.load(read_file)
45
+ main(config)
Filter2D_Lrec/TrainClutterFilter2D.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for training the 2D clutter filtering model with L2 loss.
3
+ """
4
+ import os
5
+ import argparse
6
+ import json
7
+ import numpy as np
8
+ from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
9
+
10
+ from utils import *
11
+ from Model_ClutterFilter2D import clutter_filter_2D
12
+ from DataGen import DataGen
13
+
14
+ def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
15
+ DtaGenTr_prm = {
16
+ 'dim': config["network_prm"]["input_dim"],
17
+ 'in_dir': in_ids_tr,
18
+ 'out_dir': out_ids_tr,
19
+ 'id_list': np.arange(len(in_ids_tr)),
20
+ 'batch_size': config["learning_prm"]["batch_size"],
21
+ 'tr_phase': True}
22
+ DtaGenVal_prm = {
23
+ 'dim': config["network_prm"]["input_dim"],
24
+ 'in_dir': in_ids_val,
25
+ 'out_dir': out_ids_val,
26
+ 'id_list': np.arange(len(in_ids_val)),
27
+ 'batch_size': config["learning_prm"]["batch_size"],
28
+ 'tr_phase': True}
29
+ tr_gen = DataGen(**DtaGenTr_prm)
30
+ val_gen = DataGen(**DtaGenVal_prm)
31
+ return tr_gen, val_gen
32
+
33
+ def model_chkpnt(val_subject, te_subject, weight_dir, config):
34
+ weight_name = (
35
+ f'CF2D_ValTeSbj_{val_subject}_{te_subject}_nLvl{config["network_prm"]["n_levels"]}'
36
+ f'_InSkp{config["network_prm"]["in_skip"]}_Att{config["network_prm"]["attention"]}'
37
+ f'_Act{config["network_prm"]["act"]}_nInitFlt{config["network_prm"]["n_init_filters"]}_lr{config["learning_prm"]["lr"]}')
38
+ filepath = (weight_dir + '/'+ weight_name +
39
+ '_epc' + "{epoch:03d}" + '_trloss' + "{loss:.5f}" +
40
+ '_valloss' + "{val_loss:.5f}" + ".hdf5")
41
+ model_checkpoint = ModelCheckpoint(filepath=filepath,
42
+ monitor="val_loss",
43
+ verbose=0,
44
+ save_best_only=True)
45
+ reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
46
+ patience=4, min_lr=1e-7)
47
+ return model_checkpoint, reduce_lr
48
+
49
+ def main(config):
50
+ in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject = id_preparation(config)
51
+ weight_dir = create_weight_dir(val_subject, te_subject, config)
52
+ tr_gen, val_gen = data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)
53
+ model = clutter_filter_2D(**config)
54
+ model_checkpoint, reduce_lr = model_chkpnt(val_subject, te_subject, weight_dir, config)
55
+ model.fit(tr_gen,
56
+ validation_data=val_gen,
57
+ epochs=config["learning_prm"]["n_epochs"],
58
+ verbose=1,
59
+ callbacks=[model_checkpoint, reduce_lr])
60
+ return None
61
+
62
+ if __name__ == '__main__':
63
+ parser = argparse.ArgumentParser()
64
+ parser.add_argument("--config", help="path of the config file", default="config.json")
65
+ args = parser.parse_args()
66
+ assert os.path.isfile(args.config)
67
+ with open(args.config, "r") as read_file:
68
+ config = json.load(read_file)
69
+ main(config)
Filter2D_Lrec/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "paths": {
3
+ "data_path": "",
4
+ "save_path": ""
5
+ },
6
+ "subject_list": ["rca", "ladprox", "laddist", "lcx", "normal"],
7
+ "CV": {
8
+ "val_subject_id": 0
9
+ },
10
+ "network_prm": {
11
+ "input_dim": [128, 128, 1],
12
+ "n_levels": 4,
13
+ "n_init_filters": 16,
14
+ "in_skip": true,
15
+ "attention": true,
16
+ "kernel_size": 3,
17
+ "conv_stride": 1,
18
+ "upsampling_stride": [2, 2],
19
+ "pool_size": [2, 2],
20
+ "pool_stride": 1,
21
+ "padding": "same",
22
+ "act": "linear"
23
+ },
24
+ "learning_prm": {
25
+ "batch_size": 50,
26
+ "lr": 1e-4,
27
+ "drp": 0.05,
28
+ "loss": "mean_squared_error",
29
+ "metrics": ["mae"],
30
+ "n_epochs": 10
31
+ },
32
+ "tr_phase": true,
33
+ "model_name": "CF2D_L2Loss",
34
+ "weight_name": ""
35
+ }
Filter2D_Lrec/utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions.
3
+ """
4
+ import os
5
+ import numpy as np
6
+
7
+ def generate_tr_val_te_subject_ids(subject_list, val_subject_id):
8
+ val_subject = subject_list[val_subject_id]
9
+ te_subject = subject_list[val_subject_id-1]
10
+ subject_list.remove(val_subject)
11
+ subject_list.remove(te_subject)
12
+ tr_subjects = subject_list
13
+ return tr_subjects, val_subject, te_subject
14
+
15
+ def generate_data_ids(data_dir, subject_list):
16
+ in_ids, out_ids = [], []
17
+ vendor_list = [vendor for vendor in os.listdir(data_dir) if '.' not in vendor]
18
+ for vendor in vendor_list:
19
+ vendor_dir = os.path.join(data_dir, vendor)
20
+ view_list = [view for view in os.listdir(vendor_dir) if '.' not in view]
21
+ for view in view_list:
22
+ view_dir = os.path.join(vendor_dir, view)
23
+ subject_full_list = [subject for subject in os.listdir(view_dir) if '.' not in subject]
24
+ for subject in subject_full_list:
25
+ if subject in subject_list:
26
+ subject_dir = os.path.join(view_dir, subject)
27
+ org_data_dir = os.path.join(subject_dir, 'data_org')
28
+ org_data_id = os.path.join(org_data_dir, os.listdir(org_data_dir)[0])
29
+ clutter_list = [clutter for clutter in os.listdir(subject_dir)
30
+ if clutter in ['data_NFClt', 'data_NFRvbClt', 'data_RvbClt']
31
+ and '.' not in clutter]
32
+ for clutter in clutter_list:
33
+ clutter_dir = os.path.join(subject_dir, clutter)
34
+ clutter_ids = os.listdir(clutter_dir)
35
+ clutter_ids_dir = [os.path.join(clutter_dir, id_dir) for id_dir in clutter_ids if '.DS' not in id_dir]
36
+ in_ids += clutter_ids_dir
37
+ out_ids += [org_data_id]*len(os.listdir(clutter_dir))
38
+ return in_ids, out_ids
39
+
40
+ def id_preparation(config):
41
+ tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids(
42
+ subject_list=config["subject_list"], val_subject_id=config["CV"]["val_subject_id"])
43
+ if config["tr_phase"]:
44
+ in_ids_tr, out_ids_tr = generate_data_ids(config["paths"]["data_path"], tr_subjects)
45
+ in_ids_val, out_ids_val = generate_data_ids(config["paths"]["data_path"], val_subject)
46
+ return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject
47
+ else:
48
+ in_ids_te, out_ids_te = generate_data_ids(config["paths"]["data_path"], te_subject)
49
+ return in_ids_te, out_ids_te, te_subject, val_subject
50
+
51
+ def create_weight_dir(val_subject, te_subject, config):
52
+ weight_dir = os.path.join(config["paths"]["save_path"],
53
+ "Weights", f"ValTeIDs_{val_subject}_{te_subject}")
54
+ if not os.path.exists(weight_dir):
55
+ os.makedirs(weight_dir)
56
+ return weight_dir
Filter3D_Lrec/DataGen.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for generating batches of 2D images.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import random
7
+ import scipy.io as sio
8
+ import tensorflow as tf
9
+
10
+ def _shift_first_frame(vol_in, vol_out, tr_phase):
11
+ n_frames = vol_in.shape[-1]
12
+ n, p = 1, 0.5 # p is the probability of shifting the first frame.
13
+ if tr_phase:
14
+ if np.random.binomial(n,p):
15
+ first_frm = np.random.permutation(np.arange(n_frames))[0]
16
+ vol_in = np.concatenate((vol_in[:,:,first_frm:], vol_in[:,:,:first_frm]), axis=-1)
17
+ vol_out = np.concatenate((vol_out[:,:,first_frm:], vol_out[:,:,:first_frm]), axis=-1)
18
+ return vol_in, vol_out
19
+
20
+ def _image_vol_normalization(vol):
21
+ vol = vol/255
22
+ vol[vol < 0] = 0
23
+ vol[vol > 1] = 1
24
+ return vol
25
+
26
+ def _image_vol_augmentation(vol_in, vol_out, tr_phase):
27
+ vol_in, vol_out = _shift_first_frame(vol_in, vol_out, tr_phase)
28
+ vol_in_norm = _image_vol_normalization(vol_in)
29
+ vol_out_norm = _image_vol_normalization(vol_out)
30
+ vol_shape = [sh for sh in vol_in.shape]
31
+ vol_shape.append(1)
32
+ vol_in, vol_out = np.empty(vol_shape), np.empty(vol_shape)
33
+ vol_in[:,:,:,0], vol_out[:,:,:,0] = vol_in_norm, vol_out_norm
34
+ return [vol_in, vol_out]
35
+
36
+ class DataGen(tf.keras.utils.Sequence):
37
+ """
38
+ Generating batches of input cluttered volumes and their corresponding
39
+ clutter-free output volumes
40
+ """
41
+ def __init__(
42
+ self,
43
+ dim:list,
44
+ in_dir:str,
45
+ out_dir:str,
46
+ id_list:list,
47
+ batch_size:int,
48
+ tr_phase=True,
49
+ te_subsample=False,
50
+ te_frames=0,
51
+ *args,
52
+ **kwargs):
53
+ 'Initialization'
54
+ self.dim = dim
55
+ self.in_dir = in_dir
56
+ self.out_dir = out_dir
57
+ self.id_list = id_list
58
+ self.batch_size = batch_size
59
+ self.tr_phase = tr_phase
60
+
61
+ def __len__(self):
62
+ return int(np.floor(len(self.id_list) / self.batch_size))
63
+
64
+ def __getitem__(self, idx):
65
+ batch = self.id_list[idx*self.batch_size:(idx+1)*self.batch_size]
66
+ x_aug, y_aug = [], []
67
+ for i, ID in enumerate(batch):
68
+ vol_in = sio.loadmat(self.in_dir[ID])['data_artf']
69
+ vol_out = sio.loadmat(self.out_dir[ID])['data_org']
70
+ # Call the data augmentation function
71
+ aug_vols = _image_vol_augmentation(vol_in, vol_out, self.tr_phase)
72
+ x_aug.append(aug_vols[0])
73
+ y_aug.append(aug_vols[1])
74
+ if self.tr_phase:
75
+ return np.asarray(x_aug), np.asarray(y_aug)
76
+ else:
77
+ return np.asarray(x_aug)
Filter3D_Lrec/Error_analysis.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A module with functions for computing different MAE and coherence errors.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scipy.io as sio
8
+
9
+ def _compute_sample_temporal_coherency_score(filtered_smp, org_smp):
10
+ abs_diff_flt_org_smp = np.abs(filtered_smp-org_smp)
11
+ frm_pixel_sum = [np.sum(abs_diff_flt_org_smp[:,:,i]) for i in range(abs_diff_flt_org_smp.shape[-1])]
12
+ frm_pixel_sum_shifted = np.roll(frm_pixel_sum, -1)
13
+ frm_diff = np.abs(frm_pixel_sum - frm_pixel_sum_shifted)
14
+ frm_diff = frm_diff[:-1]
15
+ return np.mean(frm_diff)
16
+
17
+ def _compute_sample_mae(smp, in_ids, filtered_smp, clutter_class):
18
+ clt_smp = sio.loadmat(in_ids[smp])['data_artf']
19
+ org_smp = sio.loadmat(
20
+ in_ids[smp].split(f'data_{clutter_class}')[0] + 'data_org/1.mat')['data_org']
21
+ mae_CltFiltered_CltFree = np.mean(np.abs(255*filtered_smp-org_smp))
22
+ mae_Cltrd_CltFree = np.mean(np.abs(clt_smp-org_smp))
23
+ temporal_coherency_score = _compute_sample_temporal_coherency_score(255*filtered_smp, org_smp)
24
+ return mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score
25
+
26
+ def _make_res_dct():
27
+ res_dct = {'Clutter_class': [],
28
+ 'Clutter_spec': [],
29
+ 'View': [],
30
+ 'Vendor': [],
31
+ 'MAE_CltFiltered_CltFree': [],
32
+ 'MAE_Cltrd_CltFree': [],
33
+ 'temporal_coherency_score': []
34
+ }
35
+ return res_dct
36
+
37
+ def _id_separation(in_id):
38
+ id_part0 = in_id.split('/A')[0].split('/')
39
+ id_part1 = in_id.split('/data_')[1].split('/')
40
+ v = [v for v in in_id.split('/') if 'A' in v and 'C' in v]
41
+ view = v[0]
42
+ vendor = id_part0[-1]
43
+ clutter_class = id_part1[0]
44
+ clutter_spec = id_part1[1]
45
+ return view, vendor, clutter_class, clutter_spec
46
+
47
+ def compute_mae(in_ids, filtered_dta, te_subsample=False, te_frames=50):
48
+ res_dct = _make_res_dct()
49
+ for i in range(len(in_ids)):
50
+ view, vendor, clutter_class, clutter_spec = _id_separation(in_ids[i])
51
+ res_dct['Clutter_class'].append(clutter_class)
52
+ res_dct['Clutter_spec'].append(clutter_spec)
53
+ res_dct['Vendor'].append(vendor)
54
+ res_dct['View'].append(view)
55
+ mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score = _compute_sample_mae(
56
+ smp=i, in_ids=in_ids, filtered_smp=filtered_dta[i,:,:,:,0],
57
+ clutter_class=clutter_class)
58
+ res_dct['MAE_CltFiltered_CltFree'].append(mae_CltFiltered_CltFree)
59
+ res_dct['MAE_Cltrd_CltFree'].append(mae_Cltrd_CltFree)
60
+ res_dct['temporal_coherency_score'].append(temporal_coherency_score)
61
+ return pd.DataFrame(res_dct)
Filter3D_Lrec/Model_ClutterFilter3D.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions of the 3D clutter filtering algorithm.
3
+ """
4
+ import tensorflow as tf
5
+ from tensorflow.keras import backend as K
6
+ from tensorflow.keras.layers import (Conv3D, MaxPooling3D, Activation, BatchNormalization, Add,
7
+ Dropout, Concatenate, UpSampling3D, multiply, Input, Lambda)
8
+ from tensorflow.keras.models import Model
9
+ from tensorflow.keras.optimizers import Adam
10
+
11
+ def tensor_expansion(tensor, rep, axs):
12
+ expanded_tensor = Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=axs),
13
+ arguments={'repnum': rep})(tensor)
14
+ return expanded_tensor
15
+
16
+ def attention_gate_block_3D(x, g, n_inter_filters=None, name=None, **prm):
17
+ """
18
+ Attention gate block.
19
+ """
20
+ shape_x = K.int_shape(x)
21
+ shape_g = K.int_shape(g)
22
+ if n_inter_filters is None:
23
+ n_inter_filters = shape_x[-1] // 2
24
+
25
+ theta_x = Conv3D(n_inter_filters, 3, strides=(2, 2, 1), padding='same', name=f"{name}_theta_x")(x)
26
+ phi_g = Conv3D(n_inter_filters, 1, strides=1, padding='valid', name=f"{name}_phi_g")(g)
27
+ concat_xg = Add()([phi_g, theta_x])
28
+ act_xg = Activation('relu')(concat_xg)
29
+ psi = Conv3D(1, 1, padding='same', name=f"{name}_psi")(act_xg)
30
+ sigmoid_xg = Activation('sigmoid')(psi)
31
+ shape_sigmoid = K.int_shape(sigmoid_xg)
32
+ upsample_sigmoid = UpSampling3D(size=(2, 2, 1), name=f"{name}_upsampled_sig")(sigmoid_xg)
33
+ upsample_sigmoid_rep = tensor_expansion(upsample_sigmoid, rep=shape_x[-1], axs=-1)
34
+ y = multiply([upsample_sigmoid_rep, x], name=f"{name}_weighted_x")
35
+ return y
36
+
37
+ def conv_block(x, filters_list, act='linear', kernel_size=3, stride=1, pad='same', drp=0.05, name=None):
38
+ """
39
+ Blocks of 3D conv filters.
40
+ """
41
+ for i in range(len(filters_list)):
42
+ x = Conv3D(filters_list[i], kernel_size, padding=pad, strides=stride, name=f"{name}_blk{i+1}")(x)
43
+ x = BatchNormalization(name=f"{name}_bn{i+1}")(x)
44
+ x = Activation(act, name=f"{name}_act{i+1}")(x)
45
+ x = Dropout(drp)(x)
46
+ return x
47
+
48
+ def encoding_block(x_in, name, **config):
49
+ """
50
+ Encoding block of the 3D Unet.
51
+ """
52
+ encoding_dct = {}
53
+ for i in range(config["network_prm"]["n_levels"]):
54
+ if i == 0:
55
+ x = x_in
56
+ n_filters = config["network_prm"]["n_init_filters"]
57
+ else:
58
+ n_filters = (2**i)*config["network_prm"]["n_init_filters"]
59
+ x = MaxPooling3D(pool_size=config["network_prm"]["pool_size"], name=f"{name}_encd_pool{i}")(x)
60
+ x = conv_block(x, filters_list=[n_filters, 2*n_filters], act=config["network_prm"]["act"],
61
+ kernel_size=config["network_prm"]["kernel_size"],
62
+ stride=config["network_prm"]["conv_stride"],
63
+ pad=config["network_prm"]["padding"],
64
+ drp=config["learning_prm"]['drp'], name=f"{name}_encd_conv_lvl{i}")
65
+ encoding_dct[f"{name}_out_lvl{i}"] = x
66
+ return encoding_dct
67
+
68
+ def decoding_block(encoding_dct, name, **config):
69
+ """
70
+ Decoding block of the 3D Unet.
71
+ """
72
+ decoding_dct = {}
73
+ n_levels = config["network_prm"]["n_levels"]
74
+ for i in range(n_levels-1):
75
+ if i == 0:
76
+ x = encoding_dct[f"{name}_out_lvl{n_levels-i-1}"]
77
+ # upsampling via Conv(Upsampling)
78
+ x_shape = K.int_shape(x)
79
+ x_up = Conv3D(x_shape[-1], 2, activation=config["network_prm"]["act"], padding='same', strides=1,
80
+ name=f"{name}_decd_upsmpl{i}")(UpSampling3D(size=(2,2,1))(x))
81
+ x_up_shape = K.int_shape(x_up)
82
+ # concatenation
83
+ if config["network_prm"]['attention']:
84
+ if i == 0:
85
+ g = encoding_dct[f"{name}_out_lvl{n_levels-1}"]
86
+ else:
87
+ g = decoding_dct[f"{name}_out_lvl{i-1}"]
88
+ x_encd = attention_gate_block_3D(x=encoding_dct[f"{name}_out_lvl{n_levels-i-2}"],
89
+ g=g, name=f"{name}_att_blk{i}")
90
+ else:
91
+ x_encd = encoding_dct[f"{name}_out_lvl{n_levels-i-2}"]
92
+ x_concat = Concatenate(axis=-1, name=f"{name}_decd_concat{i}")([x_encd, x_up])
93
+ n_filters = x_up_shape[-1]//2
94
+ x = conv_block(x_concat, filters_list=[n_filters, n_filters], act=config["network_prm"]["act"],
95
+ kernel_size=config["network_prm"]["kernel_size"],
96
+ stride=config["network_prm"]["conv_stride"],
97
+ pad=config["network_prm"]["padding"],
98
+ drp=config["learning_prm"]['drp'], name=f"{name}_decd_conv_lvl{i}")
99
+ decoding_dct[f"{name}_out_lvl{i}"] = x
100
+ x = conv_block(x, filters_list=[1], act=config["network_prm"]["act"], kernel_size=1, stride=1,
101
+ pad='same', drp=1e-4, name=f"{name}_final_decd_conv")
102
+ decoding_dct[f"{name}_final_conv"] = x
103
+ return decoding_dct
104
+
105
+ def Unet3D(x_in, name, **config):
106
+ """
107
+ Spatiotemporal clutter filtering model based on the 3D Unet.
108
+ """
109
+ encoding_dct = encoding_block(x_in, name, **config)
110
+ decoding_dct = decoding_block(encoding_dct, name, **config)
111
+ if config["network_prm"]["in_skip"]:
112
+ out_Unet = Add()([x_in, decoding_dct[f"{name}_final_conv"]])
113
+ else:
114
+ out_Unet = decoding_dct[f"{name}_final_conv"]
115
+ return out_Unet
116
+
117
+ def clutter_filter_3D(**config):
118
+ """
119
+ The main function for designing the clutter filtering algorithm.
120
+ """
121
+ main_in = Input(config["network_prm"]["input_dim"])
122
+ filter_out = Unet3D(x_in=main_in, name="CF", **config)
123
+ model = Model(inputs=main_in, outputs=filter_out, name=config['model_name'])
124
+ opt = Adam(learning_rate=config["learning_prm"]["lr"])
125
+ model.compile(optimizer=opt, loss=config["learning_prm"]["loss"],
126
+ metrics=config["learning_prm"]["metrics"])
127
+ model.summary()
128
+ return model
Filter3D_Lrec/TestClutterFilter3D.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for testing the 3D clutter filtering model.
3
+ """
4
+ import os
5
+ import argparse
6
+ import json
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from utils import *
11
+ from Model_ClutterFilter3D import clutter_filter_3D
12
+ from DataGen import DataGen
13
+ from Error_analysis import compute_mae
14
+
15
+ def data_generation(in_ids_te, out_ids_te, config):
16
+ DtaGenTe_prm = {
17
+ 'dim': config["network_prm"]["input_dim"],
18
+ 'in_dir': in_ids_te,
19
+ 'out_dir': out_ids_te,
20
+ 'id_list': np.arange(len(in_ids_te)),
21
+ 'batch_size': config["learning_prm"]["batch_size"],
22
+ 'tr_phase': False}
23
+ return DataGen(**DtaGenTe_prm)
24
+
25
+ def main(config):
26
+ in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
27
+ te_gen = data_generation(in_ids_te, out_ids_te, config)
28
+ model = clutter_filter_3D(**config)
29
+ weight_dir = create_weight_dir(val_subject, te_subject, config)
30
+ model.load_weights(
31
+ os.path.join(weight_dir, config["weight_name"] + ".hdf5"))
32
+ results_te = model.predict_generator(te_gen, verbose=2)
33
+ df_errors = compute_mae(in_ids_te, results_te)
34
+ df_errors.to_csv(
35
+ os.path.join(weight_dir, config["weight_name"] + ".csv"))
36
+ return None
37
+
38
+ if __name__ == '__main__':
39
+ parser = argparse.ArgumentParser()
40
+ parser.add_argument("--config", help="path of the config file", default="config.json")
41
+ args = parser.parse_args()
42
+ assert os.path.isfile(args.config)
43
+ with open(args.config, "r") as read_file:
44
+ config = json.load(read_file)
45
+ main(config)
Filter3D_Lrec/TrainClutterFilter3D.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for training the 3D clutter filtering model with L2 loss.
3
+ """
4
+ import os
5
+ import argparse
6
+ import json
7
+ import numpy as np
8
+ from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
9
+
10
+ from utils import *
11
+ from Model_ClutterFilter3D import clutter_filter_3D
12
+ from DataGen import DataGen
13
+
14
+ def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
15
+ DtaGenTr_prm = {
16
+ 'dim': config["network_prm"]["input_dim"],
17
+ 'in_dir': in_ids_tr,
18
+ 'out_dir': out_ids_tr,
19
+ 'id_list': np.arange(len(in_ids_tr)),
20
+ 'batch_size': config["learning_prm"]["batch_size"],
21
+ 'tr_phase': True}
22
+ DtaGenVal_prm = {
23
+ 'dim': config["network_prm"]["input_dim"],
24
+ 'in_dir': in_ids_val,
25
+ 'out_dir': out_ids_val,
26
+ 'id_list': np.arange(len(in_ids_val)),
27
+ 'batch_size': config["learning_prm"]["batch_size"],
28
+ 'tr_phase': True}
29
+ tr_gen = DataGen(**DtaGenTr_prm)
30
+ val_gen = DataGen(**DtaGenVal_prm)
31
+ return tr_gen, val_gen
32
+
33
+ def model_chkpnt(val_subject, te_subject, weight_dir, config):
34
+ weight_name = (
35
+ f'CF3D_ValTeSbj_{val_subject}_{te_subject}_nLvl{config["network_prm"]["n_levels"]}'
36
+ f'_InSkp{config["network_prm"]["in_skip"]}_Att{config["network_prm"]["attention"]}'
37
+ f'_Act{config["network_prm"]["act"]}_nInitFlt{config["network_prm"]["n_init_filters"]}'
38
+ f'_lr{config["learning_prm"]["lr"]}')
39
+ filepath = (weight_dir + '/'+ weight_name +
40
+ '_epc' + "{epoch:03d}" + '_trloss' + "{loss:.5f}" +
41
+ '_valloss' + "{val_loss:.5f}" + ".hdf5")
42
+ model_checkpoint = ModelCheckpoint(filepath=filepath,
43
+ monitor="val_loss",
44
+ verbose=0,
45
+ save_best_only=True)
46
+ reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
47
+ patience=4, min_lr=1e-7)
48
+ return model_checkpoint, reduce_lr
49
+
50
+ def main(config):
51
+ in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject = id_preparation(config)
52
+ weight_dir = create_weight_dir(val_subject, te_subject, config)
53
+ tr_gen, val_gen = data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)
54
+ model = clutter_filter_3D(**config)
55
+ model_checkpoint, reduce_lr = model_chkpnt(val_subject, te_subject, weight_dir, config)
56
+ model.fit(tr_gen,
57
+ validation_data=val_gen,
58
+ epochs=config["learning_prm"]["n_epochs"],
59
+ verbose=1,
60
+ callbacks=[model_checkpoint, reduce_lr])
61
+ return None
62
+
63
+ if __name__ == '__main__':
64
+ parser = argparse.ArgumentParser()
65
+ parser.add_argument("--config", help="path of the config file", default="config.json")
66
+ args = parser.parse_args()
67
+ assert os.path.isfile(args.config)
68
+ with open(args.config, "r") as read_file:
69
+ config = json.load(read_file)
70
+ main(config)
Filter3D_Lrec/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "paths": {
3
+ "data_path": "",
4
+ "save_path": ""
5
+ },
6
+ "subject_list": ["rca", "ladprox", "laddist", "lcx", "normal"],
7
+ "CV": {
8
+ "val_subject_id": 0
9
+ },
10
+ "network_prm": {
11
+ "input_dim": [128, 128, 50, 1],
12
+ "n_levels": 4,
13
+ "n_init_filters": 16,
14
+ "in_skip": true,
15
+ "attention": true,
16
+ "kernel_size": 3,
17
+ "conv_stride": 1,
18
+ "upsampling_stride": [2, 2, 1],
19
+ "pool_size": [2, 2, 1],
20
+ "pool_stride": 1,
21
+ "padding": "same",
22
+ "act": "linear"
23
+ },
24
+ "learning_prm": {
25
+ "batch_size": 1,
26
+ "lr": 1e-4,
27
+ "drp": 0.05,
28
+ "loss": "mean_squared_error",
29
+ "metrics": ["mae"],
30
+ "n_epochs": 10
31
+ },
32
+ "tr_phase": true,
33
+ "model_name": "CF3D_L2Loss",
34
+ "weight_name": ""
35
+ }
Filter3D_Lrec/utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions.
3
+ """
4
+ import os
5
+ import numpy as np
6
+
7
+ def generate_tr_val_te_subject_ids(subject_list, val_subject_id):
8
+ val_subject = subject_list[val_subject_id]
9
+ te_subject = subject_list[val_subject_id-1]
10
+ subject_list.remove(val_subject)
11
+ subject_list.remove(te_subject)
12
+ tr_subjects = subject_list
13
+ return tr_subjects, val_subject, te_subject
14
+
15
+ def generate_data_ids(data_dir, subject_list):
16
+ in_ids, out_ids = [], []
17
+ vendor_list = [vendor for vendor in os.listdir(data_dir) if '.' not in vendor]
18
+ for vendor in vendor_list:
19
+ vendor_dir = os.path.join(data_dir, vendor)
20
+ view_list = [view for view in os.listdir(vendor_dir) if '.' not in view]
21
+ for view in view_list:
22
+ view_dir = os.path.join(vendor_dir, view)
23
+ subject_full_list = [subject for subject in os.listdir(view_dir) if '.' not in subject]
24
+ for subject in subject_full_list:
25
+ if subject in subject_list:
26
+ subject_dir = os.path.join(view_dir, subject)
27
+ org_data_dir = os.path.join(subject_dir, 'data_org')
28
+ org_data_id = os.path.join(org_data_dir, os.listdir(org_data_dir)[0])
29
+ clutter_list = [clutter for clutter in os.listdir(subject_dir)
30
+ if clutter in ['data_NFClt', 'data_NFRvbClt', 'data_RvbClt']
31
+ and '.' not in clutter]
32
+ for clutter in clutter_list:
33
+ clutter_dir = os.path.join(subject_dir, clutter)
34
+ clutter_ids = os.listdir(clutter_dir)
35
+ clutter_ids_dir = [os.path.join(clutter_dir, id_dir) for id_dir in clutter_ids if '.DS' not in id_dir]
36
+ in_ids += clutter_ids_dir
37
+ out_ids += [org_data_id]*len(os.listdir(clutter_dir))
38
+ return in_ids, out_ids
39
+
40
+ def id_preparation(config):
41
+ tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids(
42
+ subject_list=config["subject_list"], val_subject_id=config["CV"]["val_subject_id"])
43
+ if config["tr_phase"]:
44
+ in_ids_tr, out_ids_tr = generate_data_ids(config["paths"]["data_path"], tr_subjects)
45
+ in_ids_val, out_ids_val = generate_data_ids(config["paths"]["data_path"], val_subject)
46
+ return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject
47
+ else:
48
+ in_ids_te, out_ids_te = generate_data_ids(config["paths"]["data_path"], te_subject)
49
+ return in_ids_te, out_ids_te, te_subject, val_subject
50
+
51
+ def create_weight_dir(val_subject, te_subject, config):
52
+ weight_dir = os.path.join(config["paths"]["save_path"],
53
+ "Weights", f"ValTeIDs_{val_subject}_{te_subject}")
54
+ if not os.path.exists(weight_dir):
55
+ os.makedirs(weight_dir)
56
+ return weight_dir
Filter3D_Lrec_adv/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Filter3D_Lrec_adv/DataGen.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for generating batches of 3D images and a 3D mask.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import random
7
+ import scipy.io as sio
8
+ import tensorflow as tf
9
+
10
+ def _shift_first_frame(vol_in, vol_out, tr_phase):
11
+ n_frames = vol_in.shape[-1]
12
+ n, p = 1, 0.5 # p is the probability of shifting the first frame.
13
+ if tr_phase:
14
+ if np.random.binomial(n,p):
15
+ first_frm = np.random.permutation(np.arange(n_frames))[0]
16
+ vol_in = np.concatenate((vol_in[:,:,first_frm:], vol_in[:,:,:first_frm]), axis=-1)
17
+ vol_out = np.concatenate((vol_out[:,:,first_frm:], vol_out[:,:,:first_frm]), axis=-1)
18
+ return vol_in, vol_out
19
+
20
+ def _image_vol_normalization(vol):
21
+ vol = vol/255
22
+ vol[vol < 0] = 0
23
+ vol[vol > 1] = 1
24
+ return vol
25
+
26
+ def _image_vol_augmentation(vol_in, vol_out, tr_phase):
27
+ vol_in, vol_out = _shift_first_frame(vol_in, vol_out, tr_phase)
28
+ vol_in_norm = _image_vol_normalization(vol_in)
29
+ vol_out_norm = _image_vol_normalization(vol_out)
30
+ vol_shape = [sh for sh in vol_in.shape]
31
+ vol_shape.append(1)
32
+ vol_in, vol_out = np.empty(vol_shape), np.empty(vol_shape)
33
+ vol_in[:,:,:,0], vol_out[:,:,:,0] = vol_in_norm, vol_out_norm
34
+ return [vol_in, vol_out]
35
+
36
+ def _mask_generation(clt_pos_dta, NFClt):
37
+ vol_mask = np.zeros((128, 128, 50))
38
+ for frm in range(50):
39
+ if NFClt:
40
+ x0, x1 = clt_pos_dta[0,0], clt_pos_dta[0,1]
41
+ y0, y1 = clt_pos_dta[1,0], clt_pos_dta[1,1]
42
+ elif clt_pos_dta.shape[-1] == 1:
43
+ x0, x1 = clt_pos_dta[frm][0][0,0], clt_pos_dta[frm][0][0,1]
44
+ y0, y1 = clt_pos_dta[frm][0][1,0], clt_pos_dta[frm][0][1,1]
45
+ else:
46
+ x0 = np.min([clt_pos_dta[0][0][0,0], clt_pos_dta[0][1][0,0]])
47
+ x1 = np.max([clt_pos_dta[0][0][0,1], clt_pos_dta[0][1][0,1]])
48
+ y0 = np.min([clt_pos_dta[0][0][1,0], clt_pos_dta[0][1][1,0]])
49
+ y1 = np.max([clt_pos_dta[0][0][1,1], clt_pos_dta[0][1][1,1]])
50
+ # create the masks
51
+ vol_mask[x0:x1,y0:y1,frm] = 1
52
+ return vol_mask
53
+
54
+ class DataGen(tf.keras.utils.Sequence):
55
+ """
56
+ Generating batches of input cluttered volumes and their corresponding
57
+ clutter-free output volumes
58
+ """
59
+ def __init__(
60
+ self,
61
+ dim:list,
62
+ in_dir:str,
63
+ out_dir:str,
64
+ id_list:list,
65
+ batch_size:int,
66
+ masked_in=True,
67
+ tr_phase=True,
68
+ *args,
69
+ **kwargs):
70
+
71
+ 'Initialization'
72
+ self.dim = dim
73
+ self.in_dir = in_dir
74
+ self.out_dir = out_dir
75
+ self.id_list = id_list
76
+ self.batch_size = batch_size
77
+ self.masked_in = masked_in
78
+ self.tr_phase = tr_phase
79
+
80
+ def __len__(self):
81
+ return int(np.floor(len(self.id_list) / self.batch_size))
82
+
83
+ def __getitem__(self, idx):
84
+ batch = self.id_list[idx*self.batch_size:(idx+1)*self.batch_size]
85
+ x_aug, y_aug, mask = [], [], []
86
+ for i, ID in enumerate(batch):
87
+ vol_in = sio.loadmat(self.in_dir[ID])['data_artf']
88
+ vol_out = sio.loadmat(self.out_dir[ID])['data_org']
89
+ # Call the data augmentation function
90
+ aug_vols = _image_vol_augmentation(vol_in, vol_out, self.tr_phase)
91
+ # Creat the binary mask
92
+ if self.masked_in:
93
+ clt_pos_id = self.in_dir[ID]
94
+ clt_pos_id = clt_pos_id.replace('/data_', '/pos_')
95
+ if 'pos_NFClt' in clt_pos_id:
96
+ clt_pos_dta = sio.loadmat(clt_pos_id)['pos_sp_nf']
97
+ NFClt = True
98
+ else:
99
+ clt_pos_dta = sio.loadmat(clt_pos_id)['pos_artf']
100
+ NFClt = False
101
+ vol_mask = _mask_generation(clt_pos_dta, NFClt)
102
+ mask.append(vol_mask)
103
+ x_aug.append(aug_vols[0])
104
+ y_aug.append(aug_vols[1])
105
+
106
+ if self.tr_phase and self.masked_in:
107
+ return np.asarray(x_aug), np.asarray(y_aug), np.asarray(mask)
108
+ elif self.tr_phase:
109
+ return np.asarray(x_aug), np.asarray(y_aug)
110
+ else:
111
+ return np.asarray(x_aug)
Filter3D_Lrec_adv/Discriminator.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions for training discriminator model (ResNet) of the 3D clutter filtering network.
3
+ """
4
+ import tensorflow as tf
5
+ from tensorflow.keras import backend as K
6
+ from tensorflow.keras.layers import (Conv3D, MaxPooling3D, Dense, Activation, BatchNormalization, Add,
7
+ Dropout, Concatenate, multiply, Input, Flatten, Lambda)
8
+ from tensorflow.keras.models import Model
9
+ from tensorflow.keras.optimizers import Adam
10
+
11
+ def max_fnc(x):
12
+ return K.max(x, axis=1, keepdims=False)
13
+
14
+ def ResNetModule(x, n_filters, n_blks, krnl_size=3, stride=1, name=None, **prm):
15
+
16
+ shortcut = Conv3D(2*n_filters, 1, strides=stride, name=f"{name}_conv_shortcut")(x)
17
+ shortcut = BatchNormalization(name=f"{name}_bn_shortcut")(shortcut)
18
+ for i in range(n_blks):
19
+ x = Conv3D(n_filters, krnl_size, padding='same', name=f"{name}_conv{i+1}")(x)
20
+ x = BatchNormalization(name=f"{name}_bn{i+1}")(x)
21
+ x = Activation('relu', name=f"{name}_relu{i+1}")(x)
22
+ x = Dropout(0.05, name=f"{name}_drp{i+1}")(x)
23
+
24
+ x = Conv3D(2*n_filters, 1, name=f"{name}_conv_last")(x)
25
+ x = BatchNormalization(name=f"{name}_bn_last")(x)
26
+ x = Add(name=f"{name}_add")([shortcut, x])
27
+ x = Activation('relu', name=f"{name}_out")(x)
28
+ return x
29
+
30
+ def ConvToFC(inp, kernel_size=3, name=None, **prm):
31
+ inp_shape = K.int_shape(inp)
32
+ x = Conv3D(filters=prm['ConvToFC'][0], kernel_size=(1,1,inp_shape[3]),
33
+ padding='valid', activation='relu', strides=1, name=f'{name}_1')(inp)
34
+ x = Conv3D(filters=prm['ConvToFC'][1], kernel_size=(inp_shape[1],inp_shape[2],1),
35
+ padding='valid', activation='relu', strides=1, name=f'{name}_2')(x)
36
+ if len(prm['ConvToFC']) > 2:
37
+ for i in range(2, len(prm['ConvToFC'])):
38
+ x = Conv3D(filters=prm['ConvToFC'][i], kernel_size=1, padding='valid',
39
+ activation='relu', strides=1, name=f'{name}_{i+1}')(x)
40
+ print(x.shape)
41
+ x = Flatten()(x)
42
+ return x
43
+
44
+ def ResNet(inp, n_krn, name=None, **prm):
45
+ x = Conv3D(n_krn, kernel_size=prm['kernel_size'], padding='same',
46
+ data_format="channels_last", strides=1, name=f"{name}_conv0")(inp)
47
+ x = BatchNormalization(name=f"{name}_bn0")(x)
48
+ x = Activation('relu', name=f"{name}_relu0")(x)
49
+ for i in range(len(prm['lvl_blks_config'])):
50
+ x = MaxPooling3D(pool_size=prm['pool_size'], strides=prm['strides'], name=f"{name}_pool{i+1}")(x)
51
+ x = ResNetModule(x, n_filters=(2**i)*n_krn, krnl_size=prm['kernel_size'], stride=1,
52
+ n_blks=prm['lvl_blks_config'][i], name=f"{name}_module{i+1}", **prm)
53
+ x = ConvToFC(x, name='CnvToFC', **prm)
54
+ return x
55
+
56
+ def DenseLayer(x, name=None, **prm):
57
+ for i in range(len(prm['dense_layer_spec'])):
58
+ x = Dense(prm['dense_layer_spec'][i], activation='relu', name=f"{name}_DenseLayer_{i}")(x)
59
+ return Dense(1, activation='sigmoid', name=f"{name}_DenseLayer_out")(x)
60
+
61
+ def discriminator_3D(lr, **prm):
62
+ inp_3D = Input(prm['input_dim'])
63
+ conv_out = ResNet(inp_3D, n_krn=prm['n_init_filters'], name=prm['model_name'], **prm)
64
+ print(conv_out.shape)
65
+ dense_out = DenseLayer(x=conv_out, name=prm['model_name'], **prm)
66
+ conv_model = Model(inputs=inp_3D, outputs=dense_out, name=prm['model_name'])
67
+ conv_model.summary()
68
+ opt = Adam(learning_rate=lr)
69
+ conv_model.compile(optimizer=opt, loss=prm['loss'], metrics=prm['metrics'])
70
+ return conv_model
Filter3D_Lrec_adv/TestClutterFilter3D_GAN.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for testing the 3D clutter filtering model.
3
+ """
4
+ import os
5
+ import sys
6
+ import argparse
7
+ import json
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ from tensorflow.keras.layers import Input
11
+ from tensorflow.keras.models import Model
12
+ from tensorflow.keras.optimizers import Adam
13
+
14
+ code_dir = os.path.dirname(os.path.abspath(__file__))
15
+ sys.path.append(os.path.dirname(code_dir))
16
+ from Filter3D_Lrec.utils import *
17
+ from Filter3D_Lrec.Error_analysis import *
18
+ from Filter3D_Lrec.Model_ClutterFilter3D import Unet3D
19
+ from DataGen import DataGen
20
+ from Discriminator import discriminator_3D
21
+ from Train_GAN import train_gan
22
+
23
+ def data_generation(in_ids_te, out_ids_te, config):
24
+ DtaGenTe_prm = {
25
+ 'dim': config["generator_prm"]["input_dim"],
26
+ 'in_dir': in_ids_te,
27
+ 'out_dir': out_ids_te,
28
+ 'id_list': np.arange(len(in_ids_te)),
29
+ 'batch_size': config["learning_prm"]["batch_size"],
30
+ 'tr_phase': False}
31
+ return DataGen(**DtaGenTe_prm)
32
+
33
+ def generator_3D(**prm):
34
+ prm_gen = prm
35
+ prm_gen["network_prm"] = prm["generator_prm"]
36
+ main_in = Input(prm_gen["network_prm"]["input_dim"])
37
+ filter_out = Unet3D(x_in=main_in, name="CF", **prm_gen)
38
+ model = Model(inputs=main_in, outputs=filter_out, name='generator_3D')
39
+ model.summary()
40
+ return model
41
+
42
+ def main(config):
43
+ in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
44
+ te_gen = data_generation(in_ids_te, out_ids_te, config)
45
+ weight_dir = create_weight_dir(val_subject, te_subject, config)
46
+ generator = generator_3D(**config)
47
+ generator.load_weights(os.path.join(weight_dir, config["weight_name"] + ".hdf5"))
48
+ results_te = generator.predict_generator(te_gen, verbose=2)
49
+ df_errors = compute_mae(in_ids_te, results_te)
50
+ df_errors.to_csv(
51
+ os.path.join(weight_dir, config["weight_name"] + ".csv"))
52
+ return None
53
+
54
+ if __name__ == '__main__':
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument("--config", help="path of the config file", default="config.json")
57
+ args = parser.parse_args()
58
+ assert os.path.isfile(args.config)
59
+ with open(args.config, "r") as read_file:
60
+ config = json.load(read_file)
61
+ main(config)
Filter3D_Lrec_adv/TrainClutterFilter3D_L2_adv_loss.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for training the 3D clutter filtering model with adversarial loss.
3
+ """
4
+ import os
5
+ import sys
6
+ import argparse
7
+ import json
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ from tensorflow.keras.layers import Input
11
+ from tensorflow.keras.models import Model
12
+ from tensorflow.keras.optimizers import Adam
13
+
14
+ code_dir = os.path.dirname(os.path.abspath(__file__))
15
+ sys.path.append(os.path.dirname(code_dir))
16
+ from Filter3D_Lrec.utils import *
17
+ from Filter3D_Lrec.Model_ClutterFilter3D import Unet3D
18
+ from DataGen import DataGen
19
+ from Discriminator import discriminator_3D
20
+ from Train_GAN import train_gan
21
+
22
+ def data_generation(in_ids_tr, out_ids_tr, config):
23
+ DtaGenTr_prm = {
24
+ 'dim': config["generator_prm"]["input_dim"],
25
+ 'in_dir': in_ids_tr,
26
+ 'out_dir': out_ids_tr,
27
+ 'id_list': np.arange(len(in_ids_tr)),
28
+ 'batch_size': config["learning_prm"]["batch_size"],
29
+ 'tr_phase': True}
30
+ return DataGen(**DtaGenTr_prm)
31
+
32
+ def generator_3D(**prm):
33
+ prm_gen = prm
34
+ prm_gen["network_prm"] = prm["generator_prm"]
35
+ main_in = Input(prm_gen["network_prm"]["input_dim"])
36
+ filter_out = Unet3D(x_in=main_in, name="CF", **prm_gen)
37
+ model = Model(inputs=main_in, outputs=filter_out, name='generator_3D')
38
+ model.summary()
39
+ return model
40
+
41
+ def gan_3D(gen_model, disc_model, w_g, w_d, lr):
42
+ disc_model.trainable = False
43
+ disc_output = disc_model(gen_model.output)
44
+ opt = Adam(lr=lr, beta_1=0.5)
45
+ model = Model(gen_model.input, [gen_model.output, disc_output])
46
+ model.compile(loss=['mse', 'binary_crossentropy'],
47
+ loss_weights=[w_g, w_d],
48
+ optimizer=opt)
49
+ return model
50
+
51
+ def main(config):
52
+ in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject = id_preparation(config)
53
+ weight_dir = create_weight_dir(val_subject, te_subject, config)
54
+ weight_name = (
55
+ f'CF3D_GAN_ValTeSbj_{val_subject}_{te_subject}'
56
+ f'_InSkp{config["generator_prm"]["in_skip"]}'
57
+ f'_Att{config["generator_prm"]["attention"]}_lr{config["learning_prm"]["lr"]}'
58
+ f'_wg{config["generator_prm"]["w_g"]}'
59
+ f'_MaskedIn{config["discriminator_prm"]["masked_in"]}')
60
+ tr_gen = data_generation(in_ids_tr, out_ids_tr, config)
61
+ generator = generator_3D(**config)
62
+ discriminator = discriminator_3D(config["learning_prm"]["lr"],
63
+ **config["discriminator_prm"])
64
+ gan_model = gan_3D(generator, discriminator,
65
+ config["generator_prm"]["w_g"], config["discriminator_prm"]["w_d"],
66
+ config["learning_prm"]["lr"])
67
+ train_gan(g_model=generator, d_model=discriminator, gan_model=gan_model,
68
+ n_epochs=config["learning_prm"]["n_epochs"],
69
+ tr_batches=tr_gen, w_dir=weight_dir, w_name=weight_name,
70
+ masked_in=config["discriminator_prm"]['masked_in'])
71
+ return None
72
+
73
+ if __name__ == '__main__':
74
+ parser = argparse.ArgumentParser()
75
+ parser.add_argument("--config", help="path of the config file", default="config.json")
76
+ args = parser.parse_args()
77
+ assert os.path.isfile(args.config)
78
+ with open(args.config, "r") as read_file:
79
+ config = json.load(read_file)
80
+ main(config)
Filter3D_Lrec_adv/Train_GAN.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions for training the 3D GAN model.
3
+ """
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from tensorflow.keras import backend as K
7
+ from tensorflow.keras.models import Model
8
+
9
+ def _apply_the_mask(in_g, in_d, gen_out, th=0.1):
10
+ mask = np.abs(in_g-in_d)
11
+ mask[mask < th] = 0
12
+ mask[mask >= th] = 1
13
+ in_d_masked = np.multiply(in_d, mask)
14
+ gen_out_masked = np.multiply(gen_out, mask)
15
+ return gen_out_masked, in_d_masked
16
+
17
+ def train_gan(g_model, d_model, gan_model, n_epochs, tr_batches, w_dir, w_name, masked_in, **prm):
18
+ for i in range(n_epochs):
19
+ print(f"epoch:{i}")
20
+ rnd_ids = np.random.permutation(tr_batches.__len__())
21
+ for j in range(tr_batches.__len__()):
22
+ # Train Discriminator
23
+ in_g, in_d = tr_batches.__getitem__(rnd_ids[j])[0], tr_batches.__getitem__(rnd_ids[j])[1]
24
+ gen_out = g_model.predict(in_g)
25
+ if masked_in:
26
+ gen_out, in_d = _apply_the_mask(in_g, in_d, gen_out)
27
+ d_loss_r, d_acc_r = d_model.train_on_batch(in_d, np.ones((1, 1)))
28
+ d_loss_f, d_acc_f = d_model.train_on_batch(gen_out, np.zeros((1, 1)))
29
+ d_loss = 0.5 * np.add(d_loss_r, d_loss_f)
30
+ # Train Generator
31
+ g_loss = gan_model.train_on_batch(tr_batches.__getitem__(rnd_ids[j])[0],
32
+ [tr_batches.__getitem__(rnd_ids[j])[1], np.ones((1, 1))])
33
+ # Save weights after each epoch
34
+ filename = (w_dir + '/' + f"{w_name}_epc{i}_g_loss{np.round(g_loss, 3)}_d_loss{np.round(d_loss, 3)}" + ".hdf5")
35
+ g_model.save_weights(filename)
36
+ return None
Filter3D_Lrec_adv/config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "paths": {
3
+ "data_path": "",
4
+ "save_path": ""
5
+ },
6
+ "subject_list": ["rca", "ladprox", "laddist", "lcx", "normal"],
7
+ "CV": {
8
+ "val_subject_id": 0
9
+ },
10
+ "generator_prm": {
11
+ "input_dim": [128, 128, 50, 1],
12
+ "n_levels": 4,
13
+ "n_init_filters": 16,
14
+ "in_skip": true,
15
+ "attention": true,
16
+ "kernel_size": 3,
17
+ "conv_stride": 1,
18
+ "upsampling_stride": [2, 2, 1],
19
+ "pool_size": [2, 2, 1],
20
+ "pool_stride": 1,
21
+ "padding": "same",
22
+ "act": "linear",
23
+ "w_g": 0.999
24
+ },
25
+ "discriminator_prm": {
26
+ "input_dim": [128, 128, 50, 1],
27
+ "n_init_filters": 16,
28
+ "lvl_blks_config": [3, 4, 6, 3],
29
+ "kernel_size": 3,
30
+ "pool_size": 2,
31
+ "strides": 2,
32
+ "ConvToFC": [128, 32],
33
+ "dense_layer_spec": [64, 32, 16],
34
+ "loss": "binary_crossentropy",
35
+ "metrics": ["accuracy"],
36
+ "attention": false,
37
+ "masked_in": true,
38
+ "model_name": "disc_ResNet34",
39
+ "w_d": 0.001
40
+ },
41
+ "learning_prm": {
42
+ "batch_size": 1,
43
+ "lr": 1e-4,
44
+ "drp": 0.05,
45
+ "n_epochs": 10
46
+ },
47
+ "tr_phase": true,
48
+ "model_name": "CF3D_AdvLoss",
49
+ "weight_name": ""
50
+ }
Filtering_results_videos/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Filtering_results_videos/in-vivo/Subject1.mp4 ADDED
Binary file (325 kB). View file
 
Filtering_results_videos/in-vivo/Subject2.mp4 ADDED
Binary file (409 kB). View file
 
Filtering_results_videos/in-vivo/Subject3.mp4 ADDED
Binary file (366 kB). View file
 
Filtering_results_videos/in-vivo/Subject4.mp4 ADDED
Binary file (354 kB). View file
 
Filtering_results_videos/synthetic/GE.mp4 ADDED
Binary file (669 kB). View file
 
Filtering_results_videos/synthetic/Hitachi.mp4 ADDED
Binary file (492 kB). View file
 
Filtering_results_videos/synthetic/Philips.mp4 ADDED
Binary file (530 kB). View file
 
Filtering_results_videos/synthetic/Samsung.mp4 ADDED
Binary file (509 kB). View file
 
Filtering_results_videos/synthetic/Siemens.mp4 ADDED
Binary file (401 kB). View file
 
Filtering_results_videos/synthetic/Toshiba.mp4 ADDED
Binary file (415 kB). View file