Commit
·
6477265
1
Parent(s):
aea2e47
Filtering models and example video clips
Browse files- Filter2D_Lrec/DataGen.py +87 -0
- Filter2D_Lrec/Error_analysis.py +61 -0
- Filter2D_Lrec/Model_ClutterFilter2D.py +128 -0
- Filter2D_Lrec/TestClutterFilter2D.py +45 -0
- Filter2D_Lrec/TrainClutterFilter2D.py +69 -0
- Filter2D_Lrec/config.json +35 -0
- Filter2D_Lrec/utils.py +56 -0
- Filter3D_Lrec/DataGen.py +77 -0
- Filter3D_Lrec/Error_analysis.py +61 -0
- Filter3D_Lrec/Model_ClutterFilter3D.py +128 -0
- Filter3D_Lrec/TestClutterFilter3D.py +45 -0
- Filter3D_Lrec/TrainClutterFilter3D.py +70 -0
- Filter3D_Lrec/config.json +35 -0
- Filter3D_Lrec/utils.py +56 -0
- Filter3D_Lrec_adv/.DS_Store +0 -0
- Filter3D_Lrec_adv/DataGen.py +111 -0
- Filter3D_Lrec_adv/Discriminator.py +70 -0
- Filter3D_Lrec_adv/TestClutterFilter3D_GAN.py +61 -0
- Filter3D_Lrec_adv/TrainClutterFilter3D_L2_adv_loss.py +80 -0
- Filter3D_Lrec_adv/Train_GAN.py +36 -0
- Filter3D_Lrec_adv/config.json +50 -0
- Filtering_results_videos/.DS_Store +0 -0
- Filtering_results_videos/in-vivo/Subject1.mp4 +0 -0
- Filtering_results_videos/in-vivo/Subject2.mp4 +0 -0
- Filtering_results_videos/in-vivo/Subject3.mp4 +0 -0
- Filtering_results_videos/in-vivo/Subject4.mp4 +0 -0
- Filtering_results_videos/synthetic/GE.mp4 +0 -0
- Filtering_results_videos/synthetic/Hitachi.mp4 +0 -0
- Filtering_results_videos/synthetic/Philips.mp4 +0 -0
- Filtering_results_videos/synthetic/Samsung.mp4 +0 -0
- Filtering_results_videos/synthetic/Siemens.mp4 +0 -0
- Filtering_results_videos/synthetic/Toshiba.mp4 +0 -0
Filter2D_Lrec/DataGen.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A module for generating batches of 2D images.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
import scipy.io as sio
|
8 |
+
import tensorflow as tf
|
9 |
+
|
10 |
+
def _shift_first_frame(vol_in, vol_out, tr_phase):
|
11 |
+
n_frames = vol_in.shape[-1]
|
12 |
+
n, p = 1, 0.5 # p is the probability of shifting the first frame.
|
13 |
+
if tr_phase:
|
14 |
+
if np.random.binomial(n,p):
|
15 |
+
first_frm = np.random.permutation(np.arange(n_frames))[0]
|
16 |
+
vol_in = np.concatenate((vol_in[:,:,first_frm:], vol_in[:,:,:first_frm]), axis=-1)
|
17 |
+
vol_out = np.concatenate((vol_out[:,:,first_frm:], vol_out[:,:,:first_frm]), axis=-1)
|
18 |
+
return vol_in, vol_out
|
19 |
+
|
20 |
+
def _image_vol_normalization(vol):
|
21 |
+
vol = vol/255
|
22 |
+
vol[vol < 0] = 0
|
23 |
+
vol[vol > 1] = 1
|
24 |
+
return vol
|
25 |
+
|
26 |
+
def _reshape_vol(vol):
|
27 |
+
new_vol = np.empty([vol.shape[-1], vol.shape[0], vol.shape[1]])
|
28 |
+
for i in range(vol.shape[-1]):
|
29 |
+
new_vol[i,:,:] = vol[:,:,i]
|
30 |
+
return new_vol
|
31 |
+
|
32 |
+
def _image_vol_augmentation(vol_in, vol_out, tr_phase):
|
33 |
+
"""
|
34 |
+
Augmenting and normalizing the input and output image volumes.
|
35 |
+
"""
|
36 |
+
vol_in, vol_out = _shift_first_frame(vol_in, vol_out, tr_phase)
|
37 |
+
vol_in_norm = _image_vol_normalization(vol_in)
|
38 |
+
vol_out_norm = _image_vol_normalization(vol_out)
|
39 |
+
return [_reshape_vol(vol_in_norm), _reshape_vol(vol_out_norm)]
|
40 |
+
|
41 |
+
class DataGen(tf.keras.utils.Sequence):
|
42 |
+
"""
|
43 |
+
Generating batches of input cluttered volumes and their corresponding
|
44 |
+
clutter-free output volumes
|
45 |
+
"""
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
dim:list,
|
49 |
+
in_dir:str,
|
50 |
+
out_dir:str,
|
51 |
+
id_list:list,
|
52 |
+
batch_size:int,
|
53 |
+
tr_phase=True,
|
54 |
+
*args,
|
55 |
+
**kwargs):
|
56 |
+
'Initialization'
|
57 |
+
self.dim = dim
|
58 |
+
self.in_dir = in_dir
|
59 |
+
self.out_dir = out_dir
|
60 |
+
self.id_list = id_list
|
61 |
+
self.batch_size = batch_size
|
62 |
+
self.tr_phase = tr_phase
|
63 |
+
|
64 |
+
def __len__(self):
|
65 |
+
return len(self.id_list)
|
66 |
+
|
67 |
+
def __getitem__(self, idx):
|
68 |
+
'Generate one or more batches of data'
|
69 |
+
|
70 |
+
# Initialization
|
71 |
+
in_out_shape = [self.batch_size, self.dim[0], self.dim[1], self.dim[2]]
|
72 |
+
x_aug, y_aug = np.empty(in_out_shape), np.empty(in_out_shape)
|
73 |
+
vol_id = self.id_list[idx:(idx + 1)]
|
74 |
+
|
75 |
+
# Generate the data
|
76 |
+
for i, ID in enumerate(vol_id):
|
77 |
+
# Store sample
|
78 |
+
vol_in = sio.loadmat(self.in_dir[ID])['data_artf']
|
79 |
+
vol_out = sio.loadmat(self.out_dir[ID])['data_org']
|
80 |
+
# Call the data augmentation function
|
81 |
+
aug_vols = _image_vol_augmentation(vol_in, vol_out, self.tr_phase)
|
82 |
+
x_aug[:,:,:,0], y_aug[:,:,:,0] = aug_vols[0], aug_vols[1]
|
83 |
+
|
84 |
+
if self.tr_phase:
|
85 |
+
return x_aug, y_aug
|
86 |
+
else:
|
87 |
+
return x_aug
|
Filter2D_Lrec/Error_analysis.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A module with functions for computing different MAE and coherence errors.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import scipy.io as sio
|
8 |
+
|
9 |
+
def _compute_sample_temporal_coherency_score(filtered_smp, org_smp):
|
10 |
+
abs_diff_flt_org_smp = np.abs(filtered_smp-org_smp)
|
11 |
+
frm_pixel_sum = [np.sum(abs_diff_flt_org_smp[:,:,i]) for i in range(abs_diff_flt_org_smp.shape[-1])]
|
12 |
+
frm_pixel_sum_shifted = np.roll(frm_pixel_sum, -1)
|
13 |
+
frm_diff = np.abs(frm_pixel_sum - frm_pixel_sum_shifted)
|
14 |
+
frm_diff = frm_diff[:-1]
|
15 |
+
return np.mean(frm_diff)
|
16 |
+
|
17 |
+
def _compute_sample_mae(smp, in_ids, filtered_smp, clutter_class):
|
18 |
+
clt_smp = sio.loadmat(in_ids[smp])['data_artf']
|
19 |
+
org_smp = sio.loadmat(
|
20 |
+
in_ids[smp].split(f'data_{clutter_class}')[0] + 'data_org/1.mat')['data_org']
|
21 |
+
mae_CltFiltered_CltFree = np.mean(np.abs(255*filtered_smp-org_smp))
|
22 |
+
mae_Cltrd_CltFree = np.mean(np.abs(clt_smp-org_smp))
|
23 |
+
temporal_coherency_score = _compute_sample_temporal_coherency_score(255*filtered_smp, org_smp)
|
24 |
+
return mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score
|
25 |
+
|
26 |
+
def _make_res_dct():
|
27 |
+
res_dct = {'Clutter_class': [],
|
28 |
+
'Clutter_spec': [],
|
29 |
+
'View': [],
|
30 |
+
'Vendor': [],
|
31 |
+
'MAE_CltFiltered_CltFree': [],
|
32 |
+
'MAE_Cltrd_CltFree': [],
|
33 |
+
'temporal_coherency_score': []
|
34 |
+
}
|
35 |
+
return res_dct
|
36 |
+
|
37 |
+
def _id_separation(in_id):
|
38 |
+
id_part0 = in_id.split('/A')[0].split('/')
|
39 |
+
id_part1 = in_id.split('/data_')[1].split('/')
|
40 |
+
v = [v for v in in_id.split('/') if 'A' in v and 'C' in v]
|
41 |
+
view = v[0]
|
42 |
+
vendor = id_part0[-1]
|
43 |
+
clutter_class = id_part1[0]
|
44 |
+
clutter_spec = id_part1[1]
|
45 |
+
return view, vendor, clutter_class, clutter_spec
|
46 |
+
|
47 |
+
def compute_mae(in_ids, filtered_dta, te_subsample=False, te_frames=50):
|
48 |
+
res_dct = _make_res_dct()
|
49 |
+
for i in range(len(in_ids)):
|
50 |
+
view, vendor, clutter_class, clutter_spec = _id_separation(in_ids[i])
|
51 |
+
res_dct['Clutter_class'].append(clutter_class)
|
52 |
+
res_dct['Clutter_spec'].append(clutter_spec)
|
53 |
+
res_dct['Vendor'].append(vendor)
|
54 |
+
res_dct['View'].append(view)
|
55 |
+
mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score = _compute_sample_mae(
|
56 |
+
smp=i, in_ids=in_ids, filtered_smp=filtered_dta[i],
|
57 |
+
clutter_class=clutter_class)
|
58 |
+
res_dct['MAE_CltFiltered_CltFree'].append(mae_CltFiltered_CltFree)
|
59 |
+
res_dct['MAE_Cltrd_CltFree'].append(mae_Cltrd_CltFree)
|
60 |
+
res_dct['temporal_coherency_score'].append(temporal_coherency_score)
|
61 |
+
return pd.DataFrame(res_dct)
|
Filter2D_Lrec/Model_ClutterFilter2D.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Functions of the 2D clutter filtering algorithm.
|
3 |
+
"""
|
4 |
+
import tensorflow as tf
|
5 |
+
from tensorflow.keras import backend as K
|
6 |
+
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Activation, BatchNormalization, Add,
|
7 |
+
Dropout, Concatenate, UpSampling2D, multiply, Input, Lambda)
|
8 |
+
from tensorflow.keras.models import Model
|
9 |
+
from tensorflow.keras.optimizers import Adam
|
10 |
+
|
11 |
+
def tensor_expansion(tensor, rep, axs):
|
12 |
+
expanded_tensor = Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=axs),
|
13 |
+
arguments={'repnum': rep})(tensor)
|
14 |
+
return expanded_tensor
|
15 |
+
|
16 |
+
def attention_gate_block_2D(x, g, n_inter_filters=None, name=None, **prm):
|
17 |
+
"""
|
18 |
+
Attention gate block.
|
19 |
+
"""
|
20 |
+
shape_x = K.int_shape(x)
|
21 |
+
shape_g = K.int_shape(g)
|
22 |
+
if n_inter_filters is None:
|
23 |
+
n_inter_filters = shape_x[-1] // 2
|
24 |
+
|
25 |
+
theta_x = Conv2D(n_inter_filters, 3, strides=(2, 2), padding='same', name=f"{name}_theta_x")(x)
|
26 |
+
phi_g = Conv2D(n_inter_filters, 1, strides=1, padding='valid', name=f"{name}_phi_g")(g)
|
27 |
+
concat_xg = Add()([phi_g, theta_x])
|
28 |
+
act_xg = Activation('relu')(concat_xg)
|
29 |
+
psi = Conv2D(1, 1, padding='same', name=f"{name}_psi")(act_xg)
|
30 |
+
sigmoid_xg = Activation('sigmoid')(psi)
|
31 |
+
shape_sigmoid = K.int_shape(sigmoid_xg)
|
32 |
+
upsample_sigmoid = UpSampling2D(size=(2, 2), name=f"{name}_upsampled_sig")(sigmoid_xg)
|
33 |
+
upsample_sigmoid_rep = tensor_expansion(upsample_sigmoid, rep=shape_x[-1], axs=-1)
|
34 |
+
y = multiply([upsample_sigmoid_rep, x], name=f"{name}_weighted_x")
|
35 |
+
return y
|
36 |
+
|
37 |
+
def conv_block(x, filters_list, act='linear', kernel_size=3, stride=1, pad='same', drp=0.05, name=None):
|
38 |
+
"""
|
39 |
+
Blocks of 2D conv filters.
|
40 |
+
"""
|
41 |
+
for i in range(len(filters_list)):
|
42 |
+
x = Conv2D(filters_list[i], kernel_size, padding=pad, strides=stride, name=f"{name}_blk{i+1}")(x)
|
43 |
+
x = BatchNormalization(name=f"{name}_bn{i+1}")(x)
|
44 |
+
x = Activation(act, name=f"{name}_act{i+1}")(x)
|
45 |
+
x = Dropout(drp)(x)
|
46 |
+
return x
|
47 |
+
|
48 |
+
def encoding_block(x_in, name, **config):
|
49 |
+
"""
|
50 |
+
Encoding block of the 2D Unet.
|
51 |
+
"""
|
52 |
+
encoding_dct = {}
|
53 |
+
for i in range(config["network_prm"]["n_levels"]):
|
54 |
+
if i == 0:
|
55 |
+
x = x_in
|
56 |
+
n_filters = config["network_prm"]["n_init_filters"]
|
57 |
+
else:
|
58 |
+
n_filters = (2**i)*config["network_prm"]["n_init_filters"]
|
59 |
+
x = MaxPooling2D(pool_size=config["network_prm"]["pool_size"], name=f"{name}_encd_pool{i}")(x)
|
60 |
+
x = conv_block(x, filters_list=[n_filters, 2*n_filters], act=config["network_prm"]["act"],
|
61 |
+
kernel_size=config["network_prm"]["kernel_size"],
|
62 |
+
stride=config["network_prm"]["conv_stride"],
|
63 |
+
pad=config["network_prm"]["padding"],
|
64 |
+
drp=config["learning_prm"]['drp'], name=f"{name}_encd_conv_lvl{i}")
|
65 |
+
encoding_dct[f"{name}_out_lvl{i}"] = x
|
66 |
+
return encoding_dct
|
67 |
+
|
68 |
+
def decoding_block(encoding_dct, name, **config):
|
69 |
+
"""
|
70 |
+
Decoding block of the 2D Unet.
|
71 |
+
"""
|
72 |
+
decoding_dct = {}
|
73 |
+
n_levels = config["network_prm"]["n_levels"]
|
74 |
+
for i in range(n_levels-1):
|
75 |
+
if i == 0:
|
76 |
+
x = encoding_dct[f"{name}_out_lvl{n_levels-i-1}"]
|
77 |
+
# upsampling via Conv(Upsampling)
|
78 |
+
x_shape = K.int_shape(x)
|
79 |
+
x_up = Conv2D(x_shape[-1], 2, activation=config["network_prm"]["act"], padding='same', strides=1,
|
80 |
+
name=f"{name}_decd_upsmpl{i}")(UpSampling2D(size=(2,2))(x))
|
81 |
+
x_up_shape = K.int_shape(x_up)
|
82 |
+
# concatenation
|
83 |
+
if config["network_prm"]['attention']:
|
84 |
+
if i == 0:
|
85 |
+
g = encoding_dct[f"{name}_out_lvl{n_levels-1}"]
|
86 |
+
else:
|
87 |
+
g = decoding_dct[f"{name}_out_lvl{i-1}"]
|
88 |
+
x_encd = attention_gate_block_2D(x=encoding_dct[f"{name}_out_lvl{n_levels-i-2}"],
|
89 |
+
g=g, name=f"{name}_att_blk{i}")
|
90 |
+
else:
|
91 |
+
x_encd = encoding_dct[f"{name}_out_lvl{n_levels-i-2}"]
|
92 |
+
x_concat = Concatenate(axis=-1, name=f"{name}_decd_concat{i}")([x_encd, x_up])
|
93 |
+
n_filters = x_up_shape[-1]//2
|
94 |
+
x = conv_block(x_concat, filters_list=[n_filters, n_filters], act=config["network_prm"]["act"],
|
95 |
+
kernel_size=config["network_prm"]["kernel_size"],
|
96 |
+
stride=config["network_prm"]["conv_stride"],
|
97 |
+
pad=config["network_prm"]["padding"],
|
98 |
+
drp=config["learning_prm"]['drp'], name=f"{name}_decd_conv_lvl{i}")
|
99 |
+
decoding_dct[f"{name}_out_lvl{i}"] = x
|
100 |
+
x = conv_block(x, filters_list=[1], act=config["network_prm"]["act"], kernel_size=1, stride=1,
|
101 |
+
pad='same', drp=1e-4, name=f"{name}_final_decd_conv")
|
102 |
+
decoding_dct[f"{name}_final_conv"] = x
|
103 |
+
return decoding_dct
|
104 |
+
|
105 |
+
def Unet2D(x_in, name, **config):
|
106 |
+
"""
|
107 |
+
Spatial clutter filtering based on 2D Unet.
|
108 |
+
"""
|
109 |
+
encoding_dct = encoding_block(x_in, name, **config)
|
110 |
+
decoding_dct = decoding_block(encoding_dct, name, **config)
|
111 |
+
# Add the Main input here
|
112 |
+
if config["network_prm"]["in_skip"]:
|
113 |
+
out_Unet = Add()([x_in, decoding_dct[f"{name}_final_conv"]])
|
114 |
+
else:
|
115 |
+
out_Unet = decoding_dct[f"{name}_final_conv"]
|
116 |
+
return out_Unet
|
117 |
+
|
118 |
+
def clutter_filter_2D(**config):
|
119 |
+
"""
|
120 |
+
The main function for designing the clutter filtering algorithm.
|
121 |
+
"""
|
122 |
+
main_in = Input(config["network_prm"]["input_dim"])
|
123 |
+
filter_out = Unet2D(x_in=main_in, name="CF", **config)
|
124 |
+
model = Model(inputs=main_in, outputs=filter_out, name=config['model_name'])
|
125 |
+
opt = Adam(learning_rate=config["learning_prm"]["lr"])
|
126 |
+
model.compile(optimizer=opt, loss=config["learning_prm"]["loss"], metrics=config["learning_prm"]["metrics"])
|
127 |
+
model.summary()
|
128 |
+
return model
|
Filter2D_Lrec/TestClutterFilter2D.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for testing the 3D clutter filtering model.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
from utils import *
|
11 |
+
from Model_ClutterFilter2D import clutter_filter_2D
|
12 |
+
from DataGen import DataGen
|
13 |
+
from Error_analysis import compute_mae
|
14 |
+
|
15 |
+
def data_generation(in_ids_te, out_ids_te, config):
|
16 |
+
DtaGenTe_prm = {
|
17 |
+
'dim': config["network_prm"]["input_dim"],
|
18 |
+
'in_dir': in_ids_te,
|
19 |
+
'out_dir': out_ids_te,
|
20 |
+
'id_list': np.arange(len(in_ids_te)),
|
21 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
22 |
+
'tr_phase': False}
|
23 |
+
return DataGen(**DtaGenTe_prm)
|
24 |
+
|
25 |
+
def main(config):
|
26 |
+
in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
|
27 |
+
te_gen = data_generation(in_ids_te, out_ids_te, config)
|
28 |
+
model = clutter_filter_2D(**config)
|
29 |
+
weight_dir = create_weight_dir(val_subject, te_subject, config)
|
30 |
+
model.load_weights(
|
31 |
+
os.path.join(weight_dir, config["weight_name"] + ".hdf5"))
|
32 |
+
results_te = model.predict_generator(te_gen, verbose=2)
|
33 |
+
df_errors = compute_mae(in_ids_te, results_te)
|
34 |
+
df_errors.to_csv(
|
35 |
+
os.path.join(weight_dir, config["weight_name"] + ".csv"))
|
36 |
+
return None
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
parser = argparse.ArgumentParser()
|
40 |
+
parser.add_argument("--config", help="path of the config file", default="config.json")
|
41 |
+
args = parser.parse_args()
|
42 |
+
assert os.path.isfile(args.config)
|
43 |
+
with open(args.config, "r") as read_file:
|
44 |
+
config = json.load(read_file)
|
45 |
+
main(config)
|
Filter2D_Lrec/TrainClutterFilter2D.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for training the 2D clutter filtering model with L2 loss.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
|
9 |
+
|
10 |
+
from utils import *
|
11 |
+
from Model_ClutterFilter2D import clutter_filter_2D
|
12 |
+
from DataGen import DataGen
|
13 |
+
|
14 |
+
def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
|
15 |
+
DtaGenTr_prm = {
|
16 |
+
'dim': config["network_prm"]["input_dim"],
|
17 |
+
'in_dir': in_ids_tr,
|
18 |
+
'out_dir': out_ids_tr,
|
19 |
+
'id_list': np.arange(len(in_ids_tr)),
|
20 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
21 |
+
'tr_phase': True}
|
22 |
+
DtaGenVal_prm = {
|
23 |
+
'dim': config["network_prm"]["input_dim"],
|
24 |
+
'in_dir': in_ids_val,
|
25 |
+
'out_dir': out_ids_val,
|
26 |
+
'id_list': np.arange(len(in_ids_val)),
|
27 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
28 |
+
'tr_phase': True}
|
29 |
+
tr_gen = DataGen(**DtaGenTr_prm)
|
30 |
+
val_gen = DataGen(**DtaGenVal_prm)
|
31 |
+
return tr_gen, val_gen
|
32 |
+
|
33 |
+
def model_chkpnt(val_subject, te_subject, weight_dir, config):
|
34 |
+
weight_name = (
|
35 |
+
f'CF2D_ValTeSbj_{val_subject}_{te_subject}_nLvl{config["network_prm"]["n_levels"]}'
|
36 |
+
f'_InSkp{config["network_prm"]["in_skip"]}_Att{config["network_prm"]["attention"]}'
|
37 |
+
f'_Act{config["network_prm"]["act"]}_nInitFlt{config["network_prm"]["n_init_filters"]}_lr{config["learning_prm"]["lr"]}')
|
38 |
+
filepath = (weight_dir + '/'+ weight_name +
|
39 |
+
'_epc' + "{epoch:03d}" + '_trloss' + "{loss:.5f}" +
|
40 |
+
'_valloss' + "{val_loss:.5f}" + ".hdf5")
|
41 |
+
model_checkpoint = ModelCheckpoint(filepath=filepath,
|
42 |
+
monitor="val_loss",
|
43 |
+
verbose=0,
|
44 |
+
save_best_only=True)
|
45 |
+
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
|
46 |
+
patience=4, min_lr=1e-7)
|
47 |
+
return model_checkpoint, reduce_lr
|
48 |
+
|
49 |
+
def main(config):
|
50 |
+
in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject = id_preparation(config)
|
51 |
+
weight_dir = create_weight_dir(val_subject, te_subject, config)
|
52 |
+
tr_gen, val_gen = data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)
|
53 |
+
model = clutter_filter_2D(**config)
|
54 |
+
model_checkpoint, reduce_lr = model_chkpnt(val_subject, te_subject, weight_dir, config)
|
55 |
+
model.fit(tr_gen,
|
56 |
+
validation_data=val_gen,
|
57 |
+
epochs=config["learning_prm"]["n_epochs"],
|
58 |
+
verbose=1,
|
59 |
+
callbacks=[model_checkpoint, reduce_lr])
|
60 |
+
return None
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
parser = argparse.ArgumentParser()
|
64 |
+
parser.add_argument("--config", help="path of the config file", default="config.json")
|
65 |
+
args = parser.parse_args()
|
66 |
+
assert os.path.isfile(args.config)
|
67 |
+
with open(args.config, "r") as read_file:
|
68 |
+
config = json.load(read_file)
|
69 |
+
main(config)
|
Filter2D_Lrec/config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"paths": {
|
3 |
+
"data_path": "",
|
4 |
+
"save_path": ""
|
5 |
+
},
|
6 |
+
"subject_list": ["rca", "ladprox", "laddist", "lcx", "normal"],
|
7 |
+
"CV": {
|
8 |
+
"val_subject_id": 0
|
9 |
+
},
|
10 |
+
"network_prm": {
|
11 |
+
"input_dim": [128, 128, 1],
|
12 |
+
"n_levels": 4,
|
13 |
+
"n_init_filters": 16,
|
14 |
+
"in_skip": true,
|
15 |
+
"attention": true,
|
16 |
+
"kernel_size": 3,
|
17 |
+
"conv_stride": 1,
|
18 |
+
"upsampling_stride": [2, 2],
|
19 |
+
"pool_size": [2, 2],
|
20 |
+
"pool_stride": 1,
|
21 |
+
"padding": "same",
|
22 |
+
"act": "linear"
|
23 |
+
},
|
24 |
+
"learning_prm": {
|
25 |
+
"batch_size": 50,
|
26 |
+
"lr": 1e-4,
|
27 |
+
"drp": 0.05,
|
28 |
+
"loss": "mean_squared_error",
|
29 |
+
"metrics": ["mae"],
|
30 |
+
"n_epochs": 10
|
31 |
+
},
|
32 |
+
"tr_phase": true,
|
33 |
+
"model_name": "CF2D_L2Loss",
|
34 |
+
"weight_name": ""
|
35 |
+
}
|
Filter2D_Lrec/utils.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utility functions.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def generate_tr_val_te_subject_ids(subject_list, val_subject_id):
|
8 |
+
val_subject = subject_list[val_subject_id]
|
9 |
+
te_subject = subject_list[val_subject_id-1]
|
10 |
+
subject_list.remove(val_subject)
|
11 |
+
subject_list.remove(te_subject)
|
12 |
+
tr_subjects = subject_list
|
13 |
+
return tr_subjects, val_subject, te_subject
|
14 |
+
|
15 |
+
def generate_data_ids(data_dir, subject_list):
|
16 |
+
in_ids, out_ids = [], []
|
17 |
+
vendor_list = [vendor for vendor in os.listdir(data_dir) if '.' not in vendor]
|
18 |
+
for vendor in vendor_list:
|
19 |
+
vendor_dir = os.path.join(data_dir, vendor)
|
20 |
+
view_list = [view for view in os.listdir(vendor_dir) if '.' not in view]
|
21 |
+
for view in view_list:
|
22 |
+
view_dir = os.path.join(vendor_dir, view)
|
23 |
+
subject_full_list = [subject for subject in os.listdir(view_dir) if '.' not in subject]
|
24 |
+
for subject in subject_full_list:
|
25 |
+
if subject in subject_list:
|
26 |
+
subject_dir = os.path.join(view_dir, subject)
|
27 |
+
org_data_dir = os.path.join(subject_dir, 'data_org')
|
28 |
+
org_data_id = os.path.join(org_data_dir, os.listdir(org_data_dir)[0])
|
29 |
+
clutter_list = [clutter for clutter in os.listdir(subject_dir)
|
30 |
+
if clutter in ['data_NFClt', 'data_NFRvbClt', 'data_RvbClt']
|
31 |
+
and '.' not in clutter]
|
32 |
+
for clutter in clutter_list:
|
33 |
+
clutter_dir = os.path.join(subject_dir, clutter)
|
34 |
+
clutter_ids = os.listdir(clutter_dir)
|
35 |
+
clutter_ids_dir = [os.path.join(clutter_dir, id_dir) for id_dir in clutter_ids if '.DS' not in id_dir]
|
36 |
+
in_ids += clutter_ids_dir
|
37 |
+
out_ids += [org_data_id]*len(os.listdir(clutter_dir))
|
38 |
+
return in_ids, out_ids
|
39 |
+
|
40 |
+
def id_preparation(config):
|
41 |
+
tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids(
|
42 |
+
subject_list=config["subject_list"], val_subject_id=config["CV"]["val_subject_id"])
|
43 |
+
if config["tr_phase"]:
|
44 |
+
in_ids_tr, out_ids_tr = generate_data_ids(config["paths"]["data_path"], tr_subjects)
|
45 |
+
in_ids_val, out_ids_val = generate_data_ids(config["paths"]["data_path"], val_subject)
|
46 |
+
return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject
|
47 |
+
else:
|
48 |
+
in_ids_te, out_ids_te = generate_data_ids(config["paths"]["data_path"], te_subject)
|
49 |
+
return in_ids_te, out_ids_te, te_subject, val_subject
|
50 |
+
|
51 |
+
def create_weight_dir(val_subject, te_subject, config):
|
52 |
+
weight_dir = os.path.join(config["paths"]["save_path"],
|
53 |
+
"Weights", f"ValTeIDs_{val_subject}_{te_subject}")
|
54 |
+
if not os.path.exists(weight_dir):
|
55 |
+
os.makedirs(weight_dir)
|
56 |
+
return weight_dir
|
Filter3D_Lrec/DataGen.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for generating batches of 2D images.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
import scipy.io as sio
|
8 |
+
import tensorflow as tf
|
9 |
+
|
10 |
+
def _shift_first_frame(vol_in, vol_out, tr_phase):
|
11 |
+
n_frames = vol_in.shape[-1]
|
12 |
+
n, p = 1, 0.5 # p is the probability of shifting the first frame.
|
13 |
+
if tr_phase:
|
14 |
+
if np.random.binomial(n,p):
|
15 |
+
first_frm = np.random.permutation(np.arange(n_frames))[0]
|
16 |
+
vol_in = np.concatenate((vol_in[:,:,first_frm:], vol_in[:,:,:first_frm]), axis=-1)
|
17 |
+
vol_out = np.concatenate((vol_out[:,:,first_frm:], vol_out[:,:,:first_frm]), axis=-1)
|
18 |
+
return vol_in, vol_out
|
19 |
+
|
20 |
+
def _image_vol_normalization(vol):
|
21 |
+
vol = vol/255
|
22 |
+
vol[vol < 0] = 0
|
23 |
+
vol[vol > 1] = 1
|
24 |
+
return vol
|
25 |
+
|
26 |
+
def _image_vol_augmentation(vol_in, vol_out, tr_phase):
|
27 |
+
vol_in, vol_out = _shift_first_frame(vol_in, vol_out, tr_phase)
|
28 |
+
vol_in_norm = _image_vol_normalization(vol_in)
|
29 |
+
vol_out_norm = _image_vol_normalization(vol_out)
|
30 |
+
vol_shape = [sh for sh in vol_in.shape]
|
31 |
+
vol_shape.append(1)
|
32 |
+
vol_in, vol_out = np.empty(vol_shape), np.empty(vol_shape)
|
33 |
+
vol_in[:,:,:,0], vol_out[:,:,:,0] = vol_in_norm, vol_out_norm
|
34 |
+
return [vol_in, vol_out]
|
35 |
+
|
36 |
+
class DataGen(tf.keras.utils.Sequence):
|
37 |
+
"""
|
38 |
+
Generating batches of input cluttered volumes and their corresponding
|
39 |
+
clutter-free output volumes
|
40 |
+
"""
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
dim:list,
|
44 |
+
in_dir:str,
|
45 |
+
out_dir:str,
|
46 |
+
id_list:list,
|
47 |
+
batch_size:int,
|
48 |
+
tr_phase=True,
|
49 |
+
te_subsample=False,
|
50 |
+
te_frames=0,
|
51 |
+
*args,
|
52 |
+
**kwargs):
|
53 |
+
'Initialization'
|
54 |
+
self.dim = dim
|
55 |
+
self.in_dir = in_dir
|
56 |
+
self.out_dir = out_dir
|
57 |
+
self.id_list = id_list
|
58 |
+
self.batch_size = batch_size
|
59 |
+
self.tr_phase = tr_phase
|
60 |
+
|
61 |
+
def __len__(self):
|
62 |
+
return int(np.floor(len(self.id_list) / self.batch_size))
|
63 |
+
|
64 |
+
def __getitem__(self, idx):
|
65 |
+
batch = self.id_list[idx*self.batch_size:(idx+1)*self.batch_size]
|
66 |
+
x_aug, y_aug = [], []
|
67 |
+
for i, ID in enumerate(batch):
|
68 |
+
vol_in = sio.loadmat(self.in_dir[ID])['data_artf']
|
69 |
+
vol_out = sio.loadmat(self.out_dir[ID])['data_org']
|
70 |
+
# Call the data augmentation function
|
71 |
+
aug_vols = _image_vol_augmentation(vol_in, vol_out, self.tr_phase)
|
72 |
+
x_aug.append(aug_vols[0])
|
73 |
+
y_aug.append(aug_vols[1])
|
74 |
+
if self.tr_phase:
|
75 |
+
return np.asarray(x_aug), np.asarray(y_aug)
|
76 |
+
else:
|
77 |
+
return np.asarray(x_aug)
|
Filter3D_Lrec/Error_analysis.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A module with functions for computing different MAE and coherence errors.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import scipy.io as sio
|
8 |
+
|
9 |
+
def _compute_sample_temporal_coherency_score(filtered_smp, org_smp):
|
10 |
+
abs_diff_flt_org_smp = np.abs(filtered_smp-org_smp)
|
11 |
+
frm_pixel_sum = [np.sum(abs_diff_flt_org_smp[:,:,i]) for i in range(abs_diff_flt_org_smp.shape[-1])]
|
12 |
+
frm_pixel_sum_shifted = np.roll(frm_pixel_sum, -1)
|
13 |
+
frm_diff = np.abs(frm_pixel_sum - frm_pixel_sum_shifted)
|
14 |
+
frm_diff = frm_diff[:-1]
|
15 |
+
return np.mean(frm_diff)
|
16 |
+
|
17 |
+
def _compute_sample_mae(smp, in_ids, filtered_smp, clutter_class):
|
18 |
+
clt_smp = sio.loadmat(in_ids[smp])['data_artf']
|
19 |
+
org_smp = sio.loadmat(
|
20 |
+
in_ids[smp].split(f'data_{clutter_class}')[0] + 'data_org/1.mat')['data_org']
|
21 |
+
mae_CltFiltered_CltFree = np.mean(np.abs(255*filtered_smp-org_smp))
|
22 |
+
mae_Cltrd_CltFree = np.mean(np.abs(clt_smp-org_smp))
|
23 |
+
temporal_coherency_score = _compute_sample_temporal_coherency_score(255*filtered_smp, org_smp)
|
24 |
+
return mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score
|
25 |
+
|
26 |
+
def _make_res_dct():
|
27 |
+
res_dct = {'Clutter_class': [],
|
28 |
+
'Clutter_spec': [],
|
29 |
+
'View': [],
|
30 |
+
'Vendor': [],
|
31 |
+
'MAE_CltFiltered_CltFree': [],
|
32 |
+
'MAE_Cltrd_CltFree': [],
|
33 |
+
'temporal_coherency_score': []
|
34 |
+
}
|
35 |
+
return res_dct
|
36 |
+
|
37 |
+
def _id_separation(in_id):
|
38 |
+
id_part0 = in_id.split('/A')[0].split('/')
|
39 |
+
id_part1 = in_id.split('/data_')[1].split('/')
|
40 |
+
v = [v for v in in_id.split('/') if 'A' in v and 'C' in v]
|
41 |
+
view = v[0]
|
42 |
+
vendor = id_part0[-1]
|
43 |
+
clutter_class = id_part1[0]
|
44 |
+
clutter_spec = id_part1[1]
|
45 |
+
return view, vendor, clutter_class, clutter_spec
|
46 |
+
|
47 |
+
def compute_mae(in_ids, filtered_dta, te_subsample=False, te_frames=50):
|
48 |
+
res_dct = _make_res_dct()
|
49 |
+
for i in range(len(in_ids)):
|
50 |
+
view, vendor, clutter_class, clutter_spec = _id_separation(in_ids[i])
|
51 |
+
res_dct['Clutter_class'].append(clutter_class)
|
52 |
+
res_dct['Clutter_spec'].append(clutter_spec)
|
53 |
+
res_dct['Vendor'].append(vendor)
|
54 |
+
res_dct['View'].append(view)
|
55 |
+
mae_CltFiltered_CltFree, mae_Cltrd_CltFree, temporal_coherency_score = _compute_sample_mae(
|
56 |
+
smp=i, in_ids=in_ids, filtered_smp=filtered_dta[i,:,:,:,0],
|
57 |
+
clutter_class=clutter_class)
|
58 |
+
res_dct['MAE_CltFiltered_CltFree'].append(mae_CltFiltered_CltFree)
|
59 |
+
res_dct['MAE_Cltrd_CltFree'].append(mae_Cltrd_CltFree)
|
60 |
+
res_dct['temporal_coherency_score'].append(temporal_coherency_score)
|
61 |
+
return pd.DataFrame(res_dct)
|
Filter3D_Lrec/Model_ClutterFilter3D.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Functions of the 3D clutter filtering algorithm.
|
3 |
+
"""
|
4 |
+
import tensorflow as tf
|
5 |
+
from tensorflow.keras import backend as K
|
6 |
+
from tensorflow.keras.layers import (Conv3D, MaxPooling3D, Activation, BatchNormalization, Add,
|
7 |
+
Dropout, Concatenate, UpSampling3D, multiply, Input, Lambda)
|
8 |
+
from tensorflow.keras.models import Model
|
9 |
+
from tensorflow.keras.optimizers import Adam
|
10 |
+
|
11 |
+
def tensor_expansion(tensor, rep, axs):
|
12 |
+
expanded_tensor = Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=axs),
|
13 |
+
arguments={'repnum': rep})(tensor)
|
14 |
+
return expanded_tensor
|
15 |
+
|
16 |
+
def attention_gate_block_3D(x, g, n_inter_filters=None, name=None, **prm):
|
17 |
+
"""
|
18 |
+
Attention gate block.
|
19 |
+
"""
|
20 |
+
shape_x = K.int_shape(x)
|
21 |
+
shape_g = K.int_shape(g)
|
22 |
+
if n_inter_filters is None:
|
23 |
+
n_inter_filters = shape_x[-1] // 2
|
24 |
+
|
25 |
+
theta_x = Conv3D(n_inter_filters, 3, strides=(2, 2, 1), padding='same', name=f"{name}_theta_x")(x)
|
26 |
+
phi_g = Conv3D(n_inter_filters, 1, strides=1, padding='valid', name=f"{name}_phi_g")(g)
|
27 |
+
concat_xg = Add()([phi_g, theta_x])
|
28 |
+
act_xg = Activation('relu')(concat_xg)
|
29 |
+
psi = Conv3D(1, 1, padding='same', name=f"{name}_psi")(act_xg)
|
30 |
+
sigmoid_xg = Activation('sigmoid')(psi)
|
31 |
+
shape_sigmoid = K.int_shape(sigmoid_xg)
|
32 |
+
upsample_sigmoid = UpSampling3D(size=(2, 2, 1), name=f"{name}_upsampled_sig")(sigmoid_xg)
|
33 |
+
upsample_sigmoid_rep = tensor_expansion(upsample_sigmoid, rep=shape_x[-1], axs=-1)
|
34 |
+
y = multiply([upsample_sigmoid_rep, x], name=f"{name}_weighted_x")
|
35 |
+
return y
|
36 |
+
|
37 |
+
def conv_block(x, filters_list, act='linear', kernel_size=3, stride=1, pad='same', drp=0.05, name=None):
|
38 |
+
"""
|
39 |
+
Blocks of 3D conv filters.
|
40 |
+
"""
|
41 |
+
for i in range(len(filters_list)):
|
42 |
+
x = Conv3D(filters_list[i], kernel_size, padding=pad, strides=stride, name=f"{name}_blk{i+1}")(x)
|
43 |
+
x = BatchNormalization(name=f"{name}_bn{i+1}")(x)
|
44 |
+
x = Activation(act, name=f"{name}_act{i+1}")(x)
|
45 |
+
x = Dropout(drp)(x)
|
46 |
+
return x
|
47 |
+
|
48 |
+
def encoding_block(x_in, name, **config):
|
49 |
+
"""
|
50 |
+
Encoding block of the 3D Unet.
|
51 |
+
"""
|
52 |
+
encoding_dct = {}
|
53 |
+
for i in range(config["network_prm"]["n_levels"]):
|
54 |
+
if i == 0:
|
55 |
+
x = x_in
|
56 |
+
n_filters = config["network_prm"]["n_init_filters"]
|
57 |
+
else:
|
58 |
+
n_filters = (2**i)*config["network_prm"]["n_init_filters"]
|
59 |
+
x = MaxPooling3D(pool_size=config["network_prm"]["pool_size"], name=f"{name}_encd_pool{i}")(x)
|
60 |
+
x = conv_block(x, filters_list=[n_filters, 2*n_filters], act=config["network_prm"]["act"],
|
61 |
+
kernel_size=config["network_prm"]["kernel_size"],
|
62 |
+
stride=config["network_prm"]["conv_stride"],
|
63 |
+
pad=config["network_prm"]["padding"],
|
64 |
+
drp=config["learning_prm"]['drp'], name=f"{name}_encd_conv_lvl{i}")
|
65 |
+
encoding_dct[f"{name}_out_lvl{i}"] = x
|
66 |
+
return encoding_dct
|
67 |
+
|
68 |
+
def decoding_block(encoding_dct, name, **config):
|
69 |
+
"""
|
70 |
+
Decoding block of the 3D Unet.
|
71 |
+
"""
|
72 |
+
decoding_dct = {}
|
73 |
+
n_levels = config["network_prm"]["n_levels"]
|
74 |
+
for i in range(n_levels-1):
|
75 |
+
if i == 0:
|
76 |
+
x = encoding_dct[f"{name}_out_lvl{n_levels-i-1}"]
|
77 |
+
# upsampling via Conv(Upsampling)
|
78 |
+
x_shape = K.int_shape(x)
|
79 |
+
x_up = Conv3D(x_shape[-1], 2, activation=config["network_prm"]["act"], padding='same', strides=1,
|
80 |
+
name=f"{name}_decd_upsmpl{i}")(UpSampling3D(size=(2,2,1))(x))
|
81 |
+
x_up_shape = K.int_shape(x_up)
|
82 |
+
# concatenation
|
83 |
+
if config["network_prm"]['attention']:
|
84 |
+
if i == 0:
|
85 |
+
g = encoding_dct[f"{name}_out_lvl{n_levels-1}"]
|
86 |
+
else:
|
87 |
+
g = decoding_dct[f"{name}_out_lvl{i-1}"]
|
88 |
+
x_encd = attention_gate_block_3D(x=encoding_dct[f"{name}_out_lvl{n_levels-i-2}"],
|
89 |
+
g=g, name=f"{name}_att_blk{i}")
|
90 |
+
else:
|
91 |
+
x_encd = encoding_dct[f"{name}_out_lvl{n_levels-i-2}"]
|
92 |
+
x_concat = Concatenate(axis=-1, name=f"{name}_decd_concat{i}")([x_encd, x_up])
|
93 |
+
n_filters = x_up_shape[-1]//2
|
94 |
+
x = conv_block(x_concat, filters_list=[n_filters, n_filters], act=config["network_prm"]["act"],
|
95 |
+
kernel_size=config["network_prm"]["kernel_size"],
|
96 |
+
stride=config["network_prm"]["conv_stride"],
|
97 |
+
pad=config["network_prm"]["padding"],
|
98 |
+
drp=config["learning_prm"]['drp'], name=f"{name}_decd_conv_lvl{i}")
|
99 |
+
decoding_dct[f"{name}_out_lvl{i}"] = x
|
100 |
+
x = conv_block(x, filters_list=[1], act=config["network_prm"]["act"], kernel_size=1, stride=1,
|
101 |
+
pad='same', drp=1e-4, name=f"{name}_final_decd_conv")
|
102 |
+
decoding_dct[f"{name}_final_conv"] = x
|
103 |
+
return decoding_dct
|
104 |
+
|
105 |
+
def Unet3D(x_in, name, **config):
|
106 |
+
"""
|
107 |
+
Spatiotemporal clutter filtering model based on the 3D Unet.
|
108 |
+
"""
|
109 |
+
encoding_dct = encoding_block(x_in, name, **config)
|
110 |
+
decoding_dct = decoding_block(encoding_dct, name, **config)
|
111 |
+
if config["network_prm"]["in_skip"]:
|
112 |
+
out_Unet = Add()([x_in, decoding_dct[f"{name}_final_conv"]])
|
113 |
+
else:
|
114 |
+
out_Unet = decoding_dct[f"{name}_final_conv"]
|
115 |
+
return out_Unet
|
116 |
+
|
117 |
+
def clutter_filter_3D(**config):
|
118 |
+
"""
|
119 |
+
The main function for designing the clutter filtering algorithm.
|
120 |
+
"""
|
121 |
+
main_in = Input(config["network_prm"]["input_dim"])
|
122 |
+
filter_out = Unet3D(x_in=main_in, name="CF", **config)
|
123 |
+
model = Model(inputs=main_in, outputs=filter_out, name=config['model_name'])
|
124 |
+
opt = Adam(learning_rate=config["learning_prm"]["lr"])
|
125 |
+
model.compile(optimizer=opt, loss=config["learning_prm"]["loss"],
|
126 |
+
metrics=config["learning_prm"]["metrics"])
|
127 |
+
model.summary()
|
128 |
+
return model
|
Filter3D_Lrec/TestClutterFilter3D.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for testing the 3D clutter filtering model.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
from utils import *
|
11 |
+
from Model_ClutterFilter3D import clutter_filter_3D
|
12 |
+
from DataGen import DataGen
|
13 |
+
from Error_analysis import compute_mae
|
14 |
+
|
15 |
+
def data_generation(in_ids_te, out_ids_te, config):
|
16 |
+
DtaGenTe_prm = {
|
17 |
+
'dim': config["network_prm"]["input_dim"],
|
18 |
+
'in_dir': in_ids_te,
|
19 |
+
'out_dir': out_ids_te,
|
20 |
+
'id_list': np.arange(len(in_ids_te)),
|
21 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
22 |
+
'tr_phase': False}
|
23 |
+
return DataGen(**DtaGenTe_prm)
|
24 |
+
|
25 |
+
def main(config):
|
26 |
+
in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
|
27 |
+
te_gen = data_generation(in_ids_te, out_ids_te, config)
|
28 |
+
model = clutter_filter_3D(**config)
|
29 |
+
weight_dir = create_weight_dir(val_subject, te_subject, config)
|
30 |
+
model.load_weights(
|
31 |
+
os.path.join(weight_dir, config["weight_name"] + ".hdf5"))
|
32 |
+
results_te = model.predict_generator(te_gen, verbose=2)
|
33 |
+
df_errors = compute_mae(in_ids_te, results_te)
|
34 |
+
df_errors.to_csv(
|
35 |
+
os.path.join(weight_dir, config["weight_name"] + ".csv"))
|
36 |
+
return None
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
parser = argparse.ArgumentParser()
|
40 |
+
parser.add_argument("--config", help="path of the config file", default="config.json")
|
41 |
+
args = parser.parse_args()
|
42 |
+
assert os.path.isfile(args.config)
|
43 |
+
with open(args.config, "r") as read_file:
|
44 |
+
config = json.load(read_file)
|
45 |
+
main(config)
|
Filter3D_Lrec/TrainClutterFilter3D.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for training the 3D clutter filtering model with L2 loss.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
|
9 |
+
|
10 |
+
from utils import *
|
11 |
+
from Model_ClutterFilter3D import clutter_filter_3D
|
12 |
+
from DataGen import DataGen
|
13 |
+
|
14 |
+
def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
|
15 |
+
DtaGenTr_prm = {
|
16 |
+
'dim': config["network_prm"]["input_dim"],
|
17 |
+
'in_dir': in_ids_tr,
|
18 |
+
'out_dir': out_ids_tr,
|
19 |
+
'id_list': np.arange(len(in_ids_tr)),
|
20 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
21 |
+
'tr_phase': True}
|
22 |
+
DtaGenVal_prm = {
|
23 |
+
'dim': config["network_prm"]["input_dim"],
|
24 |
+
'in_dir': in_ids_val,
|
25 |
+
'out_dir': out_ids_val,
|
26 |
+
'id_list': np.arange(len(in_ids_val)),
|
27 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
28 |
+
'tr_phase': True}
|
29 |
+
tr_gen = DataGen(**DtaGenTr_prm)
|
30 |
+
val_gen = DataGen(**DtaGenVal_prm)
|
31 |
+
return tr_gen, val_gen
|
32 |
+
|
33 |
+
def model_chkpnt(val_subject, te_subject, weight_dir, config):
|
34 |
+
weight_name = (
|
35 |
+
f'CF3D_ValTeSbj_{val_subject}_{te_subject}_nLvl{config["network_prm"]["n_levels"]}'
|
36 |
+
f'_InSkp{config["network_prm"]["in_skip"]}_Att{config["network_prm"]["attention"]}'
|
37 |
+
f'_Act{config["network_prm"]["act"]}_nInitFlt{config["network_prm"]["n_init_filters"]}'
|
38 |
+
f'_lr{config["learning_prm"]["lr"]}')
|
39 |
+
filepath = (weight_dir + '/'+ weight_name +
|
40 |
+
'_epc' + "{epoch:03d}" + '_trloss' + "{loss:.5f}" +
|
41 |
+
'_valloss' + "{val_loss:.5f}" + ".hdf5")
|
42 |
+
model_checkpoint = ModelCheckpoint(filepath=filepath,
|
43 |
+
monitor="val_loss",
|
44 |
+
verbose=0,
|
45 |
+
save_best_only=True)
|
46 |
+
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
|
47 |
+
patience=4, min_lr=1e-7)
|
48 |
+
return model_checkpoint, reduce_lr
|
49 |
+
|
50 |
+
def main(config):
|
51 |
+
in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject = id_preparation(config)
|
52 |
+
weight_dir = create_weight_dir(val_subject, te_subject, config)
|
53 |
+
tr_gen, val_gen = data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)
|
54 |
+
model = clutter_filter_3D(**config)
|
55 |
+
model_checkpoint, reduce_lr = model_chkpnt(val_subject, te_subject, weight_dir, config)
|
56 |
+
model.fit(tr_gen,
|
57 |
+
validation_data=val_gen,
|
58 |
+
epochs=config["learning_prm"]["n_epochs"],
|
59 |
+
verbose=1,
|
60 |
+
callbacks=[model_checkpoint, reduce_lr])
|
61 |
+
return None
|
62 |
+
|
63 |
+
if __name__ == '__main__':
|
64 |
+
parser = argparse.ArgumentParser()
|
65 |
+
parser.add_argument("--config", help="path of the config file", default="config.json")
|
66 |
+
args = parser.parse_args()
|
67 |
+
assert os.path.isfile(args.config)
|
68 |
+
with open(args.config, "r") as read_file:
|
69 |
+
config = json.load(read_file)
|
70 |
+
main(config)
|
Filter3D_Lrec/config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"paths": {
|
3 |
+
"data_path": "",
|
4 |
+
"save_path": ""
|
5 |
+
},
|
6 |
+
"subject_list": ["rca", "ladprox", "laddist", "lcx", "normal"],
|
7 |
+
"CV": {
|
8 |
+
"val_subject_id": 0
|
9 |
+
},
|
10 |
+
"network_prm": {
|
11 |
+
"input_dim": [128, 128, 50, 1],
|
12 |
+
"n_levels": 4,
|
13 |
+
"n_init_filters": 16,
|
14 |
+
"in_skip": true,
|
15 |
+
"attention": true,
|
16 |
+
"kernel_size": 3,
|
17 |
+
"conv_stride": 1,
|
18 |
+
"upsampling_stride": [2, 2, 1],
|
19 |
+
"pool_size": [2, 2, 1],
|
20 |
+
"pool_stride": 1,
|
21 |
+
"padding": "same",
|
22 |
+
"act": "linear"
|
23 |
+
},
|
24 |
+
"learning_prm": {
|
25 |
+
"batch_size": 1,
|
26 |
+
"lr": 1e-4,
|
27 |
+
"drp": 0.05,
|
28 |
+
"loss": "mean_squared_error",
|
29 |
+
"metrics": ["mae"],
|
30 |
+
"n_epochs": 10
|
31 |
+
},
|
32 |
+
"tr_phase": true,
|
33 |
+
"model_name": "CF3D_L2Loss",
|
34 |
+
"weight_name": ""
|
35 |
+
}
|
Filter3D_Lrec/utils.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utility functions.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def generate_tr_val_te_subject_ids(subject_list, val_subject_id):
|
8 |
+
val_subject = subject_list[val_subject_id]
|
9 |
+
te_subject = subject_list[val_subject_id-1]
|
10 |
+
subject_list.remove(val_subject)
|
11 |
+
subject_list.remove(te_subject)
|
12 |
+
tr_subjects = subject_list
|
13 |
+
return tr_subjects, val_subject, te_subject
|
14 |
+
|
15 |
+
def generate_data_ids(data_dir, subject_list):
|
16 |
+
in_ids, out_ids = [], []
|
17 |
+
vendor_list = [vendor for vendor in os.listdir(data_dir) if '.' not in vendor]
|
18 |
+
for vendor in vendor_list:
|
19 |
+
vendor_dir = os.path.join(data_dir, vendor)
|
20 |
+
view_list = [view for view in os.listdir(vendor_dir) if '.' not in view]
|
21 |
+
for view in view_list:
|
22 |
+
view_dir = os.path.join(vendor_dir, view)
|
23 |
+
subject_full_list = [subject for subject in os.listdir(view_dir) if '.' not in subject]
|
24 |
+
for subject in subject_full_list:
|
25 |
+
if subject in subject_list:
|
26 |
+
subject_dir = os.path.join(view_dir, subject)
|
27 |
+
org_data_dir = os.path.join(subject_dir, 'data_org')
|
28 |
+
org_data_id = os.path.join(org_data_dir, os.listdir(org_data_dir)[0])
|
29 |
+
clutter_list = [clutter for clutter in os.listdir(subject_dir)
|
30 |
+
if clutter in ['data_NFClt', 'data_NFRvbClt', 'data_RvbClt']
|
31 |
+
and '.' not in clutter]
|
32 |
+
for clutter in clutter_list:
|
33 |
+
clutter_dir = os.path.join(subject_dir, clutter)
|
34 |
+
clutter_ids = os.listdir(clutter_dir)
|
35 |
+
clutter_ids_dir = [os.path.join(clutter_dir, id_dir) for id_dir in clutter_ids if '.DS' not in id_dir]
|
36 |
+
in_ids += clutter_ids_dir
|
37 |
+
out_ids += [org_data_id]*len(os.listdir(clutter_dir))
|
38 |
+
return in_ids, out_ids
|
39 |
+
|
40 |
+
def id_preparation(config):
|
41 |
+
tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids(
|
42 |
+
subject_list=config["subject_list"], val_subject_id=config["CV"]["val_subject_id"])
|
43 |
+
if config["tr_phase"]:
|
44 |
+
in_ids_tr, out_ids_tr = generate_data_ids(config["paths"]["data_path"], tr_subjects)
|
45 |
+
in_ids_val, out_ids_val = generate_data_ids(config["paths"]["data_path"], val_subject)
|
46 |
+
return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject
|
47 |
+
else:
|
48 |
+
in_ids_te, out_ids_te = generate_data_ids(config["paths"]["data_path"], te_subject)
|
49 |
+
return in_ids_te, out_ids_te, te_subject, val_subject
|
50 |
+
|
51 |
+
def create_weight_dir(val_subject, te_subject, config):
|
52 |
+
weight_dir = os.path.join(config["paths"]["save_path"],
|
53 |
+
"Weights", f"ValTeIDs_{val_subject}_{te_subject}")
|
54 |
+
if not os.path.exists(weight_dir):
|
55 |
+
os.makedirs(weight_dir)
|
56 |
+
return weight_dir
|
Filter3D_Lrec_adv/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Filter3D_Lrec_adv/DataGen.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for generating batches of 3D images and a 3D mask.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
import scipy.io as sio
|
8 |
+
import tensorflow as tf
|
9 |
+
|
10 |
+
def _shift_first_frame(vol_in, vol_out, tr_phase):
|
11 |
+
n_frames = vol_in.shape[-1]
|
12 |
+
n, p = 1, 0.5 # p is the probability of shifting the first frame.
|
13 |
+
if tr_phase:
|
14 |
+
if np.random.binomial(n,p):
|
15 |
+
first_frm = np.random.permutation(np.arange(n_frames))[0]
|
16 |
+
vol_in = np.concatenate((vol_in[:,:,first_frm:], vol_in[:,:,:first_frm]), axis=-1)
|
17 |
+
vol_out = np.concatenate((vol_out[:,:,first_frm:], vol_out[:,:,:first_frm]), axis=-1)
|
18 |
+
return vol_in, vol_out
|
19 |
+
|
20 |
+
def _image_vol_normalization(vol):
|
21 |
+
vol = vol/255
|
22 |
+
vol[vol < 0] = 0
|
23 |
+
vol[vol > 1] = 1
|
24 |
+
return vol
|
25 |
+
|
26 |
+
def _image_vol_augmentation(vol_in, vol_out, tr_phase):
|
27 |
+
vol_in, vol_out = _shift_first_frame(vol_in, vol_out, tr_phase)
|
28 |
+
vol_in_norm = _image_vol_normalization(vol_in)
|
29 |
+
vol_out_norm = _image_vol_normalization(vol_out)
|
30 |
+
vol_shape = [sh for sh in vol_in.shape]
|
31 |
+
vol_shape.append(1)
|
32 |
+
vol_in, vol_out = np.empty(vol_shape), np.empty(vol_shape)
|
33 |
+
vol_in[:,:,:,0], vol_out[:,:,:,0] = vol_in_norm, vol_out_norm
|
34 |
+
return [vol_in, vol_out]
|
35 |
+
|
36 |
+
def _mask_generation(clt_pos_dta, NFClt):
|
37 |
+
vol_mask = np.zeros((128, 128, 50))
|
38 |
+
for frm in range(50):
|
39 |
+
if NFClt:
|
40 |
+
x0, x1 = clt_pos_dta[0,0], clt_pos_dta[0,1]
|
41 |
+
y0, y1 = clt_pos_dta[1,0], clt_pos_dta[1,1]
|
42 |
+
elif clt_pos_dta.shape[-1] == 1:
|
43 |
+
x0, x1 = clt_pos_dta[frm][0][0,0], clt_pos_dta[frm][0][0,1]
|
44 |
+
y0, y1 = clt_pos_dta[frm][0][1,0], clt_pos_dta[frm][0][1,1]
|
45 |
+
else:
|
46 |
+
x0 = np.min([clt_pos_dta[0][0][0,0], clt_pos_dta[0][1][0,0]])
|
47 |
+
x1 = np.max([clt_pos_dta[0][0][0,1], clt_pos_dta[0][1][0,1]])
|
48 |
+
y0 = np.min([clt_pos_dta[0][0][1,0], clt_pos_dta[0][1][1,0]])
|
49 |
+
y1 = np.max([clt_pos_dta[0][0][1,1], clt_pos_dta[0][1][1,1]])
|
50 |
+
# create the masks
|
51 |
+
vol_mask[x0:x1,y0:y1,frm] = 1
|
52 |
+
return vol_mask
|
53 |
+
|
54 |
+
class DataGen(tf.keras.utils.Sequence):
|
55 |
+
"""
|
56 |
+
Generating batches of input cluttered volumes and their corresponding
|
57 |
+
clutter-free output volumes
|
58 |
+
"""
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
dim:list,
|
62 |
+
in_dir:str,
|
63 |
+
out_dir:str,
|
64 |
+
id_list:list,
|
65 |
+
batch_size:int,
|
66 |
+
masked_in=True,
|
67 |
+
tr_phase=True,
|
68 |
+
*args,
|
69 |
+
**kwargs):
|
70 |
+
|
71 |
+
'Initialization'
|
72 |
+
self.dim = dim
|
73 |
+
self.in_dir = in_dir
|
74 |
+
self.out_dir = out_dir
|
75 |
+
self.id_list = id_list
|
76 |
+
self.batch_size = batch_size
|
77 |
+
self.masked_in = masked_in
|
78 |
+
self.tr_phase = tr_phase
|
79 |
+
|
80 |
+
def __len__(self):
|
81 |
+
return int(np.floor(len(self.id_list) / self.batch_size))
|
82 |
+
|
83 |
+
def __getitem__(self, idx):
|
84 |
+
batch = self.id_list[idx*self.batch_size:(idx+1)*self.batch_size]
|
85 |
+
x_aug, y_aug, mask = [], [], []
|
86 |
+
for i, ID in enumerate(batch):
|
87 |
+
vol_in = sio.loadmat(self.in_dir[ID])['data_artf']
|
88 |
+
vol_out = sio.loadmat(self.out_dir[ID])['data_org']
|
89 |
+
# Call the data augmentation function
|
90 |
+
aug_vols = _image_vol_augmentation(vol_in, vol_out, self.tr_phase)
|
91 |
+
# Creat the binary mask
|
92 |
+
if self.masked_in:
|
93 |
+
clt_pos_id = self.in_dir[ID]
|
94 |
+
clt_pos_id = clt_pos_id.replace('/data_', '/pos_')
|
95 |
+
if 'pos_NFClt' in clt_pos_id:
|
96 |
+
clt_pos_dta = sio.loadmat(clt_pos_id)['pos_sp_nf']
|
97 |
+
NFClt = True
|
98 |
+
else:
|
99 |
+
clt_pos_dta = sio.loadmat(clt_pos_id)['pos_artf']
|
100 |
+
NFClt = False
|
101 |
+
vol_mask = _mask_generation(clt_pos_dta, NFClt)
|
102 |
+
mask.append(vol_mask)
|
103 |
+
x_aug.append(aug_vols[0])
|
104 |
+
y_aug.append(aug_vols[1])
|
105 |
+
|
106 |
+
if self.tr_phase and self.masked_in:
|
107 |
+
return np.asarray(x_aug), np.asarray(y_aug), np.asarray(mask)
|
108 |
+
elif self.tr_phase:
|
109 |
+
return np.asarray(x_aug), np.asarray(y_aug)
|
110 |
+
else:
|
111 |
+
return np.asarray(x_aug)
|
Filter3D_Lrec_adv/Discriminator.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Functions for training discriminator model (ResNet) of the 3D clutter filtering network.
|
3 |
+
"""
|
4 |
+
import tensorflow as tf
|
5 |
+
from tensorflow.keras import backend as K
|
6 |
+
from tensorflow.keras.layers import (Conv3D, MaxPooling3D, Dense, Activation, BatchNormalization, Add,
|
7 |
+
Dropout, Concatenate, multiply, Input, Flatten, Lambda)
|
8 |
+
from tensorflow.keras.models import Model
|
9 |
+
from tensorflow.keras.optimizers import Adam
|
10 |
+
|
11 |
+
def max_fnc(x):
|
12 |
+
return K.max(x, axis=1, keepdims=False)
|
13 |
+
|
14 |
+
def ResNetModule(x, n_filters, n_blks, krnl_size=3, stride=1, name=None, **prm):
|
15 |
+
|
16 |
+
shortcut = Conv3D(2*n_filters, 1, strides=stride, name=f"{name}_conv_shortcut")(x)
|
17 |
+
shortcut = BatchNormalization(name=f"{name}_bn_shortcut")(shortcut)
|
18 |
+
for i in range(n_blks):
|
19 |
+
x = Conv3D(n_filters, krnl_size, padding='same', name=f"{name}_conv{i+1}")(x)
|
20 |
+
x = BatchNormalization(name=f"{name}_bn{i+1}")(x)
|
21 |
+
x = Activation('relu', name=f"{name}_relu{i+1}")(x)
|
22 |
+
x = Dropout(0.05, name=f"{name}_drp{i+1}")(x)
|
23 |
+
|
24 |
+
x = Conv3D(2*n_filters, 1, name=f"{name}_conv_last")(x)
|
25 |
+
x = BatchNormalization(name=f"{name}_bn_last")(x)
|
26 |
+
x = Add(name=f"{name}_add")([shortcut, x])
|
27 |
+
x = Activation('relu', name=f"{name}_out")(x)
|
28 |
+
return x
|
29 |
+
|
30 |
+
def ConvToFC(inp, kernel_size=3, name=None, **prm):
|
31 |
+
inp_shape = K.int_shape(inp)
|
32 |
+
x = Conv3D(filters=prm['ConvToFC'][0], kernel_size=(1,1,inp_shape[3]),
|
33 |
+
padding='valid', activation='relu', strides=1, name=f'{name}_1')(inp)
|
34 |
+
x = Conv3D(filters=prm['ConvToFC'][1], kernel_size=(inp_shape[1],inp_shape[2],1),
|
35 |
+
padding='valid', activation='relu', strides=1, name=f'{name}_2')(x)
|
36 |
+
if len(prm['ConvToFC']) > 2:
|
37 |
+
for i in range(2, len(prm['ConvToFC'])):
|
38 |
+
x = Conv3D(filters=prm['ConvToFC'][i], kernel_size=1, padding='valid',
|
39 |
+
activation='relu', strides=1, name=f'{name}_{i+1}')(x)
|
40 |
+
print(x.shape)
|
41 |
+
x = Flatten()(x)
|
42 |
+
return x
|
43 |
+
|
44 |
+
def ResNet(inp, n_krn, name=None, **prm):
|
45 |
+
x = Conv3D(n_krn, kernel_size=prm['kernel_size'], padding='same',
|
46 |
+
data_format="channels_last", strides=1, name=f"{name}_conv0")(inp)
|
47 |
+
x = BatchNormalization(name=f"{name}_bn0")(x)
|
48 |
+
x = Activation('relu', name=f"{name}_relu0")(x)
|
49 |
+
for i in range(len(prm['lvl_blks_config'])):
|
50 |
+
x = MaxPooling3D(pool_size=prm['pool_size'], strides=prm['strides'], name=f"{name}_pool{i+1}")(x)
|
51 |
+
x = ResNetModule(x, n_filters=(2**i)*n_krn, krnl_size=prm['kernel_size'], stride=1,
|
52 |
+
n_blks=prm['lvl_blks_config'][i], name=f"{name}_module{i+1}", **prm)
|
53 |
+
x = ConvToFC(x, name='CnvToFC', **prm)
|
54 |
+
return x
|
55 |
+
|
56 |
+
def DenseLayer(x, name=None, **prm):
|
57 |
+
for i in range(len(prm['dense_layer_spec'])):
|
58 |
+
x = Dense(prm['dense_layer_spec'][i], activation='relu', name=f"{name}_DenseLayer_{i}")(x)
|
59 |
+
return Dense(1, activation='sigmoid', name=f"{name}_DenseLayer_out")(x)
|
60 |
+
|
61 |
+
def discriminator_3D(lr, **prm):
|
62 |
+
inp_3D = Input(prm['input_dim'])
|
63 |
+
conv_out = ResNet(inp_3D, n_krn=prm['n_init_filters'], name=prm['model_name'], **prm)
|
64 |
+
print(conv_out.shape)
|
65 |
+
dense_out = DenseLayer(x=conv_out, name=prm['model_name'], **prm)
|
66 |
+
conv_model = Model(inputs=inp_3D, outputs=dense_out, name=prm['model_name'])
|
67 |
+
conv_model.summary()
|
68 |
+
opt = Adam(learning_rate=lr)
|
69 |
+
conv_model.compile(optimizer=opt, loss=prm['loss'], metrics=prm['metrics'])
|
70 |
+
return conv_model
|
Filter3D_Lrec_adv/TestClutterFilter3D_GAN.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for testing the 3D clutter filtering model.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import numpy as np
|
9 |
+
import tensorflow as tf
|
10 |
+
from tensorflow.keras.layers import Input
|
11 |
+
from tensorflow.keras.models import Model
|
12 |
+
from tensorflow.keras.optimizers import Adam
|
13 |
+
|
14 |
+
code_dir = os.path.dirname(os.path.abspath(__file__))
|
15 |
+
sys.path.append(os.path.dirname(code_dir))
|
16 |
+
from Filter3D_Lrec.utils import *
|
17 |
+
from Filter3D_Lrec.Error_analysis import *
|
18 |
+
from Filter3D_Lrec.Model_ClutterFilter3D import Unet3D
|
19 |
+
from DataGen import DataGen
|
20 |
+
from Discriminator import discriminator_3D
|
21 |
+
from Train_GAN import train_gan
|
22 |
+
|
23 |
+
def data_generation(in_ids_te, out_ids_te, config):
|
24 |
+
DtaGenTe_prm = {
|
25 |
+
'dim': config["generator_prm"]["input_dim"],
|
26 |
+
'in_dir': in_ids_te,
|
27 |
+
'out_dir': out_ids_te,
|
28 |
+
'id_list': np.arange(len(in_ids_te)),
|
29 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
30 |
+
'tr_phase': False}
|
31 |
+
return DataGen(**DtaGenTe_prm)
|
32 |
+
|
33 |
+
def generator_3D(**prm):
|
34 |
+
prm_gen = prm
|
35 |
+
prm_gen["network_prm"] = prm["generator_prm"]
|
36 |
+
main_in = Input(prm_gen["network_prm"]["input_dim"])
|
37 |
+
filter_out = Unet3D(x_in=main_in, name="CF", **prm_gen)
|
38 |
+
model = Model(inputs=main_in, outputs=filter_out, name='generator_3D')
|
39 |
+
model.summary()
|
40 |
+
return model
|
41 |
+
|
42 |
+
def main(config):
|
43 |
+
in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
|
44 |
+
te_gen = data_generation(in_ids_te, out_ids_te, config)
|
45 |
+
weight_dir = create_weight_dir(val_subject, te_subject, config)
|
46 |
+
generator = generator_3D(**config)
|
47 |
+
generator.load_weights(os.path.join(weight_dir, config["weight_name"] + ".hdf5"))
|
48 |
+
results_te = generator.predict_generator(te_gen, verbose=2)
|
49 |
+
df_errors = compute_mae(in_ids_te, results_te)
|
50 |
+
df_errors.to_csv(
|
51 |
+
os.path.join(weight_dir, config["weight_name"] + ".csv"))
|
52 |
+
return None
|
53 |
+
|
54 |
+
if __name__ == '__main__':
|
55 |
+
parser = argparse.ArgumentParser()
|
56 |
+
parser.add_argument("--config", help="path of the config file", default="config.json")
|
57 |
+
args = parser.parse_args()
|
58 |
+
assert os.path.isfile(args.config)
|
59 |
+
with open(args.config, "r") as read_file:
|
60 |
+
config = json.load(read_file)
|
61 |
+
main(config)
|
Filter3D_Lrec_adv/TrainClutterFilter3D_L2_adv_loss.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for training the 3D clutter filtering model with adversarial loss.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import numpy as np
|
9 |
+
import tensorflow as tf
|
10 |
+
from tensorflow.keras.layers import Input
|
11 |
+
from tensorflow.keras.models import Model
|
12 |
+
from tensorflow.keras.optimizers import Adam
|
13 |
+
|
14 |
+
code_dir = os.path.dirname(os.path.abspath(__file__))
|
15 |
+
sys.path.append(os.path.dirname(code_dir))
|
16 |
+
from Filter3D_Lrec.utils import *
|
17 |
+
from Filter3D_Lrec.Model_ClutterFilter3D import Unet3D
|
18 |
+
from DataGen import DataGen
|
19 |
+
from Discriminator import discriminator_3D
|
20 |
+
from Train_GAN import train_gan
|
21 |
+
|
22 |
+
def data_generation(in_ids_tr, out_ids_tr, config):
|
23 |
+
DtaGenTr_prm = {
|
24 |
+
'dim': config["generator_prm"]["input_dim"],
|
25 |
+
'in_dir': in_ids_tr,
|
26 |
+
'out_dir': out_ids_tr,
|
27 |
+
'id_list': np.arange(len(in_ids_tr)),
|
28 |
+
'batch_size': config["learning_prm"]["batch_size"],
|
29 |
+
'tr_phase': True}
|
30 |
+
return DataGen(**DtaGenTr_prm)
|
31 |
+
|
32 |
+
def generator_3D(**prm):
|
33 |
+
prm_gen = prm
|
34 |
+
prm_gen["network_prm"] = prm["generator_prm"]
|
35 |
+
main_in = Input(prm_gen["network_prm"]["input_dim"])
|
36 |
+
filter_out = Unet3D(x_in=main_in, name="CF", **prm_gen)
|
37 |
+
model = Model(inputs=main_in, outputs=filter_out, name='generator_3D')
|
38 |
+
model.summary()
|
39 |
+
return model
|
40 |
+
|
41 |
+
def gan_3D(gen_model, disc_model, w_g, w_d, lr):
|
42 |
+
disc_model.trainable = False
|
43 |
+
disc_output = disc_model(gen_model.output)
|
44 |
+
opt = Adam(lr=lr, beta_1=0.5)
|
45 |
+
model = Model(gen_model.input, [gen_model.output, disc_output])
|
46 |
+
model.compile(loss=['mse', 'binary_crossentropy'],
|
47 |
+
loss_weights=[w_g, w_d],
|
48 |
+
optimizer=opt)
|
49 |
+
return model
|
50 |
+
|
51 |
+
def main(config):
|
52 |
+
in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject = id_preparation(config)
|
53 |
+
weight_dir = create_weight_dir(val_subject, te_subject, config)
|
54 |
+
weight_name = (
|
55 |
+
f'CF3D_GAN_ValTeSbj_{val_subject}_{te_subject}'
|
56 |
+
f'_InSkp{config["generator_prm"]["in_skip"]}'
|
57 |
+
f'_Att{config["generator_prm"]["attention"]}_lr{config["learning_prm"]["lr"]}'
|
58 |
+
f'_wg{config["generator_prm"]["w_g"]}'
|
59 |
+
f'_MaskedIn{config["discriminator_prm"]["masked_in"]}')
|
60 |
+
tr_gen = data_generation(in_ids_tr, out_ids_tr, config)
|
61 |
+
generator = generator_3D(**config)
|
62 |
+
discriminator = discriminator_3D(config["learning_prm"]["lr"],
|
63 |
+
**config["discriminator_prm"])
|
64 |
+
gan_model = gan_3D(generator, discriminator,
|
65 |
+
config["generator_prm"]["w_g"], config["discriminator_prm"]["w_d"],
|
66 |
+
config["learning_prm"]["lr"])
|
67 |
+
train_gan(g_model=generator, d_model=discriminator, gan_model=gan_model,
|
68 |
+
n_epochs=config["learning_prm"]["n_epochs"],
|
69 |
+
tr_batches=tr_gen, w_dir=weight_dir, w_name=weight_name,
|
70 |
+
masked_in=config["discriminator_prm"]['masked_in'])
|
71 |
+
return None
|
72 |
+
|
73 |
+
if __name__ == '__main__':
|
74 |
+
parser = argparse.ArgumentParser()
|
75 |
+
parser.add_argument("--config", help="path of the config file", default="config.json")
|
76 |
+
args = parser.parse_args()
|
77 |
+
assert os.path.isfile(args.config)
|
78 |
+
with open(args.config, "r") as read_file:
|
79 |
+
config = json.load(read_file)
|
80 |
+
main(config)
|
Filter3D_Lrec_adv/Train_GAN.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Functions for training the 3D GAN model.
|
3 |
+
"""
|
4 |
+
import numpy as np
|
5 |
+
import tensorflow as tf
|
6 |
+
from tensorflow.keras import backend as K
|
7 |
+
from tensorflow.keras.models import Model
|
8 |
+
|
9 |
+
def _apply_the_mask(in_g, in_d, gen_out, th=0.1):
|
10 |
+
mask = np.abs(in_g-in_d)
|
11 |
+
mask[mask < th] = 0
|
12 |
+
mask[mask >= th] = 1
|
13 |
+
in_d_masked = np.multiply(in_d, mask)
|
14 |
+
gen_out_masked = np.multiply(gen_out, mask)
|
15 |
+
return gen_out_masked, in_d_masked
|
16 |
+
|
17 |
+
def train_gan(g_model, d_model, gan_model, n_epochs, tr_batches, w_dir, w_name, masked_in, **prm):
|
18 |
+
for i in range(n_epochs):
|
19 |
+
print(f"epoch:{i}")
|
20 |
+
rnd_ids = np.random.permutation(tr_batches.__len__())
|
21 |
+
for j in range(tr_batches.__len__()):
|
22 |
+
# Train Discriminator
|
23 |
+
in_g, in_d = tr_batches.__getitem__(rnd_ids[j])[0], tr_batches.__getitem__(rnd_ids[j])[1]
|
24 |
+
gen_out = g_model.predict(in_g)
|
25 |
+
if masked_in:
|
26 |
+
gen_out, in_d = _apply_the_mask(in_g, in_d, gen_out)
|
27 |
+
d_loss_r, d_acc_r = d_model.train_on_batch(in_d, np.ones((1, 1)))
|
28 |
+
d_loss_f, d_acc_f = d_model.train_on_batch(gen_out, np.zeros((1, 1)))
|
29 |
+
d_loss = 0.5 * np.add(d_loss_r, d_loss_f)
|
30 |
+
# Train Generator
|
31 |
+
g_loss = gan_model.train_on_batch(tr_batches.__getitem__(rnd_ids[j])[0],
|
32 |
+
[tr_batches.__getitem__(rnd_ids[j])[1], np.ones((1, 1))])
|
33 |
+
# Save weights after each epoch
|
34 |
+
filename = (w_dir + '/' + f"{w_name}_epc{i}_g_loss{np.round(g_loss, 3)}_d_loss{np.round(d_loss, 3)}" + ".hdf5")
|
35 |
+
g_model.save_weights(filename)
|
36 |
+
return None
|
Filter3D_Lrec_adv/config.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"paths": {
|
3 |
+
"data_path": "",
|
4 |
+
"save_path": ""
|
5 |
+
},
|
6 |
+
"subject_list": ["rca", "ladprox", "laddist", "lcx", "normal"],
|
7 |
+
"CV": {
|
8 |
+
"val_subject_id": 0
|
9 |
+
},
|
10 |
+
"generator_prm": {
|
11 |
+
"input_dim": [128, 128, 50, 1],
|
12 |
+
"n_levels": 4,
|
13 |
+
"n_init_filters": 16,
|
14 |
+
"in_skip": true,
|
15 |
+
"attention": true,
|
16 |
+
"kernel_size": 3,
|
17 |
+
"conv_stride": 1,
|
18 |
+
"upsampling_stride": [2, 2, 1],
|
19 |
+
"pool_size": [2, 2, 1],
|
20 |
+
"pool_stride": 1,
|
21 |
+
"padding": "same",
|
22 |
+
"act": "linear",
|
23 |
+
"w_g": 0.999
|
24 |
+
},
|
25 |
+
"discriminator_prm": {
|
26 |
+
"input_dim": [128, 128, 50, 1],
|
27 |
+
"n_init_filters": 16,
|
28 |
+
"lvl_blks_config": [3, 4, 6, 3],
|
29 |
+
"kernel_size": 3,
|
30 |
+
"pool_size": 2,
|
31 |
+
"strides": 2,
|
32 |
+
"ConvToFC": [128, 32],
|
33 |
+
"dense_layer_spec": [64, 32, 16],
|
34 |
+
"loss": "binary_crossentropy",
|
35 |
+
"metrics": ["accuracy"],
|
36 |
+
"attention": false,
|
37 |
+
"masked_in": true,
|
38 |
+
"model_name": "disc_ResNet34",
|
39 |
+
"w_d": 0.001
|
40 |
+
},
|
41 |
+
"learning_prm": {
|
42 |
+
"batch_size": 1,
|
43 |
+
"lr": 1e-4,
|
44 |
+
"drp": 0.05,
|
45 |
+
"n_epochs": 10
|
46 |
+
},
|
47 |
+
"tr_phase": true,
|
48 |
+
"model_name": "CF3D_AdvLoss",
|
49 |
+
"weight_name": ""
|
50 |
+
}
|
Filtering_results_videos/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Filtering_results_videos/in-vivo/Subject1.mp4
ADDED
Binary file (325 kB). View file
|
|
Filtering_results_videos/in-vivo/Subject2.mp4
ADDED
Binary file (409 kB). View file
|
|
Filtering_results_videos/in-vivo/Subject3.mp4
ADDED
Binary file (366 kB). View file
|
|
Filtering_results_videos/in-vivo/Subject4.mp4
ADDED
Binary file (354 kB). View file
|
|
Filtering_results_videos/synthetic/GE.mp4
ADDED
Binary file (669 kB). View file
|
|
Filtering_results_videos/synthetic/Hitachi.mp4
ADDED
Binary file (492 kB). View file
|
|
Filtering_results_videos/synthetic/Philips.mp4
ADDED
Binary file (530 kB). View file
|
|
Filtering_results_videos/synthetic/Samsung.mp4
ADDED
Binary file (509 kB). View file
|
|
Filtering_results_videos/synthetic/Siemens.mp4
ADDED
Binary file (401 kB). View file
|
|
Filtering_results_videos/synthetic/Toshiba.mp4
ADDED
Binary file (415 kB). View file
|
|