hujiecpp commited on
Commit
eb55960
·
1 Parent(s): 0887f00

init project

Browse files
Files changed (1) hide show
  1. modules/dust3r/cloud_opt/commons.py +30 -7
modules/dust3r/cloud_opt/commons.py CHANGED
@@ -49,19 +49,42 @@ def get_imshapes(edges, pred_i, pred_j):
49
  return imshapes
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def get_conf_trf(mode):
53
  if mode == 'log':
54
- def conf_trf(x): return x.log()
55
  elif mode == 'sqrt':
56
- def conf_trf(x): return x.sqrt()
57
  elif mode == 'm1':
58
- def conf_trf(x): return x-1
59
  elif mode in ('id', 'none'):
60
- def conf_trf(x): return x
61
  else:
62
- raise ValueError(f'bad mode for {mode=}')
63
- return conf_trf
64
-
65
 
66
  def l2_dist(a, b, weight=None):
67
  if weight == None:
 
49
  return imshapes
50
 
51
 
52
+ # def get_conf_trf(mode):
53
+ # if mode == 'log':
54
+ # def conf_trf(x): return x.log()
55
+ # elif mode == 'sqrt':
56
+ # def conf_trf(x): return x.sqrt()
57
+ # elif mode == 'm1':
58
+ # def conf_trf(x): return x-1
59
+ # elif mode in ('id', 'none'):
60
+ # def conf_trf(x): return x
61
+ # else:
62
+ # raise ValueError(f'bad mode for {mode=}')
63
+ # return conf_trf
64
+
65
+ def conf_trf_log(x):
66
+ return x.log()
67
+
68
+ def conf_trf_sqrt(x):
69
+ return x.sqrt()
70
+
71
+ def conf_trf_m1(x):
72
+ return x - 1
73
+
74
+ def conf_trf_id(x):
75
+ return x
76
+
77
  def get_conf_trf(mode):
78
  if mode == 'log':
79
+ return conf_trf_log
80
  elif mode == 'sqrt':
81
+ return conf_trf_sqrt
82
  elif mode == 'm1':
83
+ return conf_trf_m1
84
  elif mode in ('id', 'none'):
85
+ return conf_trf_id
86
  else:
87
+ raise ValueError(f"bad mode {mode=}")
 
 
88
 
89
  def l2_dist(a, b, weight=None):
90
  if weight == None: