Kazuto Nakashima commited on
Commit
5acffd4
·
1 Parent(s): af80c65
Files changed (4) hide show
  1. README.md +4 -4
  2. app.py +225 -0
  3. pre-requirements.txt +3 -0
  4. requirements.txt +14 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: R2flow
3
- emoji: 🌖
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.22.0
8
  app_file: app.py
 
1
  ---
2
+ title: R2Flow
3
+ emoji: 🚗
4
+ colorFrom: indigo
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.22.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import einops
4
+ import gradio as gr
5
+ import matplotlib.cm as cm
6
+ import numpy as np
7
+ import plotly.graph_objects as go
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchdiffeq
11
+
12
+ DESCRIPTION = """
13
+ <div class="head">
14
+ <div class="title">Fast LiDAR Data Generation with Rectified Flows</div>
15
+ <div class="conference">ICRA 2025</div>
16
+ <div class="authors">
17
+ <a href="https://kazuto1011.github.io/" target="_blank" rel="noopener"> Kazuto Nakashima</a><sup>1</sup>
18
+ &nbsp;&nbsp;&nbsp;
19
+ <a> Xiaowen Liu</a><sup>1</sup>
20
+ &nbsp;&nbsp;&nbsp;
21
+ <a> Tomoya Miyawaki</a><sup>1</sup>
22
+ &nbsp;&nbsp;&nbsp;
23
+ <a> Yumi Iwashita</a><sup>2</sup>
24
+ &nbsp;&nbsp;&nbsp;
25
+ <a> Ryo Kurazume</a><sup>1</sup>
26
+ </div>
27
+ <div class="affiliations">
28
+ <sup>1</sup>Kyushu University
29
+ &nbsp;&nbsp;&nbsp;
30
+ <sup>2</sup>NASA Jet Propulsion Laboratory
31
+ </div>
32
+ <div class="materials">
33
+ <a href="https://kazuto1011.github.io/r2flow">Project</a> |
34
+ <a href="https://arxiv.org/abs/2412.02241">Paper</a> |
35
+ <a href="https://github.com/kazuto1011/r2flow">Code</a>
36
+ </div>
37
+ <br>
38
+ <div class="description">
39
+ This is a demo of our paper "Fast LiDAR Data Generation with Rectified Flows" accepted to ICRA 2025.<br>
40
+ We propose <strong>R2Flow</strong>, a rectified flow-based LiDAR generative model which generate the LiDAR range/reflectance images.<br>
41
+ </div>
42
+ <br>
43
+ </div>
44
+ """
45
+
46
+ if torch.cuda.is_available():
47
+ device = "cuda"
48
+ elif torch.backends.mps.is_available():
49
+ device = "mps"
50
+ else:
51
+ device = "cpu"
52
+
53
+ torch.set_grad_enabled(False)
54
+ torch.backends.cudnn.benchmark = True
55
+ device = torch.device(device)
56
+
57
+
58
+ model_dict = {
59
+ "1-RF": "r2flow-kitti360-1rf",
60
+ "2-RF": "r2flow-kitti360-2rf",
61
+ "2-RF + 4-TD": "r2flow-kitti360-2rf-4td",
62
+ "2-RF + 2-TD": "r2flow-kitti360-2rf-2td",
63
+ "2-RF + 1-TD": "r2flow-kitti360-2rf-1td",
64
+ }
65
+
66
+
67
+ torch_hub_kwargs = dict(
68
+ repo_or_dir="kazuto1011/r2flow",
69
+ model="pretrained_r2flow",
70
+ device=device,
71
+ show_info=False,
72
+ )
73
+
74
+
75
+ def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
76
+ colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
77
+ colors = torch.from_numpy(colors).to(tensor)
78
+ tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
79
+ ids = (tensor * 256).clamp(0, 255).long()
80
+ tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
81
+ tensor = tensor.mul(255).clamp(0, 255).byte()
82
+ return tensor
83
+
84
+
85
+ def model_verbose(model, nfe, progress):
86
+ handler = progress.tqdm(range(nfe), desc="Generating...")
87
+
88
+ def _model(t, x):
89
+ handler.update(1)
90
+ return model(t, x)
91
+
92
+ return _model
93
+
94
+
95
+ def generate(nfe: int, solver: str, phase: str, progress=gr.Progress()):
96
+ model, lidar_utils, _ = torch.hub.load(config=model_dict[phase], **torch_hub_kwargs)
97
+
98
+ with torch.inference_mode():
99
+ x1 = torchdiffeq.odeint(
100
+ func=model_verbose(model, int(nfe), progress),
101
+ y0=torch.randn(1, model.in_channels, *model.resolution, device=device),
102
+ t=torch.linspace(0, 1, int(nfe) + 1, device=device),
103
+ method=solver,
104
+ )[-1]
105
+
106
+ depth = lidar_utils.restore_metric_depth(x1[:, [0]])
107
+ rflct = lidar_utils.denormalize(x1[:, [1]])
108
+ point = lidar_utils.convert_metric_depth(depth, format="cartesian")
109
+
110
+ z_min, z_max = -2, 0.5
111
+ z = (point[:, [2]] - z_min) / (z_max - z_min)
112
+ color = colorize(z.clamp(0, 1), cm.viridis) / 255
113
+ point = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
114
+ color = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
115
+ fig = go.Figure(
116
+ data=[
117
+ go.Scatter3d(
118
+ x=-point[..., 0],
119
+ y=-point[..., 1],
120
+ z=point[..., 2],
121
+ mode="markers",
122
+ marker=dict(size=1, color=color),
123
+ )
124
+ ],
125
+ layout=dict(
126
+ scene=dict(
127
+ xaxis=dict(showticklabels=False, visible=False),
128
+ yaxis=dict(showticklabels=False, visible=False),
129
+ zaxis=dict(showticklabels=False, visible=False),
130
+ aspectmode="data",
131
+ ),
132
+ margin=dict(l=0, r=0, b=0, t=0),
133
+ paper_bgcolor="white",
134
+ plot_bgcolor="white",
135
+ ),
136
+ )
137
+ depth = depth / lidar_utils.max_depth
138
+ depth = colorize(depth, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
139
+ rflct = colorize(rflct, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
140
+
141
+ model.cpu()
142
+ lidar_utils.cpu()
143
+ return depth, rflct, fig
144
+
145
+
146
+ def setup_dropdown(value):
147
+ if "TD" in value:
148
+ solver_choices = ["euler"]
149
+ solver_default = "euler"
150
+ num_step = re.findall(r"(\d+)-TD", value)[0]
151
+ nfe_choices = [num_step]
152
+ nfe_default = num_step
153
+ else:
154
+ solver_choices = ["euler", "dopri5"]
155
+ solver_default = "euler"
156
+ nfe_choices = [2**i for i in range(0, 9)]
157
+ nfe_default = 256
158
+ dropdown_solver = gr.Dropdown(
159
+ choices=solver_choices,
160
+ value=solver_default,
161
+ label="ODE solver",
162
+ info="Fixed if TD enabled",
163
+ )
164
+ dropdown_nfe = gr.Dropdown(
165
+ choices=nfe_choices,
166
+ value=nfe_default,
167
+ label="Number of sampling steps",
168
+ info="Fixed if TD enabled",
169
+ )
170
+ return dropdown_solver, dropdown_nfe
171
+
172
+
173
+ with gr.Blocks(
174
+ css="""
175
+ .head {
176
+ text-align: center;
177
+ display: block;
178
+ font-size: var(--text-xl);
179
+ }
180
+
181
+ .title {
182
+ font-size: var(--text-xxl);
183
+ font-weight: bold;
184
+ margin-top: 2rem;
185
+ }
186
+
187
+ .description {
188
+ font-size: var(--text-lg);
189
+ }
190
+ """,
191
+ theme=gr.themes.Ocean(),
192
+ ) as demo:
193
+ gr.HTML(DESCRIPTION)
194
+
195
+ with gr.Row(variant="panel"):
196
+ with gr.Column():
197
+ gr.Textbox(device, label="Running device")
198
+ dropdown_model = gr.Dropdown(
199
+ choices=list(model_dict.keys()),
200
+ value="2-RF + 4-TD",
201
+ label="Model checkpoint",
202
+ info="RF: rectified flow, TD: timestep distillation",
203
+ )
204
+ dropdown_solver, dropdown_nfe = setup_dropdown(dropdown_model.value)
205
+ dropdown_model.change(
206
+ setup_dropdown,
207
+ inputs=[dropdown_model],
208
+ outputs=[dropdown_solver, dropdown_nfe],
209
+ )
210
+ btn = gr.Button(value="Generate", variant="primary")
211
+
212
+ with gr.Column():
213
+ range_view = gr.Image(type="numpy", label="Range image")
214
+ rflct_view = gr.Image(type="numpy", label="Reflectance image")
215
+ point_view = gr.Plot(label="Point cloud")
216
+
217
+ btn.click(
218
+ generate,
219
+ inputs=[dropdown_nfe, dropdown_solver, dropdown_model],
220
+ outputs=[range_view, rflct_view, point_view],
221
+ )
222
+
223
+
224
+ demo.queue()
225
+ demo.launch()
pre-requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ --index-url https://download.pytorch.org/whl/cpu
2
+ torch==2.1.2
3
+ torchvision==0.16.2
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ einops==0.6.1
2
+ gradio==5.22.0
3
+ kornia==0.7.0
4
+ matplotlib==3.7.1
5
+ pydantic==2.6.3
6
+ rich==13.5.1
7
+ simple-parsing==0.1.5
8
+ torchcfm==1.0.5
9
+ torchdiffeq==0.2.4
10
+ tqdm==4.66.1
11
+ plotly==6.0.1
12
+ numpy==1.26.4
13
+ --find-links https://shi-labs.com/natten/wheels
14
+ natten==0.17.5+torch260cpu