adding files
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .gitattributes +14 -0
- .gitignore +139 -0
- CODE_OF_CONDUCT.md +128 -0
- Expr.py +14 -0
- LICENSE +21 -0
- README.md +186 -3
- configs/InSPyReNet_Res2Net50.yaml +76 -0
- configs/InSPyReNet_SwinB.yaml +76 -0
- configs/extra_dataset/InSPyReNet_SwinB_DH.yaml +76 -0
- configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml +76 -0
- configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml +76 -0
- configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml +76 -0
- configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml +76 -0
- configs/extra_dataset/InSPyReNet_SwinB_HU.yaml +76 -0
- configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml +76 -0
- configs/extra_dataset/Plus_Ultra.yaml +98 -0
- configs/extra_dataset/Plus_Ultra_LR.yaml +96 -0
- data/custom_transforms.py +206 -0
- data/dataloader.py +243 -0
- docs/getting_started.md +138 -0
- docs/model_zoo.md +65 -0
- figures/demo_image.gif +3 -0
- figures/demo_type.png +3 -0
- figures/demo_video.gif +3 -0
- figures/demo_webapp.gif +3 -0
- figures/fig_architecture.png +3 -0
- figures/fig_pyramid_blending.png +3 -0
- figures/fig_qualitative.png +3 -0
- figures/fig_qualitative2.png +3 -0
- figures/fig_qualitative3.jpg +3 -0
- figures/fig_qualitative_dis.png +3 -0
- figures/fig_quantitative.png +3 -0
- figures/fig_quantitative2.png +3 -0
- figures/fig_quantitative3.png +3 -0
- figures/fig_quantitative4.png +3 -0
- lib/InSPyReNet.py +199 -0
- lib/__init__.py +1 -0
- lib/backbones/Res2Net_v1b.py +268 -0
- lib/backbones/SwinTransformer.py +652 -0
- lib/modules/attention_module.py +103 -0
- lib/modules/context_module.py +54 -0
- lib/modules/decoder_module.py +44 -0
- lib/modules/layers.py +175 -0
- lib/optim/__init__.py +2 -0
- lib/optim/losses.py +36 -0
- lib/optim/scheduler.py +30 -0
- requirements.txt +9 -0
- run/Eval.py +138 -0
- run/Inference.py +167 -0
- run/Test.py +59 -0
.gitattributes
CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+figures/demo_image.gif filter=lfs diff=lfs merge=lfs -text
+figures/demo_type.png filter=lfs diff=lfs merge=lfs -text
+figures/demo_video.gif filter=lfs diff=lfs merge=lfs -text
+figures/demo_webapp.gif filter=lfs diff=lfs merge=lfs -text
+figures/fig_architecture.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_pyramid_blending.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_qualitative_dis.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_qualitative.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_qualitative2.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_qualitative3.jpg filter=lfs diff=lfs merge=lfs -text
+figures/fig_quantitative.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_quantitative2.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_quantitative3.png filter=lfs diff=lfs merge=lfs -text
+figures/fig_quantitative4.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

data/backbone_ckpt
data/Train_Dataset
data/Test_Dataset
results/*
snapshots
snapshots/*
Samples/*
temp/*
.idea
.vscode
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
Expr.py
ADDED
@@ -0,0 +1,14 @@
import torch.cuda as cuda

from utils.misc import *
from run import *

if __name__ == "__main__":
    args = parse_args()
    opt = load_config(args.config)

    train(opt, args)
    cuda.empty_cache()
    if args.local_rank <= 0:
        test(opt, args)
        evaluate(opt, args)
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Taehun Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,186 @@
# Revisiting Image Pyramid Structure for High Resolution Salient Object Detection (InSPyReNet)

[](https://paperswithcode.com/sota/salient-object-detection-on-hku-is?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/salient-object-detection-on-duts-te?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/salient-object-detection-on-pascal-s?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/salient-object-detection-on-ecssd?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/salient-object-detection-on-dut-omron?p=revisiting-image-pyramid-structure-for-high)

[](https://paperswithcode.com/sota/rgb-salient-object-detection-on-davis-s?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/rgb-salient-object-detection-on-hrsod?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/rgb-salient-object-detection-on-uhrsd?p=revisiting-image-pyramid-structure-for-high)

[](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te1?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te2?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te3?p=revisiting-image-pyramid-structure-for-high)
[](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te4?p=revisiting-image-pyramid-structure-for-high)

Official PyTorch implementation of Revisiting Image Pyramid Structure for High Resolution Salient Object Detection (InSPyReNet)

To appear in the 16th Asian Conference on Computer Vision (ACCV2022)

<p align="center">
<a href="https://github.com/plemeri/InSPyReNet/blob/main/LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue"></a>
<a href="https://arxiv.org/abs/2209.09475"><img src="https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg" ></a>
<a href="https://openaccess.thecvf.com/content/ACCV2022/html/Kim_Revisiting_Image_Pyramid_Structure_for_High_Resolution_Salient_Object_Detection_ACCV_2022_paper.html"><img src="https://img.shields.io/static/v1?label=inproceedings&message=Paper&color=orange"></a>
<a href="https://huggingface.co/spaces/taskswithcode/salient-object-detection"><img src="https://img.shields.io/static/v1?label=HuggingFace&message=Demo&color=yellow"></a>
<a href="https://www.taskswithcode.com/salient_object_detection/"><img src="https://img.shields.io/static/v1?label=TasksWithCode&message=Demo&color=blue"></a>
<a href="https://colab.research.google.com/github/taskswithcode/InSPyReNet/blob/main/TWCSOD.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg"></a>
</p>

> [Taehun Kim](https://scholar.google.co.kr/citations?user=f12-9yQAAAAJ&hl=en), [Kunhee Kim](https://scholar.google.co.kr/citations?user=6sU5r7MAAAAJ&hl=en), [Joonyeong Lee](https://scholar.google.co.kr/citations?hl=en&user=pOM4zSYAAAAJ), Dongmin Cha, [Jiho Lee](https://scholar.google.co.kr/citations?user=1Q1awj8AAAAJ&hl=en), [Daijin Kim](https://scholar.google.co.kr/citations?user=Mw6anjAAAAAJ&hl=en)

> **Abstract:**
Salient object detection (SOD) has been in the spotlight recently, yet has been studied less for high-resolution (HR) images.
Unfortunately, HR images and their pixel-level annotations are certainly more labor-intensive and time-consuming compared to low-resolution (LR) images.
Therefore, we propose an image pyramid-based SOD framework, Inverse Saliency Pyramid Reconstruction Network (InSPyReNet), for HR prediction without any HR datasets.
We design InSPyReNet to produce a strict image pyramid structure of the saliency map, which enables ensembling multiple results with pyramid-based image blending.
For HR prediction, we design a pyramid blending method which synthesizes two different image pyramids from a pair of LR and HR scales of the same image to overcome the effective receptive field (ERF) discrepancy. Our extensive evaluation on public LR and HR SOD benchmarks demonstrates that InSPyReNet surpasses the State-of-the-Art (SotA) methods on various SOD metrics and boundary accuracy.

## Contents

1. [News](#news-newspaper)
2. [Demo](#demo-rocket)
3. [Applications](#applications-video_game)
4. [Easy Download](#easy-download-cake)
5. [Getting Started](#getting-started-flight_departure)
6. [Model Zoo](#model-zoo-giraffe)
7. [Results](#results-100)
   * [Quantitative Results](#quantitative-results)
   * [Qualitative Results](#qualitative-results)
8. [Citation](#citation)
9. [Acknowledgement](#acknowledgement)
   * [Special Thanks to](#special-thanks-to-tada)
10. [References](#references)

## News :newspaper:

[2022.10.04] [TasksWithCode](https://github.com/taskswithcode) mentioned our work in a [Blog](https://medium.com/@taskswithcode/twc-9-7c960c921f69) post and reproduced it on [Colab](https://github.com/taskswithcode/InSPyReNet). Thank you for your attention!

[2022.10.20] We trained our model on the [Dichotomous Image Segmentation dataset (DIS5K)](https://xuebinqin.github.io/dis/index.html) and showed competitive results! The trained checkpoint and pre-computed segmentation masks are available in the [Model Zoo](./docs/model_zoo.md). Also, you can check our qualitative and quantitative results in the [Results](#results-100) section.

[2022.10.28] Multi-GPU training for the latest PyTorch is now available.

[2022.10.31] [TasksWithCode](https://github.com/taskswithcode) provided an amazing web demo with [HuggingFace](https://huggingface.co). Visit the [WebApp](https://huggingface.co/spaces/taskswithcode/salient-object-detection) and try it with your image!

[2022.11.09] :car: Lane segmentation for driving scenes built on InSPyReNet is available in the [`LaneSOD`](https://github.com/plemeri/LaneSOD) repository.

[2022.11.18] I am speaking at The 16th Asian Conference on Computer Vision (ACCV2022). Please check out my talk if you're attending the event! #ACCV2022 #Macau - via the #Whova event app

[2022.11.23] We made our work available as a pypi package. Please visit [`transparent-background`](https://github.com/plemeri/transparent-background) to download our tool and try it on your machine. It works as a command-line tool and a Python API.

[2023.01.18] [rsreetech](https://github.com/rsreetech) shared a tutorial for our pypi package [`transparent-background`](https://github.com/plemeri/transparent-background) using Colab. :tv: [[Youtube](https://www.youtube.com/watch?v=jKuQEnKmv4A)]

## Demo :rocket:

[Image Sample](./figures/demo_image.gif) | [Video Sample](./figures/demo_video.gif)
:-:|:-:
<img src=./figures/demo_image.gif height=200px> | <img src=./figures/demo_video.gif height=200px>

## Applications :video_game:
Here are some applications/extensions of our work.
### Web Application <img src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg height="20px" width="20px">
[TasksWithCode](https://github.com/taskswithcode) provided a [WebApp](https://huggingface.co/spaces/taskswithcode/salient-object-detection) on HuggingFace so you can generate your own results!

[Web Demo](https://huggingface.co/spaces/taskswithcode/salient-object-detection) |
|:-:
<img src=./figures/demo_webapp.gif height=200px> |

### Command-line Tool / Python API :pager:
Try using our model as a command-line tool or a Python API. More details on how to use it are available at [`transparent-background`](https://github.com/plemeri/transparent-background).
```bash
pip install transparent-background
```
### Lane Segmentation :car:
We extend our model to detect lane markers in a driving scene in [`LaneSOD`](https://github.com/plemeri/LaneSOD).

[Lane Segmentation](https://github.com/plemeri/LaneSOD) |
|:-:
<img src=https://github.com/plemeri/LaneSOD/blob/main/figures/Teaser.gif height=200px> |

## Easy Download :cake:

<details><summary>How to use easy download</summary>
<p>

Downloading each dataset and checkpoint separately is quite bothersome, even for me :zzz:. Instead, you can download the data we provide, including `ImageNet pre-trained backbone checkpoints`, `Training Datasets`, `Testing Datasets for benchmark`, `Pre-trained model checkpoints`, and `Pre-computed saliency maps`, with the single command below.
```bash
python utils/download.py --extra --dest [DEST]
```

* `--extra, -e`: Without this argument, only the datasets, checkpoints, and results from our main paper will be downloaded. Otherwise, all data will be downloaded, including results from the supplementary material and DIS5K results.
* `--dest [DEST], -d [DEST]`: If you want to specify the destination, use this argument. It will automatically create symbolic links to the destination folders inside `data` and `snapshots`. Use this argument if you want to download the data to another physical disk. Otherwise, it will download everything inside this repository folder.

If you want to download a certain checkpoint or pre-computed map, please refer to [Getting Started](#getting-started-flight_departure) and [Model Zoo](#model-zoo-giraffe).

</p>
</details>

## Getting Started :flight_departure:

Please refer to [getting_started.md](./docs/getting_started.md) for training, testing, and evaluating on benchmarks, and for running inference on your own images.

## Model Zoo :giraffe:

Please refer to [model_zoo.md](./docs/model_zoo.md) for downloading pre-trained models and pre-computed saliency maps.

## Results :100:

### Quantitative Results

[LR Benchmark](./figures/fig_quantitative.png) | [HR Benchmark](./figures/fig_quantitative2.png) | [HR Benchmark (Trained with extra DB)](./figures/fig_quantitative3.png) | [DIS](./figures/fig_quantitative4.png)
:-:|:-:|:-:|:-:
<img src=./figures/fig_quantitative.png height=200px> | <img src=./figures/fig_quantitative2.png height=200px> | <img src=./figures/fig_quantitative3.png height=200px> | <img src=./figures/fig_quantitative4.png height=200px>

### Qualitative Results

[DAVIS-S & HRSOD](./figures/fig_qualitative.png) | [UHRSD](./figures/fig_qualitative2.png) | [UHRSD (w/ HR scale)](./figures/fig_qualitative3.jpg) | [DIS](./figures/fig_qualitative_dis.png)
:-:|:-:|:-:|:-:
<img src=./figures/fig_qualitative.png height=200px> | <img src=./figures/fig_qualitative2.png height=200px> | <img src=./figures/fig_qualitative3.jpg height=200px> | <img src=./figures/fig_qualitative_dis.png height=200px>

## Citation

```
@inproceedings{kim2022revisiting,
  title={Revisiting Image Pyramid Structure for High Resolution Salient Object Detection},
  author={Kim, Taehun and Kim, Kunhee and Lee, Joonyeong and Cha, Dongmin and Lee, Jiho and Kim, Daijin},
  booktitle={Proceedings of the Asian Conference on Computer Vision},
  pages={108--124},
  year={2022}
}
```

## Acknowledgement

This work was supported by the Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT)
(No.2017-0-00897, Development of Object Detection and Recognition for Intelligent Vehicles) and
(No.B0101-15-0266, Development of High Performance Visual BigData Discovery Platform for Large-Scale Realtime Data Analysis).

### Special Thanks to :tada:
* The [TasksWithCode](https://github.com/taskswithcode) team for sharing our work and making the most amazing web app demo.


## References

### Related Works

* Towards High-Resolution Salient Object Detection ([paper](https://drive.google.com/open?id=15o-Fel0BSyNulGoptrxfHR0t22qMHlTr) | [github](https://github.com/yi94code/HRSOD))
* Disentangled High Quality Salient Object Detection ([paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Tang_Disentangled_High_Quality_Salient_Object_Detection_ICCV_2021_paper.pdf) | [github](https://github.com/luckybird1994/HQSOD))
* Pyramid Grafting Network for One-Stage High Resolution Saliency Detection ([paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Xie_Pyramid_Grafting_Network_for_One-Stage_High_Resolution_Saliency_Detection_CVPR_2022_paper.pdf) | [github](https://github.com/iCVTEAM/PGNet))

### Resources

* Backbones: [Res2Net](https://github.com/Res2Net/Res2Net-PretrainedModels), [Swin Transformer](https://github.com/microsoft/Swin-Transformer)

* Datasets
  * LR Benchmarks: [DUTS](http://saliencydetection.net/duts/), [DUT-OMRON](http://saliencydetection.net/dut-omron/), [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html), [HKU-IS](https://i.cs.hku.hk/~gbli/deep_saliency.html), [PASCAL-S](http://cbi.gatech.edu/salobj/)
  * HR Benchmarks: [DAVIS-S, HRSOD](https://github.com/yi94code/HRSOD), [UHRSD](https://github.com/iCVTEAM/PGNet)
  * Dichotomous Image Segmentation: [DIS5K](https://xuebinqin.github.io/dis/index.html)

* Evaluation Toolkit
  * SOD Metrics (e.g., S-measure): [PySOD Metrics](https://github.com/lartpang/PySODMetrics)
  * Boundary Metric (mBA): [CascadePSP](https://github.com/hkchengrex/CascadePSP)
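
As a quick illustration of the Python API route mentioned in the README's Command-line Tool / Python API section, a minimal sketch with the `transparent-background` package might look like the following. The `Remover` class and its `process` method are taken from that package's documentation rather than from this diff, so treat the exact call and its return type as assumptions.

```python
# Hedged sketch: assumes transparent_background exposes Remover().process(),
# as described in the plemeri/transparent-background repository.
from PIL import Image
from transparent_background import Remover  # pip install transparent-background

remover = Remover()                           # loads a pretrained InSPyReNet checkpoint
img = Image.open("input.jpg").convert("RGB")
result = remover.process(img)                 # background-removed output; see the
                                              # package docs for modes and return types
```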
configs/InSPyReNet_Res2Net50.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_Res2Net50"
  depth: 64
  pretrained: True
  base_size: [384, 384]
  threshold: NULL

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['DUTS-TR']
    transforms:
      static_resize:
        size: [384, 384]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 6
    shuffle: True
    num_workers: 8
    pin_memory: True
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_Res2Net50"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD', 'UHRSD-TE']
    transforms:
      static_resize:
        size: [384, 384]
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_Res2Net50"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_Res2Net50"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/InSPyReNet_SwinB.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [384, 384]
  threshold: 512

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['DUTS-TR']
    transforms:
      static_resize:
        size: [384, 384]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 6
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      dynamic_resize:
        L: 1280
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DH.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [1024, 1024]
  threshold: NULL

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['DUTS-TR', 'HRSOD-TR']
    transforms:
      static_resize:
        size: [1024, 1024]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 1
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      static_resize:
        size: [1024, 1024]
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB_DH"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [384, 384]
  threshold: 512

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['DUTS-TR', 'HRSOD-TR', 'UHRSD-TR']
    transforms:
      static_resize:
        size: [384, 384]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 6
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DHU_LR"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      dynamic_resize:
        L: 1280
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DHU_LR"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB_DHU_LR"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [384, 384]
  threshold: 512

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['DUTS-TR', 'HRSOD-TR']
    transforms:
      static_resize:
        size: [384, 384]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 6
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH_LR"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      dynamic_resize:
        L: 1280
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH_LR"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB_DH_LR"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [1024, 1024]
  threshold: NULL

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['DIS-TR']
    transforms:
      static_resize:
        size: [1024, 1024]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 1
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 180
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
    transforms:
      static_resize:
        size: [1024, 1024]
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB_DIS5K"
  result_path: "results"
  datasets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [384, 384]
  threshold: 512

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['DIS-TR']
    transforms:
      static_resize:
        size: [384, 384]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 6
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 180
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K_LR"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
    transforms:
      dynamic_resize:
        L: 1280
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K_LR"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB_DIS5K_LR"
  result_path: "results"
  datasets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_HU.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [1024, 1024]
  threshold: NULL

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['HRSOD-TR', 'UHRSD-TR']
    transforms:
      static_resize:
        size: [1024, 1024]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 1
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      static_resize:
        size: [1024, 1024]
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB_HU"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml
ADDED
@@ -0,0 +1,76 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [384, 384]
  threshold: 512

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Train_Dataset"
    sets: ['HRSOD-TR', 'UHRSD-TR']
    transforms:
      static_resize:
        size: [384, 384]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 6
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU_LR"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      dynamic_resize:
        L: 1280
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU_LR"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/InSPyReNet_SwinB_HU_LR"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/Plus_Ultra.yaml
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [1024, 1024]
  threshold: NULL

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data"
    sets: [
      'Train_Dataset/DUTS-TR',
      'Train_Dataset/HRSOD-TR',
      'Train_Dataset/UHRSD-TR',
      'Train_Dataset/DIS-TR',
      'Test_Dataset/DUTS-TE',
      'Test_Dataset/DUT-OMRON',
      'Test_Dataset/ECSSD',
      'Test_Dataset/HKU-IS',
      'Test_Dataset/PASCAL-S',
      'Test_Dataset/DAVIS-S',
      'Test_Dataset/HRSOD-TE',
      'Test_Dataset/UHRSD-TE',
      'Test_Dataset/FSS-1000',
      'Test_Dataset/MSRA-10K',
      'Test_Dataset/DIS-VD',
      'Test_Dataset/DIS-TE1',
      'Test_Dataset/DIS-TE2',
      'Test_Dataset/DIS-TE3',
      'Test_Dataset/DIS-TE4'
      ]
    transforms:
      static_resize:
        size: [1024, 1024]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 1
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/Plus_Ultra"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      # static_resize:
      #   size: [1024, 1024]
      dynamic_resize:
        L: 1280
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/Plus_Ultra"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/Plus_Ultra"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
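The Scheduler block above selects a polynomial learning-rate decay ("PolyLr") with warmup. Its implementation lives in lib/optim/scheduler.py, which is not shown in this view, so the sketch below only illustrates the usual polynomial-decay-with-warmup form implied by these options; treat the exact formula and warmup handling as assumptions.

```
# Hypothetical sketch of a polynomial LR decay with linear warmup, matching the
# options above (epoch: 60, gamma: 0.9, minimum_lr: 1e-7, warmup_iteration: 12000).
# The repository's PolyLr in lib/optim/scheduler.py may differ in details.
def poly_lr(iteration, max_iteration, base_lr=1.0e-5, minimum_lr=1.0e-7,
            gamma=0.9, warmup_iteration=12000):
    if iteration < warmup_iteration:          # linear warmup phase
        return base_lr * iteration / warmup_iteration
    progress = (iteration - warmup_iteration) / max(1, max_iteration - warmup_iteration)
    return (base_lr - minimum_lr) * (1 - progress) ** gamma + minimum_lr

# e.g. assuming 60 epochs x 1000 iterations per epoch (hypothetical numbers):
print(poly_lr(0, 60000), poly_lr(30000, 60000), poly_lr(60000, 60000))
```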
configs/extra_dataset/Plus_Ultra_LR.yaml
ADDED
@@ -0,0 +1,96 @@
Model:
  name: "InSPyReNet_SwinB"
  depth: 64
  pretrained: True
  base_size: [384, 384]
  threshold: 512

Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data"
    sets: [
      'Train_Dataset/DUTS-TR',
      'Train_Dataset/HRSOD-TR',
      'Train_Dataset/UHRSD-TR',
      'Train_Dataset/DIS-TR',
      'Test_Dataset/DUTS-TE',
      'Test_Dataset/DUT-OMRON',
      'Test_Dataset/ECSSD',
      'Test_Dataset/HKU-IS',
      'Test_Dataset/PASCAL-S',
      'Test_Dataset/DAVIS-S',
      'Test_Dataset/HRSOD-TE',
      'Test_Dataset/UHRSD-TE',
      'Test_Dataset/FSS-1000',
      'Test_Dataset/MSRA-10K',
      'Test_Dataset/DIS-VD',
      'Test_Dataset/DIS-TE1',
      'Test_Dataset/DIS-TE2',
      'Test_Dataset/DIS-TE3',
      'Test_Dataset/DIS-TE4'
      ]
    transforms:
      static_resize:
        size: [384, 384]
      random_scale_crop:
        range: [0.75, 1.25]
      random_flip:
        lr: True
        ud: False
      random_rotate:
        range: [-10, 10]
      random_image_enhance:
        methods: ['contrast', 'sharpness', 'brightness']
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    batch_size: 6
    shuffle: True
    num_workers: 8
    pin_memory: False
  Optimizer:
    type: "Adam"
    lr: 1.0e-05
    weight_decay: 0.0
    mixed_precision: False
  Scheduler:
    type: "PolyLr"
    epoch: 60
    gamma: 0.9
    minimum_lr: 1.0e-07
    warmup_iteration: 12000
  Checkpoint:
    checkpoint_epoch: 1
    checkpoint_dir: "snapshots/Plus_Ultra_LR"
  Debug:
    keys: ['saliency', 'laplacian']

Test:
  Dataset:
    type: "RGB_Dataset"
    root: "data/Test_Dataset"
    sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
    transforms:
      dynamic_resize:
        L: 1280
      tonumpy: NULL
      normalize:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
      totensor: NULL
  Dataloader:
    num_workers: 8
    pin_memory: True
  Checkpoint:
    checkpoint_dir: "snapshots/Plus_Ultra_LR"

Eval:
  gt_root: "data/Test_Dataset"
  pred_root: "snapshots/Plus_Ultra_LR"
  result_path: "results"
  datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
  metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
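Compared to Plus_Ultra.yaml above, this LR config trains at `base_size: [384, 384]` and sets `threshold: 512`. At inference time the threshold decides whether an input is handled by a single forward pass or by the pyramid-blending path; a minimal sketch of that decision, mirroring the check visible in lib/InSPyReNet.py later in this commit, is:

```
# Sketch of how `threshold` from the config above selects the inference path
# (see forward_inference in lib/InSPyReNet.py). Purely illustrative.
def inference_mode(height, width, threshold=512):
    if threshold is None:
        return "single forward pass at the input resolution"
    if height <= threshold or width <= threshold:
        return "single forward pass (low-resolution input)"
    return "pyramid blending of LR and HR predictions"

print(inference_mode(384, 384))    # low-resolution path
print(inference_mode(1024, 1024))  # pyramid blending path
```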
data/custom_transforms.py
ADDED
@@ -0,0 +1,206 @@
import numpy as np
from PIL import Image
import os
import sys
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from typing import Optional

filepath = os.path.split(__file__)[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)

from utils.misc import *

class static_resize:
    # Resize for training
    # size: h x w
    def __init__(self, size=[384, 384], base_size=None):
        self.size = size[::-1]
        self.base_size = base_size[::-1] if base_size is not None else None

    def __call__(self, sample):
        sample['image'] = sample['image'].resize(self.size, Image.BILINEAR)
        if 'gt' in sample.keys():
            sample['gt'] = sample['gt'].resize(self.size, Image.NEAREST)

        if self.base_size is not None:
            sample['image_resized'] = sample['image'].resize(self.size, Image.BILINEAR)
            if 'gt' in sample.keys():
                sample['gt_resized'] = sample['gt'].resize(self.size, Image.NEAREST)

        return sample

class dynamic_resize:
    # base_size: h x w
    def __init__(self, L=1280, base_size=[384, 384]):
        self.L = L
        self.base_size = base_size[::-1]

    def __call__(self, sample):
        size = list(sample['image'].size)
        if (size[0] >= size[1]) and size[1] > self.L:
            size[0] = size[0] / (size[1] / self.L)
            size[1] = self.L
        elif (size[1] > size[0]) and size[0] > self.L:
            size[1] = size[1] / (size[0] / self.L)
            size[0] = self.L
        size = (int(round(size[0] / 32)) * 32, int(round(size[1] / 32)) * 32)

        if 'image' in sample.keys():
            sample['image_resized'] = sample['image'].resize(self.base_size, Image.BILINEAR)
            sample['image'] = sample['image'].resize(size, Image.BILINEAR)

        if 'gt' in sample.keys():
            sample['gt_resized'] = sample['gt'].resize(self.base_size, Image.NEAREST)
            sample['gt'] = sample['gt'].resize(size, Image.NEAREST)

        return sample

class random_scale_crop:
    def __init__(self, range=[0.75, 1.25]):
        self.range = range

    def __call__(self, sample):
        scale = np.random.random() * (self.range[1] - self.range[0]) + self.range[0]
        if np.random.random() < 0.5:
            for key in sample.keys():
                if key in ['image', 'gt']:
                    base_size = sample[key].size

                    scale_size = tuple((np.array(base_size) * scale).round().astype(int))
                    sample[key] = sample[key].resize(scale_size)

                    lf = (sample[key].size[0] - base_size[0]) // 2
                    up = (sample[key].size[1] - base_size[1]) // 2
                    rg = (sample[key].size[0] + base_size[0]) // 2
                    lw = (sample[key].size[1] + base_size[1]) // 2

                    border = -min(0, min(lf, up))
                    sample[key] = ImageOps.expand(sample[key], border=border)
                    sample[key] = sample[key].crop((lf + border, up + border, rg + border, lw + border))
        return sample

class random_flip:
    def __init__(self, lr=True, ud=True):
        self.lr = lr
        self.ud = ud

    def __call__(self, sample):
        lr = np.random.random() < 0.5 and self.lr is True
        ud = np.random.random() < 0.5 and self.ud is True

        for key in sample.keys():
            if key in ['image', 'gt']:
                sample[key] = np.array(sample[key])
                if lr:
                    sample[key] = np.fliplr(sample[key])
                if ud:
                    sample[key] = np.flipud(sample[key])
                sample[key] = Image.fromarray(sample[key])

        return sample

class random_rotate:
    def __init__(self, range=[0, 360], interval=1):
        self.range = range
        self.interval = interval

    def __call__(self, sample):
        rot = (np.random.randint(*self.range) // self.interval) * self.interval
        rot = rot + 360 if rot < 0 else rot

        if np.random.random() < 0.5:
            for key in sample.keys():
                if key in ['image', 'gt']:
                    base_size = sample[key].size
                    sample[key] = sample[key].rotate(rot, expand=True, fillcolor=255 if key == 'depth' else None)

                    sample[key] = sample[key].crop(((sample[key].size[0] - base_size[0]) // 2,
                                                    (sample[key].size[1] - base_size[1]) // 2,
                                                    (sample[key].size[0] + base_size[0]) // 2,
                                                    (sample[key].size[1] + base_size[1]) // 2))

        return sample

class random_image_enhance:
    def __init__(self, methods=['contrast', 'brightness', 'sharpness']):
        self.enhance_method = []
        if 'contrast' in methods:
            self.enhance_method.append(ImageEnhance.Contrast)
        if 'brightness' in methods:
            self.enhance_method.append(ImageEnhance.Brightness)
        if 'sharpness' in methods:
            self.enhance_method.append(ImageEnhance.Sharpness)

    def __call__(self, sample):
        if 'image' in sample.keys():
            np.random.shuffle(self.enhance_method)

            for method in self.enhance_method:
                if np.random.random() > 0.5:
                    enhancer = method(sample['image'])
                    factor = float(1 + np.random.random() / 10)
                    sample['image'] = enhancer.enhance(factor)

        return sample

class tonumpy:
    def __init__(self):
        pass

    def __call__(self, sample):
        for key in sample.keys():
            if key in ['image', 'image_resized', 'gt', 'gt_resized']:
                sample[key] = np.array(sample[key], dtype=np.float32)

        return sample

class normalize:
    def __init__(self, mean: Optional[list]=None, std: Optional[list]=None, div=255):
        self.mean = mean if mean is not None else 0.0
        self.std = std if std is not None else 1.0
        self.div = div

    def __call__(self, sample):
        if 'image' in sample.keys():
            sample['image'] /= self.div
            sample['image'] -= self.mean
            sample['image'] /= self.std

        if 'image_resized' in sample.keys():
            sample['image_resized'] /= self.div
            sample['image_resized'] -= self.mean
            sample['image_resized'] /= self.std

        if 'gt' in sample.keys():
            sample['gt'] /= self.div

        if 'gt_resized' in sample.keys():
            sample['gt_resized'] /= self.div

        return sample

class totensor:
    def __init__(self):
        pass

    def __call__(self, sample):
        if 'image' in sample.keys():
            sample['image'] = sample['image'].transpose((2, 0, 1))
            sample['image'] = torch.from_numpy(sample['image']).float()

        if 'image_resized' in sample.keys():
            sample['image_resized'] = sample['image_resized'].transpose((2, 0, 1))
            sample['image_resized'] = torch.from_numpy(sample['image_resized']).float()

        if 'gt' in sample.keys():
            sample['gt'] = torch.from_numpy(sample['gt'])
            sample['gt'] = sample['gt'].unsqueeze(dim=0)

        if 'gt_resized' in sample.keys():
            sample['gt_resized'] = torch.from_numpy(sample['gt_resized'])
            sample['gt_resized'] = sample['gt_resized'].unsqueeze(dim=0)

        return sample
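The transform classes above all operate on a `sample` dict of PIL images, so they compose with `torchvision.transforms.Compose` the same way `get_transform` in data/dataloader.py builds them from the config. A minimal usage sketch with a dummy image (the image content and sizes here are illustrative, and the repository root is assumed to be on `sys.path`):

```
# Minimal sketch: compose the transforms above on a dummy sample dict.
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

from data.custom_transforms import static_resize, tonumpy, normalize, totensor

pipeline = transforms.Compose([
    static_resize(size=[384, 384]),
    tonumpy(),
    normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    totensor(),
])

dummy = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
sample = pipeline({'image': dummy, 'name': 'dummy', 'shape': (480, 640)})
print(sample['image'].shape)  # torch.Size([3, 384, 384])
```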
data/dataloader.py
ADDED
@@ -0,0 +1,243 @@
import os
import cv2
import sys

import numpy as np
import torchvision.transforms as transforms

from torch.utils.data.dataset import Dataset
from PIL import Image
from threading import Thread

filepath = os.path.split(__file__)[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)

from data.custom_transforms import *
from utils.misc import *

Image.MAX_IMAGE_PIXELS = None

def get_transform(tfs):
    comp = []
    for key, value in zip(tfs.keys(), tfs.values()):
        if value is not None:
            tf = eval(key)(**value)
        else:
            tf = eval(key)()
        comp.append(tf)
    return transforms.Compose(comp)

class RGB_Dataset(Dataset):
    def __init__(self, root, sets, tfs):
        self.images, self.gts = [], []

        for set in sets:
            image_root, gt_root = os.path.join(root, set, 'images'), os.path.join(root, set, 'masks')

            images = [os.path.join(image_root, f) for f in os.listdir(image_root) if f.lower().endswith(('.jpg', '.png'))]
            images = sort(images)

            gts = [os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.lower().endswith(('.jpg', '.png'))]
            gts = sort(gts)

            self.images.extend(images)
            self.gts.extend(gts)

        self.filter_files()

        self.size = len(self.images)
        self.transform = get_transform(tfs)

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB')
        gt = Image.open(self.gts[index]).convert('L')
        shape = gt.size[::-1]
        name = self.images[index].split(os.sep)[-1]
        name = os.path.splitext(name)[0]

        sample = {'image': image, 'gt': gt, 'name': name, 'shape': shape}

        sample = self.transform(sample)
        return sample

    def filter_files(self):
        assert len(self.images) == len(self.gts)
        images, gts = [], []
        for img_path, gt_path in zip(self.images, self.gts):
            img, gt = Image.open(img_path), Image.open(gt_path)
            if img.size == gt.size:
                images.append(img_path)
                gts.append(gt_path)
        self.images, self.gts = images, gts

    def __len__(self):
        return self.size

class ImageLoader:
    def __init__(self, root, tfs):
        if os.path.isdir(root):
            self.images = [os.path.join(root, f) for f in os.listdir(root) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            self.images = sort(self.images)
        elif os.path.isfile(root):
            self.images = [root]
        self.size = len(self.images)
        self.transform = get_transform(tfs)

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index == self.size:
            raise StopIteration
        image = Image.open(self.images[self.index]).convert('RGB')
        shape = image.size[::-1]
        name = self.images[self.index].split(os.sep)[-1]
        name = os.path.splitext(name)[0]

        sample = {'image': image, 'name': name, 'shape': shape, 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample.keys():
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)

        self.index += 1
        return sample

    def __len__(self):
        return self.size

class VideoLoader:
    def __init__(self, root, tfs):
        if os.path.isdir(root):
            self.videos = [os.path.join(root, f) for f in os.listdir(root) if f.lower().endswith(('.mp4', '.avi', 'mov'))]
        elif os.path.isfile(root):
            self.videos = [root]
        self.size = len(self.videos)
        self.transform = get_transform(tfs)

    def __iter__(self):
        self.index = 0
        self.cap = None
        self.fps = None
        return self

    def __next__(self):
        if self.index == self.size:
            raise StopIteration

        if self.cap is None:
            self.cap = cv2.VideoCapture(self.videos[self.index])
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        ret, frame = self.cap.read()
        name = self.videos[self.index].split(os.sep)[-1]
        name = os.path.splitext(name)[0]
        if ret is False:
            self.cap.release()
            self.cap = None
            sample = {'image': None, 'shape': None, 'name': name, 'original': None}
            self.index += 1

        else:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame).convert('RGB')
            shape = image.size[::-1]
            sample = {'image': image, 'shape': shape, 'name': name, 'original': image}
            sample = self.transform(sample)
            sample['image'] = sample['image'].unsqueeze(0)
            if 'image_resized' in sample.keys():
                sample['image_resized'] = sample['image_resized'].unsqueeze(0)

        return sample

    def __len__(self):
        return self.size


class WebcamLoader:
    def __init__(self, ID, tfs):
        self.ID = int(ID)
        self.cap = cv2.VideoCapture(self.ID)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self.transform = get_transform(tfs)
        self.imgs = []
        self.imgs.append(self.cap.read()[1])
        self.thread = Thread(target=self.update, daemon=True)
        self.thread.start()

    def update(self):
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret is True:
                self.imgs.append(frame)
            else:
                break

    def __iter__(self):
        return self

    def __next__(self):
        if len(self.imgs) > 0:
            frame = self.imgs[-1]
        else:
            frame = np.zeros((480, 640, 3)).astype(np.uint8)
        if self.thread.is_alive() is False or cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            raise StopIteration

        else:
            image = Image.fromarray(frame).convert('RGB')
            shape = image.size[::-1]
            sample = {'image': image, 'shape': shape, 'name': 'webcam', 'original': image}
            sample = self.transform(sample)
            sample['image'] = sample['image'].unsqueeze(0)
            if 'image_resized' in sample.keys():
                sample['image_resized'] = sample['image_resized'].unsqueeze(0)

            del self.imgs[:-1]
            return sample

    def __len__(self):
        return 0

class RefinementLoader:
    def __init__(self, image_dir, seg_dir, tfs):
        self.images = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
        self.images = sort(self.images)

        self.segs = [os.path.join(seg_dir, f) for f in os.listdir(seg_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
        self.segs = sort(self.segs)

        self.size = len(self.images)
        self.transform = get_transform(tfs)

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index == self.size:
            raise StopIteration
        image = Image.open(self.images[self.index]).convert('RGB')
        seg = Image.open(self.segs[self.index]).convert('L')
        shape = image.size[::-1]
        name = self.images[self.index].split(os.sep)[-1]
        name = os.path.splitext(name)[0]

        sample = {'image': image, 'gt': seg, 'name': name, 'shape': shape, 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        sample['mask'] = sample['gt'].unsqueeze(0)
        if 'image_resized' in sample.keys():
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        del sample['gt']

        self.index += 1
        return sample

    def __len__(self):
        return self.size
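A short sketch of how `RGB_Dataset` above pairs with a standard PyTorch `DataLoader`, using the Test settings from the configs in this commit (the dataset path and split are illustrative and must exist locally):

```
# Minimal sketch: build a test split and wrap it in a torch DataLoader.
from torch.utils.data import DataLoader
from data.dataloader import RGB_Dataset

test_tfs = {
    'dynamic_resize': {'L': 1280},
    'tonumpy': None,
    'normalize': {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]},
    'totensor': None,
}

dataset = RGB_Dataset(root='data/Test_Dataset', sets=['DUTS-TE'], tfs=test_tfs)
loader = DataLoader(dataset, batch_size=1, num_workers=8, pin_memory=True)

for sample in loader:
    print(sample['name'], sample['image'].shape)
    break
```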
docs/getting_started.md
ADDED
@@ -0,0 +1,138 @@
# Getting Started :flight_departure:

## Create environment
* Create a conda environment with the following command: `conda create -y -n inspyrenet python`
* Activate the environment with the following command: `conda activate inspyrenet`
* Install the latest PyTorch from the [Official Website](https://pytorch.org/get-started/locally/)
    * Linux [2022.11.09]
    ```
    pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
    ```
* Install the requirements with the following command: `pip install -r requirements.txt`

## Preparation

* For training, you will need the training datasets and ImageNet pre-trained checkpoints for the backbone. For testing (inference), you will need the test datasets (or your own sample images).
* Training datasets are expected to be located under [Train.Dataset.root](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L11). Likewise, testing datasets should be under [Test.Dataset.root](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L55).
* Each dataset folder should contain an `images` folder and a `masks` folder for the images and the ground truth masks respectively (e.g., `data/Train_Dataset/DUTS-TR/images` and `data/Train_Dataset/DUTS-TR/masks`).
* You may use multiple training datasets by listing dataset folders for [Train.Dataset.sets](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L12), such as `[DUTS-TR] -> [DUTS-TR, HRSOD-TR, UHRSD-TR]`.

### Backbone Checkpoints
Item | Destination Folder | OneDrive | GDrive
:-|:-|:-|:-
Res2Net50 checkpoint | `data/backbone_ckpt/*.pth` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUO7GDBwoC9CulTPdnq_yhQBlc0SIyyELMy3OmrNhOjcGg?e=T3PVyG&download=1) | [Link](https://drive.google.com/file/d/1MMhioAsZ-oYa5FpnTi22XBGh5HkjLX3y/view?usp=sharing)
SwinB checkpoint | `data/backbone_ckpt/*.pth` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESlYCLy0endMhcZm9eC2A4ABatxupp4UPh03EcqFjbtSRw?e=7y6lLt&download=1) | [Link](https://drive.google.com/file/d/1fBJFMupe5pV-Vtou-k8LTvHclWs0y1bI/view?usp=sharing)

* We changed the Res2Net50 checkpoint to resolve an error while training with DDP. Please refer to [issue #9](https://github.com/plemeri/InSPyReNet/issues/9).

### Train Datasets
Item | Destination Folder | OneDrive | GDrive
:-|:-|:-|:-
DUTS-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EQ7L2XS-5YFMuJGee7o7HQ8BdRSLO8utbC_zRrv-KtqQ3Q?e=bCSIeo&download=1) | [Link](https://drive.google.com/file/d/1hy5UTq65uQWFO5yzhEn9KFIbdvhduThP/view?usp=share_link)

### Extra Train Datasets (High Resolution, Optional)
Item | Destination Folder | OneDrive | GDrive
:-|:-|:-|:-
HRSOD-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfUx92hUgZJNrWPj46PC0yEBXcorQskXOCSz8SnGH5AcLQ?e=WA5pc6&download=1) | N/A
UHRSD-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea4_UCbsKmhKnMCccAJOTLgBmQFsQ4KhJSf2jx8WQqj3Wg?e=18kYZS&download=1) | N/A
DIS-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZtZJ493tVNJjBIpNLdus68B3u906PdWtHsf87pulj78jw?e=bUg2UQ&download=1) | N/A

### Test Datasets
Item | Destination Folder | OneDrive | GDrive
:-|:-|:-|:-
DUTS-TE | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfuCxjveXphPpIska9BxHDMBHpYroEKdVlq9HsonZ4wLDw?e=Mz5giA&download=1) | [Link](https://drive.google.com/file/d/1w4pigcQe9zplMulp1rAwmp6yYXmEbmvy/view?usp=share_link)
DUT-OMRON | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERvApm9rHH5LiR4NJoWHqDoBCneUQNextk8EjQ_Hy0bUHg?e=wTRZQb&download=1) | [Link](https://drive.google.com/file/d/1qIm_GQLLQkP6s-xDZhmp_FEAalavJDXf/view?usp=sharing)
ECSSD | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ES_GCdS0yblBmnRaDZ8xmKQBPU_qeECTVB9vlPUups8bnA?e=POVAlG&download=1) | [Link](https://drive.google.com/file/d/1qk_12KLGX6FPr1P_S9dQ7vXKaMqyIRoA/view?usp=sharing)
HKU-IS | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYBRVvC1MJRAgSfzt0zaG94BU_UWaVrvpv4tjogu4vSV6w?e=TKN7hQ&download=1) | [Link](https://drive.google.com/file/d/1H3szJYbr5_CRCzrYfhPHThTgszkKd1EU/view?usp=share_link)
PASCAL-S | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfUDGDckMnZHhEPy8YQGwBQB5MN3qInBkEygpIr7ccJdTQ?e=YarZaQ&download=1) | [Link](https://drive.google.com/file/d/1h0IE2DlUt0HHZcvzMV5FCxtZqQqh9Ztf/view?usp=sharing)
DAVIS-S | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ebam8I2o-tRJgADcq-r9YOkBCDyaAdWBVWyfN-xCYyAfDQ?e=Mqz8cK&download=1) | [Link](https://drive.google.com/file/d/15F0dy9o02LPTlpUbnD9NJlGeKyKU3zOz/view?usp=sharing)
HRSOD-TE | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EbHOQZKC59xIpIdrM11ulWsBHRYY1wZY2njjWCDFXvT6IA?e=wls17m&download=1) | [Link](https://drive.google.com/file/d/1KnUCsvluS4kP2HwUFVRbKU8RK_v6rv2N/view?usp=sharing)
UHRSD-TE | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUpc8QJffNpNpESv-vpBi40BppucqOoXm_IaK7HYJkuOog?e=JTjGmS&download=1) | [Link](https://drive.google.com/file/d/1niiHBo9LX6-I3KsEWYOi_s6cul80IYvK/view?usp=sharing)

### Extra Test Datasets (Optional)
Item | Destination Folder | OneDrive | GDrive
:-|:-|:-|:-
FSS-1000 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EaP6DogjMAVCtTKcC_Bx-YoBoBSWBo90lesVcMyuCN35NA?e=0DDohA&download=1) | N/A
MSRA-10K | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EauXsIkxBzhDjio6fW0TubUB4L7YJc0GMTaq7VfjI2nPsg?e=c5DIxg&download=1) | N/A
DIS-VD | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYJm3BqheaxNhdVoMt6X41gBgVnE4dulBwkp6pbOQtcIrQ?e=T6dtXm&download=1) | [Link](https://drive.google.com/file/d/1jhlZb3QyNPkc0o8nL3RWF0MLuVsVtJju/view?usp=sharing)
DIS-TE1 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcGYE_Gc0cVHoHi_qUtmsawB_5v9RSpJS5PIAPlLu6xo9A?e=Nu5mkQ&download=1) | [Link](https://drive.google.com/file/d/1iz8Y4uaX3ZBy42N2MIOkmNb0D5jroFPJ/view?usp=sharing)
DIS-TE2 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdhgMdbZ049GvMv7tNrjbbQB1wL9Ok85YshiXIkgLyTfkQ?e=mPA6Po&download=1) | [Link](https://drive.google.com/file/d/1DWSoWogTWDuS2PFbD1Qx9P8_SnSv2zTe/view?usp=sharing)
DIS-TE3 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcxXYC_3rXxKsQBrp6BdNb4BOKxBK3_vsR9RL76n7YVG-g?e=2M0cse&download=1) | [Link](https://drive.google.com/file/d/1bIVSjsxCjMrcmV1fsGplkKl9ORiiiJTZ/view?usp=sharing)
DIS-TE4 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdkG2SUi8flJvoYbHHOmvMABsGhkCJCsLLZlaV2E_SZimA?e=zlM2kC&download=1) | [Link](https://drive.google.com/file/d/1VuPNqkGTP1H4BFEHe807dTIkv8Kfzk5_/view?usp=sharing)

## Train & Evaluate

* Train InSPyReNet
```
# Single GPU
python run/Train.py --config configs/InSPyReNet_SwinB.yaml --verbose

# Multi GPUs with DDP (e.g., 4 GPUs)
torchrun --standalone --nproc_per_node=4 run/Train.py --config configs/InSPyReNet_SwinB.yaml --verbose

# Multi GPUs with DDP with designated devices (e.g., 2 GPUs - 0 and 1)
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nproc_per_node=2 run/Train.py --config configs/InSPyReNet_SwinB.yaml --verbose
```

* `--config [CONFIG_FILE], -c [CONFIG_FILE]`: config file path for training.
* `--resume, -r`: use this argument to resume from the last checkpoint.
* `--verbose, -v`: use this argument to output progress info.
* `--debug, -d`: use this argument to save debug images every epoch.

* Training with extra datasets can be done by simply changing [Train.Dataset.sets](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L12) in the `yaml` config file, i.e., listing more dataset directories (e.g., HRSOD-TR, UHRSD-TR, ...):
```
Train:
  Dataset:
    type: "RGB_Dataset"
    root: "data/RGB_Dataset/Train_Dataset"
    sets: ['DUTS-TR'] --> ['DUTS-TR', 'HRSOD-TR', 'UHRSD-TR']
```
* Inference for test benchmarks
```
python run/Test.py --config configs/InSPyReNet_SwinB.yaml --verbose
```
* Evaluate metrics
```
python run/Eval.py --config configs/InSPyReNet_SwinB.yaml --verbose
```

* All-in-One command (Train, Test, Eval in a single command)
```
# Single GPU
python Expr.py --config configs/InSPyReNet_SwinB.yaml --verbose

# Multi GPUs with DDP (e.g., 4 GPUs)
torchrun --standalone --nproc_per_node=4 Expr.py --config configs/InSPyReNet_SwinB.yaml --verbose

# Multi GPUs with DDP with designated devices (e.g., 2 GPUs - 0 and 1)
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nproc_per_node=2 Expr.py --config configs/InSPyReNet_SwinB.yaml --verbose
```

## Inference on your own data
* You can run inference on your own image or images (.jpg, .jpeg, and .png are supported), a single video or videos (.mp4, .mov, and .avi are supported), and webcam input (Ubuntu and macOS are tested so far).
```
python run/Inference.py --config configs/InSPyReNet_SwinB.yaml --source [SOURCE] --dest [DEST] --type [TYPE] --gpu --jit --verbose
```

* `--source [SOURCE]`: Specify your data in this argument.
    * Single image - `image.png`
    * Folder containing images - `path/to/img/folder`
    * Single video - `video.mp4`
    * Folder containing videos - `path/to/vid/folder`
    * Webcam input: `0` (may vary depending on your device.)
* `--dest [DEST]` (optional): Specify your destination folder. If not specified, results will be saved in the `results` folder.
* `--type [TYPE]`: Choose between `map`, `green`, `rgba`, `blur`, `overlay`, and another image file.
    * `map` will output the saliency map only.
    * `green` will change the background to a green screen.
    * `rgba` will generate RGBA output, using the saliency score as an alpha map. Note that this will not work for video and webcam input.
    * `blur` will blur the background.
    * `overlay` will cover the salient object with a translucent green color, and highlight the edges.
    * Another image file (e.g., `background.png`) will be used as a background, and the object will be overlapped on it.
    <details><summary>Examples of TYPE argument</summary>
    <p>
    <img src=../figures/demo_type.png >
    </p>
    </details>
* `--gpu`: Use this argument if you want to use GPU.
* `--jit`: Slightly improves inference speed when used.
* `--verbose`: Use when you want to visualize progress.
docs/model_zoo.md
ADDED
@@ -0,0 +1,65 @@
1 |
+
# Model Zoo :giraffe:
|
2 |
+
|
3 |
+
## Pre-trained Checkpoints
|
4 |
+
|
5 |
+
:label: Note: If you want to try our trained checkpoints below, please make sure to place the `latest.pth` file in [Test.Checkpoint.checkpoint_dir](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L69).
|
6 |
+
|
7 |
+
### Trained with LR dataset only (DUTS-TR, 384 X 384)
|
8 |
+
|
9 |
+
Backbone | Train DB | Config | OneDrive | GDrive
|
10 |
+
:-|:-|:-|:-|:-
|
11 |
+
Res2Net50 | DUTS-TR | [InSPyReNet_Res2Net50](../configs/InSPyReNet_Res2Net50.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERqm7RPeNBFPvVxkA5P5G2AB-mtFsiYkCNHnBf0DcwpFzw?e=nayVno&download=1) | [Link](https://drive.google.com/file/d/12moRuU8F0-xRvE16bVg6mkGWDuqYHJor/view?usp=sharing)
|
12 |
+
SwinB | DUTS-TR | [InSPyReNet_SwinB](../configs/InSPyReNet_SwinB.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EV0ow4E8LddCgu5tAuAkMbcBpBYoEDmJgQg5wkiuvLoQUA?e=cOZspv&download=1) | [Link](https://drive.google.com/file/d/1k5hNJImgEgSmz-ZeJEEb_dVkrOnswVMq/view?usp=sharing)
|
13 |
+
|
14 |
+
### Trained with LR+HR dataset (with LR scale 384 X 384)
|
15 |
+
|
16 |
+
Backbone | Train DB | Config | OneDrive | GDrive
|
17 |
+
:-|:-|:-|:-|:-
|
18 |
+
SwinB | DUTS-TR, HRSOD-TR | [InSPyReNet_SwinB_DH_LR](../configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EWxPZoIKALlGsfrNgUFNvxwBC8IE8jzzhPNtzcbHmTNFcg?e=e22wmy&download=1) | [Link](https://drive.google.com/file/d/1nbs6Xa7NMtcikeHFtkQRVrsHbBRHtIqC/view?usp=sharing)
|
19 |
+
SwinB | HRSOD-TR, UHRSD-TR | [InSPyReNet_SwinB_HU_LR](../configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EQe-iy0AZctIkgl3o-BmVYUBn795wvii3tsnBq1fNUbc9g?e=gMZ4PV&download=1) | [Link](https://drive.google.com/file/d/1uLSIYXlRsZv4Ho0C-c87xKPhmF_b-Ll4/view?usp=sharing)
|
20 |
+
SwinB | DUTS-TR, HRSOD-TR, UHRSD-TR | [InSPyReNet_SwinB_DHU_LR](../configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfsCbnfAU1RAqCJIkj1ewRgBhFetStsGB6SMSq_UJZimjA?e=Ghuacy&download=1) | [Link](https://drive.google.com/file/d/14gRNwR7XwJ5oEcR4RWIVbYH3HEV6uBUq/view?usp=sharing)
|
21 |
+
|
22 |
+
### Trained with LR+HR dataset (with HR scale 1024 X 1024)
|
23 |
+
|
24 |
+
Backbone | Train DB | Config | OneDrive | GDrive
|
25 |
+
:-|:-|:-|:-|:-
|
26 |
+
SwinB | DUTS-TR, HRSOD-TR | [InSPyReNet_SwinB_DH](../configs/extra_dataset/InSPyReNet_SwinB_DH.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EW2Qg-tMBBxNkygMj-8QgMUBiqHox5ExTOJl0LGLsn6AtA?e=Mam8Ur&download=1) | [Link](https://drive.google.com/file/d/1UBGFDUYZ9SysZr96dhsscZg7nDXt6IUD/view?usp=sharing)
|
27 |
+
SwinB | HRSOD-TR, UHRSD-TR | [InSPyReNet_SwinB_HU](../configs/extra_dataset/InSPyReNet_SwinB_HU.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EeE8nnCt_AdFvxxu0JsxwDgBCtGchuUka6DW9za_epX-Qw?e=U7wZu9&download=1) | [Link](https://drive.google.com/file/d/1HB02tiInEgo-pNzwqyvyV6eSN1Y2xPRJ/view?usp=sharing)
|
28 |
+
|
29 |
+
|
30 |
+
### Trained with Dichotomous Image Segmentation dataset (DIS5K-TR) with LR (384 X 384) and HR (1024 X 1024) scale
|
31 |
+
|
32 |
+
Backbone | Train DB | Config | OneDrive | GDrive
|
33 |
+
:-|:-|:-|:-|:-
|
34 |
+
SwinB | DIS5K-TR | [InSPyReNet_SwinB_DIS5K_LR](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERKrQ_YeoJRHl_3HcH8ZJLoBedsa6hZlmIIf66wobZRGuw?e=EywJmS&download=1) | [Link](https://drive.google.com/file/d/1Sj7GZoocGMHyKNhFnQQc1FTs76ysJIX3/view?usp=sharing)
|
35 |
+
SwinB | DIS5K-TR | [InSPyReNet_SwinB_DIS5K](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EShRbD-jZuRJiWv6DS2Us34BwgazGZvK1t4uTKvgE5379Q?e=8oVpS8) | [Link](https://drive.google.com/file/d/1aCxHMbhvj8ah77jXVgqvqImQA_Y0G-Yg/view?usp=sharing)
|
36 |
+
|
37 |
+
### <img src="https://www.kindpng.com/picc/b/124-1249525_all-might-png.png" width=20px> Trained with Massive SOD Datasets (with LR (384 X 384) and HR (1024 X 1024), Not in the paper, for real-world scenarios)
|
38 |
+
|
39 |
+
Backbone | Train DB | Config | OneDrive | GDrive
|
40 |
+
:-|:-|:-|:-|:-
|
41 |
+
SwinB | DUTS-TR, HRSOD-TR, UHRSD-TR, DIS-TR, DUTS-TE, DUT-OMRON, ECSSD,HKU-IS, PASCAL-S, DAVIS-S, HRSOD-TE,UHRSD-TE, FSS-1000, MSRA-10K, DIS-VD, DIS-TE1, DIS-TE2, DIS-TE3, DIS-TE4 | [Plus_Ultra_LR](../configs/extra_dataset/Plus_Ultra_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESKuh1zhToVFsIxhUUsgkbgBnu2kFXCFLRuSz1xxsKzjhA?e=02HDrm&download=1) | [Link](https://drive.google.com/file/d/1iRX-0MVbUjvAVns5MtVdng6CQlGOIo3m/view?usp=sharing)
|
42 |
+
SwinB | DUTS-TR, HRSOD-TR, UHRSD-TR, DIS-TR, DUTS-TE, DUT-OMRON, ECSSD,HKU-IS, PASCAL-S, DAVIS-S, HRSOD-TE,UHRSD-TE, FSS-1000, MSRA-10K, DIS-VD, DIS-TE1, DIS-TE2, DIS-TE3, DIS-TE4 | [Plus_Ultra](../configs/extra_dataset/Plus_Ultra.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ET0R-yM8MfVHqI4g94AlL6AB-D6LxNajaWeDV4xbVQyh7w?e=l4JkZn) | [Link](https://drive.google.com/file/d/13oBl5MTVcWER3YU4fSxW3ATlVfueFQPY/view?usp=sharing)
|
43 |
+
|
44 |
+
* We used the checkpoints above for [`transparent-background`](https://github.com/plemeri/transparent-background), a background-removal command-line tool / Python API!
|
45 |
+
|
46 |
+
## Pre-Computed Saliency Maps
|
47 |
+
|
48 |
+
### Salient Object Detection benchmarks
|
49 |
+
|
50 |
+
Config | DUTS-TE | DUT-OMRON | ECSSD | HKU-IS | PASCAL-S | DAVIS-S | HRSOD-TE | UHRSD-TE
|
51 |
+
:-|:-|:-|:-|:-|:-|:-|:-|:-
|
52 |
+
[InSPyReNet_Res2Net50](../configs/InSPyReNet_Res2Net50.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Eb0iKXGX1vxEjPhe9KGBKr0Bv7v2vv6Ua5NFybwc6aIi1w?e=oHnGyJ&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ef1HaYMvgh1EuuOL8bw3JGYB41-yo6KdTD8FGXcFZX3-Bg?e=TkW2m8&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdEQQ8o-yI9BtTpROcuB_iIBFSIk0uBJAkNyob0WI04-kw?e=cwEj2V&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ec6LyrumVZ9PoB2Af0OW4dcBrDht0OznnwOBYiu8pdyJ4A?e=Y04Fmn&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETPijMHlTRZIjqO5H4LBknUBmy8TGDwOyUQ1H4EnIpHVOw?e=k1afrh&download=1) | N/A | N/A | N/A |
|
53 |
+
[InSPyReNet_SwinB](../configs/InSPyReNet_SwinB.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETumLjuBantLim4kRqj4e_MBpK_X5XrTwjGQUToN8TKVjw?e=ZT8AWy&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZbwxhwT6dtHkBJrIMMjTnkBK_HaDTXgHcDSjxuswZKTZw?e=9XeE4b&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESfQK-557uZOmUwG5W49j0EBK42_7dMOaQcPsc_U1zsYlA?e=IvjkKX&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EURH96JUp55EgUHI0A8RzKoBBqvQc1nVb_a67RgwOY7f-w?e=IP9xKa&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EakMpwONph9EmnCM2rS3hn4B_TL42T6tuLjBEeEa5ndkIw?e=XksfA5&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETUCKFX0k8lAvpsDj5sT23QB2ohuE_ST7oQnWdaW7AoCIw?e=MbSmM2&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea6kf6Kk8fpIs15WWDfJMoYBeQUeo9WXvYx9oM5yWFE1Jg?e=RNN0Ns&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EVJLvAP3HwtHksZMUolIfCABHqP7GgAWcG_1V5T_Xrnr2g?e=ct3pzo&download=1) |
|
54 |
+
[InSPyReNet_SwinB_DH_LR](../configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EbaXjDWFb6lGp9x5ae9mJPgBt9dkmgclq9XrXDjj4B5qSw?e=57wnhE) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EU7uNIpKPXZHrDNt3gapGfsBaUrkCj67-Paj4w_7E8xs1g?e=n0QBR1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EeBq5uU02FRLiW-0g-mQ_CsB3iHlMgAyOSY2Deu5suo9pQ?e=zjhB33) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EY5AppwMB4hFqbISFL_u5QMBeux7dQWtDeXaMMcAyLqLqQ?e=N71XVt) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfhejUtLallLroU6jjgDl-oBSHITWnkdiU6NVV95DO5YqQ?e=T6rrRW) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EesN4fRAE4JJk1aZ_tH3EBQBgirysALcgfw1Ipsa9dLe9Q?e=b5oWsg) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETpVYDijnZhInN9zBHRuwhUBSPslqe9FP0m3Eo3TWS0d5A?e=QTzfx6) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUkKVH3LSyFLq6UXuvWEVuYBEH8W8uAKgyolLRVuIUILag?e=y1SceD)
|
55 |
+
[InSPyReNet_SwinB_HU_LR](../configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUSblgWAwg9Plc4LCGj4TLwB-7HLEdZGJqF1jHOU55g3OA?e=2YT3zM) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESNaVX4Fh5JHn5jOfgnSWi4Bx1bc9t6pg79IoG3mrpZpAw?e=M8D0CM) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETKjTH1vcVZDu5ahRaw4cb8B7JKaPMR-0Uae1DbwarobIA?e=Qw67IZ) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcC9-lPgAXZHs7th9DiVjygB-zPIq_1Ii6i1GpbLGc1iPQ?e=EVXKp9) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EXjyAS4XyzlPqrpaizgElioBdmgd4E81qQzj11Qm4xo5sA?e=hoOzc2) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcSCMR033GVNotOilIzYhIsBikzb8ZzGlkuW6aSNMlUpqQ?e=TFcgvE) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUHRSEANcmVFjcS_K-PeYr0B6VPXPgb2AHFUnlJYrf3dOQ?e=unwcqV) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EU2gb0hS5kZBgKNXWqQAomsBXU-zjGCXKAzYNNk4d6EAiQ?e=pjhiN2)
|
56 |
+
[InSPyReNet_SwinB_DHU_LR](../configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ee7Y647ERzxMgBoFceEEO6kBIUkIlmYHoizMj71gT37sxw?e=xDt83C) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESFGN3FCdzRAvlsW6bEaGdoBYNoJgK4DAjaS6WkVVyI_QQ?e=nYHklV) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EeONq5kOirRCkErrb6fFqd8B4w4SMZXBY1Q2mJvZcRsGdQ?e=K7fwQt) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea5P4CBBQatPiUzsH53lckoB0k23haePzuERBfyJXaCbBg?e=AZ96mc) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea2pezB7eo1BraeBpIA8YZoBkVY38rRa3KrwSIzY1cn2dQ?e=o121S6) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfNDzH8O54pGtnAxit_hUjUBK9poVq4sxxnJjSG7PUQCkw?e=OWt7k8) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESfZaII0pO9IqqL1FkpjIuAB8SGxLcslJeWTuKQxPNFIVA?e=Ce1CWg) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYsQqf1GKShJhBnkZ6gD5PABPOcjRcUGSfvTbe-wYh2O2Q?e=t4Xlxv)
|
57 |
+
[InSPyReNet_SwinB_DH](../configs/extra_dataset/InSPyReNet_SwinB_DH.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EWNaRqtzhtNFhMVfLcoyfqQBw35M8q8bxME3yZyhkTtc7Q?e=jrJe3v) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZF5_s8JfR9HqGBUZHSM_j4BVVMONp38_gJ1ekEdvlM-qQ?e=0chMdl) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZuaFObyNOtKg0W5cM7bqPYBZYGg7Z3V3i4sClI6bU_ntA?e=BxxQI7) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EY-yEnNNNT5KpiURhiDAkDEBMkiA1QwQ_T0wB1UC75GXVg?e=Lle02B) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYfbsgWhm7lAlX_wj_WZZowBV_-l-UvvThC4LJEKpV0BQQ?e=zTiKpI) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ef0GUP7c0bBBonHlqgB988YB0rgxCFq3oo0u8xCN8wfyyQ?e=LCb8UV) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EVUZBnRpa35AmrvdUybsQDMBzMZvuJWe5tT7635lh9MHDQ?e=FlpQW1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETfZ_zdrDvhOh21u2mqVhigBSxn3vlfKVIwXhRfzzSSFzA?e=kXBBi9)
|
58 |
+
[InSPyReNet_SwinB_HU](../configs/extra_dataset/InSPyReNet_SwinB_HU.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZq4JUACKCBMk2bn4yoWz6sBOKrSFTPfL7d5xopc1uDw_A?e=RtVHSl) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETJaqoSPaYtNkc8eSGDeKzMBbjbuOAWgJwG4q52bW87aew?e=Pguh4b) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZAeCI6BPMdNsicnQ-m1pVEBwAhOiIcbelhOMoRGXGEvVA?e=BQKd7Q) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EVmGvZGz54JOvrIymLsSwq4Bpos3vWSXZm3oV7-qmGZgHA?e=4UhDgv) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERHDUybOh4ZKkqWZpcu7MiMBFuTK6wACkKUZaNeEQGbCNQ?e=GCQnoe) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESPmZXTnfO5CrCoo_0OADxgBt_3FoU5mSFoSE4QWbWxumQ?e=HAsAYz) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdTnUwEeMZNBrPSBBbGZKQcBmVshSTfca9qz_BqNpAUpOg?e=HsJ4Gx) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ET48owfVQEdImrh0V4gx8_ABsYXgbIJqtpq77aK_U28VwQ?e=h8er3H)
|
59 |
+
|
60 |
+
### DIS5K Results
|
61 |
+
|
62 |
+
Config | DIS-VD | DIS-TE1 | DIS-TE2 | DIS-TE3 | DIS-TE4
|
63 |
+
:-|:-|:-|:-|:-|:-|
|
64 |
+
[InSPyReNet_SwinB_DIS5K_LR](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUbzddb_QRRCtnXC8Xl6vZoBC6IqOfom52BWbzOYk-b2Ow?e=aqJYi1&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESeW_SOD26tHjBLymmgFaXwBIJlljzNycaGWXLpOp_d_kA?e=2EyMai&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYWT5fZDjI5Bn-lr-iQM1TsB1num0-UqfJC1TIv-LuOXoA?e=jCcnty&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EQXm1DEBfaNJmH0B-A3o23kBn4v5j53kP2nF9CpG9SQkyw?e=lEUiZh&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZeH2ufGsFZIoUh6D8Rtv88BBF_ddQXav4xYXXRP_ayEAg?e=AMzIp8&download=1)
|
65 |
+
[InSPyReNet_SwinB_DIS5K](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUisiphB_W5BjgfpIYu9oNgB_fY4XxL-MhR2gR-ZZUt49Q?e=gqorYs) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EW2y54ROYIZFlEq5ilRFwOsBSrIm2-HGsUHPHykaJvUBfA?e=797fxr) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ER6yEGZkgWVOsL-mauYgPyoBDIoU0Mck-twEBgQi5g3Mxw?e=0yJZTT) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Edvzj0iZ8hdDthm4Q-p2YHgBhP1X5z4AAccAoUasr2nihA?e=1dognG) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EWwMORYg8DlCgGGqOQFThZ8BgIU9wV9-0DwLrKldQl7N8w?e=nCwuqy)
|
figures/demo_image.gif
ADDED (image stored with Git LFS)
figures/demo_type.png
ADDED (image stored with Git LFS)
figures/demo_video.gif
ADDED (image stored with Git LFS)
figures/demo_webapp.gif
ADDED (image stored with Git LFS)
figures/fig_architecture.png
ADDED (image stored with Git LFS)
figures/fig_pyramid_blending.png
ADDED (image stored with Git LFS)
figures/fig_qualitative.png
ADDED (image stored with Git LFS)
figures/fig_qualitative2.png
ADDED (image stored with Git LFS)
figures/fig_qualitative3.jpg
ADDED (image stored with Git LFS)
figures/fig_qualitative_dis.png
ADDED (image stored with Git LFS)
figures/fig_quantitative.png
ADDED (image stored with Git LFS)
figures/fig_quantitative2.png
ADDED (image stored with Git LFS)
figures/fig_quantitative3.png
ADDED (image stored with Git LFS)
figures/fig_quantitative4.png
ADDED (image stored with Git LFS)
lib/InSPyReNet.py
ADDED
@@ -0,0 +1,199 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from lib.optim import *
from lib.modules.layers import *
from lib.modules.context_module import *
from lib.modules.attention_module import *
from lib.modules.decoder_module import *

from lib.backbones.Res2Net_v1b import res2net50_v1b_26w_4s
from lib.backbones.SwinTransformer import SwinB

class InSPyReNet(nn.Module):
    def __init__(self, backbone, in_channels, depth=64, base_size=[384, 384], threshold=512, **kwargs):
        super(InSPyReNet, self).__init__()
        self.backbone = backbone
        self.in_channels = in_channels
        self.depth = depth
        self.base_size = base_size
        self.threshold = threshold

        self.context1 = PAA_e(self.in_channels[0], self.depth, base_size=self.base_size, stage=0)
        self.context2 = PAA_e(self.in_channels[1], self.depth, base_size=self.base_size, stage=1)
        self.context3 = PAA_e(self.in_channels[2], self.depth, base_size=self.base_size, stage=2)
        self.context4 = PAA_e(self.in_channels[3], self.depth, base_size=self.base_size, stage=3)
        self.context5 = PAA_e(self.in_channels[4], self.depth, base_size=self.base_size, stage=4)

        self.decoder = PAA_d(self.depth * 3, depth=self.depth, base_size=base_size, stage=2)

        self.attention0 = SICA(self.depth, depth=self.depth, base_size=self.base_size, stage=0, lmap_in=True)
        self.attention1 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=1, lmap_in=True)
        self.attention2 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=2)

        self.sod_loss_fn = lambda x, y: weighted_bce_loss_with_logits(x, y, reduction='mean') + iou_loss_with_logits(x, y, reduction='mean')
        self.pc_loss_fn = nn.L1Loss()

        self.ret = lambda x, target: F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)
        self.res = lambda x, size: F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        self.des = lambda x, size: F.interpolate(x, size=size, mode='nearest')

        self.image_pyramid = ImagePyramid(7, 1)

        self.transition0 = Transition(17)
        self.transition1 = Transition(9)
        self.transition2 = Transition(5)

        self.forward = self.forward_inference

    def to(self, device):
        self.image_pyramid.to(device)
        self.transition0.to(device)
        self.transition1.to(device)
        self.transition2.to(device)
        super(InSPyReNet, self).to(device)
        return self

    def cuda(self, idx=None):
        if idx is None:
            idx = torch.cuda.current_device()

        self.to(device="cuda:{}".format(idx))
        return self

    def train(self, mode=True):
        super(InSPyReNet, self).train(mode)
        self.forward = self.forward_train
        return self

    def eval(self):
        super(InSPyReNet, self).train(False)
        self.forward = self.forward_inference
        return self

    def forward_inspyre(self, x):
        B, _, H, W = x.shape

        x1, x2, x3, x4, x5 = self.backbone(x)

        x1 = self.context1(x1)  # 4
        x2 = self.context2(x2)  # 4
        x3 = self.context3(x3)  # 8
        x4 = self.context4(x4)  # 16
        x5 = self.context5(x5)  # 32

        f3, d3 = self.decoder([x3, x4, x5])  # 16

        f3 = self.res(f3, (H // 4, W // 4))
        f2, p2 = self.attention2(torch.cat([x2, f3], dim=1), d3.detach())
        d2 = self.image_pyramid.reconstruct(d3.detach(), p2)  # 4

        x1 = self.res(x1, (H // 2, W // 2))
        f2 = self.res(f2, (H // 2, W // 2))
        f1, p1 = self.attention1(torch.cat([x1, f2], dim=1), d2.detach(), p2.detach())  # 2
        d1 = self.image_pyramid.reconstruct(d2.detach(), p1)  # 2

        f1 = self.res(f1, (H, W))
        _, p0 = self.attention0(f1, d1.detach(), p1.detach())  # 2
        d0 = self.image_pyramid.reconstruct(d1.detach(), p0)  # 2

        out = dict()
        out['saliency'] = [d3, d2, d1, d0]
        out['laplacian'] = [p2, p1, p0]

        return out

    def forward_train(self, sample):
        x = sample['image']
        B, _, H, W = x.shape
        out = self.forward_inspyre(x)

        d3, d2, d1, d0 = out['saliency']
        p2, p1, p0 = out['laplacian']

        if type(sample) == dict and 'gt' in sample.keys() and sample['gt'] is not None:
            y = sample['gt']

            y1 = self.image_pyramid.reduce(y)
            y2 = self.image_pyramid.reduce(y1)
            y3 = self.image_pyramid.reduce(y2)

            loss = self.pc_loss_fn(self.des(d3, (H, W)), self.des(self.image_pyramid.reduce(d2), (H, W)).detach()) * 0.0001
            loss += self.pc_loss_fn(self.des(d2, (H, W)), self.des(self.image_pyramid.reduce(d1), (H, W)).detach()) * 0.0001
            loss += self.pc_loss_fn(self.des(d1, (H, W)), self.des(self.image_pyramid.reduce(d0), (H, W)).detach()) * 0.0001

            loss += self.sod_loss_fn(self.des(d3, (H, W)), self.des(y3, (H, W)))
            loss += self.sod_loss_fn(self.des(d2, (H, W)), self.des(y2, (H, W)))
            loss += self.sod_loss_fn(self.des(d1, (H, W)), self.des(y1, (H, W)))
            loss += self.sod_loss_fn(self.des(d0, (H, W)), self.des(y, (H, W)))

        else:
            loss = 0

        pred = torch.sigmoid(d0)
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

        sample['pred'] = pred
        sample['loss'] = loss
        sample['saliency'] = [d3, d2, d1, d0]
        sample['laplacian'] = [p2, p1, p0]
        return sample

    def forward_inference(self, sample):
        B, _, H, W = sample['image'].shape

        if self.threshold is None:
            out = self.forward_inspyre(sample['image'])
            d3, d2, d1, d0 = out['saliency']
            p2, p1, p0 = out['laplacian']

        elif (H <= self.threshold or W <= self.threshold):
            if 'image_resized' in sample.keys():
                out = self.forward_inspyre(sample['image_resized'])
            else:
                out = self.forward_inspyre(sample['image'])
            d3, d2, d1, d0 = out['saliency']
            p2, p1, p0 = out['laplacian']

        else:
            # LR Saliency Pyramid
            lr_out = self.forward_inspyre(sample['image_resized'])
            lr_d3, lr_d2, lr_d1, lr_d0 = lr_out['saliency']
            lr_p2, lr_p1, lr_p0 = lr_out['laplacian']

            # HR Saliency Pyramid
            hr_out = self.forward_inspyre(sample['image'])
            hr_d3, hr_d2, hr_d1, hr_d0 = hr_out['saliency']
            hr_p2, hr_p1, hr_p0 = hr_out['laplacian']

            # Pyramid Blending
            d3 = self.ret(lr_d0, hr_d3)

            t2 = self.ret(self.transition2(d3), hr_p2)
            p2 = t2 * hr_p2
            d2 = self.image_pyramid.reconstruct(d3, p2)

            t1 = self.ret(self.transition1(d2), hr_p1)
            p1 = t1 * hr_p1
            d1 = self.image_pyramid.reconstruct(d2, p1)

            t0 = self.ret(self.transition0(d1), hr_p0)
            p0 = t0 * hr_p0
            d0 = self.image_pyramid.reconstruct(d1, p0)

        pred = torch.sigmoid(d0)
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

        sample['pred'] = pred
        sample['loss'] = 0
        sample['saliency'] = [d3, d2, d1, d0]
        sample['laplacian'] = [p2, p1, p0]
        return sample

def InSPyReNet_Res2Net50(depth, pretrained, base_size, **kwargs):
    return InSPyReNet(res2net50_v1b_26w_4s(pretrained=pretrained), [64, 256, 512, 1024, 2048], depth, base_size, **kwargs)

def InSPyReNet_SwinB(depth, pretrained, base_size, **kwargs):
    return InSPyReNet(SwinB(pretrained=pretrained), [128, 128, 256, 512, 1024], depth, base_size, **kwargs)
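A minimal inference sketch (not part of the committed files), assuming the SwinB checkpoint under data/backbone_ckpt/ is present and that depth=64, base_size=[384, 384], threshold=512 match the shipped config; adjust to your setup:

import torch
from lib.InSPyReNet import InSPyReNet_SwinB

model = InSPyReNet_SwinB(depth=64, pretrained=True, base_size=[384, 384], threshold=512).eval()

sample = {'image': torch.rand(1, 3, 384, 384)}  # normalized RGB batch
with torch.no_grad():
    out = model(sample)                         # eval() binds forward to forward_inference
print(out['pred'].shape)                        # torch.Size([1, 1, 384, 384]), values in [0, 1]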
lib/__init__.py
ADDED
@@ -0,0 +1 @@
from lib.InSPyReNet import InSPyReNet_Res2Net50, InSPyReNet_SwinB
lib/backbones/Res2Net_v1b.py
ADDED
@@ -0,0 +1,268 @@
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F

__all__ = ['Res2Net', 'res2net50_v1b',
           'res2net101_v1b', 'res2net50_v1b_26w_4s']

model_urls = {
    'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
    'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
}

class Bottle2neck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
        super(Bottle2neck, self).__init__()

        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(inplanes, width * scale,
                               kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)

        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale - 1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride,
                                   dilation=dilation, padding=dilation, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv2d(width * scale, planes *
                               self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Res2Net(nn.Module):

    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000, output_stride=32):
        self.inplanes = 64
        super(Res2Net, self).__init__()
        self.baseWidth = baseWidth
        self.scale = scale
        self.output_stride = output_stride
        if self.output_stride == 8:
            self.grid = [1, 2, 1]
            self.stride = [1, 2, 1, 1]
            self.dilation = [1, 1, 2, 4]
        elif self.output_stride == 16:
            self.grid = [1, 2, 4]
            self.stride = [1, 2, 2, 1]
            self.dilation = [1, 1, 1, 2]
        elif self.output_stride == 32:
            self.grid = [1, 2, 4]
            self.stride = [1, 2, 2, 2]
            self.dilation = [1, 1, 2, 4]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            block, 64, layers[0], stride=self.stride[0], dilation=self.dilation[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=self.stride[1], dilation=self.dilation[1])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=self.stride[2], dilation=self.dilation[2])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=self.stride[3], dilation=self.dilation[3], grid=self.grid)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, grid=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=stride, stride=stride,
                             ceil_mode=True, count_include_pad=False),
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample=downsample,
                            stype='stage', baseWidth=self.baseWidth, scale=self.scale))
        self.inplanes = planes * block.expansion

        if grid is not None:
            assert len(grid) == blocks
        else:
            grid = [1] * blocks

        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation *
                                grid[i], baseWidth=self.baseWidth, scale=self.scale))

        return nn.Sequential(*layers)

    def change_stride(self, output_stride=16):
        if output_stride == self.output_stride:
            return
        else:
            self.output_stride = output_stride
            if self.output_stride == 8:
                self.grid = [1, 2, 1]
                self.stride = [1, 2, 1, 1]
                self.dilation = [1, 1, 2, 4]
            elif self.output_stride == 16:
                self.grid = [1, 2, 4]
                self.stride = [1, 2, 2, 1]
                self.dilation = [1, 1, 1, 2]
            elif self.output_stride == 32:
                self.grid = [1, 2, 4]
                self.stride = [1, 2, 2, 2]
                self.dilation = [1, 1, 2, 4]

            for i, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
                for j, block in enumerate(layer):
                    if block.downsample is not None:
                        block.downsample[0].kernel_size = (
                            self.stride[i], self.stride[i])
                        block.downsample[0].stride = (
                            self.stride[i], self.stride[i])
                    if hasattr(block, 'pool'):
                        block.pool.stride = (
                            self.stride[i], self.stride[i])
                    for conv in block.convs:
                        conv.stride = (self.stride[i], self.stride[i])
                    for conv in block.convs:
                        d = self.dilation[i] if i != 3 else self.dilation[i] * self.grid[j]
                        conv.dilation = (d, d)
                        conv.padding = (d, d)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        out = [x]

        x = self.layer1(x)
        out.append(x)
        x = self.layer2(x)
        out.append(x)
        x = self.layer3(x)
        out.append(x)
        x = self.layer4(x)
        out.append(x)

        return out


def res2net50_v1b(pretrained=False, **kwargs):
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(
            model_urls['res2net50_v1b_26w_4s']))

    return model


def res2net101_v1b(pretrained=False, **kwargs):
    model = Res2Net(Bottle2neck, [3, 4, 23, 3],
                    baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(
            model_urls['res2net101_v1b_26w_4s']))

    return model


def res2net50_v1b_26w_4s(pretrained=True, **kwargs):
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained is True:
        model.load_state_dict(torch.load('data/backbone_ckpt/res2net50_v1b_26w_4s-3cf99910.pth', map_location='cpu'))

    return model


def res2net101_v1b_26w_4s(pretrained=True, **kwargs):
    model = Res2Net(Bottle2neck, [3, 4, 23, 3],
                    baseWidth=26, scale=4, **kwargs)
    if pretrained is True:
        model.load_state_dict(torch.load('data/backbone_ckpt/res2net101_v1b_26w_4s-0812c246.pth', map_location='cpu'))

    return model


def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
    model = Res2Net(Bottle2neck, [3, 8, 36, 3],
                    baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(
            model_urls['res2net152_v1b_26w_4s']))

    return model


if __name__ == '__main__':
    images = torch.rand(1, 3, 224, 224).cuda(0)
    model = res2net50_v1b_26w_4s(pretrained=True)
    model = model.cuda(0)
    print(model(images).size())
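For reference, this backbone's forward returns the stem output plus the four stage outputs, which is where the [64, 256, 512, 1024, 2048] channel list in InSPyReNet_Res2Net50 comes from. A quick shape check, a sketch only (passing pretrained=False skips the local checkpoint):

import torch
from lib.backbones.Res2Net_v1b import res2net50_v1b_26w_4s

backbone = res2net50_v1b_26w_4s(pretrained=False).eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 384, 384))
print([f.shape[1] for f in feats])   # [64, 256, 512, 1024, 2048]
print([f.shape[-1] for f in feats])  # [96, 96, 48, 24, 12], i.e. strides 4, 4, 8, 16, 32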
lib/backbones/SwinTransformer.py
ADDED
@@ -0,0 +1,652 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Swin Transformer
|
3 |
+
# Copyright (c) 2021 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Ze Liu, Yutong Lin, Yixuan Wei
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
import numpy as np
|
13 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
""" Multilayer perceptron."""
|
17 |
+
|
18 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
19 |
+
super().__init__()
|
20 |
+
out_features = out_features or in_features
|
21 |
+
hidden_features = hidden_features or in_features
|
22 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
23 |
+
self.act = act_layer()
|
24 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
25 |
+
self.drop = nn.Dropout(drop)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
x = self.fc1(x)
|
29 |
+
x = self.act(x)
|
30 |
+
x = self.drop(x)
|
31 |
+
x = self.fc2(x)
|
32 |
+
x = self.drop(x)
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
def window_partition(x, window_size):
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
x: (B, H, W, C)
|
40 |
+
window_size (int): window size
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
windows: (num_windows*B, window_size, window_size, C)
|
44 |
+
"""
|
45 |
+
B, H, W, C = x.shape
|
46 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
47 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
48 |
+
return windows
|
49 |
+
|
50 |
+
|
51 |
+
def window_reverse(windows, window_size, H, W):
|
52 |
+
"""
|
53 |
+
Args:
|
54 |
+
windows: (num_windows*B, window_size, window_size, C)
|
55 |
+
window_size (int): Window size
|
56 |
+
H (int): Height of image
|
57 |
+
W (int): Width of image
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
x: (B, H, W, C)
|
61 |
+
"""
|
62 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
63 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
64 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class WindowAttention(nn.Module):
|
69 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
70 |
+
It supports both of shifted and non-shifted window.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
dim (int): Number of input channels.
|
74 |
+
window_size (tuple[int]): The height and width of the window.
|
75 |
+
num_heads (int): Number of attention heads.
|
76 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
77 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
78 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
79 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
83 |
+
|
84 |
+
super().__init__()
|
85 |
+
self.dim = dim
|
86 |
+
self.window_size = window_size # Wh, Ww
|
87 |
+
self.num_heads = num_heads
|
88 |
+
head_dim = dim // num_heads
|
89 |
+
self.scale = qk_scale or head_dim ** -0.5
|
90 |
+
|
91 |
+
# define a parameter table of relative position bias
|
92 |
+
self.relative_position_bias_table = nn.Parameter(
|
93 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
94 |
+
|
95 |
+
# get pair-wise relative position index for each token inside the window
|
96 |
+
coords_h = torch.arange(self.window_size[0])
|
97 |
+
coords_w = torch.arange(self.window_size[1])
|
98 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
99 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
100 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
101 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
102 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
103 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
104 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
105 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
106 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
107 |
+
|
108 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
109 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
110 |
+
self.proj = nn.Linear(dim, dim)
|
111 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
112 |
+
|
113 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
114 |
+
self.softmax = nn.Softmax(dim=-1)
|
115 |
+
|
116 |
+
def forward(self, x, mask=None):
|
117 |
+
""" Forward function.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
x: input features with shape of (num_windows*B, N, C)
|
121 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
122 |
+
"""
|
123 |
+
B_, N, C = x.shape
|
124 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
125 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
126 |
+
|
127 |
+
q = q * self.scale
|
128 |
+
attn = (q @ k.transpose(-2, -1))
|
129 |
+
|
130 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
131 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
132 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
133 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
134 |
+
|
135 |
+
if mask is not None:
|
136 |
+
nW = mask.shape[0]
|
137 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
138 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
139 |
+
attn = self.softmax(attn)
|
140 |
+
else:
|
141 |
+
attn = self.softmax(attn)
|
142 |
+
|
143 |
+
attn = self.attn_drop(attn)
|
144 |
+
|
145 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
146 |
+
x = self.proj(x)
|
147 |
+
x = self.proj_drop(x)
|
148 |
+
return x
|
149 |
+
|
150 |
+
|
151 |
+
class SwinTransformerBlock(nn.Module):
|
152 |
+
""" Swin Transformer Block.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
dim (int): Number of input channels.
|
156 |
+
num_heads (int): Number of attention heads.
|
157 |
+
window_size (int): Window size.
|
158 |
+
shift_size (int): Shift size for SW-MSA.
|
159 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
160 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
161 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
162 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
163 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
164 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
165 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
166 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
167 |
+
"""
|
168 |
+
|
169 |
+
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
170 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
171 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
172 |
+
super().__init__()
|
173 |
+
self.dim = dim
|
174 |
+
self.num_heads = num_heads
|
175 |
+
self.window_size = window_size
|
176 |
+
self.shift_size = shift_size
|
177 |
+
self.mlp_ratio = mlp_ratio
|
178 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
179 |
+
|
180 |
+
self.norm1 = norm_layer(dim)
|
181 |
+
self.attn = WindowAttention(
|
182 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
183 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
184 |
+
|
185 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
186 |
+
self.norm2 = norm_layer(dim)
|
187 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
188 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
189 |
+
|
190 |
+
self.H = None
|
191 |
+
self.W = None
|
192 |
+
|
193 |
+
def forward(self, x, mask_matrix):
|
194 |
+
""" Forward function.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
x: Input feature, tensor size (B, H*W, C).
|
198 |
+
H, W: Spatial resolution of the input feature.
|
199 |
+
mask_matrix: Attention mask for cyclic shift.
|
200 |
+
"""
|
201 |
+
B, L, C = x.shape
|
202 |
+
H, W = self.H, self.W
|
203 |
+
assert L == H * W, "input feature has wrong size"
|
204 |
+
|
205 |
+
shortcut = x
|
206 |
+
x = self.norm1(x)
|
207 |
+
x = x.view(B, H, W, C)
|
208 |
+
|
209 |
+
# pad feature maps to multiples of window size
|
210 |
+
pad_l = pad_t = 0
|
211 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
212 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
213 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
214 |
+
_, Hp, Wp, _ = x.shape
|
215 |
+
|
216 |
+
# cyclic shift
|
217 |
+
if self.shift_size > 0:
|
218 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
219 |
+
attn_mask = mask_matrix
|
220 |
+
else:
|
221 |
+
shifted_x = x
|
222 |
+
attn_mask = None
|
223 |
+
|
224 |
+
# partition windows
|
225 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
226 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
227 |
+
|
228 |
+
# W-MSA/SW-MSA
|
229 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
230 |
+
|
231 |
+
# merge windows
|
232 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
233 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
234 |
+
|
235 |
+
# reverse cyclic shift
|
236 |
+
if self.shift_size > 0:
|
237 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
238 |
+
else:
|
239 |
+
x = shifted_x
|
240 |
+
|
241 |
+
if pad_r > 0 or pad_b > 0:
|
242 |
+
x = x[:, :H, :W, :].contiguous()
|
243 |
+
|
244 |
+
x = x.view(B, H * W, C)
|
245 |
+
|
246 |
+
# FFN
|
247 |
+
x = shortcut + self.drop_path(x)
|
248 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
249 |
+
|
250 |
+
return x
|
251 |
+
|
252 |
+
|
253 |
+
class PatchMerging(nn.Module):
|
254 |
+
""" Patch Merging Layer
|
255 |
+
|
256 |
+
Args:
|
257 |
+
dim (int): Number of input channels.
|
258 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
259 |
+
"""
|
260 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
261 |
+
super().__init__()
|
262 |
+
self.dim = dim
|
263 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
264 |
+
self.norm = norm_layer(4 * dim)
|
265 |
+
|
266 |
+
def forward(self, x, H, W):
|
267 |
+
""" Forward function.
|
268 |
+
|
269 |
+
Args:
|
270 |
+
x: Input feature, tensor size (B, H*W, C).
|
271 |
+
H, W: Spatial resolution of the input feature.
|
272 |
+
"""
|
273 |
+
B, L, C = x.shape
|
274 |
+
assert L == H * W, "input feature has wrong size"
|
275 |
+
|
276 |
+
x = x.view(B, H, W, C)
|
277 |
+
|
278 |
+
# padding
|
279 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
280 |
+
if pad_input:
|
281 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
282 |
+
|
283 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
284 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
285 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
286 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
287 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
288 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
289 |
+
|
290 |
+
x = self.norm(x)
|
291 |
+
x = self.reduction(x)
|
292 |
+
|
293 |
+
return x
|
294 |
+
|
295 |
+
|
296 |
+
class BasicLayer(nn.Module):
|
297 |
+
""" A basic Swin Transformer layer for one stage.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
dim (int): Number of feature channels
|
301 |
+
depth (int): Depths of this stage.
|
302 |
+
num_heads (int): Number of attention head.
|
303 |
+
window_size (int): Local window size. Default: 7.
|
304 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
305 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
306 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
307 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
308 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
309 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
310 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
311 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
312 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
313 |
+
"""
|
314 |
+
|
315 |
+
def __init__(self,
|
316 |
+
dim,
|
317 |
+
depth,
|
318 |
+
num_heads,
|
319 |
+
window_size=7,
|
320 |
+
mlp_ratio=4.,
|
321 |
+
qkv_bias=True,
|
322 |
+
qk_scale=None,
|
323 |
+
drop=0.,
|
324 |
+
attn_drop=0.,
|
325 |
+
drop_path=0.,
|
326 |
+
norm_layer=nn.LayerNorm,
|
327 |
+
downsample=None,
|
328 |
+
use_checkpoint=False):
|
329 |
+
super().__init__()
|
330 |
+
self.window_size = window_size
|
331 |
+
self.shift_size = window_size // 2
|
332 |
+
self.depth = depth
|
333 |
+
self.use_checkpoint = use_checkpoint
|
334 |
+
|
335 |
+
# build blocks
|
336 |
+
self.blocks = nn.ModuleList([
|
337 |
+
SwinTransformerBlock(
|
338 |
+
dim=dim,
|
339 |
+
num_heads=num_heads,
|
340 |
+
window_size=window_size,
|
341 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
342 |
+
mlp_ratio=mlp_ratio,
|
343 |
+
qkv_bias=qkv_bias,
|
344 |
+
qk_scale=qk_scale,
|
345 |
+
drop=drop,
|
346 |
+
attn_drop=attn_drop,
|
347 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
348 |
+
norm_layer=norm_layer)
|
349 |
+
for i in range(depth)])
|
350 |
+
|
351 |
+
# patch merging layer
|
352 |
+
if downsample is not None:
|
353 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
354 |
+
else:
|
355 |
+
self.downsample = None
|
356 |
+
|
357 |
+
def forward(self, x, H, W):
|
358 |
+
""" Forward function.
|
359 |
+
|
360 |
+
Args:
|
361 |
+
x: Input feature, tensor size (B, H*W, C).
|
362 |
+
H, W: Spatial resolution of the input feature.
|
363 |
+
"""
|
364 |
+
|
365 |
+
# calculate attention mask for SW-MSA
|
366 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
367 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
368 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
369 |
+
h_slices = (slice(0, -self.window_size),
|
370 |
+
slice(-self.window_size, -self.shift_size),
|
371 |
+
slice(-self.shift_size, None))
|
372 |
+
w_slices = (slice(0, -self.window_size),
|
373 |
+
slice(-self.window_size, -self.shift_size),
|
374 |
+
slice(-self.shift_size, None))
|
375 |
+
cnt = 0
|
376 |
+
for h in h_slices:
|
377 |
+
for w in w_slices:
|
378 |
+
img_mask[:, h, w, :] = cnt
|
379 |
+
cnt += 1
|
380 |
+
|
381 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
382 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
383 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
384 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
385 |
+
|
386 |
+
for blk in self.blocks:
|
387 |
+
blk.H, blk.W = H, W
|
388 |
+
if self.use_checkpoint:
|
389 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
390 |
+
else:
|
391 |
+
x = blk(x, attn_mask)
|
392 |
+
if self.downsample is not None:
|
393 |
+
x_down = self.downsample(x, H, W)
|
394 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
395 |
+
return x, H, W, x_down, Wh, Ww
|
396 |
+
else:
|
397 |
+
return x, H, W, x, H, W
|
398 |
+
|
399 |
+
|
400 |
+
class PatchEmbed(nn.Module):
|
401 |
+
""" Image to Patch Embedding
|
402 |
+
|
403 |
+
Args:
|
404 |
+
patch_size (int): Patch token size. Default: 4.
|
405 |
+
in_chans (int): Number of input image channels. Default: 3.
|
406 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
407 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
408 |
+
"""
|
409 |
+
|
410 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
411 |
+
super().__init__()
|
412 |
+
patch_size = to_2tuple(patch_size)
|
413 |
+
self.patch_size = patch_size
|
414 |
+
|
415 |
+
self.in_chans = in_chans
|
416 |
+
self.embed_dim = embed_dim
|
417 |
+
|
418 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
419 |
+
if norm_layer is not None:
|
420 |
+
self.norm = norm_layer(embed_dim)
|
421 |
+
else:
|
422 |
+
self.norm = None
|
423 |
+
|
424 |
+
def forward(self, x):
|
425 |
+
"""Forward function."""
|
426 |
+
# padding
|
427 |
+
_, _, H, W = x.size()
|
428 |
+
if W % self.patch_size[1] != 0:
|
429 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
430 |
+
if H % self.patch_size[0] != 0:
|
431 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
432 |
+
|
433 |
+
x = self.proj(x) # B C Wh Ww
|
434 |
+
if self.norm is not None:
|
435 |
+
Wh, Ww = x.size(2), x.size(3)
|
436 |
+
x = x.flatten(2).transpose(1, 2)
|
437 |
+
x = self.norm(x)
|
438 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
439 |
+
|
440 |
+
return x
|
441 |
+
|
442 |
+
|
443 |
+
class SwinTransformer(nn.Module):
|
444 |
+
""" Swin Transformer backbone.
|
445 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
446 |
+
https://arxiv.org/pdf/2103.14030
|
447 |
+
|
448 |
+
Args:
|
449 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
450 |
+
used in absolute postion embedding. Default 224.
|
451 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
452 |
+
in_chans (int): Number of input image channels. Default: 3.
|
453 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
454 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
455 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
456 |
+
window_size (int): Window size. Default: 7.
|
457 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
458 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
459 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
460 |
+
drop_rate (float): Dropout rate.
|
461 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
462 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
463 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
464 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
465 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
466 |
+
out_indices (Sequence[int]): Output from which stages.
|
467 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
468 |
+
-1 means not freezing any parameters.
|
469 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
470 |
+
"""
|
471 |
+
|
472 |
+
def __init__(self,
|
473 |
+
pretrain_img_size=224,
|
474 |
+
patch_size=4,
|
475 |
+
in_chans=3,
|
476 |
+
embed_dim=96,
|
477 |
+
depths=[2, 2, 6, 2],
|
478 |
+
num_heads=[3, 6, 12, 24],
|
479 |
+
window_size=7,
|
480 |
+
mlp_ratio=4.,
|
481 |
+
qkv_bias=True,
|
482 |
+
qk_scale=None,
|
483 |
+
drop_rate=0.,
|
484 |
+
attn_drop_rate=0.,
|
485 |
+
drop_path_rate=0.2,
|
486 |
+
norm_layer=nn.LayerNorm,
|
487 |
+
ape=False,
|
488 |
+
patch_norm=True,
|
489 |
+
out_indices=(0, 1, 2, 3),
|
490 |
+
frozen_stages=-1,
|
491 |
+
use_checkpoint=False):
|
492 |
+
super().__init__()
|
493 |
+
|
494 |
+
self.pretrain_img_size = pretrain_img_size
|
495 |
+
self.num_layers = len(depths)
|
496 |
+
self.embed_dim = embed_dim
|
497 |
+
self.ape = ape
|
498 |
+
self.patch_norm = patch_norm
|
499 |
+
self.out_indices = out_indices
|
500 |
+
self.frozen_stages = frozen_stages
|
501 |
+
|
502 |
+
# split image into non-overlapping patches
|
503 |
+
self.patch_embed = PatchEmbed(
|
504 |
+
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
505 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
506 |
+
|
507 |
+
# absolute position embedding
|
508 |
+
if self.ape:
|
509 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
510 |
+
patch_size = to_2tuple(patch_size)
|
511 |
+
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
512 |
+
|
513 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
514 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
515 |
+
|
516 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
517 |
+
|
518 |
+
# stochastic depth
|
519 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
520 |
+
|
521 |
+
# build layers
|
522 |
+
self.layers = nn.ModuleList()
|
523 |
+
for i_layer in range(self.num_layers):
|
524 |
+
layer = BasicLayer(
|
525 |
+
dim=int(embed_dim * 2 ** i_layer),
|
526 |
+
depth=depths[i_layer],
|
527 |
+
num_heads=num_heads[i_layer],
|
528 |
+
window_size=window_size,
|
529 |
+
mlp_ratio=mlp_ratio,
|
530 |
+
qkv_bias=qkv_bias,
|
531 |
+
qk_scale=qk_scale,
|
532 |
+
drop=drop_rate,
|
533 |
+
attn_drop=attn_drop_rate,
|
534 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
535 |
+
norm_layer=norm_layer,
|
536 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
537 |
+
use_checkpoint=use_checkpoint)
|
538 |
+
self.layers.append(layer)
|
539 |
+
|
540 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
541 |
+
self.num_features = num_features
|
542 |
+
|
543 |
+
# add a norm layer for each output
|
544 |
+
for i_layer in out_indices:
|
545 |
+
layer = norm_layer(num_features[i_layer])
|
546 |
+
layer_name = f'norm{i_layer}'
|
547 |
+
self.add_module(layer_name, layer)
|
548 |
+
|
549 |
+
self._freeze_stages()
|
550 |
+
|
551 |
+
def _freeze_stages(self):
|
552 |
+
if self.frozen_stages >= 0:
|
553 |
+
self.patch_embed.eval()
|
554 |
+
for param in self.patch_embed.parameters():
|
555 |
+
param.requires_grad = False
|
556 |
+
|
557 |
+
if self.frozen_stages >= 1 and self.ape:
|
558 |
+
self.absolute_pos_embed.requires_grad = False
|
559 |
+
|
560 |
+
if self.frozen_stages >= 2:
|
561 |
+
self.pos_drop.eval()
|
562 |
+
for i in range(0, self.frozen_stages - 1):
|
563 |
+
m = self.layers[i]
|
564 |
+
m.eval()
|
565 |
+
for param in m.parameters():
|
566 |
+
param.requires_grad = False
|
567 |
+
|
568 |
+
def init_weights(self, pretrained=None):
|
569 |
+
"""Initialize the weights in backbone.
|
570 |
+
|
571 |
+
Args:
|
572 |
+
pretrained (str, optional): Path to pre-trained weights.
|
573 |
+
Defaults to None.
|
574 |
+
"""
|
575 |
+
|
576 |
+
def _init_weights(m):
|
577 |
+
if isinstance(m, nn.Linear):
|
578 |
+
trunc_normal_(m.weight, std=.02)
|
579 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
580 |
+
nn.init.constant_(m.bias, 0)
|
581 |
+
elif isinstance(m, nn.LayerNorm):
|
582 |
+
nn.init.constant_(m.bias, 0)
|
583 |
+
nn.init.constant_(m.weight, 1.0)
|
584 |
+
|
585 |
+
if isinstance(pretrained, str):
|
586 |
+
self.apply(_init_weights)
|
587 |
+
logger = get_root_logger()
|
588 |
+
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
589 |
+
elif pretrained is None:
|
590 |
+
self.apply(_init_weights)
|
591 |
+
else:
|
592 |
+
raise TypeError('pretrained must be a str or None')
|
593 |
+
|
594 |
+
def forward(self, x):
|
595 |
+
"""Forward function."""
|
596 |
+
x = self.patch_embed(x)
|
597 |
+
|
598 |
+
Wh, Ww = x.size(2), x.size(3)
|
599 |
+
if self.ape:
|
600 |
+
# interpolate the position embedding to the corresponding size
|
601 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
602 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
603 |
+
|
604 |
+
outs = [x.contiguous()]
|
605 |
+
x = x.flatten(2).transpose(1, 2)
|
606 |
+
x = self.pos_drop(x)
|
607 |
+
for i in range(self.num_layers):
|
608 |
+
layer = self.layers[i]
|
609 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
610 |
+
|
611 |
+
if i in self.out_indices:
|
612 |
+
norm_layer = getattr(self, f'norm{i}')
|
613 |
+
x_out = norm_layer(x_out)
|
614 |
+
|
615 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
616 |
+
outs.append(out)
|
617 |
+
|
618 |
+
return tuple(outs)
|
619 |
+
|
620 |
+
def train(self, mode=True):
|
621 |
+
"""Convert the model into training mode while keep layers freezed."""
|
622 |
+
super(SwinTransformer, self).train(mode)
|
623 |
+
self._freeze_stages()
|
624 |
+
|
625 |
+
def SwinT(pretrained=True):
|
626 |
+
model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
|
627 |
+
if pretrained is True:
|
628 |
+
model.load_state_dict(torch.load('data/backbone_ckpt/swin_tiny_patch4_window7_224.pth', map_location='cpu')['model'], strict=False)
|
629 |
+
|
630 |
+
return model
|
631 |
+
|
632 |
+
def SwinS(pretrained=True):
|
633 |
+
model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7)
|
634 |
+
if pretrained is True:
|
635 |
+
model.load_state_dict(torch.load('data/backbone_ckpt/swin_small_patch4_window7_224.pth', map_location='cpu')['model'], strict=False)
|
636 |
+
|
637 |
+
return model
|
638 |
+
|
639 |
+
def SwinB(pretrained=True):
|
640 |
+
model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
|
641 |
+
if pretrained is True:
|
642 |
+
model.load_state_dict(torch.load('data/backbone_ckpt/swin_base_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False)
|
643 |
+
|
644 |
+
return model
|
645 |
+
|
646 |
+
def SwinL(pretrained=True):
|
647 |
+
model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12)
|
648 |
+
if pretrained is True:
|
649 |
+
model.load_state_dict(torch.load('data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False)
|
650 |
+
|
651 |
+
return model
|
652 |
+
|
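SwinB likewise returns five maps: the patch embedding (stride 4, 128 channels) followed by the four stage outputs, matching the [128, 128, 256, 512, 1024] list in InSPyReNet_SwinB. A shape-check sketch, assuming the swin_base checkpoint under data/backbone_ckpt/ is in place:

import torch
from lib.backbones.SwinTransformer import SwinB

backbone = SwinB(pretrained=True).eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 384, 384))
print([f.shape[1] for f in feats])  # [128, 128, 256, 512, 1024]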
lib/modules/attention_module.py
ADDED
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.parameter import Parameter
from operator import xor
from typing import Optional

from lib.modules.layers import *
from utils.misc import *

class SICA(nn.Module):
    def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None, lmap_in=False):
        super(SICA, self).__init__()
        self.in_channel = in_channel
        self.depth = depth
        self.lmap_in = lmap_in
        if base_size is not None and stage is not None:
            self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
        else:
            self.stage_size = None

        self.conv_query = nn.Sequential(Conv2d(in_channel, depth, 3, relu=True),
                                        Conv2d(depth, depth, 3, relu=True))
        self.conv_key = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True),
                                      Conv2d(depth, depth, 1, relu=True))
        self.conv_value = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True),
                                        Conv2d(depth, depth, 1, relu=True))

        if self.lmap_in is True:
            self.ctx = 5
        else:
            self.ctx = 3

        self.conv_out1 = Conv2d(depth, depth, 3, relu=True)
        self.conv_out2 = Conv2d(in_channel + depth, depth, 3, relu=True)
        self.conv_out3 = Conv2d(depth, depth, 3, relu=True)
        self.conv_out4 = Conv2d(depth, out_channel, 1)

        self.threshold = Parameter(torch.tensor([0.5]))

        if self.lmap_in is True:
            self.lthreshold = Parameter(torch.tensor([0.5]))

    def forward(self, x, smap, lmap: Optional[torch.Tensor] = None):
        assert not xor(self.lmap_in is True, lmap is not None)
        b, c, h, w = x.shape

        # compute class probability
        smap = F.interpolate(smap, size=x.shape[-2:], mode='bilinear', align_corners=False)
        smap = torch.sigmoid(smap)
        p = smap - self.threshold

        fg = torch.clip(p, 0, 1)            # foreground
        bg = torch.clip(-p, 0, 1)           # background
        cg = self.threshold - torch.abs(p)  # confusion area

        if self.lmap_in is True and lmap is not None:
            lmap = F.interpolate(lmap, size=x.shape[-2:], mode='bilinear', align_corners=False)
            lmap = torch.sigmoid(lmap)
            lp = lmap - self.lthreshold
            fp = torch.clip(lp, 0, 1)   # foreground
            bp = torch.clip(-lp, 0, 1)  # background

            prob = [fg, bg, cg, fp, bp]
        else:
            prob = [fg, bg, cg]

        prob = torch.cat(prob, dim=1)

        # reshape feature & prob
        if self.stage_size is not None:
            shape = self.stage_size
            shape_mul = self.stage_size[0] * self.stage_size[1]
        else:
            shape = (h, w)
            shape_mul = h * w

        f = F.interpolate(x, size=shape, mode='bilinear', align_corners=False).view(b, shape_mul, -1)
        prob = F.interpolate(prob, size=shape, mode='bilinear', align_corners=False).view(b, self.ctx, shape_mul)

        # compute context vector
        context = torch.bmm(prob, f).permute(0, 2, 1).unsqueeze(3)  # b, 3, c

        # k q v compute
        query = self.conv_query(x).view(b, self.depth, -1).permute(0, 2, 1)
        key = self.conv_key(context).view(b, self.depth, -1)
        value = self.conv_value(context).view(b, self.depth, -1).permute(0, 2, 1)

        # compute similarity map
        sim = torch.bmm(query, key)  # b, hw, c x b, c, 2
        sim = (self.depth ** -.5) * sim
        sim = F.softmax(sim, dim=-1)

        # compute refined feature
        context = torch.bmm(sim, value).permute(0, 2, 1).contiguous().view(b, -1, h, w)
        context = self.conv_out1(context)

        x = torch.cat([x, context], dim=1)
        x = self.conv_out2(x)
        x = self.conv_out3(x)
        out = self.conv_out4(x)

        return x, out
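SICA consumes a feature map and a coarser saliency logit map (plus a Laplacian logit map when lmap_in=True), turns the saliency map into foreground/background/confusion context vectors, and returns the refined feature together with a one-channel logit map. A standalone shape sketch mirroring self.attention2 in InSPyReNet (sizes illustrative):

import torch
from lib.modules.attention_module import SICA

sica = SICA(in_channel=128, depth=64, base_size=[384, 384], stage=2).eval()
x    = torch.rand(2, 128, 96, 96)  # stride-4 feature
smap = torch.rand(2, 1, 48, 48)    # coarser saliency logits; resized internally
with torch.no_grad():
    feat, logit = sica(x, smap)
print(feat.shape, logit.shape)     # (2, 64, 96, 96) (2, 1, 96, 96)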
lib/modules/context_module.py
ADDED
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import *

class PAA_kernel(nn.Module):
    def __init__(self, in_channel, out_channel, receptive_size, stage_size=None):
        super(PAA_kernel, self).__init__()
        self.conv0 = Conv2d(in_channel, out_channel, 1)
        self.conv1 = Conv2d(out_channel, out_channel, kernel_size=(1, receptive_size))
        self.conv2 = Conv2d(out_channel, out_channel, kernel_size=(receptive_size, 1))
        self.conv3 = Conv2d(out_channel, out_channel, 3, dilation=receptive_size)
        self.Hattn = SelfAttention(out_channel, 'h', stage_size[0] if stage_size is not None else None)
        self.Wattn = SelfAttention(out_channel, 'w', stage_size[1] if stage_size is not None else None)

    def forward(self, x):
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)

        Hx = self.Hattn(x)
        Wx = self.Wattn(x)

        x = self.conv3(Hx + Wx)
        return x

class PAA_e(nn.Module):
    def __init__(self, in_channel, out_channel, base_size=None, stage=None):
        super(PAA_e, self).__init__()
        self.relu = nn.ReLU(True)
        if base_size is not None and stage is not None:
            self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
        else:
            self.stage_size = None

        self.branch0 = Conv2d(in_channel, out_channel, 1)
        self.branch1 = PAA_kernel(in_channel, out_channel, 3, self.stage_size)
        self.branch2 = PAA_kernel(in_channel, out_channel, 5, self.stage_size)
        self.branch3 = PAA_kernel(in_channel, out_channel, 7, self.stage_size)

        self.conv_cat = Conv2d(4 * out_channel, out_channel, 3)
        self.conv_res = Conv2d(in_channel, out_channel, 1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)

        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
        x = self.relu(x_cat + self.conv_res(x))

        return x
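PAA_e keeps the spatial size and maps an arbitrary backbone channel count down to out_channel, combining a 1x1 branch with three PAA_kernel branches (receptive sizes 3/5/7, asymmetric convolutions plus axis-wise self-attention) and a residual 1x1 projection. A small sketch (channel and stage values illustrative):

import torch
from lib.modules.context_module import PAA_e

paa_e = PAA_e(in_channel=512, out_channel=64, base_size=[384, 384], stage=3).eval()
with torch.no_grad():
    y = paa_e(torch.rand(1, 512, 48, 48))
print(y.shape)  # torch.Size([1, 64, 48, 48])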
lib/modules/decoder_module.py
ADDED
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import *

class PAA_d(nn.Module):
    def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None):
        super(PAA_d, self).__init__()
        self.conv1 = Conv2d(in_channel, depth, 3)
        self.conv2 = Conv2d(depth, depth, 3)
        self.conv3 = Conv2d(depth, depth, 3)
        self.conv4 = Conv2d(depth, depth, 3)
        self.conv5 = Conv2d(depth, out_channel, 3, bn=False)

        self.base_size = base_size
        self.stage = stage

        if base_size is not None and stage is not None:
            self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
        else:
            self.stage_size = [None, None]

        self.Hattn = SelfAttention(depth, 'h', self.stage_size[0])
        self.Wattn = SelfAttention(depth, 'w', self.stage_size[1])

        self.upsample = lambda img, size: F.interpolate(img, size=size, mode='bilinear', align_corners=True)

    def forward(self, fs):  # f3 f4 f5 -> f3 f2 f1
        fx = fs[0]
        for i in range(1, len(fs)):
            fs[i] = self.upsample(fs[i], fx.shape[-2:])
        fx = torch.cat(fs[::-1], dim=1)

        fx = self.conv1(fx)

        Hfx = self.Hattn(fx)
        Wfx = self.Wattn(fx)

        fx = self.conv2(Hfx + Wfx)
        fx = self.conv3(fx)
        fx = self.conv4(fx)
        out = self.conv5(fx)

        return fx, out
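PAA_d upsamples the coarser inputs to the finest one, concatenates all three (hence the depth * 3 input channels the model passes in), and emits both the fused feature and a one-channel saliency logit. A sketch with the same depth=64 and illustrative sizes:

import torch
from lib.modules.decoder_module import PAA_d

paa_d = PAA_d(in_channel=64 * 3, depth=64, base_size=[384, 384], stage=2).eval()
f3, f4, f5 = torch.rand(1, 64, 48, 48), torch.rand(1, 64, 24, 24), torch.rand(1, 64, 12, 12)
with torch.no_grad():
    feat, logit = paa_d([f3, f4, f5])
print(feat.shape, logit.shape)  # (1, 64, 48, 48) (1, 1, 48, 48)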
lib/modules/layers.py
ADDED
@@ -0,0 +1,175 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import cv2
import numpy as np

from kornia.morphology import dilation, erosion
from torch.nn.parameter import Parameter
from typing import Optional


class ImagePyramid:
    # Gaussian/Laplacian image pyramid operating on (B, C, H, W) tensors.
    def __init__(self, ksize=7, sigma=1, channels=1):
        self.ksize = ksize
        self.sigma = sigma
        self.channels = channels

        k = cv2.getGaussianKernel(ksize, sigma)
        k = np.outer(k, k)
        k = torch.tensor(k).float()
        self.kernel = k.repeat(channels, 1, 1, 1)

    def to(self, device):
        self.kernel = self.kernel.to(device)
        return self

    def cuda(self, idx=None):
        if idx is None:
            idx = torch.cuda.current_device()

        self.to(device="cuda:{}".format(idx))
        return self

    def expand(self, x):
        # Upsample by 2 (zero-interleaving via pixel shuffle) followed by Gaussian smoothing.
        z = torch.zeros_like(x)
        x = torch.cat([x, z, z, z], dim=1)
        x = F.pixel_shuffle(x, 2)
        x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect')
        x = F.conv2d(x, self.kernel * 4, groups=self.channels)
        return x

    def reduce(self, x):
        # Gaussian smoothing followed by downsampling by 2.
        x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect')
        x = F.conv2d(x, self.kernel, groups=self.channels)
        x = x[:, :, ::2, ::2]
        return x

    def deconstruct(self, x):
        reduced_x = self.reduce(x)
        expanded_reduced_x = self.expand(reduced_x)

        if x.shape != expanded_reduced_x.shape:
            expanded_reduced_x = F.interpolate(expanded_reduced_x, x.shape[-2:])

        laplacian_x = x - expanded_reduced_x
        return reduced_x, laplacian_x

    def reconstruct(self, x, laplacian_x):
        expanded_x = self.expand(x)
        if laplacian_x.shape != expanded_x.shape:
            laplacian_x = F.interpolate(laplacian_x, expanded_x.shape[-2:], mode='bilinear', align_corners=True)
        return expanded_x + laplacian_x


class Transition:
    # Extracts an uncertain (boundary) region mask as dilation minus erosion of the sigmoid map.
    def __init__(self, k=3):
        self.kernel = torch.tensor(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))).float()

    def to(self, device):
        self.kernel = self.kernel.to(device)
        return self

    def cuda(self, idx=None):
        if idx is None:
            idx = torch.cuda.current_device()

        self.to(device="cuda:{}".format(idx))
        return self

    def __call__(self, x):
        x = torch.sigmoid(x)
        dx = dilation(x, self.kernel)
        ex = erosion(x, self.kernel)

        return ((dx - ex) > .5).float()


class Conv2d(nn.Module):
    # Convolution with optional BatchNorm and ReLU, supporting 'same', 'valid', or explicit padding.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding='same', bias=False, bn=True, relu=False):
        super(Conv2d, self).__init__()
        if '__iter__' not in dir(kernel_size):
            kernel_size = (kernel_size, kernel_size)
        if '__iter__' not in dir(stride):
            stride = (stride, stride)
        if '__iter__' not in dir(dilation):
            dilation = (dilation, dilation)

        if padding == 'same':
            width_pad_size = kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1)
            height_pad_size = kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1)
        elif padding == 'valid':
            width_pad_size = 0
            height_pad_size = 0
        else:
            if '__iter__' in dir(padding):
                width_pad_size = padding[0] * 2
                height_pad_size = padding[1] * 2
            else:
                width_pad_size = padding * 2
                height_pad_size = padding * 2

        width_pad_size = width_pad_size // 2 + (width_pad_size % 2 - 1)
        height_pad_size = height_pad_size // 2 + (height_pad_size % 2 - 1)
        pad_size = (width_pad_size, height_pad_size)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_size, dilation, groups, bias=bias)
        self.reset_parameters()

        if bn is True:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None

        if relu is True:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.conv.weight)


class SelfAttention(nn.Module):
    # Axial self-attention along the height ('h'), width ('w'), or both ('hw') dimensions.
    def __init__(self, in_channels, mode='hw', stage_size=None):
        super(SelfAttention, self).__init__()

        self.mode = mode

        self.query_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.key_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.value_conv = Conv2d(in_channels, in_channels, kernel_size=(1, 1))

        self.gamma = Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

        self.stage_size = stage_size

    def forward(self, x):
        batch_size, channel, height, width = x.size()

        axis = 1
        if 'h' in self.mode:
            axis *= height
        if 'w' in self.mode:
            axis *= width

        view = (batch_size, -1, axis)

        projected_query = self.query_conv(x).view(*view).permute(0, 2, 1)
        projected_key = self.key_conv(x).view(*view)

        attention_map = torch.bmm(projected_query, projected_key)
        attention = self.softmax(attention_map)
        projected_value = self.value_conv(x).view(*view)

        out = torch.bmm(projected_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channel, height, width)

        out = self.gamma * out + x
        return out
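A minimal sketch (illustration only) of an ImagePyramid round trip: deconstruct splits an image into a half-resolution Gaussian level and a full-resolution Laplacian residual, and reconstruct inverts the split. The input size is arbitrary.

import torch
from lib.modules.layers import ImagePyramid

pyr = ImagePyramid(ksize=7, sigma=1, channels=1)
x = torch.rand(1, 1, 64, 64)       # arbitrary single-channel input

low, lap = pyr.deconstruct(x)      # low: (1, 1, 32, 32), lap: (1, 1, 64, 64)
rec = pyr.reconstruct(low, lap)    # expand(low) + lap recovers x up to numerical error
print((rec - x).abs().max())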
lib/optim/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .losses import *
from .scheduler import *
lib/optim/losses.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn.functional as F

def bce_loss(pred, mask, reduction='none'):
    bce = F.binary_cross_entropy(pred, mask, reduction=reduction)
    return bce

def weighted_bce_loss(pred, mask, reduction='none'):
    weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    weight = weight.flatten()

    bce = weight * bce_loss(pred, mask, reduction='none').flatten()

    if reduction == 'mean':
        bce = bce.mean()

    return bce

def iou_loss(pred, mask, reduction='none'):
    inter = pred * mask
    union = pred + mask
    iou = 1 - (inter + 1) / (union - inter + 1)

    if reduction == 'mean':
        iou = iou.mean()

    return iou

def bce_loss_with_logits(pred, mask, reduction='none'):
    return bce_loss(torch.sigmoid(pred), mask, reduction=reduction)

def weighted_bce_loss_with_logits(pred, mask, reduction='none'):
    return weighted_bce_loss(torch.sigmoid(pred), mask, reduction=reduction)

def iou_loss_with_logits(pred, mask, reduction='none'):
    return iou_loss(torch.sigmoid(pred), mask, reduction=reduction)
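A minimal sketch (illustration only) of combining the logit-space losses above on a dummy prediction and binary mask; the actual training objective is defined by the training script, not by this snippet.

import torch
from lib.optim.losses import weighted_bce_loss_with_logits, iou_loss_with_logits

pred = torch.randn(2, 1, 64, 64, requires_grad=True)   # raw logits
mask = (torch.rand(2, 1, 64, 64) > 0.5).float()        # binary ground truth

loss = weighted_bce_loss_with_logits(pred, mask, reduction='mean') \
     + iou_loss_with_logits(pred, mask, reduction='mean')
loss.backward()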
lib/optim/scheduler.py
ADDED
@@ -0,0 +1,30 @@
from torch.optim.lr_scheduler import _LRScheduler


class PolyLr(_LRScheduler):
    def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
        self.gamma = gamma
        self.max_iteration = max_iteration
        self.minimum_lr = minimum_lr
        self.warmup_iteration = warmup_iteration

        self.last_epoch = None
        self.base_lrs = []

        super(PolyLr, self).__init__(optimizer, last_epoch)

    def poly_lr(self, base_lr, step):
        return (base_lr - self.minimum_lr) * ((1 - (step / self.max_iteration)) ** self.gamma) + self.minimum_lr

    def warmup_lr(self, base_lr, alpha):
        return base_lr * (1 / 10.0 * (1 - alpha) + alpha)

    def get_lr(self):
        if self.last_epoch < self.warmup_iteration:
            alpha = self.last_epoch / self.warmup_iteration
            lrs = [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, self.last_epoch))
                   for base_lr in self.base_lrs]
        else:
            lrs = [self.poly_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]

        return lrs
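A minimal sketch (illustration only) of driving PolyLr: the learning rate ramps up from 10% of the base value over the first warmup_iteration steps, then decays polynomially toward minimum_lr over max_iteration steps. The optimizer and hyperparameters here are placeholders.

import torch
from lib.optim.scheduler import PolyLr

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-2)
scheduler = PolyLr(optimizer, gamma=0.9, max_iteration=1000, minimum_lr=1e-7, warmup_iteration=100)

for step in range(1000):
    optimizer.step()
    scheduler.step()
    # scheduler.get_last_lr() ramps up for the first 100 steps, then decays polynomially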
requirements.txt
ADDED
@@ -0,0 +1,9 @@
easydict
timm
tqdm
kornia
numpy
opencv-python
pyyaml
pandas
scipy
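Note that PyTorch and torchvision are not pinned in this list and are assumed to be installed separately; the remaining dependencies install with:

pip install -r requirements.txt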
run/Eval.py
ADDED
@@ -0,0 +1,138 @@
import os
import sys
import tqdm

import pandas as pd
import numpy as np

from PIL import Image

filepath = os.path.split(os.path.abspath(__file__))[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)

from utils.eval_functions import *
from utils.misc import *

BETA = 1.0

def evaluate(opt, args):
    if os.path.isdir(opt.Eval.result_path) is False:
        os.makedirs(opt.Eval.result_path)

    method = os.path.split(opt.Eval.pred_root)[-1]

    if args.verbose is True:
        print('#' * 20, 'Start Evaluation', '#' * 20)
        datasets = tqdm.tqdm(opt.Eval.datasets, desc='Expr - ' + method, total=len(opt.Eval.datasets),
                             position=0, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
    else:
        datasets = opt.Eval.datasets

    results = []

    for dataset in datasets:
        pred_root = os.path.join(opt.Eval.pred_root, dataset)
        gt_root = os.path.join(opt.Eval.gt_root, dataset, 'masks')

        preds = os.listdir(pred_root)
        gts = os.listdir(gt_root)

        preds = sort(preds)
        gts = sort(gts)

        preds = [i for i in preds if i in gts]
        gts = [i for i in gts if i in preds]

        FM = Fmeasure()
        WFM = WeightedFmeasure()
        SM = Smeasure()
        EM = Emeasure()
        MAE = Mae()
        MSE = Mse()
        MBA = BoundaryAccuracy()
        IOU = IoU()
        BIOU = BIoU()
        TIOU = TIoU()

        if args.verbose is True:
            samples = tqdm.tqdm(enumerate(zip(preds, gts)), desc=dataset + ' - Evaluation', total=len(preds),
                                position=1, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
        else:
            samples = enumerate(zip(preds, gts))

        for i, sample in samples:
            pred, gt = sample

            pred_mask = np.array(Image.open(os.path.join(pred_root, pred)).convert('L'))
            gt_mask = np.array(Image.open(os.path.join(gt_root, gt)).convert('L'))

            if len(pred_mask.shape) != 2:
                pred_mask = pred_mask[:, :, 0]
            if len(gt_mask.shape) != 2:
                gt_mask = gt_mask[:, :, 0]

            assert pred_mask.shape == gt_mask.shape, '{} does not match the size of {}'.format(pred, gt)

            FM.step(pred=pred_mask, gt=gt_mask)
            WFM.step(pred=pred_mask, gt=gt_mask)
            SM.step(pred=pred_mask, gt=gt_mask)
            EM.step(pred=pred_mask, gt=gt_mask)
            MAE.step(pred=pred_mask, gt=gt_mask)
            MSE.step(pred=pred_mask, gt=gt_mask)
            MBA.step(pred=pred_mask, gt=gt_mask)
            IOU.step(pred=pred_mask, gt=gt_mask)
            BIOU.step(pred=pred_mask, gt=gt_mask)
            TIOU.step(pred=pred_mask, gt=gt_mask)

        result = []

        Sm = SM.get_results()["sm"]
        wFm = WFM.get_results()["wfm"]
        mae = MAE.get_results()["mae"]
        mse = MSE.get_results()["mse"]
        mBA = MBA.get_results()["mba"]

        Fm = FM.get_results()["fm"]
        Em = EM.get_results()["em"]
        Iou = IOU.get_results()["iou"]
        BIou = BIOU.get_results()["biou"]
        TIou = TIOU.get_results()["tiou"]

        adpEm = Em["adp"]
        avgEm = Em["curve"].mean()
        maxEm = Em["curve"].max()
        adpFm = Fm["adp"]
        avgFm = Fm["curve"].mean()
        maxFm = Fm["curve"].max()
        avgIou = Iou["curve"].mean()
        maxIou = Iou["curve"].max()
        avgBIou = BIou["curve"].mean()
        maxBIou = BIou["curve"].max()
        avgTIou = TIou["curve"].mean()
        maxTIou = TIou["curve"].max()

        out = dict()
        for metric in opt.Eval.metrics:
            out[metric] = eval(metric)

        pkl = os.path.join(opt.Eval.result_path, 'result_' + dataset + '.pkl')
        if os.path.isfile(pkl) is True:
            result = pd.read_pickle(pkl)
            result.loc[method] = out
            result.to_pickle(pkl)
        else:
            result = pd.DataFrame(data=out, index=[method])
            result.to_pickle(pkl)
        result.to_csv(os.path.join(opt.Eval.result_path, 'result_' + dataset + '.csv'))
        results.append(result)

    if args.verbose is True:
        for dataset, result in zip(datasets, results):
            print('###', dataset, '###', '\n', result.sort_index(), '\n')

if __name__ == "__main__":
    args = parse_args()
    opt = load_config(args.config)
    evaluate(opt, args)
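Assuming parse_args in utils.misc (not included in this section) exposes --config and --verbose, which are the only attributes evaluate reads from args, a run would look roughly like:

python run/Eval.py --config configs/InSPyReNet_SwinB.yaml --verbose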
run/Inference.py
ADDED
@@ -0,0 +1,167 @@
import os
import cv2
import sys
import tqdm
import torch
import argparse

import numpy as np

from PIL import Image

filepath = os.path.split(os.path.abspath(__file__))[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)

from lib import *
from utils.misc import *
from data.dataloader import *
from data.custom_transforms import *

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

def _args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml')
    parser.add_argument('--source', '-s', type=str)
    parser.add_argument('--dest', '-d', type=str, default=None)
    parser.add_argument('--type', '-t', type=str, default='map')
    parser.add_argument('--gpu', '-g', action='store_true', default=False)
    parser.add_argument('--jit', '-j', action='store_true', default=False)
    parser.add_argument('--verbose', '-v', action='store_true', default=False)
    return parser.parse_args()

def get_format(source):
    img_count = len([i for i in source if i.lower().endswith(('.jpg', '.png', '.jpeg'))])
    vid_count = len([i for i in source if i.lower().endswith(('.mp4', '.avi', '.mov'))])

    if img_count * vid_count != 0:
        return ''
    elif img_count != 0:
        return 'Image'
    elif vid_count != 0:
        return 'Video'
    else:
        return ''

def inference(opt, args):
    model = eval(opt.Model.name)(**opt.Model)
    model.load_state_dict(torch.load(os.path.join(
        opt.Test.Checkpoint.checkpoint_dir, 'latest.pth'), map_location=torch.device('cpu')), strict=True)

    if args.gpu is True:
        model = model.cuda()
    model.eval()

    if args.jit is True:
        if os.path.isfile(os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'jit.pt')) is False:
            model = Simplify(model)
            model = torch.jit.trace(model, torch.rand(1, 3, *opt.Test.Dataset.transforms.static_resize.size).cuda(), strict=False)
            torch.jit.save(model, os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'jit.pt'))
        else:
            del model
            model = torch.jit.load(os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'jit.pt'))

    save_dir = None
    _format = None

    if args.source.isnumeric() is True:
        _format = 'Webcam'
    elif os.path.isdir(args.source):
        save_dir = os.path.join('results', args.source.split(os.sep)[-1])
        _format = get_format(os.listdir(args.source))
    elif os.path.isfile(args.source):
        save_dir = 'results'
        _format = get_format([args.source])

    if args.dest is not None:
        save_dir = args.dest

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    sample_list = eval(_format + 'Loader')(args.source, opt.Test.Dataset.transforms)

    if args.verbose is True:
        samples = tqdm.tqdm(sample_list, desc='Inference', total=len(sample_list),
                            position=0, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
    else:
        samples = sample_list

    writer = None
    background = None

    for sample in samples:
        if _format == 'Video' and writer is None:
            writer = cv2.VideoWriter(os.path.join(save_dir, sample['name'] + '.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), sample_list.fps, sample['shape'][::-1])
            samples.total += int(sample_list.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if _format == 'Video' and sample['image'] is None:
            if writer is not None:
                writer.release()
            writer = None
            continue

        if args.gpu is True:
            sample = to_cuda(sample)

        with torch.no_grad():
            if args.jit is True:
                out = model(sample['image'])
            else:
                out = model(sample)

        pred = to_numpy(out['pred'], sample['shape'])
        img = np.array(sample['original'])

        if args.type == 'map':
            img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)
        elif args.type == 'rgba':
            r, g, b = cv2.split(img)
            pred = (pred * 255).astype(np.uint8)
            img = cv2.merge([r, g, b, pred])
        elif args.type == 'green':
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155]
            img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])
        elif args.type == 'blur':
            img = img * pred[..., np.newaxis] + cv2.GaussianBlur(img, (0, 0), 15) * (1 - pred[..., np.newaxis])
        elif args.type == 'overlay':
            bg = (np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155] + img) // 2
            img = bg * pred[..., np.newaxis] + img * (1 - pred[..., np.newaxis])
            border = cv2.Canny(((pred > .5) * 255).astype(np.uint8), 50, 100)
            img[border != 0] = [120, 255, 155]
        elif args.type.lower().endswith(('.jpg', '.jpeg', '.png')):
            if background is None:
                background = cv2.cvtColor(cv2.imread(args.type), cv2.COLOR_BGR2RGB)
                background = cv2.resize(background, img.shape[:2][::-1])
            img = img * pred[..., np.newaxis] + background * (1 - pred[..., np.newaxis])
        elif args.type == 'debug':
            debs = []
            for k in opt.Train.Debug.keys:
                debs.extend(out[k])
            for i, j in enumerate(debs):
                log = torch.sigmoid(j).cpu().detach().numpy().squeeze()
                log = ((log - log.min()) / (log.max() - log.min()) * 255).astype(np.uint8)
                log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
                log = cv2.resize(log, img.shape[:2][::-1])
                Image.fromarray(log).save(os.path.join(save_dir, sample['name'] + '_' + str(i) + '.png'))

        img = img.astype(np.uint8)

        if _format == 'Image':
            Image.fromarray(img).save(os.path.join(save_dir, sample['name'] + '.png'))
        elif _format == 'Video' and writer is not None:
            writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        elif _format == 'Webcam':
            cv2.imshow('InSPyReNet', img)

if __name__ == "__main__":
    args = _args()
    opt = load_config(args.config)
    inference(opt, args)
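Based on the _args parser above, a typical invocation might look like the following; the source folder and destination are placeholders, and --type also accepts map, green, blur, overlay, debug, or a background image path, as handled by the branches above.

python run/Inference.py --config configs/InSPyReNet_SwinB.yaml --source samples/ --dest results/ --type rgba --gpu --verbose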
run/Test.py
ADDED
@@ -0,0 +1,59 @@
import os
import sys
import tqdm
import torch

import numpy as np

from PIL import Image
from torch.utils.data.dataloader import DataLoader

filepath = os.path.split(os.path.abspath(__file__))[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)

from lib import *
from utils.misc import *
from data.dataloader import *

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

def test(opt, args):
    model = eval(opt.Model.name)(**opt.Model)
    model.load_state_dict(torch.load(os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'latest.pth')), strict=True)

    model.cuda()
    model.eval()

    if args.verbose is True:
        sets = tqdm.tqdm(opt.Test.Dataset.sets, desc='Total TestSet', total=len(opt.Test.Dataset.sets),
                         position=0, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
    else:
        sets = opt.Test.Dataset.sets

    for set in sets:
        save_path = os.path.join(opt.Test.Checkpoint.checkpoint_dir, set)

        os.makedirs(save_path, exist_ok=True)
        test_dataset = eval(opt.Test.Dataset.type)(opt.Test.Dataset.root, [set], opt.Test.Dataset.transforms)
        test_loader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=opt.Test.Dataloader.num_workers, pin_memory=opt.Test.Dataloader.pin_memory)

        if args.verbose is True:
            samples = tqdm.tqdm(test_loader, desc=set + ' - Test', total=len(test_loader),
                                position=1, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
        else:
            samples = test_loader

        for sample in samples:
            sample = to_cuda(sample)
            with torch.no_grad():
                out = model(sample)

            pred = to_numpy(out['pred'], sample['shape'])
            Image.fromarray((pred * 255).astype(np.uint8)).save(os.path.join(save_path, sample['name'][0] + '.png'))

if __name__ == "__main__":
    args = parse_args()
    opt = load_config(args.config)
    test(opt, args)
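As with Eval.py, Test.py relies on parse_args and load_config from utils.misc, which are not shown in this section; assuming they provide --config and --verbose, benchmark predictions would be generated with something like:

python run/Test.py --config configs/InSPyReNet_SwinB.yaml --verbose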