DRgaddam committed on
Commit 362a3b9 · verified · 1 Parent(s): 9f0a56a

adding files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +14 -0
  2. .gitignore +139 -0
  3. CODE_OF_CONDUCT.md +128 -0
  4. Expr.py +14 -0
  5. LICENSE +21 -0
  6. README.md +186 -3
  7. configs/InSPyReNet_Res2Net50.yaml +76 -0
  8. configs/InSPyReNet_SwinB.yaml +76 -0
  9. configs/extra_dataset/InSPyReNet_SwinB_DH.yaml +76 -0
  10. configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml +76 -0
  11. configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml +76 -0
  12. configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml +76 -0
  13. configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml +76 -0
  14. configs/extra_dataset/InSPyReNet_SwinB_HU.yaml +76 -0
  15. configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml +76 -0
  16. configs/extra_dataset/Plus_Ultra.yaml +98 -0
  17. configs/extra_dataset/Plus_Ultra_LR.yaml +96 -0
  18. data/custom_transforms.py +206 -0
  19. data/dataloader.py +243 -0
  20. docs/getting_started.md +138 -0
  21. docs/model_zoo.md +65 -0
  22. figures/demo_image.gif +3 -0
  23. figures/demo_type.png +3 -0
  24. figures/demo_video.gif +3 -0
  25. figures/demo_webapp.gif +3 -0
  26. figures/fig_architecture.png +3 -0
  27. figures/fig_pyramid_blending.png +3 -0
  28. figures/fig_qualitative.png +3 -0
  29. figures/fig_qualitative2.png +3 -0
  30. figures/fig_qualitative3.jpg +3 -0
  31. figures/fig_qualitative_dis.png +3 -0
  32. figures/fig_quantitative.png +3 -0
  33. figures/fig_quantitative2.png +3 -0
  34. figures/fig_quantitative3.png +3 -0
  35. figures/fig_quantitative4.png +3 -0
  36. lib/InSPyReNet.py +199 -0
  37. lib/__init__.py +1 -0
  38. lib/backbones/Res2Net_v1b.py +268 -0
  39. lib/backbones/SwinTransformer.py +652 -0
  40. lib/modules/attention_module.py +103 -0
  41. lib/modules/context_module.py +54 -0
  42. lib/modules/decoder_module.py +44 -0
  43. lib/modules/layers.py +175 -0
  44. lib/optim/__init__.py +2 -0
  45. lib/optim/losses.py +36 -0
  46. lib/optim/scheduler.py +30 -0
  47. requirements.txt +9 -0
  48. run/Eval.py +138 -0
  49. run/Inference.py +167 -0
  50. run/Test.py +59 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/demo_image.gif filter=lfs diff=lfs merge=lfs -text
37
+ figures/demo_type.png filter=lfs diff=lfs merge=lfs -text
38
+ figures/demo_video.gif filter=lfs diff=lfs merge=lfs -text
39
+ figures/demo_webapp.gif filter=lfs diff=lfs merge=lfs -text
40
+ figures/fig_architecture.png filter=lfs diff=lfs merge=lfs -text
41
+ figures/fig_pyramid_blending.png filter=lfs diff=lfs merge=lfs -text
42
+ figures/fig_qualitative_dis.png filter=lfs diff=lfs merge=lfs -text
43
+ figures/fig_qualitative.png filter=lfs diff=lfs merge=lfs -text
44
+ figures/fig_qualitative2.png filter=lfs diff=lfs merge=lfs -text
45
+ figures/fig_qualitative3.jpg filter=lfs diff=lfs merge=lfs -text
46
+ figures/fig_quantitative.png filter=lfs diff=lfs merge=lfs -text
47
+ figures/fig_quantitative2.png filter=lfs diff=lfs merge=lfs -text
48
+ figures/fig_quantitative3.png filter=lfs diff=lfs merge=lfs -text
49
+ figures/fig_quantitative4.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,139 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ pip-wheel-metadata/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+ db.sqlite3
61
+ db.sqlite3-journal
62
+
63
+ # Flask stuff:
64
+ instance/
65
+ .webassets-cache
66
+
67
+ # Scrapy stuff:
68
+ .scrapy
69
+
70
+ # Sphinx documentation
71
+ docs/_build/
72
+
73
+ # PyBuilder
74
+ target/
75
+
76
+ # Jupyter Notebook
77
+ .ipynb_checkpoints
78
+
79
+ # IPython
80
+ profile_default/
81
+ ipython_config.py
82
+
83
+ # pyenv
84
+ .python-version
85
+
86
+ # pipenv
87
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
89
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
90
+ # install all needed dependencies.
91
+ #Pipfile.lock
92
+
93
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94
+ __pypackages__/
95
+
96
+ # Celery stuff
97
+ celerybeat-schedule
98
+ celerybeat.pid
99
+
100
+ # SageMath parsed files
101
+ *.sage.py
102
+
103
+ # Environments
104
+ .env
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ data/backbone_ckpt
131
+ data/Train_Dataset
132
+ data/Test_Dataset
133
+ results/*
134
+ snapshots
135
+ snapshots/*
136
+ Samples/*
137
+ temp/*
138
+ .idea
139
+ .vscode
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the
26
+ overall community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or
31
+ advances of any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email
35
+ address, without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ .
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series
86
+ of actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or
93
+ permanent ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within
113
+ the community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.0, available at
119
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120
+
121
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
122
+ enforcement ladder](https://github.com/mozilla/diversity).
123
+
124
+ [homepage]: https://www.contributor-covenant.org
125
+
126
+ For answers to common questions about this code of conduct, see the FAQ at
127
+ https://www.contributor-covenant.org/faq. Translations are available at
128
+ https://www.contributor-covenant.org/translations.
Expr.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch.cuda as cuda
2
+
3
+ from utils.misc import *
4
+ from run import *
5
+
6
+ if __name__ == "__main__":
7
+ args = parse_args()
8
+ opt = load_config(args.config)
9
+
10
+ train(opt, args)
11
+ cuda.empty_cache()
12
+ if args.local_rank <= 0:
13
+ test(opt, args)
14
+ evaluate(opt, args)
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Taehun Kim
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,186 @@
1
- ---
2
- license: afl-3.0
3
- ---
1
+ # Revisiting Image Pyramid Structure for High Resolution Salient Object Detection (InSPyReNet)
2
+
3
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/salient-object-detection-on-hku-is)](https://paperswithcode.com/sota/salient-object-detection-on-hku-is?p=revisiting-image-pyramid-structure-for-high)
4
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/salient-object-detection-on-duts-te)](https://paperswithcode.com/sota/salient-object-detection-on-duts-te?p=revisiting-image-pyramid-structure-for-high)
5
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/salient-object-detection-on-pascal-s)](https://paperswithcode.com/sota/salient-object-detection-on-pascal-s?p=revisiting-image-pyramid-structure-for-high)
6
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/salient-object-detection-on-ecssd)](https://paperswithcode.com/sota/salient-object-detection-on-ecssd?p=revisiting-image-pyramid-structure-for-high)
7
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/salient-object-detection-on-dut-omron)](https://paperswithcode.com/sota/salient-object-detection-on-dut-omron?p=revisiting-image-pyramid-structure-for-high)
8
+
9
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/rgb-salient-object-detection-on-davis-s)](https://paperswithcode.com/sota/rgb-salient-object-detection-on-davis-s?p=revisiting-image-pyramid-structure-for-high)
10
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/rgb-salient-object-detection-on-hrsod)](https://paperswithcode.com/sota/rgb-salient-object-detection-on-hrsod?p=revisiting-image-pyramid-structure-for-high)
11
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/rgb-salient-object-detection-on-uhrsd)](https://paperswithcode.com/sota/rgb-salient-object-detection-on-uhrsd?p=revisiting-image-pyramid-structure-for-high)
12
+
13
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=revisiting-image-pyramid-structure-for-high)
14
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/dichotomous-image-segmentation-on-dis-te1)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te1?p=revisiting-image-pyramid-structure-for-high)
15
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/dichotomous-image-segmentation-on-dis-te2)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te2?p=revisiting-image-pyramid-structure-for-high)
16
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/dichotomous-image-segmentation-on-dis-te3)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te3?p=revisiting-image-pyramid-structure-for-high)
17
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-image-pyramid-structure-for-high/dichotomous-image-segmentation-on-dis-te4)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te4?p=revisiting-image-pyramid-structure-for-high)
18
+
19
+ Official PyTorch implementation of Revisiting Image Pyramid Structure for High Resolution Salient Object Detection (InSPyReNet)
20
+
21
+ To appear in the 16th Asian Conference on Computer Vision (ACCV2022)
22
+
23
+ <p align="center">
24
+ <a href="https://github.com/plemeri/InSPyReNet/blob/main/LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue"></a>
25
+ <a href="https://arxiv.org/abs/2209.09475"><img src="https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg" ></a>
26
+ <a href="https://openaccess.thecvf.com/content/ACCV2022/html/Kim_Revisiting_Image_Pyramid_Structure_for_High_Resolution_Salient_Object_Detection_ACCV_2022_paper.html"><img src="https://img.shields.io/static/v1?label=inproceedings&message=Paper&color=orange"></a>
27
+ <a href="https://huggingface.co/spaces/taskswithcode/salient-object-detection"><img src="https://img.shields.io/static/v1?label=HuggingFace&message=Demo&color=yellow"></a>
28
+ <a href="https://www.taskswithcode.com/salient_object_detection/"><img src="https://img.shields.io/static/v1?label=TasksWithCode&message=Demo&color=blue"></a>
29
+ <a href="https://colab.research.google.com/github/taskswithcode/InSPyReNet/blob/main/TWCSOD.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg"></a>
30
+ </p>
31
+
32
+ > [Taehun Kim](https://scholar.google.co.kr/citations?user=f12-9yQAAAAJ&hl=en), [Kunhee Kim](https://scholar.google.co.kr/citations?user=6sU5r7MAAAAJ&hl=en), [Joonyeong Lee](https://scholar.google.co.kr/citations?hl=en&user=pOM4zSYAAAAJ), Dongmin Cha, [Jiho Lee](https://scholar.google.co.kr/citations?user=1Q1awj8AAAAJ&hl=en), [Daijin Kim](https://scholar.google.co.kr/citations?user=Mw6anjAAAAAJ&hl=en)
33
+
34
+ > **Abstract:**
35
+ Salient object detection (SOD) has been in the spotlight recently, yet has been studied less for high-resolution (HR) images.
36
+ Unfortunately, HR images and their pixel-level annotations are certainly more labor-intensive and time-consuming compared to low-resolution (LR) images.
37
+ Therefore, we propose an image pyramid-based SOD framework, Inverse Saliency Pyramid Reconstruction Network (InSPyReNet), for HR prediction without any HR datasets.
38
+ We design InSPyReNet to produce a strict image pyramid structure of saliency maps, which enables ensembling multiple results with pyramid-based image blending.
39
+ For HR prediction, we design a pyramid blending method which synthesizes two different image pyramids from a pair of LR and HR scales of the same image to overcome the effective receptive field (ERF) discrepancy. Our extensive evaluation on public LR and HR SOD benchmarks demonstrates that InSPyReNet surpasses State-of-the-Art (SotA) methods on various SOD metrics and boundary accuracy.
40
+
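For readers unfamiliar with pyramid-based image blending, the sketch below builds and inverts a generic Laplacian pyramid. It is only an illustration of the general idea the abstract refers to, not the paper's InSPyReNet blending scheme, and it assumes OpenCV and NumPy are available.

```python
import numpy as np
import cv2  # assumption: OpenCV is available; generic illustration, not code from this repository

def laplacian_pyramid(img, levels=3):
    # Each level stores the detail lost by one downsampling step; the last entry is the low-frequency base.
    pyr, cur = [], img.astype(np.float32)
    for _ in range(levels):
        down = cv2.pyrDown(cur)
        up = cv2.pyrUp(down, dstsize=(cur.shape[1], cur.shape[0]))
        pyr.append(cur - up)
        cur = down
    pyr.append(cur)
    return pyr

def reconstruct(pyr):
    # Invert the pyramid: repeatedly upsample the base and add back each detail level.
    cur = pyr[-1]
    for lap in reversed(pyr[:-1]):
        cur = cv2.pyrUp(cur, dstsize=(lap.shape[1], lap.shape[0])) + lap
    return cur
```

Blending two such pyramids level by level before reconstruction is what lets predictions from different scales be merged without visible seams.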
41
+ ## Contents
42
+
43
+ 1. [News](#news-newspaper)
44
+ 2. [Demo](#demo-rocket)
45
+ 3. [Applications](#applications-video_game)
46
+ 4. [Easy Download](#easy-download-cake)
47
+ 5. [Getting Started](#getting-started-flight_departure)
48
+ 6. [Model Zoo](#model-zoo-giraffe)
49
+ 7. [Results](#results-100)
50
+ * [Quantitative Results](#quantitative-results)
51
+ * [Qualitative Results](#qualitative-results)
52
+ 8. [Citation](#citation)
53
+ 9. [Acknowledgement](#acknowledgement)
54
+ * [Special Thanks to](#special-thanks-to-tada)
55
+ 10. [References](#references)
56
+
57
+ ## News :newspaper:
58
+
59
+ [2022.10.04] [TasksWithCode](https://github.com/taskswithcode) mentioned our work in a [Blog](https://medium.com/@taskswithcode/twc-9-7c960c921f69) post and reproduced our work on [Colab](https://github.com/taskswithcode/InSPyReNet). Thank you for your attention!
60
+
61
+ [2022.10.20] We trained our model on the [Dichotomous Image Segmentation dataset (DIS5K)](https://xuebinqin.github.io/dis/index.html) and showed competitive results! The trained checkpoint and pre-computed segmentation masks are available in the [Model Zoo](./docs/model_zoo.md). Also, you can check our qualitative and quantitative results in the [Results](#results-100) section.
62
+
63
+ [2022.10.28] Multi-GPU training for the latest PyTorch is now available.
64
+
65
+ [2022.10.31] [TasksWithCode](https://github.com/taskswithcode) provided an amazing web demo with [HuggingFace](https://huggingface.co). Visit the [WebApp](https://huggingface.co/spaces/taskswithcode/salient-object-detection) and try it with your image!
66
+
67
+ [2022.11.09] :car: Lane segmentation for driving scenes, built on InSPyReNet, is available in the [`LaneSOD`](https://github.com/plemeri/LaneSOD) repository.
68
+
69
+ [2022.11.18] I am speaking at the 16th Asian Conference on Computer Vision (ACCV2022). Please check out my talk if you're attending the event! #ACCV2022 #Macau - via #Whova event app
70
+
71
+ [2022.11.23] We made our work available as a pypi package. Please visit [`transparent-background`](https://github.com/plemeri/transparent-background) to download our tool and try it on your machine. It works as a command-line tool and Python API.
72
+
73
+ [2023.01.18] [rsreetech](https://github.com/rsreetech) shared a tutorial for our pypi package [`transparent-background`](https://github.com/plemeri/transparent-background) using Colab. :tv: [[Youtube](https://www.youtube.com/watch?v=jKuQEnKmv4A)]
74
+
75
+ ## Demo :rocket:
76
+
77
+ [Image Sample](./figures/demo_image.gif) | [Video Sample](./figures/demo_video.gif)
78
+ :-:|:-:
79
+ <img src=./figures/demo_image.gif height=200px> | <img src=./figures/demo_video.gif height=200px>
80
+
81
+ ## Applications :video_game:
82
+ Here are some applications/extensions of our work.
83
+ ### Web Application <img src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg height="20px" width="20px">
84
+ [TasksWithCode](https://github.com/taskswithcode) provided a [WebApp](https://huggingface.co/spaces/taskswithcode/salient-object-detection) on HuggingFace to generate your own results!
85
+
86
+ [Web Demo](https://huggingface.co/spaces/taskswithcode/salient-object-detection) |
87
+ |:-:
88
+ <img src=./figures/demo_webapp.gif height=200px> |
89
+
90
+ ### Command-line Tool / Python API :pager:
91
+ Try using our model as a command-line tool or Python API. More details about how to use it are available at [`transparent-background`](https://github.com/plemeri/transparent-background); a minimal Python usage sketch follows the install command below.
92
+ ```bash
93
+ pip install transparent-background
94
+ ```
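A minimal Python sketch of programmatic use is shown below; the `Remover` class and `process` method come from the `transparent-background` project's own documentation rather than this diff, so treat the exact signatures as assumptions to verify there.

```python
from PIL import Image
from transparent_background import Remover  # assumption: API names as documented in the transparent-background repo

remover = Remover()                      # loads an InSPyReNet checkpoint (downloaded on first use)
img = Image.open('sample.jpg').convert('RGB')
out = remover.process(img)               # assumed to return an RGBA image with the background removed
out.save('sample_rgba.png')
```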
95
+ ### Lane Segmentation :car:
96
+ We extend our model to detect lane markers in driving scenes in the [`LaneSOD`](https://github.com/plemeri/LaneSOD) repository.
97
+
98
+ [Lane Segmentation](https://github.com/plemeri/LaneSOD) |
99
+ |:-:
100
+ <img src=https://github.com/plemeri/LaneSOD/blob/main/figures/Teaser.gif height=200px> |
101
+
102
+ ## Easy Download :cake:
103
+
104
+ <details><summary>How to use easy download</summary>
105
+ <p>
106
+
107
+ Downloading each dataset and checkpoint individually is quite bothersome, even for me :zzz:. Instead, you can download all the data we provide, including `ImageNet pre-trained backbone checkpoints`, `Training Datasets`, `Testing Datasets for benchmark`, `Pre-trained model checkpoints`, and `Pre-computed saliency maps`, with the single command below.
108
+ ```bash
109
+ python utils/download.py --extra --dest [DEST]
110
+ ```
111
+
112
+ * `--extra, -e`: Without this argument, only the datasets, checkpoints, and results from our main paper will be downloaded. With it, all data will be downloaded, including results from the supplementary material and DIS5K results.
113
+ * `--dest [DEST], -d [DEST]`: Use this argument to specify the destination. It will automatically create symbolic links to the destination folders inside `data` and `snapshots`. Use it if you want to store the data on another physical disk. Otherwise, everything will be downloaded inside this repository folder.
114
+
115
+ If you want to download a certain checkpoint or pre-computed map, please refer to [Getting Started](#getting-started-flight_departure) and [Model Zoo](#model-zoo-giraffe).
116
+
117
+ </p>
118
+ </details>
119
+
120
+ ## Getting Started :flight_departure:
121
+
122
+ Please refer to [getting_started.md](./docs/getting_started.md) for training, testing, and evaluating on benchmarks, and inferencing on your own images.
123
+
124
+ ## Model Zoo :giraffe:
125
+
126
+ Please refer to [model_zoo.md](./docs/model_zoo.md) for downloading pre-trained models and pre-computed saliency maps.
127
+
128
+ ## Results :100:
129
+
130
+ ### Quantitative Results
131
+
132
+ [LR Benchmark](./figures/fig_quantitative.png) | [HR Benchmark](./figures/fig_quantitative2.png) | [HR Benchmark (Trained with extra DB)](./figures/fig_quantitative3.png) | [DIS](./figures/fig_quantitative4.png)
133
+ :-:|:-:|:-:|:-:
134
+ <img src=./figures/fig_quantitative.png height=200px> | <img src=./figures/fig_quantitative2.png height=200px> | <img src=./figures/fig_quantitative3.png height=200px> | <img src=./figures/fig_quantitative4.png height=200px>
135
+
138
+
139
+ ### Qualitative Results
140
+
141
+ [DAVIS-S & HRSOD](./figures/fig_qualitative.png) | [UHRSD](./figures/fig_qualitative2.png) | [UHRSD (w/ HR scale)](./figures/fig_qualitative3.jpg) | [DIS](./figures/fig_qualitative_dis.png)
142
+ :-:|:-:|:-:|:-:
143
+ <img src=./figures/fig_qualitative.png height=200px> | <img src=./figures/fig_qualitative2.png height=200px> | <img src=./figures/fig_qualitative3.jpg height=200px> | <img src=./figures/fig_qualitative_dis.png height=200px>
144
+
145
+ ## Citation
146
+
147
+ ```
148
+ @inproceedings{kim2022revisiting,
149
+ title={Revisiting Image Pyramid Structure for High Resolution Salient Object Detection},
150
+ author={Kim, Taehun and Kim, Kunhee and Lee, Joonyeong and Cha, Dongmin and Lee, Jiho and Kim, Daijin},
151
+ booktitle={Proceedings of the Asian Conference on Computer Vision},
152
+ pages={108--124},
153
+ year={2022}
154
+ }
155
+ ```
156
+
157
+ ## Acknowledgement
158
+
159
+ This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT)
160
+ (No.2017-0-00897, Development of Object Detection and Recognition for Intelligent Vehicles) and
161
+ (No.B0101-15-0266, Development of High Performance Visual BigData Discovery Platform for Large-Scale Realtime Data Analysis)
162
+
163
+ ### Special Thanks to :tada:
164
+ * [TasksWithCode](https://github.com/taskswithcode) team for sharing our work and making the most amazing web app demo.
165
+
166
+
167
+ ## References
168
+
169
+ ### Related Works
170
+
171
+ * Towards High-Resolution Salient Object Detection ([paper](https://drive.google.com/open?id=15o-Fel0BSyNulGoptrxfHR0t22qMHlTr) | [github](https://github.com/yi94code/HRSOD))
172
+ * Disentangled High Quality Salient Object Detection ([paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Tang_Disentangled_High_Quality_Salient_Object_Detection_ICCV_2021_paper.pdf) | [github](https://github.com/luckybird1994/HQSOD))
173
+ * Pyramid Grafting Network for One-Stage High Resolution Saliency Detection ([paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Xie_Pyramid_Grafting_Network_for_One-Stage_High_Resolution_Saliency_Detection_CVPR_2022_paper.pdf) | [github](https://github.com/iCVTEAM/PGNet))
174
+
175
+ ### Resources
176
+
177
+ * Backbones: [Res2Net](https://github.com/Res2Net/Res2Net-PretrainedModels), [Swin Transformer](https://github.com/microsoft/Swin-Transformer)
178
+
179
+ * Datasets
180
+ * LR Benchmarks: [DUTS](http://saliencydetection.net/duts/), [DUT-OMRON](http://saliencydetection.net/dut-omron/), [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html), [HKU-IS](https://i.cs.hku.hk/~gbli/deep_saliency.html), [PASCAL-S](http://cbi.gatech.edu/salobj/)
181
+ * HR Benchmarks: [DAVIS-S, HRSOD](https://github.com/yi94code/HRSOD), [UHRSD](https://github.com/iCVTEAM/PGNet)
182
+ * Dichotomous Image Segmentation: [DIS5K](https://xuebinqin.github.io/dis/index.html)
183
+
184
+ * Evaluation Toolkit
185
+ * SOD Metrics (e.g., S-measure): [PySOD Metrics](https://github.com/lartpang/PySODMetrics)
186
+ * Boundary Metric (mBA): [CascadePSP](https://github.com/hkchengrex/CascadePSP)
configs/InSPyReNet_Res2Net50.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_Res2Net50"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [384, 384]
6
+ threshold: NULL
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['DUTS-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [384, 384]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 6
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: True
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 60
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_Res2Net50"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD', 'UHRSD-TE']
57
+ transforms:
58
+ static_resize:
59
+ size: [384, 384]
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_Res2Net50"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_Res2Net50"
74
+ result_path: "results"
75
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD', 'UHRSD-TE']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
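The configs above are plain YAML and are consumed through `load_config` in `Expr.py`; since `utils/misc.py` is not part of this view, the following is only a hypothetical stand-in showing how such a file maps to attribute-style options.

```python
import yaml  # assumption: PyYAML; the repository's actual loader lives in utils/misc.py, not shown in this diff

class Namespace(dict):
    """Minimal dict wrapper allowing attribute access, e.g. opt.Train.Dataloader.batch_size."""
    def __getattr__(self, key):
        value = self[key]
        return Namespace(value) if isinstance(value, dict) else value

def load_config_sketch(path):
    with open(path) as f:
        return Namespace(yaml.safe_load(f))

opt = load_config_sketch('configs/InSPyReNet_Res2Net50.yaml')
print(opt.Model.name)                   # InSPyReNet_Res2Net50
print(opt.Train.Dataloader.batch_size)  # 6
```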
configs/InSPyReNet_SwinB.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [384, 384]
6
+ threshold: 512
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['DUTS-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [384, 384]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 6
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 60
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
57
+ transforms:
58
+ dynamic_resize:
59
+ L: 1280
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB"
74
+ result_path: "results"
75
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DH.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [1024, 1024]
6
+ threshold: NULL
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['DUTS-TR', 'HRSOD-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [1024, 1024]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 1
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 60
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
57
+ transforms:
58
+ static_resize:
59
+ size: [1024, 1024]
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB_DH"
74
+ result_path: "results"
75
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [384, 384]
6
+ threshold: 512
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['DUTS-TR', 'HRSOD-TR', 'UHRSD-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [384, 384]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 6
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 60
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DHU_LR"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
57
+ transforms:
58
+ dynamic_resize:
59
+ L: 1280
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DHU_LR"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB_DHU_LR"
74
+ result_path: "results"
75
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [384, 384]
6
+ threshold: 512
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['DUTS-TR', 'HRSOD-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [384, 384]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 6
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 60
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH_LR"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
57
+ transforms:
58
+ dynamic_resize:
59
+ L: 1280
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DH_LR"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB_DH_LR"
74
+ result_path: "results"
75
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [1024, 1024]
6
+ threshold: NULL
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['DIS-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [1024, 1024]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 1
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 180
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
57
+ transforms:
58
+ static_resize:
59
+ size: [1024, 1024]
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB_DIS5K"
74
+ result_path: "results"
75
+ datasets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [384, 384]
6
+ threshold: 512
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['DIS-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [384, 384]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 6
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 180
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K_LR"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
57
+ transforms:
58
+ dynamic_resize:
59
+ L: 1280
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_DIS5K_LR"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB_DIS5K_LR"
74
+ result_path: "results"
75
+ datasets: ['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_HU.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [1024, 1024]
6
+ threshold: NULL
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['HRSOD-TR', 'UHRSD-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [1024, 1024]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 1
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 60
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
57
+ transforms:
58
+ static_resize:
59
+ size: [1024, 1024]
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB_HU"
74
+ result_path: "results"
75
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml ADDED
@@ -0,0 +1,76 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [384, 384]
6
+ threshold: 512
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data/Train_Dataset"
12
+ sets: ['HRSOD-TR', 'UHRSD-TR']
13
+ transforms:
14
+ static_resize:
15
+ size: [384, 384]
16
+ random_scale_crop:
17
+ range: [0.75, 1.25]
18
+ random_flip:
19
+ lr: True
20
+ ud: False
21
+ random_rotate:
22
+ range: [-10, 10]
23
+ random_image_enhance:
24
+ methods: ['contrast', 'sharpness', 'brightness']
25
+ tonumpy: NULL
26
+ normalize:
27
+ mean: [0.485, 0.456, 0.406]
28
+ std: [0.229, 0.224, 0.225]
29
+ totensor: NULL
30
+ Dataloader:
31
+ batch_size: 6
32
+ shuffle: True
33
+ num_workers: 8
34
+ pin_memory: False
35
+ Optimizer:
36
+ type: "Adam"
37
+ lr: 1.0e-05
38
+ weight_decay: 0.0
39
+ mixed_precision: False
40
+ Scheduler:
41
+ type: "PolyLr"
42
+ epoch: 60
43
+ gamma: 0.9
44
+ minimum_lr: 1.0e-07
45
+ warmup_iteration: 12000
46
+ Checkpoint:
47
+ checkpoint_epoch: 1
48
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU_LR"
49
+ Debug:
50
+ keys: ['saliency', 'laplacian']
51
+
52
+ Test:
53
+ Dataset:
54
+ type: "RGB_Dataset"
55
+ root: "data/Test_Dataset"
56
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
57
+ transforms:
58
+ dynamic_resize:
59
+ L: 1280
60
+ tonumpy: NULL
61
+ normalize:
62
+ mean: [0.485, 0.456, 0.406]
63
+ std: [0.229, 0.224, 0.225]
64
+ totensor: NULL
65
+ Dataloader:
66
+ num_workers: 8
67
+ pin_memory: True
68
+ Checkpoint:
69
+ checkpoint_dir: "snapshots/InSPyReNet_SwinB_HU_LR"
70
+
71
+ Eval:
72
+ gt_root: "data/Test_Dataset"
73
+ pred_root: "snapshots/InSPyReNet_SwinB_HU_LR"
74
+ result_path: "results"
75
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
76
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/Plus_Ultra.yaml ADDED
@@ -0,0 +1,98 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [1024, 1024]
6
+ threshold: NULL
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data"
12
+ sets: [
13
+ 'Train_Dataset/DUTS-TR',
14
+ 'Train_Dataset/HRSOD-TR',
15
+ 'Train_Dataset/UHRSD-TR',
16
+ 'Train_Dataset/DIS-TR',
17
+ 'Test_Dataset/DUTS-TE',
18
+ 'Test_Dataset/DUT-OMRON',
19
+ 'Test_Dataset/ECSSD',
20
+ 'Test_Dataset/HKU-IS',
21
+ 'Test_Dataset/PASCAL-S',
22
+ 'Test_Dataset/DAVIS-S',
23
+ 'Test_Dataset/HRSOD-TE',
24
+ 'Test_Dataset/UHRSD-TE',
25
+ 'Test_Dataset/FSS-1000',
26
+ 'Test_Dataset/MSRA-10K',
27
+ 'Test_Dataset/DIS-VD',
28
+ 'Test_Dataset/DIS-TE1',
29
+ 'Test_Dataset/DIS-TE2',
30
+ 'Test_Dataset/DIS-TE3',
31
+ 'Test_Dataset/DIS-TE4'
32
+ ]
33
+ transforms:
34
+ static_resize:
35
+ size: [1024, 1024]
36
+ random_scale_crop:
37
+ range: [0.75, 1.25]
38
+ random_flip:
39
+ lr: True
40
+ ud: False
41
+ random_rotate:
42
+ range: [-10, 10]
43
+ random_image_enhance:
44
+ methods: ['contrast', 'sharpness', 'brightness']
45
+ tonumpy: NULL
46
+ normalize:
47
+ mean: [0.485, 0.456, 0.406]
48
+ std: [0.229, 0.224, 0.225]
49
+ totensor: NULL
50
+ Dataloader:
51
+ batch_size: 1
52
+ shuffle: True
53
+ num_workers: 8
54
+ pin_memory: False
55
+ Optimizer:
56
+ type: "Adam"
57
+ lr: 1.0e-05
58
+ weight_decay: 0.0
59
+ mixed_precision: False
60
+ Scheduler:
61
+ type: "PolyLr"
62
+ epoch: 60
63
+ gamma: 0.9
64
+ minimum_lr: 1.0e-07
65
+ warmup_iteration: 12000
66
+ Checkpoint:
67
+ checkpoint_epoch: 1
68
+ checkpoint_dir: "snapshots/Plus_Ultra"
69
+ Debug:
70
+ keys: ['saliency', 'laplacian']
71
+
72
+ Test:
73
+ Dataset:
74
+ type: "RGB_Dataset"
75
+ root: "data/Test_Dataset"
76
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
77
+ transforms:
78
+ # static_resize:
79
+ # size: [1024, 1024]
80
+ dynamic_resize:
81
+ L: 1280
82
+ tonumpy: NULL
83
+ normalize:
84
+ mean: [0.485, 0.456, 0.406]
85
+ std: [0.229, 0.224, 0.225]
86
+ totensor: NULL
87
+ Dataloader:
88
+ num_workers: 8
89
+ pin_memory: True
90
+ Checkpoint:
91
+ checkpoint_dir: "snapshots/Plus_Ultra"
92
+
93
+ Eval:
94
+ gt_root: "data/Test_Dataset"
95
+ pred_root: "snapshots/Plus_Ultra"
96
+ result_path: "results"
97
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
98
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
configs/extra_dataset/Plus_Ultra_LR.yaml ADDED
@@ -0,0 +1,96 @@
1
+ Model:
2
+ name: "InSPyReNet_SwinB"
3
+ depth: 64
4
+ pretrained: True
5
+ base_size: [384, 384]
6
+ threshold: 512
7
+
8
+ Train:
9
+ Dataset:
10
+ type: "RGB_Dataset"
11
+ root: "data"
12
+ sets: [
13
+ 'Train_Dataset/DUTS-TR',
14
+ 'Train_Dataset/HRSOD-TR',
15
+ 'Train_Dataset/UHRSD-TR',
16
+ 'Train_Dataset/DIS-TR',
17
+ 'Test_Dataset/DUTS-TE',
18
+ 'Test_Dataset/DUT-OMRON',
19
+ 'Test_Dataset/ECSSD',
20
+ 'Test_Dataset/HKU-IS',
21
+ 'Test_Dataset/PASCAL-S',
22
+ 'Test_Dataset/DAVIS-S',
23
+ 'Test_Dataset/HRSOD-TE',
24
+ 'Test_Dataset/UHRSD-TE',
25
+ 'Test_Dataset/FSS-1000',
26
+ 'Test_Dataset/MSRA-10K',
27
+ 'Test_Dataset/DIS-VD',
28
+ 'Test_Dataset/DIS-TE1',
29
+ 'Test_Dataset/DIS-TE2',
30
+ 'Test_Dataset/DIS-TE3',
31
+ 'Test_Dataset/DIS-TE4'
32
+ ]
33
+ transforms:
34
+ static_resize:
35
+ size: [384, 384]
36
+ random_scale_crop:
37
+ range: [0.75, 1.25]
38
+ random_flip:
39
+ lr: True
40
+ ud: False
41
+ random_rotate:
42
+ range: [-10, 10]
43
+ random_image_enhance:
44
+ methods: ['contrast', 'sharpness', 'brightness']
45
+ tonumpy: NULL
46
+ normalize:
47
+ mean: [0.485, 0.456, 0.406]
48
+ std: [0.229, 0.224, 0.225]
49
+ totensor: NULL
50
+ Dataloader:
51
+ batch_size: 6
52
+ shuffle: True
53
+ num_workers: 8
54
+ pin_memory: False
55
+ Optimizer:
56
+ type: "Adam"
57
+ lr: 1.0e-05
58
+ weight_decay: 0.0
59
+ mixed_precision: False
60
+ Scheduler:
61
+ type: "PolyLr"
62
+ epoch: 60
63
+ gamma: 0.9
64
+ minimum_lr: 1.0e-07
65
+ warmup_iteration: 12000
66
+ Checkpoint:
67
+ checkpoint_epoch: 1
68
+ checkpoint_dir: "snapshots/Plus_Ultra_LR"
69
+ Debug:
70
+ keys: ['saliency', 'laplacian']
71
+
72
+ Test:
73
+ Dataset:
74
+ type: "RGB_Dataset"
75
+ root: "data/Test_Dataset"
76
+ sets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
77
+ transforms:
78
+ dynamic_resize:
79
+ L: 1280
80
+ tonumpy: NULL
81
+ normalize:
82
+ mean: [0.485, 0.456, 0.406]
83
+ std: [0.229, 0.224, 0.225]
84
+ totensor: NULL
85
+ Dataloader:
86
+ num_workers: 8
87
+ pin_memory: True
88
+ Checkpoint:
89
+ checkpoint_dir: "snapshots/Plus_Ultra_LR"
90
+
91
+ Eval:
92
+ gt_root: "data/Test_Dataset"
93
+ pred_root: "snapshots/Plus_Ultra_LR"
94
+ result_path: "results"
95
+ datasets: ['DUTS-TE', 'DUT-OMRON', 'ECSSD', 'HKU-IS', 'PASCAL-S', 'DAVIS-S', 'HRSOD-TE', 'UHRSD-TE']
96
+ metrics: ['Sm', 'mae', 'adpEm', 'maxEm', 'avgEm', 'adpFm', 'maxFm', 'avgFm', 'wFm', 'mBA']
data/custom_transforms.py ADDED
@@ -0,0 +1,206 @@
1
+ from email.mime import base
2
+ import numpy as np
3
+ from PIL import Image
4
+ import os
5
+ import sys
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from PIL import Image, ImageOps, ImageFilter, ImageEnhance
9
+ from typing import Optional
10
+
11
+ filepath = os.path.split(__file__)[0]
12
+ repopath = os.path.split(filepath)[0]
13
+ sys.path.append(repopath)
14
+
15
+ from utils.misc import *
16
+
17
+ class static_resize:
18
+ # Resize for training
19
+ # size: h x w
20
+ def __init__(self, size=[384, 384], base_size=None):
21
+ self.size = size[::-1]
22
+ self.base_size = base_size[::-1] if base_size is not None else None
23
+
24
+ def __call__(self, sample):
25
+ sample['image'] = sample['image'].resize(self.size, Image.BILINEAR)
26
+ if 'gt' in sample.keys():
27
+ sample['gt'] = sample['gt'].resize(self.size, Image.NEAREST)
28
+
29
+ if self.base_size is not None:
30
+ sample['image_resized'] = sample['image'].resize(self.size, Image.BILINEAR)
31
+ if 'gt' in sample.keys():
32
+ sample['gt_resized'] = sample['gt'].resize(self.size, Image.NEAREST)
33
+
34
+ return sample
35
+
36
+ class dynamic_resize:
37
+ # Resize for inference: keep the aspect ratio, cap the shorter side at L, and snap both sides to multiples of 32;
+ # base_size: h x w copy kept as 'image_resized' for the low-resolution branch
38
+ def __init__(self, L=1280, base_size=[384, 384]):
39
+ self.L = L
40
+ self.base_size = base_size[::-1]
41
+
42
+ def __call__(self, sample):
43
+ size = list(sample['image'].size)
44
+ if (size[0] >= size[1]) and size[1] > self.L:
45
+ size[0] = size[0] / (size[1] / self.L)
46
+ size[1] = self.L
47
+ elif (size[1] > size[0]) and size[0] > self.L:
48
+ size[1] = size[1] / (size[0] / self.L)
49
+ size[0] = self.L
50
+ size = (int(round(size[0] / 32)) * 32, int(round(size[1] / 32)) * 32)
51
+
52
+ if 'image' in sample.keys():
53
+ sample['image_resized'] = sample['image'].resize(self.base_size, Image.BILINEAR)
54
+ sample['image'] = sample['image'].resize(size, Image.BILINEAR)
55
+
56
+ if 'gt' in sample.keys():
57
+ sample['gt_resized'] = sample['gt'].resize(self.base_size, Image.NEAREST)
58
+ sample['gt'] = sample['gt'].resize(size, Image.NEAREST)
59
+
60
+ return sample
61
+
62
+ class random_scale_crop:
63
+ def __init__(self, range=[0.75, 1.25]):
64
+ self.range = range
65
+
66
+ def __call__(self, sample):
67
+ scale = np.random.random() * (self.range[1] - self.range[0]) + self.range[0]
68
+ if np.random.random() < 0.5:
69
+ for key in sample.keys():
70
+ if key in ['image', 'gt']:
71
+ base_size = sample[key].size
72
+
73
+ scale_size = tuple((np.array(base_size) * scale).round().astype(int))
74
+ sample[key] = sample[key].resize(scale_size)
75
+
76
+ lf = (sample[key].size[0] - base_size[0]) // 2
77
+ up = (sample[key].size[1] - base_size[1]) // 2
78
+ rg = (sample[key].size[0] + base_size[0]) // 2
79
+ lw = (sample[key].size[1] + base_size[1]) // 2
80
+
81
+ border = -min(0, min(lf, up))
82
+ sample[key] = ImageOps.expand(sample[key], border=border)
83
+ sample[key] = sample[key].crop((lf + border, up + border, rg + border, lw + border))
84
+ return sample
85
+
86
+ class random_flip:
87
+ def __init__(self, lr=True, ud=True):
88
+ self.lr = lr
89
+ self.ud = ud
90
+
91
+ def __call__(self, sample):
92
+ lr = np.random.random() < 0.5 and self.lr is True
93
+ ud = np.random.random() < 0.5 and self.ud is True
94
+
95
+ for key in sample.keys():
96
+ if key in ['image', 'gt']:
97
+ sample[key] = np.array(sample[key])
98
+ if lr:
99
+ sample[key] = np.fliplr(sample[key])
100
+ if ud:
101
+ sample[key] = np.flipud(sample[key])
102
+ sample[key] = Image.fromarray(sample[key])
103
+
104
+ return sample
105
+
106
+ class random_rotate:
107
+ def __init__(self, range=[0, 360], interval=1):
108
+ self.range = range
109
+ self.interval = interval
110
+
111
+ def __call__(self, sample):
112
+ rot = (np.random.randint(*self.range) // self.interval) * self.interval
113
+ rot = rot + 360 if rot < 0 else rot
114
+
115
+ if np.random.random() < 0.5:
116
+ for key in sample.keys():
117
+ if key in ['image', 'gt']:
118
+ base_size = sample[key].size
119
+ sample[key] = sample[key].rotate(rot, expand=True, fillcolor=255 if key == 'depth' else None)
120
+
121
+ sample[key] = sample[key].crop(((sample[key].size[0] - base_size[0]) // 2,
122
+ (sample[key].size[1] - base_size[1]) // 2,
123
+ (sample[key].size[0] + base_size[0]) // 2,
124
+ (sample[key].size[1] + base_size[1]) // 2))
125
+
126
+ return sample
127
+
128
+ class random_image_enhance:
129
+ def __init__(self, methods=['contrast', 'brightness', 'sharpness']):
130
+ self.enhance_method = []
131
+ if 'contrast' in methods:
132
+ self.enhance_method.append(ImageEnhance.Contrast)
133
+ if 'brightness' in methods:
134
+ self.enhance_method.append(ImageEnhance.Brightness)
135
+ if 'sharpness' in methods:
136
+ self.enhance_method.append(ImageEnhance.Sharpness)
137
+
138
+ def __call__(self, sample):
139
+ if 'image' in sample.keys():
140
+ np.random.shuffle(self.enhance_method)
141
+
142
+ for method in self.enhance_method:
143
+ if np.random.random() > 0.5:
144
+ enhancer = method(sample['image'])
145
+ factor = float(1 + np.random.random() / 10)
146
+ sample['image'] = enhancer.enhance(factor)
147
+
148
+ return sample
149
+
150
+ class tonumpy:
151
+ def __init__(self):
152
+ pass
153
+
154
+ def __call__(self, sample):
155
+ for key in sample.keys():
156
+ if key in ['image', 'image_resized', 'gt', 'gt_resized']:
157
+ sample[key] = np.array(sample[key], dtype=np.float32)
158
+
159
+ return sample
160
+
161
+ class normalize:
162
+ def __init__(self, mean: Optional[list]=None, std: Optional[list]=None, div=255):
163
+ self.mean = mean if mean is not None else 0.0
164
+ self.std = std if std is not None else 1.0
165
+ self.div = div
166
+
167
+ def __call__(self, sample):
168
+ if 'image' in sample.keys():
169
+ sample['image'] /= self.div
170
+ sample['image'] -= self.mean
171
+ sample['image'] /= self.std
172
+
173
+ if 'image_resized' in sample.keys():
174
+ sample['image_resized'] /= self.div
175
+ sample['image_resized'] -= self.mean
176
+ sample['image_resized'] /= self.std
177
+
178
+ if 'gt' in sample.keys():
179
+ sample['gt'] /= self.div
180
+
181
+ if 'gt_resized' in sample.keys():
182
+ sample['gt_resized'] /= self.div
183
+
184
+ return sample
185
+ class totensor:
186
+ def __init__(self):
187
+ pass
188
+
189
+ def __call__(self, sample):
190
+ if 'image' in sample.keys():
191
+ sample['image'] = sample['image'].transpose((2, 0, 1))
192
+ sample['image'] = torch.from_numpy(sample['image']).float()
193
+
194
+ if 'image_resized' in sample.keys():
195
+ sample['image_resized'] = sample['image_resized'].transpose((2, 0, 1))
196
+ sample['image_resized'] = torch.from_numpy(sample['image_resized']).float()
197
+
198
+ if 'gt' in sample.keys():
199
+ sample['gt'] = torch.from_numpy(sample['gt'])
200
+ sample['gt'] = sample['gt'].unsqueeze(dim=0)
201
+
202
+ if 'gt_resized' in sample.keys():
203
+ sample['gt_resized'] = torch.from_numpy(sample['gt_resized'])
204
+ sample['gt_resized'] = sample['gt_resized'].unsqueeze(dim=0)
205
+
206
+ return sample
data/dataloader.py ADDED
@@ -0,0 +1,243 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+
5
+ import numpy as np
6
+ import torchvision.transforms as transforms
7
+
8
+ from torch.utils.data.dataset import Dataset
9
+ from PIL import Image
10
+ from threading import Thread
11
+
12
+ filepath = os.path.split(__file__)[0]
13
+ repopath = os.path.split(filepath)[0]
14
+ sys.path.append(repopath)
15
+
16
+ from data.custom_transforms import *
17
+ from utils.misc import *
18
+
19
+ Image.MAX_IMAGE_PIXELS = None
20
+
21
+ # Build a torchvision Compose from the config's 'transforms' mapping: each key names a class in
+ # data/custom_transforms.py and each value holds its keyword arguments (NULL/None means no arguments).
+ # A usage sketch follows this file.
+ def get_transform(tfs):
22
+ comp = []
23
+ for key, value in zip(tfs.keys(), tfs.values()):
24
+ if value is not None:
25
+ tf = eval(key)(**value)
26
+ else:
27
+ tf = eval(key)()
28
+ comp.append(tf)
29
+ return transforms.Compose(comp)
30
+
31
+ class RGB_Dataset(Dataset):
32
+ def __init__(self, root, sets, tfs):
33
+ self.images, self.gts = [], []
34
+
35
+ for set in sets:
36
+ image_root, gt_root = os.path.join(root, set, 'images'), os.path.join(root, set, 'masks')
37
+
38
+ images = [os.path.join(image_root, f) for f in os.listdir(image_root) if f.lower().endswith(('.jpg', '.png'))]
39
+ images = sort(images)
40
+
41
+ gts = [os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.lower().endswith(('.jpg', '.png'))]
42
+ gts = sort(gts)
43
+
44
+ self.images.extend(images)
45
+ self.gts.extend(gts)
46
+
47
+ self.filter_files()
48
+
49
+ self.size = len(self.images)
50
+ self.transform = get_transform(tfs)
51
+
52
+ def __getitem__(self, index):
53
+ image = Image.open(self.images[index]).convert('RGB')
54
+ gt = Image.open(self.gts[index]).convert('L')
55
+ shape = gt.size[::-1]
56
+ name = self.images[index].split(os.sep)[-1]
57
+ name = os.path.splitext(name)[0]
58
+
59
+ sample = {'image': image, 'gt': gt, 'name': name, 'shape': shape}
60
+
61
+ sample = self.transform(sample)
62
+ return sample
63
+
64
+ def filter_files(self):
65
+ assert len(self.images) == len(self.gts)
66
+ images, gts = [], []
67
+ for img_path, gt_path in zip(self.images, self.gts):
68
+ img, gt = Image.open(img_path), Image.open(gt_path)
69
+ if img.size == gt.size:
70
+ images.append(img_path)
71
+ gts.append(gt_path)
72
+ self.images, self.gts = images, gts
73
+
74
+ def __len__(self):
75
+ return self.size
76
+
77
+ class ImageLoader:
78
+ def __init__(self, root, tfs):
79
+ if os.path.isdir(root):
80
+ self.images = [os.path.join(root, f) for f in os.listdir(root) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
81
+ self.images = sort(self.images)
82
+ elif os.path.isfile(root):
83
+ self.images = [root]
84
+ self.size = len(self.images)
85
+ self.transform = get_transform(tfs)
86
+
87
+ def __iter__(self):
88
+ self.index = 0
89
+ return self
90
+
91
+ def __next__(self):
92
+ if self.index == self.size:
93
+ raise StopIteration
94
+ image = Image.open(self.images[self.index]).convert('RGB')
95
+ shape = image.size[::-1]
96
+ name = self.images[self.index].split(os.sep)[-1]
97
+ name = os.path.splitext(name)[0]
98
+
99
+ sample = {'image': image, 'name': name, 'shape': shape, 'original': image}
100
+ sample = self.transform(sample)
101
+ sample['image'] = sample['image'].unsqueeze(0)
102
+ if 'image_resized' in sample.keys():
103
+ sample['image_resized'] = sample['image_resized'].unsqueeze(0)
104
+
105
+ self.index += 1
106
+ return sample
107
+
108
+ def __len__(self):
109
+ return self.size
110
+
111
+ class VideoLoader:
112
+ def __init__(self, root, tfs):
113
+ if os.path.isdir(root):
114
+ self.videos = [os.path.join(root, f) for f in os.listdir(root) if f.lower().endswith(('.mp4', '.avi', '.mov'))]
115
+ elif os.path.isfile(root):
116
+ self.videos = [root]
117
+ self.size = len(self.videos)
118
+ self.transform = get_transform(tfs)
119
+
120
+ def __iter__(self):
121
+ self.index = 0
122
+ self.cap = None
123
+ self.fps = None
124
+ return self
125
+
126
+ def __next__(self):
127
+ if self.index == self.size:
128
+ raise StopIteration
129
+
130
+ if self.cap is None:
131
+ self.cap = cv2.VideoCapture(self.videos[self.index])
132
+ self.fps = self.cap.get(cv2.CAP_PROP_FPS)
133
+ ret, frame = self.cap.read()
134
+ name = self.videos[self.index].split(os.sep)[-1]
135
+ name = os.path.splitext(name)[0]
136
+ if ret is False:
137
+ self.cap.release()
138
+ self.cap = None
139
+ sample = {'image': None, 'shape': None, 'name': name, 'original': None}
140
+ self.index += 1
141
+
142
+ else:
143
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
144
+ image = Image.fromarray(frame).convert('RGB')
145
+ shape = image.size[::-1]
146
+ sample = {'image': image, 'shape': shape, 'name': name, 'original': image}
147
+ sample = self.transform(sample)
148
+ sample['image'] = sample['image'].unsqueeze(0)
149
+ if 'image_resized' in sample.keys():
150
+ sample['image_resized'] = sample['image_resized'].unsqueeze(0)
151
+
152
+ return sample
153
+
154
+ def __len__(self):
155
+ return self.size
156
+
157
+
158
+ class WebcamLoader:
159
+ def __init__(self, ID, tfs):
160
+ self.ID = int(ID)
161
+ self.cap = cv2.VideoCapture(self.ID)
162
+ self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
163
+ self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
164
+ self.transform = get_transform(tfs)
165
+ self.imgs = []
166
+ self.imgs.append(self.cap.read()[1])
167
+ self.thread = Thread(target=self.update, daemon=True)
168
+ self.thread.start()
169
+
170
+ def update(self):
171
+ while self.cap.isOpened():
172
+ ret, frame = self.cap.read()
173
+ if ret is True:
174
+ self.imgs.append(frame)
175
+ else:
176
+ break
177
+
178
+ def __iter__(self):
179
+ return self
180
+
181
+ def __next__(self):
182
+ if len(self.imgs) > 0:
183
+ frame = self.imgs[-1]
184
+ else:
185
+ frame = np.zeros((480, 640, 3)).astype(np.uint8)
186
+ if self.thread.is_alive() is False or cv2.waitKey(1) == ord('q'):
187
+ cv2.destroyAllWindows()
188
+ raise StopIteration
189
+
190
+ else:
191
+ image = Image.fromarray(frame).convert('RGB')
192
+ shape = image.size[::-1]
193
+ sample = {'image': image, 'shape': shape, 'name': 'webcam', 'original': image}
194
+ sample = self.transform(sample)
195
+ sample['image'] = sample['image'].unsqueeze(0)
196
+ if 'image_resized' in sample.keys():
197
+ sample['image_resized'] = sample['image_resized'].unsqueeze(0)
198
+
199
+ del self.imgs[:-1]
200
+ return sample
201
+
202
+
203
+ def __len__(self):
204
+ return 0
205
+
206
+ class RefinementLoader:
207
+ def __init__(self, image_dir, seg_dir, tfs):
208
+ self.images = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
209
+ self.images = sort(self.images)
210
+
211
+ self.segs = [os.path.join(seg_dir, f) for f in os.listdir(seg_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
212
+ self.segs = sort(self.segs)
213
+
214
+ self.size = len(self.images)
215
+ self.transform = get_transform(tfs)
216
+
217
+ def __iter__(self):
218
+ self.index = 0
219
+ return self
220
+
221
+ def __next__(self):
222
+ if self.index == self.size:
223
+ raise StopIteration
224
+ image = Image.open(self.images[self.index]).convert('RGB')
225
+ seg = Image.open(self.segs[self.index]).convert('L')
226
+ shape = image.size[::-1]
227
+ name = self.images[self.index].split(os.sep)[-1]
228
+ name = os.path.splitext(name)[0]
229
+
230
+ sample = {'image': image, 'gt': seg, 'name': name, 'shape': shape, 'original': image}
231
+ sample = self.transform(sample)
232
+ sample['image'] = sample['image'].unsqueeze(0)
233
+ sample['mask'] = sample['gt'].unsqueeze(0)
234
+ if 'image_resized' in sample.keys():
235
+ sample['image_resized'] = sample['image_resized'].unsqueeze(0)
236
+ del sample['gt']
237
+
238
+ self.index += 1
239
+ return sample
240
+
241
+ def __len__(self):
242
+ return self.size
243
+
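The `transforms` blocks in the YAML configs map one-to-one onto the classes in `data/custom_transforms.py`: `get_transform` looks each key up and instantiates it with the given keyword arguments. Below is a minimal sketch (not part of the repository; the image folder path is hypothetical) of how the test-time transforms drive `ImageLoader`:

```python
# Minimal sketch: mirror Test.Dataset.transforms from the configs and iterate a folder of images.
from data.dataloader import ImageLoader

tfs = {
    'dynamic_resize': {'L': 1280},
    'tonumpy': None,        # NULL in the YAML -> instantiated without arguments
    'normalize': {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]},
    'totensor': None,
}

loader = ImageLoader('path/to/img/folder', tfs)  # hypothetical folder of .jpg/.png files
for sample in loader:
    # 'image' is a (1, 3, H, W) tensor with H, W multiples of 32;
    # dynamic_resize also adds a (1, 3, 384, 384) 'image_resized' for the low-resolution branch.
    print(sample['name'], sample['image'].shape, sample['shape'])
```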
docs/getting_started.md ADDED
@@ -0,0 +1,138 @@
1
+ # Getting Started :flight_departure:
2
+
3
+ ## Create environment
4
+ * Create a conda environment with the following command: `conda create -y -n inspyrenet python`
5
+ * Activate the environment with the following command: `conda activate inspyrenet`
6
+ * Install the latest PyTorch from the [Official Website](https://pytorch.org/get-started/locally/)
7
+ * Linux [2022.11.09]
8
+ ```
9
+ pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
10
+ ```
11
+ * Install the requirements with the following command: `pip install -r requirements.txt`
12
+
13
+ ## Preparation
14
+
15
+ * For training, you may need training datasets and ImageNet pre-trained checkpoints for the backbone. For testing (inference), you may need test datasets (sample images).
16
+ * Training datasets are expected to be located under [Train.Dataset.root](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L11). Likewise, testing datasets should be under [Test.Dataset.root](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L55).
17
+ * Each dataset folder should contain an `images` folder and a `masks` folder for the images and ground-truth masks, respectively (see the layout sketch below).
18
+ * You may use multiple training datasets by listing dataset folders for [Train.Dataset.sets](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L12), such as `[DUTS-TR] -> [DUTS-TR, HRSOD-TR, UHRSD-TR]`.
19
+
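For reference, a typical layout under the dataset roots would look like the sketch below (dataset names are examples from the tables that follow; the exact roots are whatever you set in `Train.Dataset.root` and `Test.Dataset.root`):

```
data/
├── Train_Dataset/
│   └── DUTS-TR/
│       ├── images/   # RGB images (.jpg / .png)
│       └── masks/    # ground truth masks
└── Test_Dataset/
    └── DUTS-TE/
        ├── images/
        └── masks/
```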
20
+ ### Backbone Checkpoints
21
+ Item | Destination Folder | OneDrive | GDrive
22
+ :-|:-|:-|:-
23
+ Res2Net50 checkpoint | `data/backbone_ckpt/*.pth` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUO7GDBwoC9CulTPdnq_yhQBlc0SIyyELMy3OmrNhOjcGg?e=T3PVyG&download=1) | [Link](https://drive.google.com/file/d/1MMhioAsZ-oYa5FpnTi22XBGh5HkjLX3y/view?usp=sharing)
24
+ SwinB checkpoint | `data/backbone_ckpt/*.pth` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESlYCLy0endMhcZm9eC2A4ABatxupp4UPh03EcqFjbtSRw?e=7y6lLt&download=1) | [Link](https://drive.google.com/file/d/1fBJFMupe5pV-Vtou-k8LTvHclWs0y1bI/view?usp=sharing)
25
+
26
+ * We changed the Res2Net50 checkpoint to resolve an error while training with DDP. Please refer to [issue #9](https://github.com/plemeri/InSPyReNet/issues/9).
27
+
28
+ ### Train Datasets
29
+ Item | Destination Folder | OneDrive | GDrive
30
+ :-|:-|:-|:-
31
+ DUTS-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EQ7L2XS-5YFMuJGee7o7HQ8BdRSLO8utbC_zRrv-KtqQ3Q?e=bCSIeo&download=1) | [Link](https://drive.google.com/file/d/1hy5UTq65uQWFO5yzhEn9KFIbdvhduThP/view?usp=share_link)
32
+
33
+ ### Extra Train Datasets (High Resolution, Optional)
34
+ Item | Destination Folder | OneDrive | GDrive
35
+ :-|:-|:-|:-
36
+ HRSOD-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfUx92hUgZJNrWPj46PC0yEBXcorQskXOCSz8SnGH5AcLQ?e=WA5pc6&download=1) | N/A
37
+ UHRSD-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea4_UCbsKmhKnMCccAJOTLgBmQFsQ4KhJSf2jx8WQqj3Wg?e=18kYZS&download=1) | N/A
38
+ DIS-TR | `data/Train_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZtZJ493tVNJjBIpNLdus68B3u906PdWtHsf87pulj78jw?e=bUg2UQ&download=1) | N/A
39
+
40
+ ### Test Datasets
41
+ Item | Destination Folder | OneDrive | GDrive
42
+ :-|:-|:-|:-
43
+ DUTS-TE | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfuCxjveXphPpIska9BxHDMBHpYroEKdVlq9HsonZ4wLDw?e=Mz5giA&download=1) | [Link](https://drive.google.com/file/d/1w4pigcQe9zplMulp1rAwmp6yYXmEbmvy/view?usp=share_link)
44
+ DUT-OMRON | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERvApm9rHH5LiR4NJoWHqDoBCneUQNextk8EjQ_Hy0bUHg?e=wTRZQb&download=1) | [Link](https://drive.google.com/file/d/1qIm_GQLLQkP6s-xDZhmp_FEAalavJDXf/view?usp=sharing)
45
+ ECSSD | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ES_GCdS0yblBmnRaDZ8xmKQBPU_qeECTVB9vlPUups8bnA?e=POVAlG&download=1) | [Link](https://drive.google.com/file/d/1qk_12KLGX6FPr1P_S9dQ7vXKaMqyIRoA/view?usp=sharing)
46
+ HKU-IS | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYBRVvC1MJRAgSfzt0zaG94BU_UWaVrvpv4tjogu4vSV6w?e=TKN7hQ&download=1) | [Link](https://drive.google.com/file/d/1H3szJYbr5_CRCzrYfhPHThTgszkKd1EU/view?usp=share_link)
47
+ PASCAL-S | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfUDGDckMnZHhEPy8YQGwBQB5MN3qInBkEygpIr7ccJdTQ?e=YarZaQ&download=1) | [Link](https://drive.google.com/file/d/1h0IE2DlUt0HHZcvzMV5FCxtZqQqh9Ztf/view?usp=sharing)
48
+ DAVIS-S | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ebam8I2o-tRJgADcq-r9YOkBCDyaAdWBVWyfN-xCYyAfDQ?e=Mqz8cK&download=1) | [Link](https://drive.google.com/file/d/15F0dy9o02LPTlpUbnD9NJlGeKyKU3zOz/view?usp=sharing)
49
+ HRSOD-TE | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EbHOQZKC59xIpIdrM11ulWsBHRYY1wZY2njjWCDFXvT6IA?e=wls17m&download=1) | [Link](https://drive.google.com/file/d/1KnUCsvluS4kP2HwUFVRbKU8RK_v6rv2N/view?usp=sharing)
50
+ UHRSD-TE | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUpc8QJffNpNpESv-vpBi40BppucqOoXm_IaK7HYJkuOog?e=JTjGmS&download=1) | [Link](https://drive.google.com/file/d/1niiHBo9LX6-I3KsEWYOi_s6cul80IYvK/view?usp=sharing)
51
+
52
+ ### Extra Test Datasets (Optional)
53
+ Item | Destination Folder | OneDrive | GDrive
54
+ :-|:-|:-|:-
55
+ FSS-1000 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EaP6DogjMAVCtTKcC_Bx-YoBoBSWBo90lesVcMyuCN35NA?e=0DDohA&download=1) | N/A
56
+ MSRA-10K | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EauXsIkxBzhDjio6fW0TubUB4L7YJc0GMTaq7VfjI2nPsg?e=c5DIxg&download=1) | N/A
57
+ DIS-VD | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYJm3BqheaxNhdVoMt6X41gBgVnE4dulBwkp6pbOQtcIrQ?e=T6dtXm&download=1) | [Link](https://drive.google.com/file/d/1jhlZb3QyNPkc0o8nL3RWF0MLuVsVtJju/view?usp=sharing)
58
+ DIS-TE1 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcGYE_Gc0cVHoHi_qUtmsawB_5v9RSpJS5PIAPlLu6xo9A?e=Nu5mkQ&download=1) | [Link](https://drive.google.com/file/d/1iz8Y4uaX3ZBy42N2MIOkmNb0D5jroFPJ/view?usp=sharing)
59
+ DIS-TE2 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdhgMdbZ049GvMv7tNrjbbQB1wL9Ok85YshiXIkgLyTfkQ?e=mPA6Po&download=1) | [Link](https://drive.google.com/file/d/1DWSoWogTWDuS2PFbD1Qx9P8_SnSv2zTe/view?usp=sharing)
60
+ DIS-TE3 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcxXYC_3rXxKsQBrp6BdNb4BOKxBK3_vsR9RL76n7YVG-g?e=2M0cse&download=1) | [Link](https://drive.google.com/file/d/1bIVSjsxCjMrcmV1fsGplkKl9ORiiiJTZ/view?usp=sharing)
61
+ DIS-TE4 | `data/Test_Dataset/...` | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdkG2SUi8flJvoYbHHOmvMABsGhkCJCsLLZlaV2E_SZimA?e=zlM2kC&download=1) | [Link](https://drive.google.com/file/d/1VuPNqkGTP1H4BFEHe807dTIkv8Kfzk5_/view?usp=sharing)
62
+
63
+ ## Train & Evaluate
64
+
65
+ * Train InSPyReNet
66
+ ```
67
+ # Single GPU
68
+ python run/Train.py --config configs/InSPyReNet_SwinB.yaml --verbose
69
+
70
+ # Multi GPUs with DDP (e.g., 4 GPUs)
71
+ torchrun --standalone --nproc_per_node=4 run/Train.py --config configs/InSPyReNet_SwinB.yaml --verbose
72
+
73
+ # Multi GPUs with DDP with designated devices (e.g., 2 GPUs - 0 and 1)
74
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nproc_per_node=2 run/Train.py --config configs/InSPyReNet_SwinB.yaml --verbose
75
+ ```
76
+
77
+ * `--config [CONFIG_FILE], -c [CONFIG_FILE]`: config file path for training.
78
+ * `--resume, -r`: use this argument to resume from the last checkpoint.
79
+ * `--verbose, -v`: use this argument to output progress info.
80
+ * `--debug, -d`: use this argument to save debug images every epoch.
81
+
82
+ * Training with extra datasets can be done by simply changing [Train.Dataset.sets](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L12) in the `yaml` config file, i.e., by adding more directories (e.g., HRSOD-TR, UHRSD-TR, ...):
83
+ ```
84
+ Train:
85
+ Dataset:
86
+ type: "RGB_Dataset"
87
+ root: "data/RGB_Dataset/Train_Dataset"
88
+ sets: ['DUTS-TR'] --> ['DUTS-TR', 'HRSOD-TR', 'UHRSD-TR']
89
+ ```
90
+ * Inference for test benchmarks
91
+ ```
92
+ python run/Test.py --config configs/InSPyReNet_SwinB.yaml --verbose
93
+ ```
94
+ * Evaluate metrics
95
+ ```
96
+ python run/Eval.py --config configs/InSPyReNet_SwinB.yaml --verbose
97
+ ```
98
+
99
+ * All-in-One command (Train, Test, Eval in single command)
100
+ ```
101
+ # Single GPU
102
+ python Expr.py --config configs/InSPyReNet_SwinB.yaml --verbose
103
+
104
+ # Multi GPUs with DDP (e.g., 4 GPUs)
105
+ torchrun --standalone --nproc_per_node=4 Expr.py --config configs/InSPyReNet_SwinB.yaml --verbose
106
+
107
+ # Multi GPUs with DDP with designated devices (e.g., 2 GPUs - 0 and 1)
108
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nproc_per_node=2 Expr.py --config configs/InSPyReNet_SwinB.yaml --verbose
109
+ ```
110
+ ## Inference on your own data
111
+ * You can run inference on your own single image or folder of images (.jpg, .jpeg, and .png are supported), a single video or folder of videos (.mp4, .mov, and .avi are supported), and webcam input (Ubuntu and macOS are tested so far).
112
+ ```
113
+ python run/Inference.py --config configs/InSPyReNet_SwinB.yaml --source [SOURCE] --dest [DEST] --type [TYPE] --gpu --jit --verbose
114
+ ```
115
+
116
+ * `--source [SOURCE]`: Specify your data in this argument.
117
+ * Single image - `image.png`
118
+ * Folder containing images - `path/to/img/folder`
119
+ * Single video - `video.mp4`
120
+ * Folder containing videos - `path/to/vid/folder`
121
+ * Webcam input: `0` (may vary depending on your device)
122
+ * `--dest [DEST]` (optional): Specify your destination folder. If not specified, results will be saved in the `results` folder.
123
+ * `--type [TYPE]`: Choose between `map`, `green`, `rgba`, `blur`, `overlay`, and another image file (an example command is shown below).
124
+ * `map` will output the saliency map only.
125
+ * `green` will replace the background with a green screen.
126
+ * `rgba` will generate RGBA output, treating the saliency score as an alpha map. Note that this will not work for video and webcam input.
127
+ * `blur` will blur the background.
128
+ * `overlay` will cover the salient object with a translucent green color and highlight the edges.
129
+ * Another image file (e.g., `background.png`) will be used as the background, and the object will be overlaid on it.
130
+ <details><summary>Examples of TYPE argument</summary>
131
+ <p>
132
+ <img src=../figures/demo_type.png >
133
+ </p>
134
+ </details>
135
+ * `--gpu`: Use this argument if you want to use GPU.
136
+ * `--jit`: Slightly improves inference speed when used.
137
+ * `--verbose`: Use when you want to visualize progress.
138
+
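For example, a typical single-image call with a green-screen output (the file names are placeholders) would look like:

```
python run/Inference.py --config configs/InSPyReNet_SwinB.yaml --source image.png --dest results --type green --gpu --verbose
```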
docs/model_zoo.md ADDED
@@ -0,0 +1,65 @@
1
+ # Model Zoo :giraffe:
2
+
3
+ ## Pre-trained Checkpoints
4
+
5
+ :label: Note: If you want to try our trained checkpoints below, please make sure to place the `latest.pth` file in the [Test.Checkpoint.checkpoint_dir](https://github.com/plemeri/InSPyReNet/blob/main/configs/InSPyReNet_SwinB.yaml#L69) (see the snippet below).
6
+
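In other words, download the checkpoint and place `latest.pth` inside the directory that `Test.Checkpoint.checkpoint_dir` points to in your config, for example (the exact path depends on the config you use):

```
Test:
  Checkpoint:
    checkpoint_dir: "snapshots/InSPyReNet_SwinB"   # put latest.pth in this folder
```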
7
+ ### Trained with LR dataset only (DUTS-TR, 384 X 384)
8
+
9
+ Backbone | Train DB | Config | OneDrive | GDrive
10
+ :-|:-|:-|:-|:-
11
+ Res2Net50 | DUTS-TR | [InSPyReNet_Res2Net50](../configs/InSPyReNet_Res2Net50.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERqm7RPeNBFPvVxkA5P5G2AB-mtFsiYkCNHnBf0DcwpFzw?e=nayVno&download=1) | [Link](https://drive.google.com/file/d/12moRuU8F0-xRvE16bVg6mkGWDuqYHJor/view?usp=sharing)
12
+ SwinB | DUTS-TR | [InSPyReNet_SwinB](../configs/InSPyReNet_SwinB.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EV0ow4E8LddCgu5tAuAkMbcBpBYoEDmJgQg5wkiuvLoQUA?e=cOZspv&download=1) | [Link](https://drive.google.com/file/d/1k5hNJImgEgSmz-ZeJEEb_dVkrOnswVMq/view?usp=sharing)
13
+
14
+ ### Trained with LR+HR dataset (with LR scale 384 X 384)
15
+
16
+ Backbone | Train DB | Config | OneDrive | GDrive
17
+ :-|:-|:-|:-|:-
18
+ SwinB | DUTS-TR, HRSOD-TR | [InSPyReNet_SwinB_DH_LR](../configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EWxPZoIKALlGsfrNgUFNvxwBC8IE8jzzhPNtzcbHmTNFcg?e=e22wmy&download=1) | [Link](https://drive.google.com/file/d/1nbs6Xa7NMtcikeHFtkQRVrsHbBRHtIqC/view?usp=sharing)
19
+ SwinB | HRSOD-TR, UHRSD-TR | [InSPyReNet_SwinB_HU_LR](../configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EQe-iy0AZctIkgl3o-BmVYUBn795wvii3tsnBq1fNUbc9g?e=gMZ4PV&download=1) | [Link](https://drive.google.com/file/d/1uLSIYXlRsZv4Ho0C-c87xKPhmF_b-Ll4/view?usp=sharing)
20
+ SwinB | DUTS-TR, HRSOD-TR, UHRSD-TR | [InSPyReNet_SwinB_DHU_LR](../configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfsCbnfAU1RAqCJIkj1ewRgBhFetStsGB6SMSq_UJZimjA?e=Ghuacy&download=1) | [Link](https://drive.google.com/file/d/14gRNwR7XwJ5oEcR4RWIVbYH3HEV6uBUq/view?usp=sharing)
21
+
22
+ ### Trained with LR+HR dataset (with HR scale 1024 X 1024)
23
+
24
+ Backbone | Train DB | Config | OneDrive | GDrive
25
+ :-|:-|:-|:-|:-
26
+ SwinB | DUTS-TR, HRSOD-TR | [InSPyReNet_SwinB_DH](../configs/extra_dataset/InSPyReNet_SwinB_DH.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EW2Qg-tMBBxNkygMj-8QgMUBiqHox5ExTOJl0LGLsn6AtA?e=Mam8Ur&download=1) | [Link](https://drive.google.com/file/d/1UBGFDUYZ9SysZr96dhsscZg7nDXt6IUD/view?usp=sharing)
27
+ SwinB | HRSOD-TR, UHRSD-TR | [InSPyReNet_SwinB_HU](../configs/extra_dataset/InSPyReNet_SwinB_HU.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EeE8nnCt_AdFvxxu0JsxwDgBCtGchuUka6DW9za_epX-Qw?e=U7wZu9&download=1) | [Link](https://drive.google.com/file/d/1HB02tiInEgo-pNzwqyvyV6eSN1Y2xPRJ/view?usp=sharing)
28
+
29
+
30
+ ### Trained with Dichotomous Image Segmentation dataset (DIS5K-TR) with LR (384 X 384) and HR (1024 X 1024) scale
31
+
32
+ Backbone | Train DB | Config | OneDrive | GDrive
33
+ :-|:-|:-|:-|:-
34
+ SwinB | DIS5K-TR | [InSPyReNet_SwinB_DIS5K_LR](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERKrQ_YeoJRHl_3HcH8ZJLoBedsa6hZlmIIf66wobZRGuw?e=EywJmS&download=1) | [Link](https://drive.google.com/file/d/1Sj7GZoocGMHyKNhFnQQc1FTs76ysJIX3/view?usp=sharing)
35
+ SwinB | DIS5K-TR | [InSPyReNet_SwinB_DIS5K](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EShRbD-jZuRJiWv6DS2Us34BwgazGZvK1t4uTKvgE5379Q?e=8oVpS8) | [Link](https://drive.google.com/file/d/1aCxHMbhvj8ah77jXVgqvqImQA_Y0G-Yg/view?usp=sharing)
36
+
37
+ ### <img src="https://www.kindpng.com/picc/b/124-1249525_all-might-png.png" width=20px> Trained with Massive SOD Datasets (with LR (384 X 384) and HR (1024 X 1024), Not in the paper, for real-world scenarios)
38
+
39
+ Backbone | Train DB | Config | OneDrive | GDrive
40
+ :-|:-|:-|:-|:-
41
+ SwinB | DUTS-TR, HRSOD-TR, UHRSD-TR, DIS-TR, DUTS-TE, DUT-OMRON, ECSSD,HKU-IS, PASCAL-S, DAVIS-S, HRSOD-TE,UHRSD-TE, FSS-1000, MSRA-10K, DIS-VD, DIS-TE1, DIS-TE2, DIS-TE3, DIS-TE4 | [Plus_Ultra_LR](../configs/extra_dataset/Plus_Ultra_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESKuh1zhToVFsIxhUUsgkbgBnu2kFXCFLRuSz1xxsKzjhA?e=02HDrm&download=1) | [Link](https://drive.google.com/file/d/1iRX-0MVbUjvAVns5MtVdng6CQlGOIo3m/view?usp=sharing)
42
+ SwinB | DUTS-TR, HRSOD-TR, UHRSD-TR, DIS-TR, DUTS-TE, DUT-OMRON, ECSSD,HKU-IS, PASCAL-S, DAVIS-S, HRSOD-TE,UHRSD-TE, FSS-1000, MSRA-10K, DIS-VD, DIS-TE1, DIS-TE2, DIS-TE3, DIS-TE4 | [Plus_Ultra](../configs/extra_dataset/Plus_Ultra.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ET0R-yM8MfVHqI4g94AlL6AB-D6LxNajaWeDV4xbVQyh7w?e=l4JkZn) | [Link](https://drive.google.com/file/d/13oBl5MTVcWER3YU4fSxW3ATlVfueFQPY/view?usp=sharing)
43
+
44
+ * We used the above checkpoints for [`transparent-background`](https://github.com/plemeri/transparent-background), a background removal command-line tool / Python API!
45
+
46
+ ## Pre-Computed Saliency Maps
47
+
48
+ ### Salient Object Detection benchmarks
49
+
50
+ Config | DUTS-TE | DUT-OMRON | ECSSD | HKU-IS | PASCAL-S | DAVIS-S | HRSOD-TE | UHRSD-TE
51
+ :-|:-|:-|:-|:-|:-|:-|:-|:-
52
+ [InSPyReNet_Res2Net50](../configs/InSPyReNet_Res2Net50.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Eb0iKXGX1vxEjPhe9KGBKr0Bv7v2vv6Ua5NFybwc6aIi1w?e=oHnGyJ&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ef1HaYMvgh1EuuOL8bw3JGYB41-yo6KdTD8FGXcFZX3-Bg?e=TkW2m8&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdEQQ8o-yI9BtTpROcuB_iIBFSIk0uBJAkNyob0WI04-kw?e=cwEj2V&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ec6LyrumVZ9PoB2Af0OW4dcBrDht0OznnwOBYiu8pdyJ4A?e=Y04Fmn&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETPijMHlTRZIjqO5H4LBknUBmy8TGDwOyUQ1H4EnIpHVOw?e=k1afrh&download=1) | N/A | N/A | N/A |
53
+ [InSPyReNet_SwinB](../configs/InSPyReNet_SwinB.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETumLjuBantLim4kRqj4e_MBpK_X5XrTwjGQUToN8TKVjw?e=ZT8AWy&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZbwxhwT6dtHkBJrIMMjTnkBK_HaDTXgHcDSjxuswZKTZw?e=9XeE4b&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESfQK-557uZOmUwG5W49j0EBK42_7dMOaQcPsc_U1zsYlA?e=IvjkKX&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EURH96JUp55EgUHI0A8RzKoBBqvQc1nVb_a67RgwOY7f-w?e=IP9xKa&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EakMpwONph9EmnCM2rS3hn4B_TL42T6tuLjBEeEa5ndkIw?e=XksfA5&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETUCKFX0k8lAvpsDj5sT23QB2ohuE_ST7oQnWdaW7AoCIw?e=MbSmM2&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea6kf6Kk8fpIs15WWDfJMoYBeQUeo9WXvYx9oM5yWFE1Jg?e=RNN0Ns&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EVJLvAP3HwtHksZMUolIfCABHqP7GgAWcG_1V5T_Xrnr2g?e=ct3pzo&download=1) |
54
+ [InSPyReNet_SwinB_DH_LR](../configs/extra_dataset/InSPyReNet_SwinB_DH_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EbaXjDWFb6lGp9x5ae9mJPgBt9dkmgclq9XrXDjj4B5qSw?e=57wnhE) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EU7uNIpKPXZHrDNt3gapGfsBaUrkCj67-Paj4w_7E8xs1g?e=n0QBR1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EeBq5uU02FRLiW-0g-mQ_CsB3iHlMgAyOSY2Deu5suo9pQ?e=zjhB33) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EY5AppwMB4hFqbISFL_u5QMBeux7dQWtDeXaMMcAyLqLqQ?e=N71XVt) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfhejUtLallLroU6jjgDl-oBSHITWnkdiU6NVV95DO5YqQ?e=T6rrRW) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EesN4fRAE4JJk1aZ_tH3EBQBgirysALcgfw1Ipsa9dLe9Q?e=b5oWsg) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETpVYDijnZhInN9zBHRuwhUBSPslqe9FP0m3Eo3TWS0d5A?e=QTzfx6) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUkKVH3LSyFLq6UXuvWEVuYBEH8W8uAKgyolLRVuIUILag?e=y1SceD)
55
+ [InSPyReNet_SwinB_HU_LR](../configs/extra_dataset/InSPyReNet_SwinB_HU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUSblgWAwg9Plc4LCGj4TLwB-7HLEdZGJqF1jHOU55g3OA?e=2YT3zM) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESNaVX4Fh5JHn5jOfgnSWi4Bx1bc9t6pg79IoG3mrpZpAw?e=M8D0CM) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETKjTH1vcVZDu5ahRaw4cb8B7JKaPMR-0Uae1DbwarobIA?e=Qw67IZ) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcC9-lPgAXZHs7th9DiVjygB-zPIq_1Ii6i1GpbLGc1iPQ?e=EVXKp9) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EXjyAS4XyzlPqrpaizgElioBdmgd4E81qQzj11Qm4xo5sA?e=hoOzc2) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EcSCMR033GVNotOilIzYhIsBikzb8ZzGlkuW6aSNMlUpqQ?e=TFcgvE) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUHRSEANcmVFjcS_K-PeYr0B6VPXPgb2AHFUnlJYrf3dOQ?e=unwcqV) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EU2gb0hS5kZBgKNXWqQAomsBXU-zjGCXKAzYNNk4d6EAiQ?e=pjhiN2)
56
+ [InSPyReNet_SwinB_DHU_LR](../configs/extra_dataset/InSPyReNet_SwinB_DHU_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ee7Y647ERzxMgBoFceEEO6kBIUkIlmYHoizMj71gT37sxw?e=xDt83C) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESFGN3FCdzRAvlsW6bEaGdoBYNoJgK4DAjaS6WkVVyI_QQ?e=nYHklV) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EeONq5kOirRCkErrb6fFqd8B4w4SMZXBY1Q2mJvZcRsGdQ?e=K7fwQt) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea5P4CBBQatPiUzsH53lckoB0k23haePzuERBfyJXaCbBg?e=AZ96mc) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ea2pezB7eo1BraeBpIA8YZoBkVY38rRa3KrwSIzY1cn2dQ?e=o121S6) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EfNDzH8O54pGtnAxit_hUjUBK9poVq4sxxnJjSG7PUQCkw?e=OWt7k8) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESfZaII0pO9IqqL1FkpjIuAB8SGxLcslJeWTuKQxPNFIVA?e=Ce1CWg) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYsQqf1GKShJhBnkZ6gD5PABPOcjRcUGSfvTbe-wYh2O2Q?e=t4Xlxv)
57
+ [InSPyReNet_SwinB_DH](../configs/extra_dataset/InSPyReNet_SwinB_DH.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EWNaRqtzhtNFhMVfLcoyfqQBw35M8q8bxME3yZyhkTtc7Q?e=jrJe3v) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZF5_s8JfR9HqGBUZHSM_j4BVVMONp38_gJ1ekEdvlM-qQ?e=0chMdl) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZuaFObyNOtKg0W5cM7bqPYBZYGg7Z3V3i4sClI6bU_ntA?e=BxxQI7) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EY-yEnNNNT5KpiURhiDAkDEBMkiA1QwQ_T0wB1UC75GXVg?e=Lle02B) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYfbsgWhm7lAlX_wj_WZZowBV_-l-UvvThC4LJEKpV0BQQ?e=zTiKpI) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Ef0GUP7c0bBBonHlqgB988YB0rgxCFq3oo0u8xCN8wfyyQ?e=LCb8UV) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EVUZBnRpa35AmrvdUybsQDMBzMZvuJWe5tT7635lh9MHDQ?e=FlpQW1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETfZ_zdrDvhOh21u2mqVhigBSxn3vlfKVIwXhRfzzSSFzA?e=kXBBi9)
58
+ [InSPyReNet_SwinB_HU](../configs/extra_dataset/InSPyReNet_SwinB_HU.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZq4JUACKCBMk2bn4yoWz6sBOKrSFTPfL7d5xopc1uDw_A?e=RtVHSl) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ETJaqoSPaYtNkc8eSGDeKzMBbjbuOAWgJwG4q52bW87aew?e=Pguh4b) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZAeCI6BPMdNsicnQ-m1pVEBwAhOiIcbelhOMoRGXGEvVA?e=BQKd7Q) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EVmGvZGz54JOvrIymLsSwq4Bpos3vWSXZm3oV7-qmGZgHA?e=4UhDgv) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ERHDUybOh4ZKkqWZpcu7MiMBFuTK6wACkKUZaNeEQGbCNQ?e=GCQnoe) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESPmZXTnfO5CrCoo_0OADxgBt_3FoU5mSFoSE4QWbWxumQ?e=HAsAYz) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EdTnUwEeMZNBrPSBBbGZKQcBmVshSTfca9qz_BqNpAUpOg?e=HsJ4Gx) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ET48owfVQEdImrh0V4gx8_ABsYXgbIJqtpq77aK_U28VwQ?e=h8er3H)
59
+
60
+ ### DIS5K Results
61
+
62
+ Config | DIS-VD | DIS-TE1 | DIS-TE2 | DIS-TE3 | DIS-TE4
63
+ :-|:-|:-|:-|:-|:-
64
+ [InSPyReNet_SwinB_DIS5K_LR](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K_LR.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUbzddb_QRRCtnXC8Xl6vZoBC6IqOfom52BWbzOYk-b2Ow?e=aqJYi1&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ESeW_SOD26tHjBLymmgFaXwBIJlljzNycaGWXLpOp_d_kA?e=2EyMai&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EYWT5fZDjI5Bn-lr-iQM1TsB1num0-UqfJC1TIv-LuOXoA?e=jCcnty&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EQXm1DEBfaNJmH0B-A3o23kBn4v5j53kP2nF9CpG9SQkyw?e=lEUiZh&download=1) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EZeH2ufGsFZIoUh6D8Rtv88BBF_ddQXav4xYXXRP_ayEAg?e=AMzIp8&download=1)
65
+ [InSPyReNet_SwinB_DIS5K](../configs/extra_dataset/InSPyReNet_SwinB_DIS5K.yaml) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EUisiphB_W5BjgfpIYu9oNgB_fY4XxL-MhR2gR-ZZUt49Q?e=gqorYs) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EW2y54ROYIZFlEq5ilRFwOsBSrIm2-HGsUHPHykaJvUBfA?e=797fxr) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/ER6yEGZkgWVOsL-mauYgPyoBDIoU0Mck-twEBgQi5g3Mxw?e=0yJZTT) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/Edvzj0iZ8hdDthm4Q-p2YHgBhP1X5z4AAccAoUasr2nihA?e=1dognG) | [Link](https://postechackr-my.sharepoint.com/:u:/g/personal/taehoon1018_postech_ac_kr/EWwMORYg8DlCgGGqOQFThZ8BgIU9wV9-0DwLrKldQl7N8w?e=nCwuqy)
figures/demo_image.gif ADDED

Git LFS Details

  • SHA256: 73b3c5f31f8a9e3768b80a858924bacaa6416e5de90b4ec67dce070fd55366f1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.85 MB
figures/demo_type.png ADDED

Git LFS Details

  • SHA256: aa49c0e1875fdc8f9d3fdca4f3a92a5cc395ae890ccb7db94fc74addc99f937a
  • Pointer size: 132 Bytes
  • Size of remote file: 4.32 MB
figures/demo_video.gif ADDED

Git LFS Details

  • SHA256: a91422e301420d8c77ea2a2aa511f2a9377d0fb6ae1c3445283bc618943b427b
  • Pointer size: 133 Bytes
  • Size of remote file: 10.2 MB
figures/demo_webapp.gif ADDED

Git LFS Details

  • SHA256: 500cc53d85d451aa1c2833bf874dccadace67670cdd36a07807adc89b95c9731
  • Pointer size: 132 Bytes
  • Size of remote file: 5.51 MB
figures/fig_architecture.png ADDED

Git LFS Details

  • SHA256: 2824a63e92a0ca8921306c82bf098ed24926093840559cb552022f7b8fe1554f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
figures/fig_pyramid_blending.png ADDED

Git LFS Details

  • SHA256: 83550b1317798948a2dc08257720a03a23d57f7eb111d9e4d39e51ab115f4e0a
  • Pointer size: 131 Bytes
  • Size of remote file: 931 kB
figures/fig_qualitative.png ADDED

Git LFS Details

  • SHA256: 985439b05e48afeae809f1cc3127bd899e49453311147498d9d5d2a13069d343
  • Pointer size: 132 Bytes
  • Size of remote file: 4.72 MB
figures/fig_qualitative2.png ADDED

Git LFS Details

  • SHA256: e27f33cfebf9ad3dfb6474d1f0ff388bf4870d0eaf68ce34da3dc2bb904b6d26
  • Pointer size: 132 Bytes
  • Size of remote file: 5.12 MB
figures/fig_qualitative3.jpg ADDED

Git LFS Details

  • SHA256: 849e31679e43c683e454afa970f5526620eeefd4060a9a5048a8b0db51e84fa3
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
figures/fig_qualitative_dis.png ADDED

Git LFS Details

  • SHA256: 976f51d9d116610260f41031f1a9a70443b6cc22bb8386bf158e871e3a9bd14c
  • Pointer size: 133 Bytes
  • Size of remote file: 25.7 MB
figures/fig_quantitative.png ADDED

Git LFS Details

  • SHA256: 688ddc0180065aad54629d4deba8db5087b5c123a16ce62e8e1c64c69a9ca5b1
  • Pointer size: 131 Bytes
  • Size of remote file: 733 kB
figures/fig_quantitative2.png ADDED

Git LFS Details

  • SHA256: 6ce8afbbcd6fa35f2b56837267205a1f5b19825b249e2f96955d336b2288b754
  • Pointer size: 131 Bytes
  • Size of remote file: 484 kB
figures/fig_quantitative3.png ADDED

Git LFS Details

  • SHA256: 1634e88a50c92c3bd8b3fcf98c9d4c03f1996d4fcf7756a1623c2bff0cdd0ca3
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
figures/fig_quantitative4.png ADDED

Git LFS Details

  • SHA256: 8a3c8bd9b79f56abf3d2dac6b118bed521b8c0274f94c15f9a2b972659312b1c
  • Pointer size: 131 Bytes
  • Size of remote file: 368 kB
lib/InSPyReNet.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from lib.optim import *
7
+ from lib.modules.layers import *
8
+ from lib.modules.context_module import *
9
+ from lib.modules.attention_module import *
10
+ from lib.modules.decoder_module import *
11
+
12
+ from lib.backbones.Res2Net_v1b import res2net50_v1b_26w_4s
13
+ from lib.backbones.SwinTransformer import SwinB
14
+
15
+ class InSPyReNet(nn.Module):
16
+ def __init__(self, backbone, in_channels, depth=64, base_size=[384, 384], threshold=512, **kwargs):
17
+ super(InSPyReNet, self).__init__()
18
+ self.backbone = backbone
19
+ self.in_channels = in_channels
20
+ self.depth = depth
21
+ self.base_size = base_size
22
+ self.threshold = threshold
23
+
24
+ self.context1 = PAA_e(self.in_channels[0], self.depth, base_size=self.base_size, stage=0)
25
+ self.context2 = PAA_e(self.in_channels[1], self.depth, base_size=self.base_size, stage=1)
26
+ self.context3 = PAA_e(self.in_channels[2], self.depth, base_size=self.base_size, stage=2)
27
+ self.context4 = PAA_e(self.in_channels[3], self.depth, base_size=self.base_size, stage=3)
28
+ self.context5 = PAA_e(self.in_channels[4], self.depth, base_size=self.base_size, stage=4)
29
+
30
+ self.decoder = PAA_d(self.depth * 3, depth=self.depth, base_size=base_size, stage=2)
31
+
32
+ self.attention0 = SICA(self.depth , depth=self.depth, base_size=self.base_size, stage=0, lmap_in=True)
33
+ self.attention1 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=1, lmap_in=True)
34
+ self.attention2 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=2 )
35
+
36
+ self.sod_loss_fn = lambda x, y: weighted_bce_loss_with_logits(x, y, reduction='mean') + iou_loss_with_logits(x, y, reduction='mean')
37
+ self.pc_loss_fn = nn.L1Loss()
38
+
39
+ self.ret = lambda x, target: F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)
40
+ self.res = lambda x, size: F.interpolate(x, size=size, mode='bilinear', align_corners=False)
41
+ self.des = lambda x, size: F.interpolate(x, size=size, mode='nearest')
42
+
43
+ self.image_pyramid = ImagePyramid(7, 1)
44
+
45
+ self.transition0 = Transition(17)
46
+ self.transition1 = Transition(9)
47
+ self.transition2 = Transition(5)
48
+
49
+ self.forward = self.forward_inference
50
+
51
+ def to(self, device):
52
+ self.image_pyramid.to(device)
53
+ self.transition0.to(device)
54
+ self.transition1.to(device)
55
+ self.transition2.to(device)
56
+ super(InSPyReNet, self).to(device)
57
+ return self
58
+
59
+ def cuda(self, idx=None):
60
+ if idx is None:
61
+ idx = torch.cuda.current_device()
62
+
63
+ self.to(device="cuda:{}".format(idx))
64
+ return self
65
+
66
+ def train(self, mode=True):
67
+ super(InSPyReNet, self).train(mode)
68
+ self.forward = self.forward_train
69
+ return self
70
+
71
+ def eval(self):
72
+ super(InSPyReNet, self).train(False)
73
+ self.forward = self.forward_inference
74
+ return self
75
+
76
+ def forward_inspyre(self, x):
77
+ B, _, H, W = x.shape
78
+
79
+ x1, x2, x3, x4, x5 = self.backbone(x)
80
+
81
+ x1 = self.context1(x1) #4
82
+ x2 = self.context2(x2) #4
83
+ x3 = self.context3(x3) #8
84
+ x4 = self.context4(x4) #16
85
+ x5 = self.context5(x5) #32
86
+
87
+ f3, d3 = self.decoder([x3, x4, x5]) #16
88
+
89
+ f3 = self.res(f3, (H // 4, W // 4 ))
90
+ f2, p2 = self.attention2(torch.cat([x2, f3], dim=1), d3.detach())
91
+ d2 = self.image_pyramid.reconstruct(d3.detach(), p2) #4
92
+
93
+ x1 = self.res(x1, (H // 2, W // 2))
94
+ f2 = self.res(f2, (H // 2, W // 2))
95
+ f1, p1 = self.attention1(torch.cat([x1, f2], dim=1), d2.detach(), p2.detach()) #2
96
+ d1 = self.image_pyramid.reconstruct(d2.detach(), p1) #2
97
+
98
+ f1 = self.res(f1, (H, W))
99
+ _, p0 = self.attention0(f1, d1.detach(), p1.detach()) #2
100
+ d0 = self.image_pyramid.reconstruct(d1.detach(), p0) #2
101
+
102
+ out = dict()
103
+ out['saliency'] = [d3, d2, d1, d0]
104
+ out['laplacian'] = [p2, p1, p0]
105
+
106
+ return out
107
+
108
+ def forward_train(self, sample):
109
+ x = sample['image']
110
+ B, _, H, W = x.shape
111
+ out = self.forward_inspyre(x)
112
+
113
+ d3, d2, d1, d0 = out['saliency']
114
+ p2, p1, p0 = out['laplacian']
115
+
116
+ if type(sample) == dict and 'gt' in sample.keys() and sample['gt'] is not None:
117
+ y = sample['gt']
118
+
119
+ y1 = self.image_pyramid.reduce(y)
120
+ y2 = self.image_pyramid.reduce(y1)
121
+ y3 = self.image_pyramid.reduce(y2)
122
+
123
+ loss = self.pc_loss_fn(self.des(d3, (H, W)), self.des(self.image_pyramid.reduce(d2), (H, W)).detach()) * 0.0001
124
+ loss += self.pc_loss_fn(self.des(d2, (H, W)), self.des(self.image_pyramid.reduce(d1), (H, W)).detach()) * 0.0001
125
+ loss += self.pc_loss_fn(self.des(d1, (H, W)), self.des(self.image_pyramid.reduce(d0), (H, W)).detach()) * 0.0001
126
+
127
+ loss += self.sod_loss_fn(self.des(d3, (H, W)), self.des(y3, (H, W)))
128
+ loss += self.sod_loss_fn(self.des(d2, (H, W)), self.des(y2, (H, W)))
129
+ loss += self.sod_loss_fn(self.des(d1, (H, W)), self.des(y1, (H, W)))
130
+ loss += self.sod_loss_fn(self.des(d0, (H, W)), self.des(y, (H, W)))
131
+
132
+ else:
133
+ loss = 0
134
+
135
+ pred = torch.sigmoid(d0)
136
+ pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
137
+
138
+ sample['pred'] = pred
139
+ sample['loss'] = loss
140
+ sample['saliency'] = [d3, d2, d1, d0]
141
+ sample['laplacian'] = [p2, p1, p0]
142
+ return sample
143
+
144
+ def forward_inference(self, sample):
145
+ B, _, H, W = sample['image'].shape
146
+
147
+ if self.threshold is None:
148
+ out = self.forward_inspyre(sample['image'])
149
+ d3, d2, d1, d0 = out['saliency']
150
+ p2, p1, p0 = out['laplacian']
151
+
152
+ elif (H <= self.threshold or W <= self.threshold):
153
+ if 'image_resized' in sample.keys():
154
+ out = self.forward_inspyre(sample['image_resized'])
155
+ else:
156
+ out = self.forward_inspyre(sample['image'])
157
+ d3, d2, d1, d0 = out['saliency']
158
+ p2, p1, p0 = out['laplacian']
159
+
160
+ else:
161
+ # LR Saliency Pyramid
162
+ lr_out = self.forward_inspyre(sample['image_resized'])
163
+ lr_d3, lr_d2, lr_d1, lr_d0 = lr_out['saliency']
164
+ lr_p2, lr_p1, lr_p0 = lr_out['laplacian']
165
+
166
+ # HR Saliency Pyramid
167
+ hr_out = self.forward_inspyre(sample['image'])
168
+ hr_d3, hr_d2, hr_d1, hr_d0 = hr_out['saliency']
169
+ hr_p2, hr_p1, hr_p0 = hr_out['laplacian']
170
+
171
+ # Pyramid Blending
172
+ d3 = self.ret(lr_d0, hr_d3)
173
+
174
+ t2 = self.ret(self.transition2(d3), hr_p2)
175
+ p2 = t2 * hr_p2
176
+ d2 = self.image_pyramid.reconstruct(d3, p2)
177
+
178
+ t1 = self.ret(self.transition1(d2), hr_p1)
179
+ p1 = t1 * hr_p1
180
+ d1 = self.image_pyramid.reconstruct(d2, p1)
181
+
182
+ t0 = self.ret(self.transition0(d1), hr_p0)
183
+ p0 = t0 * hr_p0
184
+ d0 = self.image_pyramid.reconstruct(d1, p0)
185
+
186
+ pred = torch.sigmoid(d0)
187
+ pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
188
+
189
+ sample['pred'] = pred
190
+ sample['loss'] = 0
191
+ sample['saliency'] = [d3, d2, d1, d0]
192
+ sample['laplacian'] = [p2, p1, p0]
193
+ return sample
194
+
195
+ def InSPyReNet_Res2Net50(depth, pretrained, base_size, **kwargs):
196
+ return InSPyReNet(res2net50_v1b_26w_4s(pretrained=pretrained), [64, 256, 512, 1024, 2048], depth, base_size, **kwargs)
197
+
198
+ def InSPyReNet_SwinB(depth, pretrained, base_size, **kwargs):
199
+ return InSPyReNet(SwinB(pretrained=pretrained), [128, 128, 256, 512, 1024], depth, base_size, **kwargs)
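To make the inference path above concrete, here is a minimal sketch (not from the repository; it assumes `SwinB(pretrained=False)` builds the backbone without loading ImageNet weights) showing how `forward_inference` consumes a sample dict:

```python
# Minimal sketch: run the SwinB variant on a random image and read back the saliency map.
import torch
from lib import InSPyReNet_SwinB

model = InSPyReNet_SwinB(depth=64, pretrained=False, base_size=[384, 384]).eval()

# 384 <= threshold (512), so a single forward pass is used; larger inputs with an
# 'image_resized' key would trigger the LR/HR pyramid blending branch instead.
sample = {'image': torch.rand(1, 3, 384, 384)}
with torch.no_grad():
    out = model(sample)

print(out['pred'].shape)  # expected: torch.Size([1, 1, 384, 384]), values normalized to [0, 1]
```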
lib/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from lib.InSPyReNet import InSPyReNet_Res2Net50, InSPyReNet_SwinB
lib/backbones/Res2Net_v1b.py ADDED
@@ -0,0 +1,268 @@
1
+ import torch.nn as nn
2
+ import math
3
+ import torch.utils.model_zoo as model_zoo
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['Res2Net', 'res2net50_v1b',
8
+ 'res2net101_v1b', 'res2net50_v1b_26w_4s']
9
+
10
+ model_urls = {
11
+ 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
12
+ 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
13
+ }
14
+ class Bottle2neck(nn.Module):
15
+ expansion = 4
16
+
17
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
18
+ super(Bottle2neck, self).__init__()
19
+
20
+ width = int(math.floor(planes * (baseWidth / 64.0)))
21
+ self.conv1 = nn.Conv2d(inplanes, width * scale,
22
+ kernel_size=1, bias=False)
23
+ self.bn1 = nn.BatchNorm2d(width * scale)
24
+
25
+ if scale == 1:
26
+ self.nums = 1
27
+ else:
28
+ self.nums = scale - 1
29
+ if stype == 'stage':
30
+ self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
31
+ convs = []
32
+ bns = []
33
+ for i in range(self.nums):
34
+ convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride,
35
+ dilation=dilation, padding=dilation, bias=False))
36
+ bns.append(nn.BatchNorm2d(width))
37
+ self.convs = nn.ModuleList(convs)
38
+ self.bns = nn.ModuleList(bns)
39
+
40
+ self.conv3 = nn.Conv2d(width * scale, planes *
41
+ self.expansion, kernel_size=1, bias=False)
42
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
43
+
44
+ self.relu = nn.ReLU(inplace=True)
45
+ self.downsample = downsample
46
+ self.stype = stype
47
+ self.scale = scale
48
+ self.width = width
49
+
50
+ def forward(self, x):
51
+ residual = x
52
+
53
+ out = self.conv1(x)
54
+ out = self.bn1(out)
55
+ out = self.relu(out)
56
+
57
+ spx = torch.split(out, self.width, 1)
58
+ for i in range(self.nums):
59
+ if i == 0 or self.stype == 'stage':
60
+ sp = spx[i]
61
+ else:
62
+ sp = sp + spx[i]
63
+ sp = self.convs[i](sp)
64
+ sp = self.relu(self.bns[i](sp))
65
+ if i == 0:
66
+ out = sp
67
+ else:
68
+ out = torch.cat((out, sp), 1)
69
+ if self.scale != 1 and self.stype == 'normal':
70
+ out = torch.cat((out, spx[self.nums]), 1)
71
+ elif self.scale != 1 and self.stype == 'stage':
72
+ out = torch.cat((out, self.pool(spx[self.nums])), 1)
73
+
74
+ out = self.conv3(out)
75
+ out = self.bn3(out)
76
+
77
+ if self.downsample is not None:
78
+ residual = self.downsample(x)
79
+
80
+ out += residual
81
+ out = self.relu(out)
82
+
83
+ return out
84
+
85
+
86
+ class Res2Net(nn.Module):
87
+
88
+ def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000, output_stride=32):
89
+ self.inplanes = 64
90
+ super(Res2Net, self).__init__()
91
+ self.baseWidth = baseWidth
92
+ self.scale = scale
93
+ self.output_stride = output_stride
94
+ if self.output_stride == 8:
95
+ self.grid = [1, 2, 1]
96
+ self.stride = [1, 2, 1, 1]
97
+ self.dilation = [1, 1, 2, 4]
98
+ elif self.output_stride == 16:
99
+ self.grid = [1, 2, 4]
100
+ self.stride = [1, 2, 2, 1]
101
+ self.dilation = [1, 1, 1, 2]
102
+ elif self.output_stride == 32:
103
+ self.grid = [1, 2, 4]
104
+ self.stride = [1, 2, 2, 2]
105
+ self.dilation = [1, 1, 2, 4]
106
+ self.conv1 = nn.Sequential(
107
+ nn.Conv2d(3, 32, 3, 2, 1, bias=False),
108
+ nn.BatchNorm2d(32),
109
+ nn.ReLU(inplace=True),
110
+ nn.Conv2d(32, 32, 3, 1, 1, bias=False),
111
+ nn.BatchNorm2d(32),
112
+ nn.ReLU(inplace=True),
113
+ nn.Conv2d(32, 64, 3, 1, 1, bias=False)
114
+ )
115
+ self.bn1 = nn.BatchNorm2d(64)
116
+ self.relu = nn.ReLU()
117
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
118
+ self.layer1 = self._make_layer(
119
+ block, 64, layers[0], stride=self.stride[0], dilation=self.dilation[0])
120
+ self.layer2 = self._make_layer(
121
+ block, 128, layers[1], stride=self.stride[1], dilation=self.dilation[1])
122
+ self.layer3 = self._make_layer(
123
+ block, 256, layers[2], stride=self.stride[2], dilation=self.dilation[2])
124
+ self.layer4 = self._make_layer(
125
+ block, 512, layers[3], stride=self.stride[3], dilation=self.dilation[3], grid=self.grid)
126
+
127
+ for m in self.modules():
128
+ if isinstance(m, nn.Conv2d):
129
+ nn.init.kaiming_normal_(
130
+ m.weight, mode='fan_out', nonlinearity='relu')
131
+ elif isinstance(m, nn.BatchNorm2d):
132
+ nn.init.constant_(m.weight, 1)
133
+ nn.init.constant_(m.bias, 0)
134
+
135
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, grid=None):
136
+ downsample = None
137
+ if stride != 1 or self.inplanes != planes * block.expansion:
138
+ downsample = nn.Sequential(
139
+ nn.AvgPool2d(kernel_size=stride, stride=stride,
140
+ ceil_mode=True, count_include_pad=False),
141
+ nn.Conv2d(self.inplanes, planes * block.expansion,
142
+ kernel_size=1, stride=1, bias=False),
143
+ nn.BatchNorm2d(planes * block.expansion),
144
+ )
145
+
146
+ layers = []
147
+ layers.append(block(self.inplanes, planes, stride, dilation, downsample=downsample,
148
+ stype='stage', baseWidth=self.baseWidth, scale=self.scale))
149
+ self.inplanes = planes * block.expansion
150
+
151
+ if grid is not None:
152
+ assert len(grid) == blocks
153
+ else:
154
+ grid = [1] * blocks
155
+
156
+ for i in range(1, blocks):
157
+ layers.append(block(self.inplanes, planes, dilation=dilation *
158
+ grid[i], baseWidth=self.baseWidth, scale=self.scale))
159
+
160
+ return nn.Sequential(*layers)
161
+
162
+ def change_stride(self, output_stride=16):
163
+ if output_stride == self.output_stride:
164
+ return
165
+ else:
166
+ self.output_stride = output_stride
167
+ if self.output_stride == 8:
168
+ self.grid = [1, 2, 1]
169
+ self.stride = [1, 2, 1, 1]
170
+ self.dilation = [1, 1, 2, 4]
171
+ elif self.output_stride == 16:
172
+ self.grid = [1, 2, 4]
173
+ self.stride = [1, 2, 2, 1]
174
+ self.dilation = [1, 1, 1, 2]
175
+ elif self.output_stride == 32:
176
+ self.grid = [1, 2, 4]
177
+ self.stride = [1, 2, 2, 2]
178
+ self.dilation = [1, 1, 2, 4]
179
+
180
+ for i, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
181
+ for j, block in enumerate(layer):
182
+ if block.downsample is not None:
183
+ block.downsample[0].kernel_size = (
184
+ self.stride[i], self.stride[i])
185
+ block.downsample[0].stride = (
186
+ self.stride[i], self.stride[i])
187
+ if hasattr(block, 'pool'):
188
+ block.pool.stride = (
189
+ self.stride[i], self.stride[i])
190
+ for conv in block.convs:
191
+ conv.stride = (self.stride[i], self.stride[i])
192
+ for conv in block.convs:
193
+ d = self.dilation[i] if i != 3 else self.dilation[i] * \
194
+ self.grid[j]
195
+ conv.dilation = (d, d)
196
+ conv.padding = (d, d)
197
+
198
+ def forward(self, x):
199
+ x = self.conv1(x)
200
+ x = self.bn1(x)
201
+ x = self.relu(x)
202
+ x = self.maxpool(x)
203
+
204
+ out = [x]
205
+
206
+ x = self.layer1(x)
207
+ out.append(x)
208
+ x = self.layer2(x)
209
+ out.append(x)
210
+ x = self.layer3(x)
211
+ out.append(x)
212
+ x = self.layer4(x)
213
+ out.append(x)
214
+
215
+ return out
216
+
217
+
218
+ def res2net50_v1b(pretrained=False, **kwargs):
219
+ model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
220
+ if pretrained:
221
+ model.load_state_dict(model_zoo.load_url(
222
+ model_urls['res2net50_v1b_26w_4s']))
223
+
224
+ return model
225
+
226
+
227
+ def res2net101_v1b(pretrained=False, **kwargs):
228
+ model = Res2Net(Bottle2neck, [3, 4, 23, 3],
229
+ baseWidth=26, scale=4, **kwargs)
230
+ if pretrained:
231
+ model.load_state_dict(model_zoo.load_url(
232
+ model_urls['res2net101_v1b_26w_4s']))
233
+
234
+ return model
235
+
236
+
237
+ def res2net50_v1b_26w_4s(pretrained=True, **kwargs):
238
+ model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
239
+ if pretrained is True:
240
+ model.load_state_dict(torch.load('data/backbone_ckpt/res2net50_v1b_26w_4s-3cf99910.pth', map_location='cpu'))
241
+
242
+ return model
243
+
244
+
245
+ def res2net101_v1b_26w_4s(pretrained=True, **kwargs):
246
+ model = Res2Net(Bottle2neck, [3, 4, 23, 3],
247
+ baseWidth=26, scale=4, **kwargs)
248
+ if pretrained is True:
249
+ model.load_state_dict(torch.load('data/backbone_ckpt/res2net101_v1b_26w_4s-0812c246.pth', map_location='cpu'))
250
+
251
+ return model
252
+
253
+
254
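+ # note: model_urls above has no 'res2net152_v1b_26w_4s' entry, so calling the function below with pretrained=True would raise a KeyError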
+ def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
255
+ model = Res2Net(Bottle2neck, [3, 8, 36, 3],
256
+ baseWidth=26, scale=4, **kwargs)
257
+ if pretrained:
258
+ model.load_state_dict(model_zoo.load_url(
259
+ model_urls['res2net152_v1b_26w_4s']))
260
+
261
+ return model
262
+
263
+
264
+ if __name__ == '__main__':
265
+ images = torch.rand(1, 3, 224, 224).cuda(0)
266
+ model = res2net50_v1b_26w_4s(pretrained=True)
267
+ model = model.cuda(0)
268
+ print([o.size() for o in model(images)])  # forward returns a list of feature maps, not a single tensor
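A minimal shape check for the backbone above (a sketch, not part of the commit: it assumes the repository root is on PYTHONPATH and uses pretrained=False so that no checkpoint under data/backbone_ckpt/ is required). Unlike a classification ResNet, forward() returns a five-level feature pyramid:

import torch
from lib.backbones.Res2Net_v1b import res2net50_v1b_26w_4s

backbone = res2net50_v1b_26w_4s(pretrained=False)   # pretrained=True loads data/backbone_ckpt/res2net50_v1b_26w_4s-3cf99910.pth
feats = backbone(torch.rand(1, 3, 384, 384))        # runs on CPU, no .cuda() needed
for f in feats:
    print(f.shape)
# torch.Size([1, 64, 96, 96])     stem output after maxpool
# torch.Size([1, 256, 96, 96])    layer1
# torch.Size([1, 512, 48, 48])    layer2
# torch.Size([1, 1024, 24, 24])   layer3
# torch.Size([1, 2048, 12, 12])   layer4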
lib/backbones/SwinTransformer.py ADDED
@@ -0,0 +1,652 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as checkpoint
12
+ import numpy as np
13
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
14
+
15
+ class Mlp(nn.Module):
16
+ """ Multilayer perceptron."""
17
+
18
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19
+ super().__init__()
20
+ out_features = out_features or in_features
21
+ hidden_features = hidden_features or in_features
22
+ self.fc1 = nn.Linear(in_features, hidden_features)
23
+ self.act = act_layer()
24
+ self.fc2 = nn.Linear(hidden_features, out_features)
25
+ self.drop = nn.Dropout(drop)
26
+
27
+ def forward(self, x):
28
+ x = self.fc1(x)
29
+ x = self.act(x)
30
+ x = self.drop(x)
31
+ x = self.fc2(x)
32
+ x = self.drop(x)
33
+ return x
34
+
35
+
36
+ def window_partition(x, window_size):
37
+ """
38
+ Args:
39
+ x: (B, H, W, C)
40
+ window_size (int): window size
41
+
42
+ Returns:
43
+ windows: (num_windows*B, window_size, window_size, C)
44
+ """
45
+ B, H, W, C = x.shape
46
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
47
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
48
+ return windows
49
+
50
+
51
+ def window_reverse(windows, window_size, H, W):
52
+ """
53
+ Args:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ window_size (int): Window size
56
+ H (int): Height of image
57
+ W (int): Width of image
58
+
59
+ Returns:
60
+ x: (B, H, W, C)
61
+ """
62
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
63
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
64
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
65
+ return x
66
+
67
+
68
+ class WindowAttention(nn.Module):
69
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
70
+ It supports both of shifted and non-shifted window.
71
+
72
+ Args:
73
+ dim (int): Number of input channels.
74
+ window_size (tuple[int]): The height and width of the window.
75
+ num_heads (int): Number of attention heads.
76
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
77
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
78
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
79
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
80
+ """
81
+
82
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
83
+
84
+ super().__init__()
85
+ self.dim = dim
86
+ self.window_size = window_size # Wh, Ww
87
+ self.num_heads = num_heads
88
+ head_dim = dim // num_heads
89
+ self.scale = qk_scale or head_dim ** -0.5
90
+
91
+ # define a parameter table of relative position bias
92
+ self.relative_position_bias_table = nn.Parameter(
93
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
94
+
95
+ # get pair-wise relative position index for each token inside the window
96
+ coords_h = torch.arange(self.window_size[0])
97
+ coords_w = torch.arange(self.window_size[1])
98
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
99
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
100
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
101
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
102
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
103
+ relative_coords[:, :, 1] += self.window_size[1] - 1
104
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
105
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
106
+ self.register_buffer("relative_position_index", relative_position_index)
107
+
108
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
109
+ self.attn_drop = nn.Dropout(attn_drop)
110
+ self.proj = nn.Linear(dim, dim)
111
+ self.proj_drop = nn.Dropout(proj_drop)
112
+
113
+ trunc_normal_(self.relative_position_bias_table, std=.02)
114
+ self.softmax = nn.Softmax(dim=-1)
115
+
116
+ def forward(self, x, mask=None):
117
+ """ Forward function.
118
+
119
+ Args:
120
+ x: input features with shape of (num_windows*B, N, C)
121
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
122
+ """
123
+ B_, N, C = x.shape
124
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
125
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
126
+
127
+ q = q * self.scale
128
+ attn = (q @ k.transpose(-2, -1))
129
+
130
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
131
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
132
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
133
+ attn = attn + relative_position_bias.unsqueeze(0)
134
+
135
+ if mask is not None:
136
+ nW = mask.shape[0]
137
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
138
+ attn = attn.view(-1, self.num_heads, N, N)
139
+ attn = self.softmax(attn)
140
+ else:
141
+ attn = self.softmax(attn)
142
+
143
+ attn = self.attn_drop(attn)
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class SwinTransformerBlock(nn.Module):
152
+ """ Swin Transformer Block.
153
+
154
+ Args:
155
+ dim (int): Number of input channels.
156
+ num_heads (int): Number of attention heads.
157
+ window_size (int): Window size.
158
+ shift_size (int): Shift size for SW-MSA.
159
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
160
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
161
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
162
+ drop (float, optional): Dropout rate. Default: 0.0
163
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
164
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
165
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
166
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
167
+ """
168
+
169
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
170
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
171
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
172
+ super().__init__()
173
+ self.dim = dim
174
+ self.num_heads = num_heads
175
+ self.window_size = window_size
176
+ self.shift_size = shift_size
177
+ self.mlp_ratio = mlp_ratio
178
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
179
+
180
+ self.norm1 = norm_layer(dim)
181
+ self.attn = WindowAttention(
182
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
183
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
184
+
185
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
186
+ self.norm2 = norm_layer(dim)
187
+ mlp_hidden_dim = int(dim * mlp_ratio)
188
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
189
+
190
+ self.H = None
191
+ self.W = None
192
+
193
+ def forward(self, x, mask_matrix):
194
+ """ Forward function.
195
+
196
+ Args:
197
+ x: Input feature, tensor size (B, H*W, C).
198
+ H, W: Spatial resolution of the input feature.
199
+ mask_matrix: Attention mask for cyclic shift.
200
+ """
201
+ B, L, C = x.shape
202
+ H, W = self.H, self.W
203
+ assert L == H * W, "input feature has wrong size"
204
+
205
+ shortcut = x
206
+ x = self.norm1(x)
207
+ x = x.view(B, H, W, C)
208
+
209
+ # pad feature maps to multiples of window size
210
+ pad_l = pad_t = 0
211
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
212
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
213
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
214
+ _, Hp, Wp, _ = x.shape
215
+
216
+ # cyclic shift
217
+ if self.shift_size > 0:
218
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
219
+ attn_mask = mask_matrix
220
+ else:
221
+ shifted_x = x
222
+ attn_mask = None
223
+
224
+ # partition windows
225
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
226
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
227
+
228
+ # W-MSA/SW-MSA
229
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
230
+
231
+ # merge windows
232
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
233
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
234
+
235
+ # reverse cyclic shift
236
+ if self.shift_size > 0:
237
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
238
+ else:
239
+ x = shifted_x
240
+
241
+ if pad_r > 0 or pad_b > 0:
242
+ x = x[:, :H, :W, :].contiguous()
243
+
244
+ x = x.view(B, H * W, C)
245
+
246
+ # FFN
247
+ x = shortcut + self.drop_path(x)
248
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
249
+
250
+ return x
251
+
252
+
253
+ class PatchMerging(nn.Module):
254
+ """ Patch Merging Layer
255
+
256
+ Args:
257
+ dim (int): Number of input channels.
258
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
259
+ """
260
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
261
+ super().__init__()
262
+ self.dim = dim
263
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
264
+ self.norm = norm_layer(4 * dim)
265
+
266
+ def forward(self, x, H, W):
267
+ """ Forward function.
268
+
269
+ Args:
270
+ x: Input feature, tensor size (B, H*W, C).
271
+ H, W: Spatial resolution of the input feature.
272
+ """
273
+ B, L, C = x.shape
274
+ assert L == H * W, "input feature has wrong size"
275
+
276
+ x = x.view(B, H, W, C)
277
+
278
+ # padding
279
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
280
+ if pad_input:
281
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
282
+
283
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
284
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
285
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
286
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
287
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
288
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
289
+
290
+ x = self.norm(x)
291
+ x = self.reduction(x)
292
+
293
+ return x
294
+
295
+
296
+ class BasicLayer(nn.Module):
297
+ """ A basic Swin Transformer layer for one stage.
298
+
299
+ Args:
300
+ dim (int): Number of feature channels
301
+ depth (int): Depths of this stage.
302
+ num_heads (int): Number of attention head.
303
+ window_size (int): Local window size. Default: 7.
304
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
305
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
306
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
307
+ drop (float, optional): Dropout rate. Default: 0.0
308
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
309
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
310
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
311
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
312
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
313
+ """
314
+
315
+ def __init__(self,
316
+ dim,
317
+ depth,
318
+ num_heads,
319
+ window_size=7,
320
+ mlp_ratio=4.,
321
+ qkv_bias=True,
322
+ qk_scale=None,
323
+ drop=0.,
324
+ attn_drop=0.,
325
+ drop_path=0.,
326
+ norm_layer=nn.LayerNorm,
327
+ downsample=None,
328
+ use_checkpoint=False):
329
+ super().__init__()
330
+ self.window_size = window_size
331
+ self.shift_size = window_size // 2
332
+ self.depth = depth
333
+ self.use_checkpoint = use_checkpoint
334
+
335
+ # build blocks
336
+ self.blocks = nn.ModuleList([
337
+ SwinTransformerBlock(
338
+ dim=dim,
339
+ num_heads=num_heads,
340
+ window_size=window_size,
341
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
342
+ mlp_ratio=mlp_ratio,
343
+ qkv_bias=qkv_bias,
344
+ qk_scale=qk_scale,
345
+ drop=drop,
346
+ attn_drop=attn_drop,
347
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
348
+ norm_layer=norm_layer)
349
+ for i in range(depth)])
350
+
351
+ # patch merging layer
352
+ if downsample is not None:
353
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
354
+ else:
355
+ self.downsample = None
356
+
357
+ def forward(self, x, H, W):
358
+ """ Forward function.
359
+
360
+ Args:
361
+ x: Input feature, tensor size (B, H*W, C).
362
+ H, W: Spatial resolution of the input feature.
363
+ """
364
+
365
+ # calculate attention mask for SW-MSA
366
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
367
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
368
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
369
+ h_slices = (slice(0, -self.window_size),
370
+ slice(-self.window_size, -self.shift_size),
371
+ slice(-self.shift_size, None))
372
+ w_slices = (slice(0, -self.window_size),
373
+ slice(-self.window_size, -self.shift_size),
374
+ slice(-self.shift_size, None))
375
+ cnt = 0
376
+ for h in h_slices:
377
+ for w in w_slices:
378
+ img_mask[:, h, w, :] = cnt
379
+ cnt += 1
380
+
381
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
382
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
383
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
384
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
385
+
386
+ for blk in self.blocks:
387
+ blk.H, blk.W = H, W
388
+ if self.use_checkpoint:
389
+ x = checkpoint.checkpoint(blk, x, attn_mask)
390
+ else:
391
+ x = blk(x, attn_mask)
392
+ if self.downsample is not None:
393
+ x_down = self.downsample(x, H, W)
394
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
395
+ return x, H, W, x_down, Wh, Ww
396
+ else:
397
+ return x, H, W, x, H, W
398
+
399
+
400
+ class PatchEmbed(nn.Module):
401
+ """ Image to Patch Embedding
402
+
403
+ Args:
404
+ patch_size (int): Patch token size. Default: 4.
405
+ in_chans (int): Number of input image channels. Default: 3.
406
+ embed_dim (int): Number of linear projection output channels. Default: 96.
407
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
408
+ """
409
+
410
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
411
+ super().__init__()
412
+ patch_size = to_2tuple(patch_size)
413
+ self.patch_size = patch_size
414
+
415
+ self.in_chans = in_chans
416
+ self.embed_dim = embed_dim
417
+
418
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
419
+ if norm_layer is not None:
420
+ self.norm = norm_layer(embed_dim)
421
+ else:
422
+ self.norm = None
423
+
424
+ def forward(self, x):
425
+ """Forward function."""
426
+ # padding
427
+ _, _, H, W = x.size()
428
+ if W % self.patch_size[1] != 0:
429
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
430
+ if H % self.patch_size[0] != 0:
431
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
432
+
433
+ x = self.proj(x) # B C Wh Ww
434
+ if self.norm is not None:
435
+ Wh, Ww = x.size(2), x.size(3)
436
+ x = x.flatten(2).transpose(1, 2)
437
+ x = self.norm(x)
438
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
439
+
440
+ return x
441
+
442
+
443
+ class SwinTransformer(nn.Module):
444
+ """ Swin Transformer backbone.
445
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
446
+ https://arxiv.org/pdf/2103.14030
447
+
448
+ Args:
449
+ pretrain_img_size (int): Input image size for training the pretrained model,
450
+ used in absolute postion embedding. Default 224.
451
+ patch_size (int | tuple(int)): Patch size. Default: 4.
452
+ in_chans (int): Number of input image channels. Default: 3.
453
+ embed_dim (int): Number of linear projection output channels. Default: 96.
454
+ depths (tuple[int]): Depths of each Swin Transformer stage.
455
+ num_heads (tuple[int]): Number of attention head of each stage.
456
+ window_size (int): Window size. Default: 7.
457
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
458
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
459
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
460
+ drop_rate (float): Dropout rate.
461
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
462
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
463
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
464
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
465
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
466
+ out_indices (Sequence[int]): Output from which stages.
467
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
468
+ -1 means not freezing any parameters.
469
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
470
+ """
471
+
472
+ def __init__(self,
473
+ pretrain_img_size=224,
474
+ patch_size=4,
475
+ in_chans=3,
476
+ embed_dim=96,
477
+ depths=[2, 2, 6, 2],
478
+ num_heads=[3, 6, 12, 24],
479
+ window_size=7,
480
+ mlp_ratio=4.,
481
+ qkv_bias=True,
482
+ qk_scale=None,
483
+ drop_rate=0.,
484
+ attn_drop_rate=0.,
485
+ drop_path_rate=0.2,
486
+ norm_layer=nn.LayerNorm,
487
+ ape=False,
488
+ patch_norm=True,
489
+ out_indices=(0, 1, 2, 3),
490
+ frozen_stages=-1,
491
+ use_checkpoint=False):
492
+ super().__init__()
493
+
494
+ self.pretrain_img_size = pretrain_img_size
495
+ self.num_layers = len(depths)
496
+ self.embed_dim = embed_dim
497
+ self.ape = ape
498
+ self.patch_norm = patch_norm
499
+ self.out_indices = out_indices
500
+ self.frozen_stages = frozen_stages
501
+
502
+ # split image into non-overlapping patches
503
+ self.patch_embed = PatchEmbed(
504
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
505
+ norm_layer=norm_layer if self.patch_norm else None)
506
+
507
+ # absolute position embedding
508
+ if self.ape:
509
+ pretrain_img_size = to_2tuple(pretrain_img_size)
510
+ patch_size = to_2tuple(patch_size)
511
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
512
+
513
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
514
+ trunc_normal_(self.absolute_pos_embed, std=.02)
515
+
516
+ self.pos_drop = nn.Dropout(p=drop_rate)
517
+
518
+ # stochastic depth
519
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
520
+
521
+ # build layers
522
+ self.layers = nn.ModuleList()
523
+ for i_layer in range(self.num_layers):
524
+ layer = BasicLayer(
525
+ dim=int(embed_dim * 2 ** i_layer),
526
+ depth=depths[i_layer],
527
+ num_heads=num_heads[i_layer],
528
+ window_size=window_size,
529
+ mlp_ratio=mlp_ratio,
530
+ qkv_bias=qkv_bias,
531
+ qk_scale=qk_scale,
532
+ drop=drop_rate,
533
+ attn_drop=attn_drop_rate,
534
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
535
+ norm_layer=norm_layer,
536
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
537
+ use_checkpoint=use_checkpoint)
538
+ self.layers.append(layer)
539
+
540
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
541
+ self.num_features = num_features
542
+
543
+ # add a norm layer for each output
544
+ for i_layer in out_indices:
545
+ layer = norm_layer(num_features[i_layer])
546
+ layer_name = f'norm{i_layer}'
547
+ self.add_module(layer_name, layer)
548
+
549
+ self._freeze_stages()
550
+
551
+ def _freeze_stages(self):
552
+ if self.frozen_stages >= 0:
553
+ self.patch_embed.eval()
554
+ for param in self.patch_embed.parameters():
555
+ param.requires_grad = False
556
+
557
+ if self.frozen_stages >= 1 and self.ape:
558
+ self.absolute_pos_embed.requires_grad = False
559
+
560
+ if self.frozen_stages >= 2:
561
+ self.pos_drop.eval()
562
+ for i in range(0, self.frozen_stages - 1):
563
+ m = self.layers[i]
564
+ m.eval()
565
+ for param in m.parameters():
566
+ param.requires_grad = False
567
+
568
+ def init_weights(self, pretrained=None):
569
+ """Initialize the weights in backbone.
570
+
571
+ Args:
572
+ pretrained (str, optional): Path to pre-trained weights.
573
+ Defaults to None.
574
+ """
575
+
576
+ def _init_weights(m):
577
+ if isinstance(m, nn.Linear):
578
+ trunc_normal_(m.weight, std=.02)
579
+ if isinstance(m, nn.Linear) and m.bias is not None:
580
+ nn.init.constant_(m.bias, 0)
581
+ elif isinstance(m, nn.LayerNorm):
582
+ nn.init.constant_(m.bias, 0)
583
+ nn.init.constant_(m.weight, 1.0)
584
+
585
+ if isinstance(pretrained, str):
586
+ self.apply(_init_weights)
587
+ logger = get_root_logger()
588
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
589
+ elif pretrained is None:
590
+ self.apply(_init_weights)
591
+ else:
592
+ raise TypeError('pretrained must be a str or None')
593
+
594
+ def forward(self, x):
595
+ """Forward function."""
596
+ x = self.patch_embed(x)
597
+
598
+ Wh, Ww = x.size(2), x.size(3)
599
+ if self.ape:
600
+ # interpolate the position embedding to the corresponding size
601
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
602
+ x = (x + absolute_pos_embed) # B C Wh Ww
603
+
604
+ outs = [x.contiguous()]
605
+ x = x.flatten(2).transpose(1, 2)
606
+ x = self.pos_drop(x)
607
+ for i in range(self.num_layers):
608
+ layer = self.layers[i]
609
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
610
+
611
+ if i in self.out_indices:
612
+ norm_layer = getattr(self, f'norm{i}')
613
+ x_out = norm_layer(x_out)
614
+
615
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
616
+ outs.append(out)
617
+
618
+ return tuple(outs)
619
+
620
+ def train(self, mode=True):
621
+ """Convert the model into training mode while keep layers freezed."""
622
+ super(SwinTransformer, self).train(mode)
623
+ self._freeze_stages()
624
+
625
+ def SwinT(pretrained=True):
626
+ model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
627
+ if pretrained is True:
628
+ model.load_state_dict(torch.load('data/backbone_ckpt/swin_tiny_patch4_window7_224.pth', map_location='cpu')['model'], strict=False)
629
+
630
+ return model
631
+
632
+ def SwinS(pretrained=True):
633
+ model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7)
634
+ if pretrained is True:
635
+ model.load_state_dict(torch.load('data/backbone_ckpt/swin_small_patch4_window7_224.pth', map_location='cpu')['model'], strict=False)
636
+
637
+ return model
638
+
639
+ def SwinB(pretrained=True):
640
+ model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
641
+ if pretrained is True:
642
+ model.load_state_dict(torch.load('data/backbone_ckpt/swin_base_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False)
643
+
644
+ return model
645
+
646
+ def SwinL(pretrained=True):
647
+ model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12)
648
+ if pretrained is True:
649
+ model.load_state_dict(torch.load('data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False)
650
+
651
+ return model
652
+
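A similar shape check for the Swin backbone (a sketch with an illustrative input size; it assumes timm is installed, as listed in requirements.txt, and constructs SwinTransformer directly so that none of the data/backbone_ckpt checkpoints expected by SwinT/SwinS/SwinB/SwinL are needed):

import torch
from lib.backbones.SwinTransformer import SwinTransformer

# same hyper-parameters as SwinB above, but without loading pretrained weights
backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2],
                           num_heads=[4, 8, 16, 32], window_size=12).eval()
with torch.no_grad():
    outs = backbone(torch.rand(1, 3, 384, 384))
for o in outs:
    print(o.shape)
# torch.Size([1, 128, 96, 96])    patch embedding
# torch.Size([1, 128, 96, 96])    stage 1
# torch.Size([1, 256, 48, 48])    stage 2
# torch.Size([1, 512, 24, 24])    stage 3
# torch.Size([1, 1024, 12, 12])   stage 4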
lib/modules/attention_module.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from torch.nn.parameter import Parameter
6
+ from operator import xor
7
+ from typing import Optional
8
+
9
+ from lib.modules.layers import *
10
+ from utils.misc import *
11
+ class SICA(nn.Module):
12
+ def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None, lmap_in=False):
13
+ super(SICA, self).__init__()
14
+ self.in_channel = in_channel
15
+ self.depth = depth
16
+ self.lmap_in = lmap_in
17
+ if base_size is not None and stage is not None:
18
+ self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
19
+ else:
20
+ self.stage_size = None
21
+
22
+ self.conv_query = nn.Sequential(Conv2d(in_channel, depth, 3, relu=True),
23
+ Conv2d(depth, depth, 3, relu=True))
24
+ self.conv_key = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True),
25
+ Conv2d(depth, depth, 1, relu=True))
26
+ self.conv_value = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True),
27
+ Conv2d(depth, depth, 1, relu=True))
28
+
29
+ if self.lmap_in is True:
30
+ self.ctx = 5
31
+ else:
32
+ self.ctx = 3
33
+
34
+ self.conv_out1 = Conv2d(depth, depth, 3, relu=True)
35
+ self.conv_out2 = Conv2d(in_channel + depth, depth, 3, relu=True)
36
+ self.conv_out3 = Conv2d(depth, depth, 3, relu=True)
37
+ self.conv_out4 = Conv2d(depth, out_channel, 1)
38
+
39
+ self.threshold = Parameter(torch.tensor([0.5]))
40
+
41
+ if self.lmap_in is True:
42
+ self.lthreshold = Parameter(torch.tensor([0.5]))
43
+
44
+ def forward(self, x, smap, lmap: Optional[torch.Tensor]=None):
45
+ assert not xor(self.lmap_in is True, lmap is not None)
46
+ b, c, h, w = x.shape
47
+
48
+ # compute class probability
49
+ smap = F.interpolate(smap, size=x.shape[-2:], mode='bilinear', align_corners=False)
50
+ smap = torch.sigmoid(smap)
51
+ p = smap - self.threshold
52
+
53
+ fg = torch.clip(p, 0, 1) # foreground
54
+ bg = torch.clip(-p, 0, 1) # background
55
+ cg = self.threshold - torch.abs(p) # confusion area
56
+
57
+ if self.lmap_in is True and lmap is not None:
58
+ lmap = F.interpolate(lmap, size=x.shape[-2:], mode='bilinear', align_corners=False)
59
+ lmap = torch.sigmoid(lmap)
60
+ lp = lmap - self.lthreshold
61
+ fp = torch.clip(lp, 0, 1) # foreground
62
+ bp = torch.clip(-lp, 0, 1) # background
63
+
64
+ prob = [fg, bg, cg, fp, bp]
65
+ else:
66
+ prob = [fg, bg, cg]
67
+
68
+ prob = torch.cat(prob, dim=1)
69
+
70
+ # reshape feature & prob
71
+ if self.stage_size is not None:
72
+ shape = self.stage_size
73
+ shape_mul = self.stage_size[0] * self.stage_size[1]
74
+ else:
75
+ shape = (h, w)
76
+ shape_mul = h * w
77
+
78
+ f = F.interpolate(x, size=shape, mode='bilinear', align_corners=False).view(b, shape_mul, -1)
79
+ prob = F.interpolate(prob, size=shape, mode='bilinear', align_corners=False).view(b, self.ctx, shape_mul)
80
+
81
+ # compute context vector
82
+ context = torch.bmm(prob, f).permute(0, 2, 1).unsqueeze(3) # b, c, ctx, 1
83
+
84
+ # k q v compute
85
+ query = self.conv_query(x).view(b, self.depth, -1).permute(0, 2, 1)
86
+ key = self.conv_key(context).view(b, self.depth, -1)
87
+ value = self.conv_value(context).view(b, self.depth, -1).permute(0, 2, 1)
88
+
89
+ # compute similarity map
90
+ sim = torch.bmm(query, key) # (b, hw, depth) x (b, depth, ctx) -> (b, hw, ctx)
91
+ sim = (self.depth ** -.5) * sim
92
+ sim = F.softmax(sim, dim=-1)
93
+
94
+ # compute refined feature
95
+ context = torch.bmm(sim, value).permute(0, 2, 1).contiguous().view(b, -1, h, w)
96
+ context = self.conv_out1(context)
97
+
98
+ x = torch.cat([x, context], dim=1)
99
+ x = self.conv_out2(x)
100
+ x = self.conv_out3(x)
101
+ out = self.conv_out4(x)
102
+
103
+ return x, out
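A minimal usage sketch for SICA (illustrative shapes; assumes the repository root is importable, since this module pulls in utils.misc and lib.modules.layers). The module takes a feature map plus a coarse saliency logit map and returns a refined feature together with a new prediction:

import torch
from lib.modules.attention_module import SICA

sica = SICA(in_channel=64, out_channel=1, depth=64, base_size=[384, 384], stage=3)
feat = torch.rand(2, 64, 48, 48)     # feature at stage 3 (384 / 2**3 = 48)
smap = torch.randn(2, 1, 48, 48)     # coarse saliency logits from an earlier stage
refined, pred = sica(feat, smap)
print(refined.shape, pred.shape)     # torch.Size([2, 64, 48, 48]) torch.Size([2, 1, 48, 48])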
lib/modules/context_module.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .layers import *
6
+
7
+ class PAA_kernel(nn.Module):
8
+ def __init__(self, in_channel, out_channel, receptive_size, stage_size=None):
9
+ super(PAA_kernel, self).__init__()
10
+ self.conv0 = Conv2d(in_channel, out_channel, 1)
11
+ self.conv1 = Conv2d(out_channel, out_channel, kernel_size=(1, receptive_size))
12
+ self.conv2 = Conv2d(out_channel, out_channel, kernel_size=(receptive_size, 1))
13
+ self.conv3 = Conv2d(out_channel, out_channel, 3, dilation=receptive_size)
14
+ self.Hattn = SelfAttention(out_channel, 'h', stage_size[0] if stage_size is not None else None)
15
+ self.Wattn = SelfAttention(out_channel, 'w', stage_size[1] if stage_size is not None else None)
16
+
17
+ def forward(self, x):
18
+ x = self.conv0(x)
19
+ x = self.conv1(x)
20
+ x = self.conv2(x)
21
+
22
+ Hx = self.Hattn(x)
23
+ Wx = self.Wattn(x)
24
+
25
+ x = self.conv3(Hx + Wx)
26
+ return x
27
+
28
+ class PAA_e(nn.Module):
29
+ def __init__(self, in_channel, out_channel, base_size=None, stage=None):
30
+ super(PAA_e, self).__init__()
31
+ self.relu = nn.ReLU(True)
32
+ if base_size is not None and stage is not None:
33
+ self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
34
+ else:
35
+ self.stage_size = None
36
+
37
+ self.branch0 = Conv2d(in_channel, out_channel, 1)
38
+ self.branch1 = PAA_kernel(in_channel, out_channel, 3, self.stage_size)
39
+ self.branch2 = PAA_kernel(in_channel, out_channel, 5, self.stage_size)
40
+ self.branch3 = PAA_kernel(in_channel, out_channel, 7, self.stage_size)
41
+
42
+ self.conv_cat = Conv2d(4 * out_channel, out_channel, 3)
43
+ self.conv_res = Conv2d(in_channel, out_channel, 1)
44
+
45
+ def forward(self, x):
46
+ x0 = self.branch0(x)
47
+ x1 = self.branch1(x)
48
+ x2 = self.branch2(x)
49
+ x3 = self.branch3(x)
50
+
51
+ x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
52
+ x = self.relu(x_cat + self.conv_res(x))
53
+
54
+ return x
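A quick sketch of PAA_e (channel sizes are illustrative): it acts as a receptive-field context block that keeps the spatial size and squeezes the channels down to out_channel:

import torch
from lib.modules.context_module import PAA_e

ctx = PAA_e(in_channel=512, out_channel=64, base_size=[384, 384], stage=3)
y = ctx(torch.rand(2, 512, 48, 48))
print(y.shape)   # torch.Size([2, 64, 48, 48])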
lib/modules/decoder_module.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .layers import *
6
+ class PAA_d(nn.Module):
7
+ def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None):
8
+ super(PAA_d, self).__init__()
9
+ self.conv1 = Conv2d(in_channel, depth, 3)
10
+ self.conv2 = Conv2d(depth, depth, 3)
11
+ self.conv3 = Conv2d(depth, depth, 3)
12
+ self.conv4 = Conv2d(depth, depth, 3)
13
+ self.conv5 = Conv2d(depth, out_channel, 3, bn=False)
14
+
15
+ self.base_size = base_size
16
+ self.stage = stage
17
+
18
+ if base_size is not None and stage is not None:
19
+ self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
20
+ else:
21
+ self.stage_size = [None, None]
22
+
23
+ self.Hattn = SelfAttention(depth, 'h', self.stage_size[0])
24
+ self.Wattn = SelfAttention(depth, 'w', self.stage_size[1])
25
+
26
+ self.upsample = lambda img, size: F.interpolate(img, size=size, mode='bilinear', align_corners=True)
27
+
28
+ def forward(self, fs): #f3 f4 f5 -> f3 f2 f1
29
+ fx = fs[0]
30
+ for i in range(1, len(fs)):
31
+ fs[i] = self.upsample(fs[i], fx.shape[-2:])
32
+ fx = torch.cat(fs[::-1], dim=1)
33
+
34
+ fx = self.conv1(fx)
35
+
36
+ Hfx = self.Hattn(fx)
37
+ Wfx = self.Wattn(fx)
38
+
39
+ fx = self.conv2(Hfx + Wfx)
40
+ fx = self.conv3(fx)
41
+ fx = self.conv4(fx)
42
+ out = self.conv5(fx)
43
+
44
+ return fx, out
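A sketch of how PAA_d consumes a list of multi-scale features (channel counts are illustrative): every feature is upsampled to the resolution of the first one and concatenated, so in_channel must equal the sum of the input channels:

import torch
from lib.modules.decoder_module import PAA_d

dec = PAA_d(in_channel=3 * 64, out_channel=1, depth=64, base_size=[384, 384], stage=3)
f3 = torch.rand(2, 64, 48, 48)
f4 = torch.rand(2, 64, 24, 24)
f5 = torch.rand(2, 64, 12, 12)
fx, out = dec([f3, f4, f5])
print(fx.shape, out.shape)   # torch.Size([2, 64, 48, 48]) torch.Size([2, 1, 48, 48])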
lib/modules/layers.py ADDED
@@ -0,0 +1,175 @@
1
+ from optparse import Option
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
+ from kornia.morphology import dilation, erosion
10
+ from torch.nn.parameter import Parameter
11
+ from typing import Optional
12
+ class ImagePyramid:
13
+ def __init__(self, ksize=7, sigma=1, channels=1):
14
+ self.ksize = ksize
15
+ self.sigma = sigma
16
+ self.channels = channels
17
+
18
+ k = cv2.getGaussianKernel(ksize, sigma)
19
+ k = np.outer(k, k)
20
+ k = torch.tensor(k).float()
21
+ self.kernel = k.repeat(channels, 1, 1, 1)
22
+
23
+ def to(self, device):
24
+ self.kernel = self.kernel.to(device)
25
+ return self
26
+
27
+ def cuda(self, idx=None):
28
+ if idx is None:
29
+ idx = torch.cuda.current_device()
30
+
31
+ self.to(device="cuda:{}".format(idx))
32
+ return self
33
+
34
+ def expand(self, x):
35
+ z = torch.zeros_like(x)
36
+ x = torch.cat([x, z, z, z], dim=1)
37
+ x = F.pixel_shuffle(x, 2)
38
+ x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect')
39
+ x = F.conv2d(x, self.kernel * 4, groups=self.channels)
40
+ return x
41
+
42
+ def reduce(self, x):
43
+ x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect')
44
+ x = F.conv2d(x, self.kernel, groups=self.channels)
45
+ x = x[:, :, ::2, ::2]
46
+ return x
47
+
48
+ def deconstruct(self, x):
49
+ reduced_x = self.reduce(x)
50
+ expanded_reduced_x = self.expand(reduced_x)
51
+
52
+ if x.shape != expanded_reduced_x.shape:
53
+ expanded_reduced_x = F.interpolate(expanded_reduced_x, x.shape[-2:])
54
+
55
+ laplacian_x = x - expanded_reduced_x
56
+ return reduced_x, laplacian_x
57
+
58
+ def reconstruct(self, x, laplacian_x):
59
+ expanded_x = self.expand(x)
60
+ if laplacian_x.shape != expanded_x.shape:
61
+ laplacian_x = F.interpolate(laplacian_x, expanded_x.shape[-2:], mode='bilinear', align_corners=True)
62
+ return expanded_x + laplacian_x
63
+
64
+ class Transition:
65
+ def __init__(self, k=3):
66
+ self.kernel = torch.tensor(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))).float()
67
+
68
+ def to(self, device):
69
+ self.kernel = self.kernel.to(device)
70
+ return self
71
+
72
+ def cuda(self, idx=None):
73
+ if idx is None:
74
+ idx = torch.cuda.current_device()
75
+
76
+ self.to(device="cuda:{}".format(idx))
77
+ return self
78
+
79
+ def __call__(self, x):
80
+ x = torch.sigmoid(x)
81
+ dx = dilation(x, self.kernel)
82
+ ex = erosion(x, self.kernel)
83
+
84
+ return ((dx - ex) > .5).float()
85
+
86
+ class Conv2d(nn.Module):
87
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding='same', bias=False, bn=True, relu=False):
88
+ super(Conv2d, self).__init__()
89
+ if '__iter__' not in dir(kernel_size):
90
+ kernel_size = (kernel_size, kernel_size)
91
+ if '__iter__' not in dir(stride):
92
+ stride = (stride, stride)
93
+ if '__iter__' not in dir(dilation):
94
+ dilation = (dilation, dilation)
95
+
96
+ if padding == 'same':
97
+ width_pad_size = kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1)
98
+ height_pad_size = kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1)
99
+ elif padding == 'valid':
100
+ width_pad_size = 0
101
+ height_pad_size = 0
102
+ else:
103
+ if '__iter__' in dir(padding):
104
+ width_pad_size = padding[0] * 2
105
+ height_pad_size = padding[1] * 2
106
+ else:
107
+ width_pad_size = padding * 2
108
+ height_pad_size = padding * 2
109
+
110
+ width_pad_size = width_pad_size // 2 + (width_pad_size % 2 - 1)
111
+ height_pad_size = height_pad_size // 2 + (height_pad_size % 2 - 1)
112
+ pad_size = (width_pad_size, height_pad_size)
113
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_size, dilation, groups, bias=bias)
114
+ self.reset_parameters()
115
+
116
+ if bn is True:
117
+ self.bn = nn.BatchNorm2d(out_channels)
118
+ else:
119
+ self.bn = None
120
+
121
+ if relu is True:
122
+ self.relu = nn.ReLU(inplace=True)
123
+ else:
124
+ self.relu = None
125
+
126
+ def forward(self, x):
127
+ x = self.conv(x)
128
+ if self.bn is not None:
129
+ x = self.bn(x)
130
+ if self.relu is not None:
131
+ x = self.relu(x)
132
+ return x
133
+
134
+ def reset_parameters(self):
135
+ nn.init.kaiming_normal_(self.conv.weight)
136
+
137
+
138
+ class SelfAttention(nn.Module):
139
+ def __init__(self, in_channels, mode='hw', stage_size=None):
140
+ super(SelfAttention, self).__init__()
141
+
142
+ self.mode = mode
143
+
144
+ self.query_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1))
145
+ self.key_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1))
146
+ self.value_conv = Conv2d(in_channels, in_channels, kernel_size=(1, 1))
147
+
148
+ self.gamma = Parameter(torch.zeros(1))
149
+ self.softmax = nn.Softmax(dim=-1)
150
+
151
+ self.stage_size = stage_size
152
+
153
+ def forward(self, x):
154
+ batch_size, channel, height, width = x.size()
155
+
156
+ axis = 1
157
+ if 'h' in self.mode:
158
+ axis *= height
159
+ if 'w' in self.mode:
160
+ axis *= width
161
+
162
+ view = (batch_size, -1, axis)
163
+
164
+ projected_query = self.query_conv(x).view(*view).permute(0, 2, 1)
165
+ projected_key = self.key_conv(x).view(*view)
166
+
167
+ attention_map = torch.bmm(projected_query, projected_key)
168
+ attention = self.softmax(attention_map)
169
+ projected_value = self.value_conv(x).view(*view)
170
+
171
+ out = torch.bmm(projected_value, attention.permute(0, 2, 1))
172
+ out = out.view(batch_size, channel, height, width)
173
+
174
+ out = self.gamma * out + x
175
+ return out
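A round-trip sketch for ImagePyramid, the Gaussian/Laplacian pyramid helper defined above (illustrative sizes; kornia and opencv-python from requirements.txt must be installed because this module imports them):

import torch
from lib.modules.layers import ImagePyramid

pyr = ImagePyramid(ksize=7, sigma=1, channels=1)
x = torch.rand(1, 1, 64, 64)

lowres, laplacian = pyr.deconstruct(x)       # Gaussian-reduced map + high-frequency residual
recon = pyr.reconstruct(lowres, laplacian)   # expand the low-res map and add the residual back
print(lowres.shape, laplacian.shape, recon.shape)
# torch.Size([1, 1, 32, 32]) torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 64, 64])
print(torch.allclose(recon, x, atol=1e-5))   # True: the decomposition is reversible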
lib/optim/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .losses import *
2
+ from .scheduler import *
lib/optim/losses.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def bce_loss(pred, mask, reduction='none'):
5
+ bce = F.binary_cross_entropy(pred, mask, reduction=reduction)
6
+ return bce
7
+
8
+ def weighted_bce_loss(pred, mask, reduction='none'):
9
+ weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
10
+ weight = weight.flatten()
11
+
12
+ bce = weight * bce_loss(pred, mask, reduction='none').flatten()
13
+
14
+ if reduction == 'mean':
15
+ bce = bce.mean()
16
+
17
+ return bce
18
+
19
+ def iou_loss(pred, mask, reduction='none'):
20
+ inter = pred * mask
21
+ union = pred + mask
22
+ iou = 1 - (inter + 1) / (union - inter + 1)
23
+
24
+ if reduction == 'mean':
25
+ iou = iou.mean()
26
+
27
+ return iou
28
+
29
+ def bce_loss_with_logits(pred, mask, reduction='none'):
30
+ return bce_loss(torch.sigmoid(pred), mask, reduction=reduction)
31
+
32
+ def weighted_bce_loss_with_logits(pred, mask, reduction='none'):
33
+ return weighted_bce_loss(torch.sigmoid(pred), mask, reduction=reduction)
34
+
35
+ def iou_loss_with_logits(pred, mask, reduction='none'):
36
+ return iou_loss(torch.sigmoid(pred), mask, reduction=reduction)
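The losses above are plain functions over prediction/mask tensors; a standalone sketch of evaluating them (how they are weighted and combined during training is decided by the training code elsewhere in this commit, not here):

import torch
from lib.optim.losses import weighted_bce_loss_with_logits, iou_loss_with_logits

pred = torch.randn(2, 1, 64, 64)                   # raw logits
mask = (torch.rand(2, 1, 64, 64) > 0.5).float()    # binary ground truth

loss = weighted_bce_loss_with_logits(pred, mask, reduction='mean') \
     + iou_loss_with_logits(pred, mask, reduction='mean')
print(loss.item())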
lib/optim/scheduler.py ADDED
@@ -0,0 +1,30 @@
1
+ from torch.optim.lr_scheduler import _LRScheduler
2
+
3
+
4
+ class PolyLr(_LRScheduler):
5
+ def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
6
+ self.gamma = gamma
7
+ self.max_iteration = max_iteration
8
+ self.minimum_lr = minimum_lr
9
+ self.warmup_iteration = warmup_iteration
10
+
11
+ self.last_epoch = None
12
+ self.base_lrs = []
13
+
14
+ super(PolyLr, self).__init__(optimizer, last_epoch)
15
+
16
+ def poly_lr(self, base_lr, step):
17
+ return (base_lr - self.minimum_lr) * ((1 - (step / self.max_iteration)) ** self.gamma) + self.minimum_lr
18
+
19
+ def warmup_lr(self, base_lr, alpha):
20
+ return base_lr * (1 / 10.0 * (1 - alpha) + alpha)
21
+
22
+ def get_lr(self):
23
+ if self.last_epoch < self.warmup_iteration:
24
+ alpha = self.last_epoch / self.warmup_iteration
25
+ lrs = [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, self.last_epoch)) for base_lr in
26
+ self.base_lrs]
27
+ else:
28
+ lrs = [self.poly_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
29
+
30
+ return lrs
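PolyLr follows the standard _LRScheduler contract, so it drops into any optimizer; a minimal sketch (hyper-parameter values are illustrative, not taken from the configs):

import torch
from lib.optim.scheduler import PolyLr

net = torch.nn.Conv2d(3, 1, 3)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
scheduler = PolyLr(optimizer, gamma=0.9, max_iteration=10000,
                   minimum_lr=1e-7, warmup_iteration=1000)

for it in range(3):
    optimizer.step()
    scheduler.step()                 # stepped once per iteration, not per epoch
    print(scheduler.get_last_lr())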
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ easydict
2
+ timm
3
+ tqdm
4
+ kornia
5
+ numpy
6
+ opencv-python
7
+ pyyaml
8
+ pandas
9
+ scipy
run/Eval.py ADDED
@@ -0,0 +1,138 @@
1
+ import os
2
+ import sys
3
+ import tqdm
4
+
5
+ import pandas as pd
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+
10
+ filepath = os.path.split(os.path.abspath(__file__))[0]
11
+ repopath = os.path.split(filepath)[0]
12
+ sys.path.append(repopath)
13
+
14
+ from utils.eval_functions import *
15
+ from utils.misc import *
16
+
17
+ BETA = 1.0
18
+
19
+ def evaluate(opt, args):
20
+ if os.path.isdir(opt.Eval.result_path) is False:
21
+ os.makedirs(opt.Eval.result_path)
22
+
23
+ method = os.path.split(opt.Eval.pred_root)[-1]
24
+
25
+ if args.verbose is True:
26
+ print('#' * 20, 'Start Evaluation', '#' * 20)
27
+ datasets = tqdm.tqdm(opt.Eval.datasets, desc='Expr - ' + method, total=len(
28
+ opt.Eval.datasets), position=0, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
29
+ else:
30
+ datasets = opt.Eval.datasets
31
+
32
+ results = []
33
+
34
+ for dataset in datasets:
35
+ pred_root = os.path.join(opt.Eval.pred_root, dataset)
36
+ gt_root = os.path.join(opt.Eval.gt_root, dataset, 'masks')
37
+
38
+ preds = os.listdir(pred_root)
39
+ gts = os.listdir(gt_root)
40
+
41
+ preds = sort(preds)
42
+ gts = sort(gts)
43
+
44
+ preds = [i for i in preds if i in gts]
45
+ gts = [i for i in gts if i in preds]
46
+
47
+ FM = Fmeasure()
48
+ WFM = WeightedFmeasure()
49
+ SM = Smeasure()
50
+ EM = Emeasure()
51
+ MAE = Mae()
52
+ MSE = Mse()
53
+ MBA = BoundaryAccuracy()
54
+ IOU = IoU()
55
+ BIOU = BIoU()
56
+ TIOU = TIoU()
57
+
58
+ if args.verbose is True:
59
+ samples = tqdm.tqdm(enumerate(zip(preds, gts)), desc=dataset + ' - Evaluation', total=len(
60
+ preds), position=1, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
61
+ else:
62
+ samples = enumerate(zip(preds, gts))
63
+
64
+ for i, sample in samples:
65
+ pred, gt = sample
66
+
67
+ pred_mask = np.array(Image.open(os.path.join(pred_root, pred)).convert('L'))
68
+ gt_mask = np.array(Image.open(os.path.join(gt_root, gt)).convert('L'))
69
+
70
+ if len(pred_mask.shape) != 2:
71
+ pred_mask = pred_mask[:, :, 0]
72
+ if len(gt_mask.shape) != 2:
73
+ gt_mask = gt_mask[:, :, 0]
74
+
75
+ assert pred_mask.shape == gt_mask.shape, '{} does not match the size of {}'.format(pred, gt)
76
+ # print(gt_mask.max())
77
+
78
+ FM.step( pred=pred_mask, gt=gt_mask)
79
+ WFM.step(pred=pred_mask, gt=gt_mask)
80
+ SM.step( pred=pred_mask, gt=gt_mask)
81
+ EM.step( pred=pred_mask, gt=gt_mask)
82
+ MAE.step(pred=pred_mask, gt=gt_mask)
83
+ MSE.step(pred=pred_mask, gt=gt_mask)
84
+ MBA.step(pred=pred_mask, gt=gt_mask)
85
+ IOU.step(pred=pred_mask, gt=gt_mask)
86
+ BIOU.step(pred=pred_mask, gt=gt_mask)
87
+ TIOU.step(pred=pred_mask, gt=gt_mask)
88
+
89
+ result = []
90
+
91
+ Sm = SM.get_results()["sm"]
92
+ wFm = WFM.get_results()["wfm"]
93
+ mae = MAE.get_results()["mae"]
94
+ mse = MSE.get_results()["mse"]
95
+ mBA = MBA.get_results()["mba"]
96
+
97
+ Fm = FM.get_results()["fm"]
98
+ Em = EM.get_results()["em"]
99
+ Iou = IOU.get_results()["iou"]
100
+ BIou = BIOU.get_results()["biou"]
101
+ TIou = TIOU.get_results()["tiou"]
102
+
103
+ adpEm = Em["adp"]
104
+ avgEm = Em["curve"].mean()
105
+ maxEm = Em["curve"].max()
106
+ adpFm = Fm["adp"]
107
+ avgFm = Fm["curve"].mean()
108
+ maxFm = Fm["curve"].max()
109
+ avgIou = Iou["curve"].mean()
110
+ maxIou = Iou["curve"].max()
111
+ avgBIou = BIou["curve"].mean()
112
+ maxBIou = BIou["curve"].max()
113
+ avgTIou = TIou["curve"].mean()
114
+ maxTIou = TIou["curve"].max()
115
+
116
+ out = dict()
117
+ for metric in opt.Eval.metrics:
118
+ out[metric] = eval(metric)
119
+
120
+ pkl = os.path.join(opt.Eval.result_path, 'result_' + dataset + '.pkl')
121
+ if os.path.isfile(pkl) is True:
122
+ result = pd.read_pickle(pkl)
123
+ result.loc[method] = out
124
+ result.to_pickle(pkl)
125
+ else:
126
+ result = pd.DataFrame(data=out, index=[method])
127
+ result.to_pickle(pkl)
128
+ result.to_csv(os.path.join(opt.Eval.result_path, 'result_' + dataset + '.csv'))
129
+ results.append(result)
130
+
131
+ if args.verbose is True:
132
+ for dataset, result in zip(datasets, results):
133
+ print('###', dataset, '###', '\n', result.sort_index(), '\n')
134
+
135
+ if __name__ == "__main__":
136
+ args = parse_args()
137
+ opt = load_config(args.config)
138
+ evaluate(opt, args)
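run/Eval.py is driven entirely by the YAML config (opt.Eval.pred_root, opt.Eval.gt_root, opt.Eval.datasets, opt.Eval.metrics); parse_args lives in utils.misc and is not shown in this view, so the exact flag names below are an assumption. A programmatic sketch that bypasses the CLI:

# assumed CLI: python run/Eval.py --config configs/InSPyReNet_SwinB.yaml --verbose
import argparse
from utils.misc import load_config
from run.Eval import evaluate

opt = load_config('configs/InSPyReNet_SwinB.yaml')
evaluate(opt, argparse.Namespace(verbose=True))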
run/Inference.py ADDED
@@ -0,0 +1,167 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import tqdm
5
+ import torch
6
+ import argparse
7
+
8
+ import numpy as np
9
+
10
+ from PIL import Image
11
+
12
+ filepath = os.path.split(os.path.abspath(__file__))[0]
13
+ repopath = os.path.split(filepath)[0]
14
+ sys.path.append(repopath)
15
+
16
+ from lib import *
17
+ from utils.misc import *
18
+ from data.dataloader import *
19
+ from data.custom_transforms import *
20
+
21
+ torch.backends.cuda.matmul.allow_tf32 = False
22
+ torch.backends.cudnn.allow_tf32 = False
23
+
24
+ def _args():
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml')
27
+ parser.add_argument('--source', '-s', type=str)
28
+ parser.add_argument('--dest', '-d', type=str, default=None)
29
+ parser.add_argument('--type', '-t', type=str, default='map')
30
+ parser.add_argument('--gpu', '-g', action='store_true', default=False)
31
+ parser.add_argument('--jit', '-j', action='store_true', default=False)
32
+ parser.add_argument('--verbose', '-v', action='store_true', default=False)
33
+ return parser.parse_args()
34
+
35
+ def get_format(source):
36
+ img_count = len([i for i in source if i.lower().endswith(('.jpg', '.png', '.jpeg'))])
37
+ vid_count = len([i for i in source if i.lower().endswith(('.mp4', '.avi', '.mov' ))])
38
+
39
+ if img_count * vid_count != 0:
40
+ return ''
41
+ elif img_count != 0:
42
+ return 'Image'
43
+ elif vid_count != 0:
44
+ return 'Video'
45
+ else:
46
+ return ''
47
+
48
+ def inference(opt, args):
49
+ model = eval(opt.Model.name)(**opt.Model)
50
+ model.load_state_dict(torch.load(os.path.join(
51
+ opt.Test.Checkpoint.checkpoint_dir, 'latest.pth'), map_location=torch.device('cpu')), strict=True)
52
+
53
+ if args.gpu is True:
54
+ model = model.cuda()
55
+ model.eval()
56
+
57
+ if args.jit is True:
58
+ if os.path.isfile(os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'jit.pt')) is False:
59
+ model = Simplify(model)
60
+ model = torch.jit.trace(model, torch.rand(1, 3, *opt.Test.Dataset.transforms.static_resize.size).cuda(), strict=False)
61
+ torch.jit.save(model, os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'jit.pt'))
62
+
63
+ else:
64
+ del model
65
+ model = torch.jit.load(os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'jit.pt'))
66
+
67
+ save_dir = None
68
+ _format = None
69
+
70
+ if args.source.isnumeric() is True:
71
+ _format = 'Webcam'
72
+
73
+ elif os.path.isdir(args.source):
74
+ save_dir = os.path.join('results', args.source.split(os.sep)[-1])
75
+ _format = get_format(os.listdir(args.source))
76
+
77
+ elif os.path.isfile(args.source):
78
+ save_dir = 'results'
79
+ _format = get_format([args.source])
80
+
81
+ if args.dest is not None:
82
+ save_dir = args.dest
83
+
84
+ if save_dir is not None:
85
+ os.makedirs(save_dir, exist_ok=True)
86
+
87
+ sample_list = eval(_format + 'Loader')(args.source, opt.Test.Dataset.transforms)
88
+
89
+ if args.verbose is True:
90
+ samples = tqdm.tqdm(sample_list, desc='Inference', total=len(
91
+ sample_list), position=0, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
92
+ else:
93
+ samples = sample_list
94
+
95
+ writer = None
96
+ background = None
97
+
98
+ for sample in samples:
99
+ if _format == 'Video' and writer is None:
100
+ writer = cv2.VideoWriter(os.path.join(save_dir, sample['name'] + '.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), sample_list.fps, sample['shape'][::-1])
101
+ samples.total += int(sample_list.cap.get(cv2.CAP_PROP_FRAME_COUNT))
102
+ if _format == 'Video' and sample['image'] is None:
103
+ if writer is not None:
104
+ writer.release()
105
+ writer = None
106
+ continue
107
+
108
+ if args.gpu is True:
109
+ sample = to_cuda(sample)
110
+
111
+ with torch.no_grad():
112
+ if args.jit is True:
113
+ out = model(sample['image'])
114
+ else:
115
+ out = model(sample)
116
+
117
+
118
+         pred = to_numpy(out['pred'], sample['shape'])
+         img = np.array(sample['original'])
+
+         # Compose the output according to --type: a grayscale saliency map, an RGBA cut-out,
+         # a green-screen / blurred / overlaid composite, a custom background image, or debug maps.
+         if args.type == 'map':
+             img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)
+         elif args.type == 'rgba':
+             r, g, b = cv2.split(img)
+             pred = (pred * 255).astype(np.uint8)
+             img = cv2.merge([r, g, b, pred])
+         elif args.type == 'green':
+             bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155]
+             img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])
+         elif args.type == 'blur':
+             img = img * pred[..., np.newaxis] + cv2.GaussianBlur(img, (0, 0), 15) * (1 - pred[..., np.newaxis])
+         elif args.type == 'overlay':
+             bg = (np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155] + img) // 2
+             img = bg * pred[..., np.newaxis] + img * (1 - pred[..., np.newaxis])
+             border = cv2.Canny(((pred > .5) * 255).astype(np.uint8), 50, 100)
+             img[border != 0] = [120, 255, 155]
+         elif args.type.lower().endswith(('.jpg', '.jpeg', '.png')):
+             # Any image path passed as --type is used as a replacement background.
+             if background is None:
+                 background = cv2.cvtColor(cv2.imread(args.type), cv2.COLOR_BGR2RGB)
+                 background = cv2.resize(background, img.shape[:2][::-1])
+             img = img * pred[..., np.newaxis] + background * (1 - pred[..., np.newaxis])
+         elif args.type == 'debug':
+             # Save every intermediate map listed in the debug keys as a normalized grayscale image.
+             debs = []
+             for k in opt.Train.Debug.keys:
+                 debs.extend(out[k])
+             for i, j in enumerate(debs):
+                 log = torch.sigmoid(j).cpu().detach().numpy().squeeze()
+                 log = ((log - log.min()) / (log.max() - log.min()) * 255).astype(np.uint8)
+                 log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
+                 log = cv2.resize(log, img.shape[:2][::-1])
+                 Image.fromarray(log).save(os.path.join(save_dir, sample['name'] + '_' + str(i) + '.png'))
+
+         img = img.astype(np.uint8)
+
+         if _format == 'Image':
+             Image.fromarray(img).save(os.path.join(save_dir, sample['name'] + '.png'))
+         elif _format == 'Video' and writer is not None:
+             writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+         elif _format == 'Webcam':
+             cv2.imshow('InSPyReNet', img)
+
+ if __name__ == "__main__":
+     args = _args()
+     opt = load_config(args.config)
+     inference(opt, args)
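
The `green`, `blur`, `overlay`, and image-background branches above all apply the same per-pixel alpha blend, with the predicted saliency map `pred` acting as the alpha channel. Below is a minimal, self-contained sketch of that blend; the function name `alpha_composite` and the toy arrays are illustrative only and do not appear in the repository.

```python
import numpy as np

def alpha_composite(fg, bg, alpha):
    """Blend fg over bg using a [0, 1] saliency map as the per-pixel alpha."""
    alpha = alpha[..., np.newaxis]            # (H, W) -> (H, W, 1) to broadcast over RGB
    out = fg * alpha + bg * (1.0 - alpha)     # same blend as the 'green' / 'blur' branches
    return out.astype(np.uint8)

# illustrative usage with toy data
fg = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)   # stand-in for the original image
bg = np.zeros_like(fg)
bg[:] = (120, 255, 155)                                      # green-screen color used above
alpha = np.random.rand(4, 4)                                 # stand-in for the predicted map
print(alpha_composite(fg, bg, alpha).shape)                  # (4, 4, 3)
```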
run/Test.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ import sys
+ import tqdm
+ import torch
+
+ import numpy as np
+
+ from PIL import Image
+ from torch.utils.data.dataloader import DataLoader
+
+ # Make the repository root importable so that lib, utils, and data resolve when run as a script.
+ filepath = os.path.split(os.path.abspath(__file__))[0]
+ repopath = os.path.split(filepath)[0]
+ sys.path.append(repopath)
+
+ from lib import *
+ from utils.misc import *
+ from data.dataloader import *
+
+ # Keep matrix multiplications and cuDNN convolutions in full FP32 (no TF32) during evaluation.
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+
+ def test(opt, args):
+     # Build the model from the config, load the latest checkpoint, and run it on GPU in eval mode.
+     model = eval(opt.Model.name)(**opt.Model)
+     model.load_state_dict(torch.load(os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'latest.pth')), strict=True)
+
+     model.cuda()
+     model.eval()
+
+     if args.verbose is True:
+         sets = tqdm.tqdm(opt.Test.Dataset.sets, desc='Total TestSet', total=len(
+             opt.Test.Dataset.sets), position=0, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
+     else:
+         sets = opt.Test.Dataset.sets
+
+     for set in sets:
+         # One output folder per benchmark set, placed under the checkpoint directory.
+         save_path = os.path.join(opt.Test.Checkpoint.checkpoint_dir, set)
+
+         os.makedirs(save_path, exist_ok=True)
+         test_dataset = eval(opt.Test.Dataset.type)(opt.Test.Dataset.root, [set], opt.Test.Dataset.transforms)
+         test_loader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=opt.Test.Dataloader.num_workers, pin_memory=opt.Test.Dataloader.pin_memory)
+
+         if args.verbose is True:
+             samples = tqdm.tqdm(test_loader, desc=set + ' - Test', total=len(test_loader),
+                 position=1, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
+         else:
+             samples = test_loader
+
+         for sample in samples:
+             sample = to_cuda(sample)
+             with torch.no_grad():
+                 out = model(sample)
+
+             # Convert the prediction to a numpy map at the original image shape and save it as an 8-bit PNG.
+             pred = to_numpy(out['pred'], sample['shape'])
+             Image.fromarray((pred * 255).astype(np.uint8)).save(os.path.join(save_path, sample['name'][0] + '.png'))
+
+ if __name__ == "__main__":
+     args = parse_args()
+     opt = load_config(args.config)
+     test(opt, args)
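
Each benchmark set therefore ends up as a folder of 8-bit grayscale PNGs under the checkpoint directory, one map per test image. As a small illustrative follow-up (the path and the 0.5 threshold below are assumptions for the example, not values taken from the repository), such a map can be read back and binarized like this:

```python
import numpy as np
from PIL import Image

# hypothetical output path; Test.py writes maps to <checkpoint_dir>/<set>/<name>.png
pred_path = 'snapshots/InSPyReNet_SwinB/DUTS-TE/example.png'

pred = np.asarray(Image.open(pred_path).convert('L'), dtype=np.float32) / 255.0  # back to [0, 1]
mask = (pred > 0.5).astype(np.uint8) * 255                                        # simple fixed threshold
Image.fromarray(mask).save('example_mask.png')
```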