Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-10-31 22:55:40 +08:00)
ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:

parent e795277391
commit fe27db2f6e

.github/workflows/ci.yaml (vendored): 2 changes
@@ -95,7 +95,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        python-version: ['3.10']
+        python-version: ['3.11']
         model: [yolov8n]
     steps:
       - uses: actions/checkout@v4

.github/workflows/format.yml (vendored, new file): 25 additions
@@ -0,0 +1,25 @@
+# Ultralytics 🚀 - AGPL-3.0 license
+# Ultralytics Actions https://github.com/ultralytics/actions
+# This workflow automatically formats code and documentation in PRs to official Ultralytics standards
+
+name: Ultralytics Actions
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  format:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Run Ultralytics Formatting
+        uses: ultralytics/actions@main
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}  # automatically generated
+          python: true
+          docstrings: true
+          markdown: true
+          spelling: true
+          links: false
.pre-commit-config.yaml

@@ -22,7 +22,6 @@ repos:
       - id: check-case-conflict
       # - id: check-yaml
       - id: check-docstring-first
-      - id: double-quote-string-fixer
       - id: detect-private-key

   - repo: https://github.com/asottile/pyupgrade
@@ -64,7 +63,7 @@ repos:
       - id: codespell
         exclude: 'docs/de|docs/fr|docs/pt|docs/es|docs/mkdocs_de.yml'
         args:
-          - --ignore-words-list=crate,nd,strack,dota,ane,segway,fo,gool,winn
+          - --ignore-words-list=crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall

   - repo: https://github.com/PyCQA/docformatter
     rev: v1.7.5
docs/build_docs.py

@@ -30,45 +30,47 @@ import subprocess
 from pathlib import Path

 DOCS = Path(__file__).parent.resolve()
-SITE = DOCS.parent / 'site'
+SITE = DOCS.parent / "site"


 def build_docs():
     """Build docs using mkdocs."""
     if SITE.exists():
-        print(f'Removing existing {SITE}')
+        print(f"Removing existing {SITE}")
         shutil.rmtree(SITE)

     # Build the main documentation
-    print(f'Building docs from {DOCS}')
-    subprocess.run(f'mkdocs build -f {DOCS}/mkdocs.yml', check=True, shell=True)
+    print(f"Building docs from {DOCS}")
+    subprocess.run(f"mkdocs build -f {DOCS}/mkdocs.yml", check=True, shell=True)

     # Build other localized documentations
-    for file in DOCS.glob('mkdocs_*.yml'):
-        print(f'Building MkDocs site with configuration file: {file}')
-        subprocess.run(f'mkdocs build -f {file}', check=True, shell=True)
-    print(f'Site built at {SITE}')
+    for file in DOCS.glob("mkdocs_*.yml"):
+        print(f"Building MkDocs site with configuration file: {file}")
+        subprocess.run(f"mkdocs build -f {file}", check=True, shell=True)
+    print(f"Site built at {SITE}")


 def update_html_links():
     """Update href links in HTML files to remove '.md' and '/index.md', excluding links starting with 'https://'."""
-    html_files = Path(SITE).rglob('*.html')
+    html_files = Path(SITE).rglob("*.html")
     total_updated_links = 0

     for html_file in html_files:
-        with open(html_file, 'r+', encoding='utf-8') as file:
+        with open(html_file, "r+", encoding="utf-8") as file:
             content = file.read()
             # Find all links to be updated, excluding those starting with 'https://'
             links_to_update = re.findall(r'href="(?!https://)([^"]+?)(/index)?\.md"', content)

             # Update the content and count the number of links updated
-            updated_content, number_of_links_updated = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"',
-                                                               r'href="\1"', content)
+            updated_content, number_of_links_updated = re.subn(
+                r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', content
+            )
             total_updated_links += number_of_links_updated

             # Special handling for '/index' links
-            updated_content, number_of_index_links_updated = re.subn(r'href="([^"]+)/index"', r'href="\1/"',
-                                                                     updated_content)
+            updated_content, number_of_index_links_updated = re.subn(
+                r'href="([^"]+)/index"', r'href="\1/"', updated_content
+            )
             total_updated_links += number_of_index_links_updated

             # Write the updated content back to the file
@@ -78,23 +80,23 @@ def update_html_links():

             # Print updated links for this file
             for link in links_to_update:
-                print(f'Updated link in {html_file}: {link[0]}')
+                print(f"Updated link in {html_file}: {link[0]}")

-    print(f'Total number of links updated: {total_updated_links}')
+    print(f"Total number of links updated: {total_updated_links}")


 def update_page_title(file_path: Path, new_title: str):
     """Update the title of an HTML file."""

     # Read the content of the file
-    with open(file_path, encoding='utf-8') as file:
+    with open(file_path, encoding="utf-8") as file:
         content = file.read()

     # Replace the existing title with the new title
-    updated_content = re.sub(r'<title>.*?</title>', f'<title>{new_title}</title>', content)
+    updated_content = re.sub(r"<title>.*?</title>", f"<title>{new_title}</title>", content)

     # Write the updated content back to the file
-    with open(file_path, 'w', encoding='utf-8') as file:
+    with open(file_path, "w", encoding="utf-8") as file:
         file.write(updated_content)


@@ -109,8 +111,8 @@ def main():
     print('Serve site at http://localhost:8000 with "python -m http.server --directory site"')

     # Update titles
-    update_page_title(SITE / '404.html', new_title='Ultralytics Docs - Not Found')
+    update_page_title(SITE / "404.html", new_title="Ultralytics Docs - Not Found")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
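For quick context on the build_docs.py hunks above: the two reflowed re.subn calls strip ".md" from local hrefs while leaving absolute links alone. A minimal standalone sketch of that behavior (the sample HTML string is invented for illustration):

    import re

    html = '<a href="modes/train.md">train</a> <a href="https://example.com/page.md">ext</a>'
    updated, n = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', html)
    print(updated)  # the local link becomes href="modes/train"; the https:// link is untouched
    print(n)        # 1 replacement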
docs/build_reference.py

@@ -14,14 +14,14 @@ from ultralytics.utils import ROOT

 NEW_YAML_DIR = ROOT.parent
 CODE_DIR = ROOT
-REFERENCE_DIR = ROOT.parent / 'docs/en/reference'
+REFERENCE_DIR = ROOT.parent / "docs/en/reference"


 def extract_classes_and_functions(filepath: Path) -> tuple:
     """Extracts class and function names from a given Python file."""
     content = filepath.read_text()
-    class_pattern = r'(?:^|\n)class\s(\w+)(?:\(|:)'
-    func_pattern = r'(?:^|\n)def\s(\w+)\('
+    class_pattern = r"(?:^|\n)class\s(\w+)(?:\(|:)"
+    func_pattern = r"(?:^|\n)def\s(\w+)\("

     classes = re.findall(class_pattern, content)
     functions = re.findall(func_pattern, content)
@@ -31,31 +31,31 @@ def extract_classes_and_functions(filepath: Path) -> tuple:

 def create_markdown(py_filepath: Path, module_path: str, classes: list, functions: list):
     """Creates a Markdown file containing the API reference for the given Python module."""
-    md_filepath = py_filepath.with_suffix('.md')
+    md_filepath = py_filepath.with_suffix(".md")

     # Read existing content and keep header content between first two ---
-    header_content = ''
+    header_content = ""
     if md_filepath.exists():
         existing_content = md_filepath.read_text()
-        header_parts = existing_content.split('---')
+        header_parts = existing_content.split("---")
         for part in header_parts:
-            if 'description:' in part or 'comments:' in part:
-                header_content += f'---{part}---\n\n'
+            if "description:" in part or "comments:" in part:
+                header_content += f"---{part}---\n\n"

-    module_name = module_path.replace('.__init__', '')
-    module_path = module_path.replace('.', '/')
-    url = f'https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py'
-    edit = f'https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py'
+    module_name = module_path.replace(".__init__", "")
+    module_path = module_path.replace(".", "/")
+    url = f"https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py"
+    edit = f"https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py"
     title_content = (
-        f'# Reference for `{module_path}.py`\n\n'
-        f'!!! Note\n\n'
-        f'    This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n'
+        f"# Reference for `{module_path}.py`\n\n"
+        f"!!! Note\n\n"
+        f"    This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n"
     )
-    md_content = ['<br><br>\n'] + [f'## ::: {module_name}.{class_name}\n\n<br><br>\n' for class_name in classes]
-    md_content.extend(f'## ::: {module_name}.{func_name}\n\n<br><br>\n' for func_name in functions)
-    md_content = header_content + title_content + '\n'.join(md_content)
-    if not md_content.endswith('\n'):
-        md_content += '\n'
+    md_content = ["<br><br>\n"] + [f"## ::: {module_name}.{class_name}\n\n<br><br>\n" for class_name in classes]
+    md_content.extend(f"## ::: {module_name}.{func_name}\n\n<br><br>\n" for func_name in functions)
+    md_content = header_content + title_content + "\n".join(md_content)
+    if not md_content.endswith("\n"):
+        md_content += "\n"

     md_filepath.parent.mkdir(parents=True, exist_ok=True)
     md_filepath.write_text(md_content)
@@ -80,28 +80,28 @@ def create_nav_menu_yaml(nav_items: list):
     for item_str in nav_items:
         item = Path(item_str)
         parts = item.parts
-        current_level = nav_tree['reference']
+        current_level = nav_tree["reference"]
         for part in parts[2:-1]:  # skip the first two parts (docs and reference) and the last part (filename)
             current_level = current_level[part]

-        md_file_name = parts[-1].replace('.md', '')
+        md_file_name = parts[-1].replace(".md", "")
         current_level[md_file_name] = item

     nav_tree_sorted = sort_nested_dict(nav_tree)

     def _dict_to_yaml(d, level=0):
         """Converts a nested dictionary to a YAML-formatted string with indentation."""
-        yaml_str = ''
-        indent = '  ' * level
+        yaml_str = ""
+        indent = "  " * level
         for k, v in d.items():
             if isinstance(v, dict):
-                yaml_str += f'{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}'
+                yaml_str += f"{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}"
             else:
                 yaml_str += f"{indent}- {k}: {str(v).replace('docs/en/', '')}\n"
         return yaml_str

     # Print updated YAML reference section
-    print('Scan complete, new mkdocs.yaml reference section is:\n\n', _dict_to_yaml(nav_tree_sorted))
+    print("Scan complete, new mkdocs.yaml reference section is:\n\n", _dict_to_yaml(nav_tree_sorted))

     # Save new YAML reference section
     # (NEW_YAML_DIR / 'nav_menu_updated.yml').write_text(_dict_to_yaml(nav_tree_sorted))
@@ -111,7 +111,7 @@ def main():
     """Main function to extract class and function names, create Markdown files, and generate a YAML navigation menu."""
     nav_items = []

-    for py_filepath in CODE_DIR.rglob('*.py'):
+    for py_filepath in CODE_DIR.rglob("*.py"):
         classes, functions = extract_classes_and_functions(py_filepath)

         if classes or functions:
@@ -124,5 +124,5 @@ def main():
     create_nav_menu_yaml(nav_items)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
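As a side note on build_reference.py: the two module-scanning regexes only match top-level definitions, because "class"/"def" must sit directly after a newline (or at the start of the file), so indented methods are skipped. A small sketch with an invented source string:

    import re

    source = "class Model(nn.Module):\n    def forward(self, x):\n        return x\n\ndef helper(x):\n    return x\n"
    print(re.findall(r"(?:^|\n)class\s(\w+)(?:\(|:)", source))  # ['Model']
    print(re.findall(r"(?:^|\n)def\s(\w+)\(", source))          # ['helper'] (forward() is indented, so excluded)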
docs/update_translations.py

@@ -22,69 +22,232 @@ class MarkdownLinkFixer:
         self.base_dir = Path(base_dir)
         self.update_links = update_links
         self.update_text = update_text
-        self.md_link_regex = re.compile(r'\[([^]]+)]\(([^:)]+)\.md\)')
+        self.md_link_regex = re.compile(r"\[([^]]+)]\(([^:)]+)\.md\)")

     @staticmethod
     def replace_front_matter(content, lang_dir):
         """Ensure front matter keywords remain in English."""
-        english = ['comments', 'description', 'keywords']
+        english = ["comments", "description", "keywords"]
         translations = {
-            'zh': ['评论', '描述', '关键词'],  # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字
-            'es': ['comentarios', 'descripción', 'palabras clave'],  # Spanish
-            'ru': ['комментарии', 'описание', 'ключевые слова'],  # Russian
-            'pt': ['comentários', 'descrição', 'palavras-chave'],  # Portuguese
-            'fr': ['commentaires', 'description', 'mots-clés'],  # French
-            'de': ['kommentare', 'beschreibung', 'schlüsselwörter'],  # German
-            'ja': ['コメント', '説明', 'キーワード'],  # Japanese
-            'ko': ['댓글', '설명', '키워드'],  # Korean
-            'hi': ['टिप्पणियाँ', 'विवरण', 'कीवर्ड'],  # Hindi
-            'ar': ['التعليقات', 'الوصف', 'الكلمات الرئيسية']  # Arabic
+            "zh": ["评论", "描述", "关键词"],  # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字
+            "es": ["comentarios", "descripción", "palabras clave"],  # Spanish
+            "ru": ["комментарии", "описание", "ключевые слова"],  # Russian
+            "pt": ["comentários", "descrição", "palavras-chave"],  # Portuguese
+            "fr": ["commentaires", "description", "mots-clés"],  # French
+            "de": ["kommentare", "beschreibung", "schlüsselwörter"],  # German
+            "ja": ["コメント", "説明", "キーワード"],  # Japanese
+            "ko": ["댓글", "설명", "키워드"],  # Korean
+            "hi": ["टिप्पणियाँ", "विवरण", "कीवर्ड"],  # Hindi
+            "ar": ["التعليقات", "الوصف", "الكلمات الرئيسية"],  # Arabic
         }  # front matter translations for comments, description, keyword

         for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
-            content = re.sub(rf'{term} *[::].*', f'{eng_key}: true', content, flags=re.IGNORECASE) if \
-                eng_key == 'comments' else re.sub(rf'{term} *[::] *', f'{eng_key}: ', content, flags=re.IGNORECASE)
+            content = (
+                re.sub(rf"{term} *[::].*", f"{eng_key}: true", content, flags=re.IGNORECASE)
+                if eng_key == "comments"
+                else re.sub(rf"{term} *[::] *", f"{eng_key}: ", content, flags=re.IGNORECASE)
+            )
         return content

     @staticmethod
     def replace_admonitions(content, lang_dir):
         """Ensure front matter keywords remain in English."""
         english = [
-            'Note', 'Summary', 'Tip', 'Info', 'Success', 'Question', 'Warning', 'Failure', 'Danger', 'Bug', 'Example',
-            'Quote', 'Abstract', 'Seealso', 'Admonition']
+            "Note",
+            "Summary",
+            "Tip",
+            "Info",
+            "Success",
+            "Question",
+            "Warning",
+            "Failure",
+            "Danger",
+            "Bug",
+            "Example",
+            "Quote",
+            "Abstract",
+            "Seealso",
+            "Admonition",
+        ]
         translations = {
-            'en':
-            english,
-            'zh': ['笔记', '摘要', '提示', '信息', '成功', '问题', '警告', '失败', '危险', '故障', '示例', '引用', '摘要', '另见', '警告'],
-            'es': [
-                'Nota', 'Resumen', 'Consejo', 'Información', 'Éxito', 'Pregunta', 'Advertencia', 'Fracaso', 'Peligro',
-                'Error', 'Ejemplo', 'Cita', 'Abstracto', 'Véase También', 'Amonestación'],
-            'ru': [
-                'Заметка', 'Сводка', 'Совет', 'Информация', 'Успех', 'Вопрос', 'Предупреждение', 'Неудача', 'Опасность',
-                'Ошибка', 'Пример', 'Цитата', 'Абстракт', 'См. Также', 'Предостережение'],
-            'pt': [
-                'Nota', 'Resumo', 'Dica', 'Informação', 'Sucesso', 'Questão', 'Aviso', 'Falha', 'Perigo', 'Bug',
-                'Exemplo', 'Citação', 'Abstrato', 'Veja Também', 'Advertência'],
-            'fr': [
-                'Note', 'Résumé', 'Conseil', 'Info', 'Succès', 'Question', 'Avertissement', 'Échec', 'Danger', 'Bug',
-                'Exemple', 'Citation', 'Abstrait', 'Voir Aussi', 'Admonestation'],
-            'de': [
-                'Hinweis', 'Zusammenfassung', 'Tipp', 'Info', 'Erfolg', 'Frage', 'Warnung', 'Ausfall', 'Gefahr',
-                'Fehler', 'Beispiel', 'Zitat', 'Abstrakt', 'Siehe Auch', 'Ermahnung'],
-            'ja': ['ノート', '要約', 'ヒント', '情報', '成功', '質問', '警告', '失敗', '危険', 'バグ', '例', '引用', '抄録', '参照', '訓告'],
-            'ko': ['노트', '요약', '팁', '정보', '성공', '질문', '경고', '실패', '위험', '버그', '예제', '인용', '추상', '참조', '경고'],
-            'hi': [
-                'नोट', 'सारांश', 'सुझाव', 'जानकारी', 'सफलता', 'प्रश्न', 'चेतावनी', 'विफलता', 'खतरा', 'बग', 'उदाहरण',
-                'उद्धरण', 'सार', 'देखें भी', 'आगाही'],
-            'ar': [
-                'ملاحظة', 'ملخص', 'نصيحة', 'معلومات', 'نجاح', 'سؤال', 'تحذير', 'فشل', 'خطر', 'عطل', 'مثال', 'اقتباس',
-                'ملخص', 'انظر أيضاً', 'تحذير']}
+            "en": english,
+            "zh": [
+                "笔记",
+                "摘要",
+                "提示",
+                "信息",
+                "成功",
+                "问题",
+                "警告",
+                "失败",
+                "危险",
+                "故障",
+                "示例",
+                "引用",
+                "摘要",
+                "另见",
+                "警告",
+            ],
+            "es": [
+                "Nota",
+                "Resumen",
+                "Consejo",
+                "Información",
+                "Éxito",
+                "Pregunta",
+                "Advertencia",
+                "Fracaso",
+                "Peligro",
+                "Error",
+                "Ejemplo",
+                "Cita",
+                "Abstracto",
+                "Véase También",
+                "Amonestación",
+            ],
+            "ru": [
+                "Заметка",
+                "Сводка",
+                "Совет",
+                "Информация",
+                "Успех",
+                "Вопрос",
+                "Предупреждение",
+                "Неудача",
+                "Опасность",
+                "Ошибка",
+                "Пример",
+                "Цитата",
+                "Абстракт",
+                "См. Также",
+                "Предостережение",
+            ],
+            "pt": [
+                "Nota",
+                "Resumo",
+                "Dica",
+                "Informação",
+                "Sucesso",
+                "Questão",
+                "Aviso",
+                "Falha",
+                "Perigo",
+                "Bug",
+                "Exemplo",
+                "Citação",
+                "Abstrato",
+                "Veja Também",
+                "Advertência",
+            ],
+            "fr": [
+                "Note",
+                "Résumé",
+                "Conseil",
+                "Info",
+                "Succès",
+                "Question",
+                "Avertissement",
+                "Échec",
+                "Danger",
+                "Bug",
+                "Exemple",
+                "Citation",
+                "Abstrait",
+                "Voir Aussi",
+                "Admonestation",
+            ],
+            "de": [
+                "Hinweis",
+                "Zusammenfassung",
+                "Tipp",
+                "Info",
+                "Erfolg",
+                "Frage",
+                "Warnung",
+                "Ausfall",
+                "Gefahr",
+                "Fehler",
+                "Beispiel",
+                "Zitat",
+                "Abstrakt",
+                "Siehe Auch",
+                "Ermahnung",
+            ],
+            "ja": [
+                "ノート",
+                "要約",
+                "ヒント",
+                "情報",
+                "成功",
+                "質問",
+                "警告",
+                "失敗",
+                "危険",
+                "バグ",
+                "例",
+                "引用",
+                "抄録",
+                "参照",
+                "訓告",
+            ],
+            "ko": [
+                "노트",
+                "요약",
+                "팁",
+                "정보",
+                "성공",
+                "질문",
+                "경고",
+                "실패",
+                "위험",
+                "버그",
+                "예제",
+                "인용",
+                "추상",
+                "참조",
+                "경고",
+            ],
+            "hi": [
+                "नोट",
+                "सारांश",
+                "सुझाव",
+                "जानकारी",
+                "सफलता",
+                "प्रश्न",
+                "चेतावनी",
+                "विफलता",
+                "खतरा",
+                "बग",
+                "उदाहरण",
+                "उद्धरण",
+                "सार",
+                "देखें भी",
+                "आगाही",
+            ],
+            "ar": [
+                "ملاحظة",
+                "ملخص",
+                "نصيحة",
+                "معلومات",
+                "نجاح",
+                "سؤال",
+                "تحذير",
+                "فشل",
+                "خطر",
+                "عطل",
+                "مثال",
+                "اقتباس",
+                "ملخص",
+                "انظر أيضاً",
+                "تحذير",
+            ],
+        }

         for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
-            if lang_dir.stem != 'en':
-                content = re.sub(rf'!!! *{eng_key} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
-                content = re.sub(rf'!!! *{term} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
-            content = re.sub(rf'!!! *{term}', f'!!! {eng_key}', content, flags=re.IGNORECASE)
+            if lang_dir.stem != "en":
+                content = re.sub(rf"!!! *{eng_key} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
+                content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
+            content = re.sub(rf"!!! *{term}", f"!!! {eng_key}", content, flags=re.IGNORECASE)
             content = re.sub(r'!!! *"', '!!! Example "', content, flags=re.IGNORECASE)

         return content
@@ -92,30 +255,30 @@ class MarkdownLinkFixer:
     @staticmethod
     def update_iframe(content):
         """Update the 'allow' attribute of iframe if it does not contain the specific English permissions."""
-        english = 'accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share'
+        english = "accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
         pattern = re.compile(f'allow="(?!{re.escape(english)}).+?"')
         return pattern.sub(f'allow="{english}"', content)

     def link_replacer(self, match, parent_dir, lang_dir, use_abs_link=False):
         """Replace broken links with corresponding links in the /en/ directory."""
         text, path = match.groups()
-        linked_path = (parent_dir / path).resolve().with_suffix('.md')
+        linked_path = (parent_dir / path).resolve().with_suffix(".md")

         if not linked_path.exists():
-            en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / 'en')))
+            en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / "en")))
             if en_linked_path.exists():
                 if use_abs_link:
                     # Use absolute links WARNING: BUGS, DO NOT USE
                     docs_root_relative_path = en_linked_path.relative_to(lang_dir.parent)
-                    updated_path = str(docs_root_relative_path).replace('en/', '/../')
+                    updated_path = str(docs_root_relative_path).replace("en/", "/../")
                 else:
                     # Use relative links
                     steps_up = len(parent_dir.relative_to(self.base_dir).parts)
-                    updated_path = Path('../' * steps_up) / en_linked_path.relative_to(self.base_dir)
-                    updated_path = str(updated_path).replace('/en/', '/')
+                    updated_path = Path("../" * steps_up) / en_linked_path.relative_to(self.base_dir)
+                    updated_path = str(updated_path).replace("/en/", "/")

                 print(f"Redirecting link '[{text}]({path})' from {parent_dir} to {updated_path}")
-                return f'[{text}]({updated_path})'
+                return f"[{text}]({updated_path})"
             else:
                 print(f"Warning: Broken link '[{text}]({path})' found in {parent_dir} does not exist in /docs/en/.")

@@ -124,28 +287,30 @@ class MarkdownLinkFixer:
     @staticmethod
     def update_html_tags(content):
         """Updates HTML tags in docs."""
-        alt_tag = 'MISSING'
+        alt_tag = "MISSING"

         # Remove closing slashes from self-closing HTML tags
-        pattern = re.compile(r'<([^>]+?)\s*/>')
-        content = re.sub(pattern, r'<\1>', content)
+        pattern = re.compile(r"<([^>]+?)\s*/>")
+        content = re.sub(pattern, r"<\1>", content)

         # Find all images without alt tags and add placeholder alt text
-        pattern = re.compile(r'!\[(.*?)\]\((.*?)\)')
-        content, num_replacements = re.subn(pattern, lambda match: f'![{match.group(1) or alt_tag}]({match.group(2)})',
-                                            content)
+        pattern = re.compile(r"!\[(.*?)\]\((.*?)\)")
+        content, num_replacements = re.subn(
+            pattern, lambda match: f"![{match.group(1) or alt_tag}]({match.group(2)})", content
+        )

         # Add missing alt tags to HTML images
         pattern = re.compile(r'<img\s+(?!.*?\balt\b)[^>]*src=["\'](.*?)["\'][^>]*>')
-        content, num_replacements = re.subn(pattern, lambda match: match.group(0).replace('>', f' alt="{alt_tag}">', 1),
-                                            content)
+        content, num_replacements = re.subn(
+            pattern, lambda match: match.group(0).replace(">", f' alt="{alt_tag}">', 1), content
+        )

         return content

     def process_markdown_file(self, md_file_path, lang_dir):
         """Process each markdown file in the language directory."""
-        print(f'Processing file: {md_file_path}')
-        with open(md_file_path, encoding='utf-8') as file:
+        print(f"Processing file: {md_file_path}")
+        with open(md_file_path, encoding="utf-8") as file:
             content = file.read()

         if self.update_links:
@@ -157,23 +322,23 @@ class MarkdownLinkFixer:
             content = self.update_iframe(content)
             content = self.update_html_tags(content)

-        with open(md_file_path, 'w', encoding='utf-8') as file:
+        with open(md_file_path, "w", encoding="utf-8") as file:
             file.write(content)

     def process_language_directory(self, lang_dir):
         """Process each language-specific directory."""
-        print(f'Processing language directory: {lang_dir}')
-        for md_file in lang_dir.rglob('*.md'):
+        print(f"Processing language directory: {lang_dir}")
+        for md_file in lang_dir.rglob("*.md"):
             self.process_markdown_file(md_file, lang_dir)

     def run(self):
         """Run the link fixing and front matter updating process for each language-specific directory."""
         for subdir in self.base_dir.iterdir():
-            if subdir.is_dir() and re.match(r'^\w\w$', subdir.name):
+            if subdir.is_dir() and re.match(r"^\w\w$", subdir.name):
                 self.process_language_directory(subdir)


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Set the path to your MkDocs 'docs' directory here
     docs_dir = str(Path(__file__).parent.resolve())
     fixer = MarkdownLinkFixer(docs_dir, update_links=True, update_text=True)
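To make the admonition normalization in update_translations.py concrete, here is a minimal sketch using the Spanish pair from the table above (term "Nota", English keyword "Note"); the sample content string is invented. The localized keyword becomes the English one that the MkDocs admonition extension recognizes, with the translation kept as the visible title:

    import re

    content = '!!! Nota\n    El contenido...\n'
    term, eng_key = "Nota", "Note"
    content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
    print(content)  # '!!! Note "Nota"' followed by the indented body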
@@ -28,7 +28,7 @@ class YOLOv8:
         self.iou_thres = iou_thres

         # Load the class names from the COCO dataset
-        self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+        self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]

         # Generate a color palette for the classes
         self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -57,7 +57,7 @@ class YOLOv8:
         cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

         # Create the label text with class name and score
-        label = f'{self.classes[class_id]}: {score:.2f}'
+        label = f"{self.classes[class_id]}: {score:.2f}"

         # Calculate the dimensions of the label text
         (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -67,8 +67,9 @@ class YOLOv8:
         label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

         # Draw a filled rectangle as the background for the label text
-        cv2.rectangle(img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color,
-                      cv2.FILLED)
+        cv2.rectangle(
+            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
+        )

         # Draw the label text on the image
         cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -182,7 +183,7 @@ class YOLOv8:
             output_img: The output image with drawn detections.
         """
         # Create an inference session using the ONNX model and specify execution providers
-        session = ort.InferenceSession(self.onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

         # Get the model inputs
         model_inputs = session.get_inputs()
@@ -202,17 +203,17 @@ class YOLOv8:
         return self.postprocess(self.img, outputs)  # output image


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Create an argument parser to handle command-line arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, default='yolov8n.onnx', help='Input your ONNX model.')
-    parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
-    parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
-    parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
+    parser.add_argument("--model", type=str, default="yolov8n.onnx", help="Input your ONNX model.")
+    parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
+    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
+    parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
     args = parser.parse_args()

     # Check the requirements and select the appropriate backend (CPU or GPU)
-    check_requirements('onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime')
+    check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime")

     # Create an instance of the YOLOv8 class with the specified arguments
     detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres)
@@ -221,8 +222,8 @@ if __name__ == '__main__':
     output_image = detection.main()

     # Display the output image in a window
-    cv2.namedWindow('Output', cv2.WINDOW_NORMAL)
-    cv2.imshow('Output', output_image)
+    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
+    cv2.imshow("Output", output_image)

     # Wait for a key press to exit
     cv2.waitKey(0)
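One detail worth flagging in the ONNX Runtime example above: passing both providers to ort.InferenceSession makes the runtime try CUDA first and fall back to CPU when it is unavailable. A minimal check (assumes onnxruntime is installed and a yolov8n.onnx file exists locally):

    import onnxruntime as ort

    session = ort.InferenceSession("yolov8n.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    print(session.get_providers())  # providers actually in use, in priority order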
@@ -6,7 +6,7 @@ import numpy as np
 from ultralytics.utils import ASSETS, yaml_load
 from ultralytics.utils.checks import check_yaml

-CLASSES = yaml_load(check_yaml('coco128.yaml'))['names']
+CLASSES = yaml_load(check_yaml("coco128.yaml"))["names"]
 colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))


@@ -23,7 +23,7 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
         x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.
         y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.
     """
-    label = f'{CLASSES[class_id]} ({confidence:.2f})'
+    label = f"{CLASSES[class_id]} ({confidence:.2f})"
     color = colors[class_id]
     cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
     cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
@@ -76,8 +76,11 @@ def main(onnx_model, input_image):
         (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores)
         if maxScore >= 0.25:
             box = [
-                outputs[0][i][0] - (0.5 * outputs[0][i][2]), outputs[0][i][1] - (0.5 * outputs[0][i][3]),
-                outputs[0][i][2], outputs[0][i][3]]
+                outputs[0][i][0] - (0.5 * outputs[0][i][2]),
+                outputs[0][i][1] - (0.5 * outputs[0][i][3]),
+                outputs[0][i][2],
+                outputs[0][i][3],
+            ]
             boxes.append(box)
             scores.append(maxScore)
             class_ids.append(maxClassIndex)
@@ -92,26 +95,34 @@ def main(onnx_model, input_image):
         index = result_boxes[i]
         box = boxes[index]
         detection = {
-            'class_id': class_ids[index],
-            'class_name': CLASSES[class_ids[index]],
-            'confidence': scores[index],
-            'box': box,
-            'scale': scale}
+            "class_id": class_ids[index],
+            "class_name": CLASSES[class_ids[index]],
+            "confidence": scores[index],
+            "box": box,
+            "scale": scale,
+        }
         detections.append(detection)
-        draw_bounding_box(original_image, class_ids[index], scores[index], round(box[0] * scale), round(box[1] * scale),
-                          round((box[0] + box[2]) * scale), round((box[1] + box[3]) * scale))
+        draw_bounding_box(
+            original_image,
+            class_ids[index],
+            scores[index],
+            round(box[0] * scale),
+            round(box[1] * scale),
+            round((box[0] + box[2]) * scale),
+            round((box[1] + box[3]) * scale),
+        )

     # Display the image with bounding boxes
-    cv2.imshow('image', original_image)
+    cv2.imshow("image", original_image)
     cv2.waitKey(0)
     cv2.destroyAllWindows()

     return detections


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', default='yolov8n.onnx', help='Input your ONNX model.')
-    parser.add_argument('--img', default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
+    parser.add_argument("--model", default="yolov8n.onnx", help="Input your ONNX model.")
+    parser.add_argument("--img", default=str(ASSETS / "bus.jpg"), help="Path to input image.")
     args = parser.parse_args()
     main(args.model, args.img)
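The reflowed box expression above is the usual center-to-corner conversion for raw YOLO outputs, (cx, cy, w, h) -> (x_left, y_top, w, h). A tiny worked example with made-up numbers:

    # center (320, 240) with size 100x50 -> top-left corner (270, 215)
    cx, cy, w, h = 320.0, 240.0, 100.0, 50.0
    box = [cx - 0.5 * w, cy - 0.5 * h, w, h]
    print(box)  # [270.0, 215.0, 100.0, 50.0]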
@@ -13,14 +13,9 @@ img_height = 640


 class LetterBox:
-
-    def __init__(self,
-                 new_shape=(img_width, img_height),
-                 auto=False,
-                 scaleFill=False,
-                 scaleup=True,
-                 center=True,
-                 stride=32):
+    def __init__(
+        self, new_shape=(img_width, img_height), auto=False, scaleFill=False, scaleup=True, center=True, stride=32
+    ):
         self.new_shape = new_shape
         self.auto = auto
         self.scaleFill = scaleFill
@@ -33,9 +28,9 @@ class LetterBox:

         if labels is None:
             labels = {}
-        img = labels.get('img') if image is None else image
+        img = labels.get("img") if image is None else image
         shape = img.shape[:2]  # current shape [height, width]
-        new_shape = labels.pop('rect_shape', self.new_shape)
+        new_shape = labels.pop("rect_shape", self.new_shape)
         if isinstance(new_shape, int):
             new_shape = (new_shape, new_shape)

@@ -63,15 +58,16 @@ class LetterBox:
             img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
         top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
         left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
-        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
-                                 value=(114, 114, 114))  # add border
-        if labels.get('ratio_pad'):
-            labels['ratio_pad'] = (labels['ratio_pad'], (left, top))  # for evaluation
+        img = cv2.copyMakeBorder(
+            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+        )  # add border
+        if labels.get("ratio_pad"):
+            labels["ratio_pad"] = (labels["ratio_pad"], (left, top))  # for evaluation

         if len(labels):
             labels = self._update_labels(labels, ratio, dw, dh)
-            labels['img'] = img
-            labels['resized_shape'] = new_shape
+            labels["img"] = img
+            labels["resized_shape"] = new_shape
             return labels
         else:
             return img
@@ -79,15 +75,14 @@ class LetterBox:
     def _update_labels(self, labels, ratio, padw, padh):
         """Update labels."""

-        labels['instances'].convert_bbox(format='xyxy')
-        labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
-        labels['instances'].scale(*ratio)
-        labels['instances'].add_padding(padw, padh)
+        labels["instances"].convert_bbox(format="xyxy")
+        labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
+        labels["instances"].scale(*ratio)
+        labels["instances"].add_padding(padw, padh)
         return labels


 class Yolov8TFLite:
-
     def __init__(self, tflite_model, input_image, confidence_thres, iou_thres):
         """
         Initializes an instance of the Yolov8TFLite class.
@@ -105,7 +100,7 @@ class Yolov8TFLite:
         self.iou_thres = iou_thres

         # Load the class names from the COCO dataset
-        self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+        self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]

         # Generate a color palette for the classes
         self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -134,7 +129,7 @@ class Yolov8TFLite:
         cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

         # Create the label text with class name and score
-        label = f'{self.classes[class_id]}: {score:.2f}'
+        label = f"{self.classes[class_id]}: {score:.2f}"

         # Calculate the dimensions of the label text
         (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -144,8 +139,13 @@ class Yolov8TFLite:
         label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

         # Draw a filled rectangle as the background for the label text
-        cv2.rectangle(img, (int(label_x), int(label_y - label_height)),
-                      (int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED)
+        cv2.rectangle(
+            img,
+            (int(label_x), int(label_y - label_height)),
+            (int(label_x + label_width), int(label_y + label_height)),
+            color,
+            cv2.FILLED,
+        )

         # Draw the label text on the image
         cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -161,7 +161,7 @@ class Yolov8TFLite:
         # Read the input image using OpenCV
         self.img = cv2.imread(self.input_image)

-        print('image befor', self.img)
+        print("image before", self.img)
         # Get the height and width of the input image
         self.img_height, self.img_width = self.img.shape[:2]

@@ -209,8 +209,10 @@ class Yolov8TFLite:
             # Get the box, score, and class ID corresponding to the index
             box = boxes[i]
             gain = min(img_width / self.img_width, img_height / self.img_height)
-            pad = round((img_width - self.img_width * gain) / 2 -
-                        0.1), round((img_height - self.img_height * gain) / 2 - 0.1)
+            pad = (
+                round((img_width - self.img_width * gain) / 2 - 0.1),
+                round((img_height - self.img_height * gain) / 2 - 0.1),
+            )
             box[0] = (box[0] - pad[0]) / gain
             box[1] = (box[1] - pad[1]) / gain
             box[2] = box[2] / gain
@@ -242,7 +244,7 @@ class Yolov8TFLite:
         output_details = interpreter.get_output_details()

         # Store the shape of the input for later use
-        input_shape = input_details[0]['shape']
+        input_shape = input_details[0]["shape"]
         self.input_width = input_shape[1]
         self.input_height = input_shape[2]

@@ -251,19 +253,19 @@ class Yolov8TFLite:
         img_data = img_data
         # img_data = img_data.cpu().numpy()
         # Set the input tensor to the interpreter
-        print(input_details[0]['index'])
+        print(input_details[0]["index"])
         print(img_data.shape)
         img_data = img_data.transpose((0, 2, 3, 1))

-        scale, zero_point = input_details[0]['quantization']
-        interpreter.set_tensor(input_details[0]['index'], img_data)
+        scale, zero_point = input_details[0]["quantization"]
+        interpreter.set_tensor(input_details[0]["index"], img_data)

         # Run inference
         interpreter.invoke()

         # Get the output tensor from the interpreter
-        output = interpreter.get_tensor(output_details[0]['index'])
-        scale, zero_point = output_details[0]['quantization']
+        output = interpreter.get_tensor(output_details[0]["index"])
+        scale, zero_point = output_details[0]["quantization"]
         output = (output.astype(np.float32) - zero_point) * scale

         output[:, [0, 2]] *= img_width
@@ -273,16 +275,15 @@ class Yolov8TFLite:
         return self.postprocess(self.img, output)


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Create an argument parser to handle command-line arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model',
-                        type=str,
-                        default='yolov8n_full_integer_quant.tflite',
-                        help='Input your TFLite model.')
-    parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
-    parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
-    parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
+    parser.add_argument(
+        "--model", type=str, default="yolov8n_full_integer_quant.tflite", help="Input your TFLite model."
+    )
+    parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
+    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
+    parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
     args = parser.parse_args()

     # Create an instance of the Yolov8TFLite class with the specified arguments
@@ -292,7 +293,7 @@ if __name__ == '__main__':
     output_image = detection.main()

     # Display the output image in a window
-    cv2.imshow('Output', output_image)
+    cv2.imshow("Output", output_image)

     # Wait for a key press to exit
     cv2.waitKey(0)
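The TFLite example's line output = (output.astype(np.float32) - zero_point) * scale is the standard affine dequantization for full-integer models: real = (quantized - zero_point) * scale. A small numeric check with hypothetical quantization parameters:

    import numpy as np

    scale, zero_point = 0.0039, 0  # hypothetical values, normally read from output_details[0]["quantization"]
    q = np.array([128, 64], dtype=np.uint8)
    print((q.astype(np.float32) - zero_point) * scale)  # [0.4992 0.2496]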
| @ -16,21 +16,22 @@ track_history = defaultdict(list) | ||||
| current_region = None | ||||
| counting_regions = [ | ||||
|     { | ||||
|         'name': 'YOLOv8 Polygon Region', | ||||
|         'polygon': Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]),  # Polygon points | ||||
|         'counts': 0, | ||||
|         'dragging': False, | ||||
|         'region_color': (255, 42, 4),  # BGR Value | ||||
|         'text_color': (255, 255, 255)  # Region Text Color | ||||
|         "name": "YOLOv8 Polygon Region", | ||||
|         "polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]),  # Polygon points | ||||
|         "counts": 0, | ||||
|         "dragging": False, | ||||
|         "region_color": (255, 42, 4),  # BGR Value | ||||
|         "text_color": (255, 255, 255),  # Region Text Color | ||||
|     }, | ||||
|     { | ||||
|         'name': 'YOLOv8 Rectangle Region', | ||||
|         'polygon': Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]),  # Polygon points | ||||
|         'counts': 0, | ||||
|         'dragging': False, | ||||
|         'region_color': (37, 255, 225),  # BGR Value | ||||
|         'text_color': (0, 0, 0),  # Region Text Color | ||||
|     }, ] | ||||
|         "name": "YOLOv8 Rectangle Region", | ||||
|         "polygon": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]),  # Polygon points | ||||
|         "counts": 0, | ||||
|         "dragging": False, | ||||
|         "region_color": (37, 255, 225),  # BGR Value | ||||
|         "text_color": (0, 0, 0),  # Region Text Color | ||||
|     }, | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
| def mouse_callback(event, x, y, flags, param): | ||||
| @ -40,32 +41,33 @@ def mouse_callback(event, x, y, flags, param): | ||||
|     # Mouse left button down event | ||||
|     if event == cv2.EVENT_LBUTTONDOWN: | ||||
|         for region in counting_regions: | ||||
|             if region['polygon'].contains(Point((x, y))): | ||||
|             if region["polygon"].contains(Point((x, y))): | ||||
|                 current_region = region | ||||
|                 current_region['dragging'] = True | ||||
|                 current_region['offset_x'] = x | ||||
|                 current_region['offset_y'] = y | ||||
|                 current_region["dragging"] = True | ||||
|                 current_region["offset_x"] = x | ||||
|                 current_region["offset_y"] = y | ||||
| 
 | ||||
|     # Mouse move event | ||||
|     elif event == cv2.EVENT_MOUSEMOVE: | ||||
|         if current_region is not None and current_region['dragging']: | ||||
|             dx = x - current_region['offset_x'] | ||||
|             dy = y - current_region['offset_y'] | ||||
|             current_region['polygon'] = Polygon([ | ||||
|                 (p[0] + dx, p[1] + dy) for p in current_region['polygon'].exterior.coords]) | ||||
|             current_region['offset_x'] = x | ||||
|             current_region['offset_y'] = y | ||||
|         if current_region is not None and current_region["dragging"]: | ||||
|             dx = x - current_region["offset_x"] | ||||
|             dy = y - current_region["offset_y"] | ||||
|             current_region["polygon"] = Polygon( | ||||
|                 [(p[0] + dx, p[1] + dy) for p in current_region["polygon"].exterior.coords] | ||||
|             ) | ||||
|             current_region["offset_x"] = x | ||||
|             current_region["offset_y"] = y | ||||
| 
 | ||||
|     # Mouse left button up event | ||||
|     elif event == cv2.EVENT_LBUTTONUP: | ||||
|         if current_region is not None and current_region['dragging']: | ||||
|             current_region['dragging'] = False | ||||
|         if current_region is not None and current_region["dragging"]: | ||||
|             current_region["dragging"] = False | ||||
| 
 | ||||
| 
 | ||||
| def run( | ||||
|     weights='yolov8n.pt', | ||||
|     weights="yolov8n.pt", | ||||
|     source=None, | ||||
|     device='cpu', | ||||
|     device="cpu", | ||||
|     view_img=False, | ||||
|     save_img=False, | ||||
|     exist_ok=False, | ||||
| @ -100,8 +102,8 @@ def run( | ||||
|         raise FileNotFoundError(f"Source path '{source}' does not exist.") | ||||
| 
 | ||||
|     # Setup Model | ||||
|     model = YOLO(f'{weights}') | ||||
|     model.to('cuda') if device == '0' else model.to('cpu') | ||||
|     model = YOLO(f"{weights}") | ||||
|     model.to("cuda") if device == "0" else model.to("cpu") | ||||
| 
 | ||||
|     # Extract class names | ||||
|     names = model.model.names | ||||
| @ -109,12 +111,12 @@ def run( | ||||
|     # Video setup | ||||
|     videocapture = cv2.VideoCapture(source) | ||||
|     frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4)) | ||||
|     fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v') | ||||
|     fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v") | ||||
| 
 | ||||
|     # Output setup | ||||
|     save_dir = increment_path(Path('ultralytics_rc_output') / 'exp', exist_ok) | ||||
|     save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height)) | ||||
|     video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height)) | ||||
| 
 | ||||
|     # Iterate over video frames | ||||
|     while videocapture.isOpened(): | ||||
| @ -146,43 +148,48 @@ def run( | ||||
| 
 | ||||
|                 # Check if detection inside region | ||||
|                 for region in counting_regions: | ||||
|                     if region['polygon'].contains(Point((bbox_center[0], bbox_center[1]))): | ||||
|                         region['counts'] += 1 | ||||
|                     if region["polygon"].contains(Point((bbox_center[0], bbox_center[1]))): | ||||
|                         region["counts"] += 1 | ||||
| 
 | ||||
|         # Draw regions (Polygons/Rectangles) | ||||
|         for region in counting_regions: | ||||
|             region_label = str(region['counts']) | ||||
|             region_color = region['region_color'] | ||||
|             region_text_color = region['text_color'] | ||||
|             region_label = str(region["counts"]) | ||||
|             region_color = region["region_color"] | ||||
|             region_text_color = region["text_color"] | ||||
| 
 | ||||
|             polygon_coords = np.array(region['polygon'].exterior.coords, dtype=np.int32) | ||||
|             centroid_x, centroid_y = int(region['polygon'].centroid.x), int(region['polygon'].centroid.y) | ||||
|             polygon_coords = np.array(region["polygon"].exterior.coords, dtype=np.int32) | ||||
|             centroid_x, centroid_y = int(region["polygon"].centroid.x), int(region["polygon"].centroid.y) | ||||
| 
 | ||||
|             text_size, _ = cv2.getTextSize(region_label, | ||||
|                                            cv2.FONT_HERSHEY_SIMPLEX, | ||||
|                                            fontScale=0.7, | ||||
|                                            thickness=line_thickness) | ||||
|             text_size, _ = cv2.getTextSize( | ||||
|                 region_label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness | ||||
|             ) | ||||
|             text_x = centroid_x - text_size[0] // 2 | ||||
|             text_y = centroid_y + text_size[1] // 2 | ||||
|             cv2.rectangle(frame, (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5), | ||||
|                           region_color, -1) | ||||
|             cv2.putText(frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, | ||||
|                         line_thickness) | ||||
|             cv2.rectangle( | ||||
|                 frame, | ||||
|                 (text_x - 5, text_y - text_size[1] - 5), | ||||
|                 (text_x + text_size[0] + 5, text_y + 5), | ||||
|                 region_color, | ||||
|                 -1, | ||||
|             ) | ||||
|             cv2.putText( | ||||
|                 frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, line_thickness | ||||
|             ) | ||||
|             cv2.polylines(frame, [polygon_coords], isClosed=True, color=region_color, thickness=region_thickness) | ||||
| 
 | ||||
|         if view_img: | ||||
|             if vid_frame_count == 1: | ||||
|                 cv2.namedWindow('Ultralytics YOLOv8 Region Counter Movable') | ||||
|                 cv2.setMouseCallback('Ultralytics YOLOv8 Region Counter Movable', mouse_callback) | ||||
|             cv2.imshow('Ultralytics YOLOv8 Region Counter Movable', frame) | ||||
|                 cv2.namedWindow("Ultralytics YOLOv8 Region Counter Movable") | ||||
|                 cv2.setMouseCallback("Ultralytics YOLOv8 Region Counter Movable", mouse_callback) | ||||
|             cv2.imshow("Ultralytics YOLOv8 Region Counter Movable", frame) | ||||
| 
 | ||||
|         if save_img: | ||||
|             video_writer.write(frame) | ||||
| 
 | ||||
|         for region in counting_regions:  # Reinitialize count for each region | ||||
|             region['counts'] = 0 | ||||
|             region["counts"] = 0 | ||||
| 
 | ||||
|         if cv2.waitKey(1) & 0xFF == ord('q'): | ||||
|         if cv2.waitKey(1) & 0xFF == ord("q"): | ||||
|             break | ||||
| 
 | ||||
|     del vid_frame_count | ||||
| @ -194,16 +201,16 @@ def run( | ||||
| def parse_opt(): | ||||
|     """Parse command line arguments.""" | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path') | ||||
|     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | ||||
|     parser.add_argument('--source', type=str, required=True, help='video file path') | ||||
|     parser.add_argument('--view-img', action='store_true', help='show results') | ||||
|     parser.add_argument('--save-img', action='store_true', help='save results') | ||||
|     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') | ||||
|     parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3') | ||||
|     parser.add_argument('--line-thickness', type=int, default=2, help='bounding box thickness') | ||||
|     parser.add_argument('--track-thickness', type=int, default=2, help='Tracking line thickness') | ||||
|     parser.add_argument('--region-thickness', type=int, default=4, help='Region thickness') | ||||
|     parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") | ||||
|     parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu") | ||||
|     parser.add_argument("--source", type=str, required=True, help="video file path") | ||||
|     parser.add_argument("--view-img", action="store_true", help="show results") | ||||
|     parser.add_argument("--save-img", action="store_true", help="save results") | ||||
|     parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") | ||||
|     parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3") | ||||
|     parser.add_argument("--line-thickness", type=int, default=2, help="bounding box thickness") | ||||
|     parser.add_argument("--track-thickness", type=int, default=2, help="Tracking line thickness") | ||||
|     parser.add_argument("--region-thickness", type=int, default=4, help="Region thickness") | ||||
| 
 | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
| @ -213,6 +220,6 @@ def main(opt): | ||||
|     run(**vars(opt)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     opt = parse_opt() | ||||
|     main(opt) | ||||
|  | ||||
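The hunk above only reflows the movable-region counter; the counting logic itself is unchanged. As a minimal sketch of that logic (assuming `shapely` is installed; all values below are illustrative), a region is a dict holding a `Polygon`, membership of a box center is tested with `contains`, and dragging translates every exterior coordinate:

```python
from shapely.geometry import Point, Polygon

# One counting region; the keys mirror the dicts in the script above.
region = {
    "name": "demo region",
    "polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]),
    "counts": 0,
}

# Count a detection whose bbox center is (cx, cy).
cx, cy = 200, 150
if region["polygon"].contains(Point((cx, cy))):
    region["counts"] += 1

# Drag the whole region by (dx, dy), as mouse_callback does on cv2.EVENT_MOUSEMOVE.
dx, dy = 10, -5
region["polygon"] = Polygon([(px + dx, py + dy) for px, py in region["polygon"].exterior.coords])
print(region["counts"], region["polygon"].bounds)
```

Rebuilding the `Polygon` from shifted coordinates, rather than mutating it in place, matches how `mouse_callback` handles the mouse-move event.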
| @ -9,7 +9,7 @@ from sahi.utils.yolov8 import download_yolov8s_model | ||||
| from ultralytics.utils.files import increment_path | ||||
| 
 | ||||
| 
 | ||||
| def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False): | ||||
| def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False): | ||||
|     """ | ||||
|     Run object detection on a video using YOLOv8 and SAHI. | ||||
| 
 | ||||
| @ -25,41 +25,41 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, | ||||
|     if not Path(source).exists(): | ||||
|         raise FileNotFoundError(f"Source path '{source}' does not exist.") | ||||
| 
 | ||||
|     yolov8_model_path = f'models/{weights}' | ||||
|     yolov8_model_path = f"models/{weights}" | ||||
|     download_yolov8s_model(yolov8_model_path) | ||||
|     detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8', | ||||
|                                                          model_path=yolov8_model_path, | ||||
|                                                          confidence_threshold=0.3, | ||||
|                                                          device='cpu') | ||||
|     detection_model = AutoDetectionModel.from_pretrained( | ||||
|         model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu" | ||||
|     ) | ||||
| 
 | ||||
|     # Video setup | ||||
|     videocapture = cv2.VideoCapture(source) | ||||
|     frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4)) | ||||
|     fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v') | ||||
|     fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v") | ||||
| 
 | ||||
|     # Output setup | ||||
|     save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok) | ||||
|     save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height)) | ||||
|     video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height)) | ||||
| 
 | ||||
|     while videocapture.isOpened(): | ||||
|         success, frame = videocapture.read() | ||||
|         if not success: | ||||
|             break | ||||
| 
 | ||||
|         results = get_sliced_prediction(frame, | ||||
|                                         detection_model, | ||||
|                                         slice_height=512, | ||||
|                                         slice_width=512, | ||||
|                                         overlap_height_ratio=0.2, | ||||
|                                         overlap_width_ratio=0.2) | ||||
|         results = get_sliced_prediction( | ||||
|             frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2 | ||||
|         ) | ||||
|         object_prediction_list = results.object_prediction_list | ||||
| 
 | ||||
|         boxes_list = [] | ||||
|         clss_list = [] | ||||
|         for ind, _ in enumerate(object_prediction_list): | ||||
|             boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \ | ||||
|                 object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy | ||||
|             boxes = ( | ||||
|                 object_prediction_list[ind].bbox.minx, | ||||
|                 object_prediction_list[ind].bbox.miny, | ||||
|                 object_prediction_list[ind].bbox.maxx, | ||||
|                 object_prediction_list[ind].bbox.maxy, | ||||
|             ) | ||||
|             clss = object_prediction_list[ind].category.name | ||||
|             boxes_list.append(boxes) | ||||
|             clss_list.append(clss) | ||||
| @ -69,21 +69,19 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, | ||||
|             cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2) | ||||
|             label = str(cls) | ||||
|             t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0] | ||||
|             cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), | ||||
|                           -1) | ||||
|             cv2.putText(frame, | ||||
|                         label, (int(x1), int(y1) - 2), | ||||
|                         0, | ||||
|                         0.6, [255, 255, 255], | ||||
|                         thickness=1, | ||||
|                         lineType=cv2.LINE_AA) | ||||
|             cv2.rectangle( | ||||
|                 frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1 | ||||
|             ) | ||||
|             cv2.putText( | ||||
|                 frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA | ||||
|             ) | ||||
| 
 | ||||
|         if view_img: | ||||
|             cv2.imshow(Path(source).stem, frame) | ||||
|         if save_img: | ||||
|             video_writer.write(frame) | ||||
| 
 | ||||
|         if cv2.waitKey(1) & 0xFF == ord('q'): | ||||
|         if cv2.waitKey(1) & 0xFF == ord("q"): | ||||
|             break | ||||
|     video_writer.release() | ||||
|     videocapture.release() | ||||
| @ -93,11 +91,11 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, | ||||
| def parse_opt(): | ||||
|     """Parse command line arguments.""" | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path') | ||||
|     parser.add_argument('--source', type=str, required=True, help='video file path') | ||||
|     parser.add_argument('--view-img', action='store_true', help='show results') | ||||
|     parser.add_argument('--save-img', action='store_true', help='save results') | ||||
|     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') | ||||
|     parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") | ||||
|     parser.add_argument("--source", type=str, required=True, help="video file path") | ||||
|     parser.add_argument("--view-img", action="store_true", help="show results") | ||||
|     parser.add_argument("--save-img", action="store_true", help="save results") | ||||
|     parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
| 
 | ||||
| @ -106,6 +104,6 @@ def main(opt): | ||||
|     run(**vars(opt)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     opt = parse_opt() | ||||
|     main(opt) | ||||
|  | ||||
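For reference, the SAHI calls being reformatted above compose into a short standalone pipeline. A hedged sketch (model and image paths are placeholders; the script above feeds video frames as numpy arrays in exactly the same way):

```python
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Placeholder paths; substitute your own model and image or frame.
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8", model_path="models/yolov8n.pt", confidence_threshold=0.3, device="cpu"
)
frame = cv2.imread("sample.jpg")
result = get_sliced_prediction(
    frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2
)
for pred in result.object_prediction_list:
    b = pred.bbox
    print(pred.category.name, b.minx, b.miny, b.maxx, b.maxy)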
| @ -21,18 +21,21 @@ class YOLOv8Seg: | ||||
|         """ | ||||
| 
 | ||||
|         # Build Ort session | ||||
|         self.session = ort.InferenceSession(onnx_model, | ||||
|                                             providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] | ||||
|                                             if ort.get_device() == 'GPU' else ['CPUExecutionProvider']) | ||||
|         self.session = ort.InferenceSession( | ||||
|             onnx_model, | ||||
|             providers=["CUDAExecutionProvider", "CPUExecutionProvider"] | ||||
|             if ort.get_device() == "GPU" | ||||
|             else ["CPUExecutionProvider"], | ||||
|         ) | ||||
| 
 | ||||
|     # Numpy dtype: support both FP32 and FP16 ONNX models | ||||
|         self.ndtype = np.half if self.session.get_inputs()[0].type == 'tensor(float16)' else np.single | ||||
|         self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single | ||||
| 
 | ||||
|     # Get model width and height (YOLOv8-seg only has one input) | ||||
|         self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:] | ||||
| 
 | ||||
|         # Load COCO class names | ||||
|         self.classes = yaml_load(check_yaml('coco128.yaml'))['names'] | ||||
|         self.classes = yaml_load(check_yaml("coco128.yaml"))["names"] | ||||
| 
 | ||||
|         # Create color palette | ||||
|         self.color_palette = Colors() | ||||
| @ -60,14 +63,16 @@ class YOLOv8Seg: | ||||
|         preds = self.session.run(None, {self.session.get_inputs()[0].name: im}) | ||||
| 
 | ||||
|         # Post-process | ||||
|         boxes, segments, masks = self.postprocess(preds, | ||||
|         boxes, segments, masks = self.postprocess( | ||||
|             preds, | ||||
|             im0=im0, | ||||
|             ratio=ratio, | ||||
|             pad_w=pad_w, | ||||
|             pad_h=pad_h, | ||||
|             conf_threshold=conf_threshold, | ||||
|             iou_threshold=iou_threshold, | ||||
|                                                   nm=nm) | ||||
|             nm=nm, | ||||
|         ) | ||||
|         return boxes, segments, masks | ||||
| 
 | ||||
|     def preprocess(self, img): | ||||
| @ -98,7 +103,7 @@ class YOLOv8Seg: | ||||
|         img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) | ||||
| 
 | ||||
|     # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis (optional) | ||||
|         img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0 | ||||
|         img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0 | ||||
|         img_process = img[None] if len(img.shape) == 3 else img | ||||
|         return img_process, ratio, (pad_w, pad_h) | ||||
| 
 | ||||
| @ -124,7 +129,7 @@ class YOLOv8Seg: | ||||
|         x, protos = preds[0], preds[1]  # Two outputs: predictions and protos | ||||
| 
 | ||||
|         # Transpose the first output: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm) | ||||
|         x = np.einsum('bcn->bnc', x) | ||||
|         x = np.einsum("bcn->bnc", x) | ||||
| 
 | ||||
|         # Predictions filtering by conf-threshold | ||||
|         x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold] | ||||
| @ -138,7 +143,6 @@ class YOLOv8Seg: | ||||
| 
 | ||||
|         # Decode and return | ||||
|         if len(x) > 0: | ||||
| 
 | ||||
|             # Bounding boxes format change: cxcywh -> xyxy | ||||
|             x[..., [0, 1]] -= x[..., [2, 3]] / 2 | ||||
|             x[..., [2, 3]] += x[..., [0, 1]] | ||||
| @ -173,13 +177,13 @@ class YOLOv8Seg: | ||||
|             segments (List): list of segment masks. | ||||
|         """ | ||||
|         segments = [] | ||||
|         for x in masks.astype('uint8'): | ||||
|         for x in masks.astype("uint8"): | ||||
|             c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0]  # CHAIN_APPROX_SIMPLE | ||||
|             if c: | ||||
|                 c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) | ||||
|             else: | ||||
|                 c = np.zeros((0, 2))  # no segments found | ||||
|             segments.append(c.astype('float32')) | ||||
|             segments.append(c.astype("float32")) | ||||
|         return segments | ||||
| 
 | ||||
|     @staticmethod | ||||
| @ -219,7 +223,7 @@ class YOLOv8Seg: | ||||
|         masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0)  # HWN | ||||
|         masks = np.ascontiguousarray(masks) | ||||
|         masks = self.scale_mask(masks, im0_shape)  # re-scale mask from P3 shape to original input image shape | ||||
|         masks = np.einsum('HWN -> NHW', masks)  # HWN -> NHW | ||||
|         masks = np.einsum("HWN -> NHW", masks)  # HWN -> NHW | ||||
|         masks = self.crop_mask(masks, bboxes) | ||||
|         return np.greater(masks, 0.5) | ||||
| 
 | ||||
| @ -250,8 +254,9 @@ class YOLOv8Seg: | ||||
|         if len(masks.shape) < 2: | ||||
|             raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') | ||||
|         masks = masks[top:bottom, left:right] | ||||
|         masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]), | ||||
|                            interpolation=cv2.INTER_LINEAR)  # INTER_CUBIC would be better | ||||
|         masks = cv2.resize( | ||||
|             masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR | ||||
|         )  # INTER_CUBIC would be better | ||||
|         if len(masks.shape) == 2: | ||||
|             masks = masks[:, :, None] | ||||
|         return masks | ||||
| @ -279,32 +284,46 @@ class YOLOv8Seg: | ||||
|             cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True)) | ||||
| 
 | ||||
|             # draw bbox rectangle | ||||
|             cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), | ||||
|                           self.color_palette(int(cls_), bgr=True), 1, cv2.LINE_AA) | ||||
|             cv2.putText(im, f'{self.classes[cls_]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)), | ||||
|                         cv2.FONT_HERSHEY_SIMPLEX, 0.7, self.color_palette(int(cls_), bgr=True), 2, cv2.LINE_AA) | ||||
|             cv2.rectangle( | ||||
|                 im, | ||||
|                 (int(box[0]), int(box[1])), | ||||
|                 (int(box[2]), int(box[3])), | ||||
|                 self.color_palette(int(cls_), bgr=True), | ||||
|                 1, | ||||
|                 cv2.LINE_AA, | ||||
|             ) | ||||
|             cv2.putText( | ||||
|                 im, | ||||
|                 f"{self.classes[cls_]}: {conf:.3f}", | ||||
|                 (int(box[0]), int(box[1] - 9)), | ||||
|                 cv2.FONT_HERSHEY_SIMPLEX, | ||||
|                 0.7, | ||||
|                 self.color_palette(int(cls_), bgr=True), | ||||
|                 2, | ||||
|                 cv2.LINE_AA, | ||||
|             ) | ||||
| 
 | ||||
|         # Mix image | ||||
|         im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0) | ||||
| 
 | ||||
|         # Show image | ||||
|         if vis: | ||||
|             cv2.imshow('demo', im) | ||||
|             cv2.imshow("demo", im) | ||||
|             cv2.waitKey(0) | ||||
|             cv2.destroyAllWindows() | ||||
| 
 | ||||
|         # Save image | ||||
|         if save: | ||||
|             cv2.imwrite('demo.jpg', im) | ||||
|             cv2.imwrite("demo.jpg", im) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     # Create an argument parser to handle command-line arguments | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--model', type=str, required=True, help='Path to ONNX model') | ||||
|     parser.add_argument('--source', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image') | ||||
|     parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold') | ||||
|     parser.add_argument('--iou', type=float, default=0.45, help='NMS IoU threshold') | ||||
|     parser.add_argument("--model", type=str, required=True, help="Path to ONNX model") | ||||
|     parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image") | ||||
|     parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold") | ||||
|     parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold") | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
|     # Build model | ||||
|  | ||||
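The session setup being reindented above follows a common ONNX Runtime pattern: prefer CUDA when ORT sees a GPU, then probe the input to serve both FP16 and FP32 exports. A minimal sketch under those assumptions (the model path is hypothetical, and the shape unpacking assumes a fixed-shape export):

```python
import numpy as np
import onnxruntime as ort

# Provider selection as in the class above: CUDA when available, else CPU.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if ort.get_device() == "GPU" else ["CPUExecutionProvider"]
session = ort.InferenceSession("yolov8n-seg.onnx", providers=providers)  # hypothetical model path

# Match the input dtype so FP16 and FP32 exports both work.
inp = session.get_inputs()[0]
ndtype = np.half if inp.type == "tensor(float16)" else np.single
model_height, model_width = inp.shape[-2:]  # assumes a fixed-shape export
print(inp.name, ndtype, model_height, model_width)
```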
| @ -179,5 +179,5 @@ pre-summary-newline = true | ||||
| close-quotes-on-newline = true | ||||
| 
 | ||||
| [tool.codespell] | ||||
| ignore-words-list = "crate,nd,strack,dota,ane,segway,fo,gool,winn,commend" | ||||
| skip = '*.csv,*venv*,docs/??/,docs/mkdocs_??.yml' | ||||
| ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall" | ||||
| skip = "*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml" | ||||
|  | ||||
| @ -3,6 +3,8 @@ import PIL | ||||
| from ultralytics import Explorer | ||||
| from ultralytics.utils import ASSETS | ||||
| 
 | ||||
| import PIL | ||||
| 
 | ||||
| 
 | ||||
| def test_similarity(): | ||||
|     """Test similarity calculations and SQL queries for correctness and response length.""" | ||||
|  | ||||
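For context on what `test_similarity` exercises, a rough usage sketch of the Explorer API (dataset, model, and query values are illustrative, and exact defaults may differ by version):

```python
from ultralytics import Explorer

exp = Explorer(data="coco8.yaml", model="yolov8n.pt")
exp.create_embeddings_table()  # embeds the dataset once; queries then run against the table
similar = exp.get_similar(img="https://ultralytics.com/images/bus.jpg", limit=10)
print(similar)
```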
| @ -1,6 +1,6 @@ | ||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||
| 
 | ||||
| __version__ = '8.0.238' | ||||
| __version__ = "8.0.239" | ||||
| 
 | ||||
| from ultralytics.data.explorer.explorer import Explorer | ||||
| from ultralytics.models import RTDETR, SAM, YOLO | ||||
| @ -10,4 +10,4 @@ from ultralytics.utils import SETTINGS as settings | ||||
| from ultralytics.utils.checks import check_yolo as checks | ||||
| from ultralytics.utils.downloads import download | ||||
| 
 | ||||
| __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings', 'Explorer' | ||||
| __all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer" | ||||
|  | ||||
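After this bump, the package-level API exported in `__all__` can be smoke-tested directly; a small sketch:

```python
import ultralytics
from ultralytics import YOLO, checks

print(ultralytics.__version__)  # "8.0.239" after this release
checks()  # environment summary (alias of ultralytics.utils.checks.check_yolo)
model = YOLO("yolov8n.pt")  # downloads weights on first use
print(model.task)  # "detect"
```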
| @ -8,34 +8,53 @@ from pathlib import Path | ||||
| from types import SimpleNamespace | ||||
| from typing import Dict, List, Union | ||||
| 
 | ||||
| from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, RUNS_DIR, | ||||
|                                SETTINGS, SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks, | ||||
|                                colorstr, deprecation_warn, yaml_load, yaml_print) | ||||
| from ultralytics.utils import ( | ||||
|     ASSETS, | ||||
|     DEFAULT_CFG, | ||||
|     DEFAULT_CFG_DICT, | ||||
|     DEFAULT_CFG_PATH, | ||||
|     LOGGER, | ||||
|     RANK, | ||||
|     ROOT, | ||||
|     RUNS_DIR, | ||||
|     SETTINGS, | ||||
|     SETTINGS_YAML, | ||||
|     TESTS_RUNNING, | ||||
|     IterableSimpleNamespace, | ||||
|     __version__, | ||||
|     checks, | ||||
|     colorstr, | ||||
|     deprecation_warn, | ||||
|     yaml_load, | ||||
|     yaml_print, | ||||
| ) | ||||
| 
 | ||||
| # Define valid tasks and modes | ||||
| MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' | ||||
| TASKS = 'detect', 'segment', 'classify', 'pose', 'obb' | ||||
| MODES = "train", "val", "predict", "export", "track", "benchmark" | ||||
| TASKS = "detect", "segment", "classify", "pose", "obb" | ||||
| TASK2DATA = { | ||||
|     'detect': 'coco8.yaml', | ||||
|     'segment': 'coco8-seg.yaml', | ||||
|     'classify': 'imagenet10', | ||||
|     'pose': 'coco8-pose.yaml', | ||||
|     'obb': 'dota8-obb.yaml'}  # not implemented yet | ||||
|     "detect": "coco8.yaml", | ||||
|     "segment": "coco8-seg.yaml", | ||||
|     "classify": "imagenet10", | ||||
|     "pose": "coco8-pose.yaml", | ||||
|     "obb": "dota8-obb.yaml", | ||||
| } | ||||
| TASK2MODEL = { | ||||
|     'detect': 'yolov8n.pt', | ||||
|     'segment': 'yolov8n-seg.pt', | ||||
|     'classify': 'yolov8n-cls.pt', | ||||
|     'pose': 'yolov8n-pose.pt', | ||||
|     'obb': 'yolov8n-obb.pt'} | ||||
|     "detect": "yolov8n.pt", | ||||
|     "segment": "yolov8n-seg.pt", | ||||
|     "classify": "yolov8n-cls.pt", | ||||
|     "pose": "yolov8n-pose.pt", | ||||
|     "obb": "yolov8n-obb.pt", | ||||
| } | ||||
| TASK2METRIC = { | ||||
|     'detect': 'metrics/mAP50-95(B)', | ||||
|     'segment': 'metrics/mAP50-95(M)', | ||||
|     'classify': 'metrics/accuracy_top1', | ||||
|     'pose': 'metrics/mAP50-95(P)', | ||||
|     'obb': 'metrics/mAP50-95(OBB)'} | ||||
|     "detect": "metrics/mAP50-95(B)", | ||||
|     "segment": "metrics/mAP50-95(M)", | ||||
|     "classify": "metrics/accuracy_top1", | ||||
|     "pose": "metrics/mAP50-95(P)", | ||||
|     "obb": "metrics/mAP50-95(OBB)", | ||||
| } | ||||
| 
 | ||||
| CLI_HELP_MSG = \ | ||||
|     f""" | ||||
| CLI_HELP_MSG = f""" | ||||
|     Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax: | ||||
| 
 | ||||
|         yolo TASK MODE ARGS | ||||
| @ -74,16 +93,83 @@ CLI_HELP_MSG = \ | ||||
|     """ | ||||
| 
 | ||||
| # Define keys for arg type checks | ||||
| CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'time' | ||||
| CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', | ||||
|                      'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', | ||||
|                      'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction')  # fraction floats 0.0 - 1.0 | ||||
| CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride', | ||||
|                 'line_width', 'workspace', 'nbs', 'save_period') | ||||
| CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val', | ||||
|                  'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop', | ||||
|                  'save_frames', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', | ||||
|                  'show_boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile', 'multi_scale') | ||||
| CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time" | ||||
| CFG_FRACTION_KEYS = ( | ||||
|     "dropout", | ||||
|     "iou", | ||||
|     "lr0", | ||||
|     "lrf", | ||||
|     "momentum", | ||||
|     "weight_decay", | ||||
|     "warmup_momentum", | ||||
|     "warmup_bias_lr", | ||||
|     "label_smoothing", | ||||
|     "hsv_h", | ||||
|     "hsv_s", | ||||
|     "hsv_v", | ||||
|     "translate", | ||||
|     "scale", | ||||
|     "perspective", | ||||
|     "flipud", | ||||
|     "fliplr", | ||||
|     "mosaic", | ||||
|     "mixup", | ||||
|     "copy_paste", | ||||
|     "conf", | ||||
|     "iou", | ||||
|     "fraction", | ||||
| )  # fraction floats 0.0 - 1.0 | ||||
| CFG_INT_KEYS = ( | ||||
|     "epochs", | ||||
|     "patience", | ||||
|     "batch", | ||||
|     "workers", | ||||
|     "seed", | ||||
|     "close_mosaic", | ||||
|     "mask_ratio", | ||||
|     "max_det", | ||||
|     "vid_stride", | ||||
|     "line_width", | ||||
|     "workspace", | ||||
|     "nbs", | ||||
|     "save_period", | ||||
| ) | ||||
| CFG_BOOL_KEYS = ( | ||||
|     "save", | ||||
|     "exist_ok", | ||||
|     "verbose", | ||||
|     "deterministic", | ||||
|     "single_cls", | ||||
|     "rect", | ||||
|     "cos_lr", | ||||
|     "overlap_mask", | ||||
|     "val", | ||||
|     "save_json", | ||||
|     "save_hybrid", | ||||
|     "half", | ||||
|     "dnn", | ||||
|     "plots", | ||||
|     "show", | ||||
|     "save_txt", | ||||
|     "save_conf", | ||||
|     "save_crop", | ||||
|     "save_frames", | ||||
|     "show_labels", | ||||
|     "show_conf", | ||||
|     "visualize", | ||||
|     "augment", | ||||
|     "agnostic_nms", | ||||
|     "retina_masks", | ||||
|     "show_boxes", | ||||
|     "keras", | ||||
|     "optimize", | ||||
|     "int8", | ||||
|     "dynamic", | ||||
|     "simplify", | ||||
|     "nms", | ||||
|     "profile", | ||||
|     "multi_scale", | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def cfg2dict(cfg): | ||||
| @ -119,38 +205,44 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove | ||||
|     # Merge overrides | ||||
|     if overrides: | ||||
|         overrides = cfg2dict(overrides) | ||||
|         if 'save_dir' not in cfg: | ||||
|             overrides.pop('save_dir', None)  # special override keys to ignore | ||||
|         if "save_dir" not in cfg: | ||||
|             overrides.pop("save_dir", None)  # special override keys to ignore | ||||
|         check_dict_alignment(cfg, overrides) | ||||
|         cfg = {**cfg, **overrides}  # merge cfg and overrides dicts (prefer overrides) | ||||
| 
 | ||||
|     # Special handling for numeric project/name | ||||
|     for k in 'project', 'name': | ||||
|     for k in "project", "name": | ||||
|         if k in cfg and isinstance(cfg[k], (int, float)): | ||||
|             cfg[k] = str(cfg[k]) | ||||
|     if cfg.get('name') == 'model':  # assign model to 'name' arg | ||||
|         cfg['name'] = cfg.get('model', '').split('.')[0] | ||||
|     if cfg.get("name") == "model":  # assign model to 'name' arg | ||||
|         cfg["name"] = cfg.get("model", "").split(".")[0] | ||||
|         LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.") | ||||
| 
 | ||||
|     # Type and Value checks | ||||
|     for k, v in cfg.items(): | ||||
|         if v is not None:  # None values may be from optional args | ||||
|             if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): | ||||
|                 raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " | ||||
|                                 f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')") | ||||
|                 raise TypeError( | ||||
|                     f"'{k}={v}' is of invalid type {type(v).__name__}. " | ||||
|                     f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" | ||||
|                 ) | ||||
|             elif k in CFG_FRACTION_KEYS: | ||||
|                 if not isinstance(v, (int, float)): | ||||
|                     raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " | ||||
|                                     f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')") | ||||
|                     raise TypeError( | ||||
|                         f"'{k}={v}' is of invalid type {type(v).__name__}. " | ||||
|                         f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" | ||||
|                     ) | ||||
|                 if not (0.0 <= v <= 1.0): | ||||
|                     raise ValueError(f"'{k}={v}' is an invalid value. " | ||||
|                                      f"Valid '{k}' values are between 0.0 and 1.0.") | ||||
|                     raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.") | ||||
|             elif k in CFG_INT_KEYS and not isinstance(v, int): | ||||
|                 raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " | ||||
|                                 f"'{k}' must be an int (i.e. '{k}=8')") | ||||
|                 raise TypeError( | ||||
|                     f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')" | ||||
|                 ) | ||||
|             elif k in CFG_BOOL_KEYS and not isinstance(v, bool): | ||||
|                 raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " | ||||
|                                 f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')") | ||||
|                 raise TypeError( | ||||
|                     f"'{k}={v}' is of invalid type {type(v).__name__}. " | ||||
|                     f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')" | ||||
|                 ) | ||||
| 
 | ||||
|     # Return instance | ||||
|     return IterableSimpleNamespace(**cfg) | ||||
| @ -159,13 +251,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove | ||||
| def get_save_dir(args, name=None): | ||||
|     """Return save_dir as created from train/val/predict arguments.""" | ||||
| 
 | ||||
|     if getattr(args, 'save_dir', None): | ||||
|     if getattr(args, "save_dir", None): | ||||
|         save_dir = args.save_dir | ||||
|     else: | ||||
|         from ultralytics.utils.files import increment_path | ||||
| 
 | ||||
|         project = args.project or (ROOT.parent / 'tests/tmp/runs' if TESTS_RUNNING else RUNS_DIR) / args.task | ||||
|         name = name or args.name or f'{args.mode}' | ||||
|         project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task | ||||
|         name = name or args.name or f"{args.mode}" | ||||
|         save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True) | ||||
| 
 | ||||
|     return Path(save_dir) | ||||
| @ -175,18 +267,18 @@ def _handle_deprecation(custom): | ||||
|     """Hardcoded function to handle deprecated config keys.""" | ||||
| 
 | ||||
|     for key in custom.copy().keys(): | ||||
|         if key == 'boxes': | ||||
|             deprecation_warn(key, 'show_boxes') | ||||
|             custom['show_boxes'] = custom.pop('boxes') | ||||
|         if key == 'hide_labels': | ||||
|             deprecation_warn(key, 'show_labels') | ||||
|             custom['show_labels'] = custom.pop('hide_labels') == 'False' | ||||
|         if key == 'hide_conf': | ||||
|             deprecation_warn(key, 'show_conf') | ||||
|             custom['show_conf'] = custom.pop('hide_conf') == 'False' | ||||
|         if key == 'line_thickness': | ||||
|             deprecation_warn(key, 'line_width') | ||||
|             custom['line_width'] = custom.pop('line_thickness') | ||||
|         if key == "boxes": | ||||
|             deprecation_warn(key, "show_boxes") | ||||
|             custom["show_boxes"] = custom.pop("boxes") | ||||
|         if key == "hide_labels": | ||||
|             deprecation_warn(key, "show_labels") | ||||
|             custom["show_labels"] = custom.pop("hide_labels") == "False" | ||||
|         if key == "hide_conf": | ||||
|             deprecation_warn(key, "show_conf") | ||||
|             custom["show_conf"] = custom.pop("hide_conf") == "False" | ||||
|         if key == "line_thickness": | ||||
|             deprecation_warn(key, "line_width") | ||||
|             custom["line_width"] = custom.pop("line_thickness") | ||||
| 
 | ||||
|     return custom | ||||
| 
 | ||||
| @ -207,11 +299,11 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None): | ||||
|     if mismatched: | ||||
|         from difflib import get_close_matches | ||||
| 
 | ||||
|         string = '' | ||||
|         string = "" | ||||
|         for x in mismatched: | ||||
|             matches = get_close_matches(x, base_keys)  # key list | ||||
|             matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches] | ||||
|             match_str = f'Similar arguments are i.e. {matches}.' if matches else '' | ||||
|             matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches] | ||||
|             match_str = f"Similar arguments are i.e. {matches}." if matches else "" | ||||
|             string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n" | ||||
|         raise SyntaxError(string + CLI_HELP_MSG) from e | ||||
| 
 | ||||
| @ -229,13 +321,13 @@ def merge_equals_args(args: List[str]) -> List[str]: | ||||
|     """ | ||||
|     new_args = [] | ||||
|     for i, arg in enumerate(args): | ||||
|         if arg == '=' and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val'] | ||||
|             new_args[-1] += f'={args[i + 1]}' | ||||
|         if arg == "=" and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val'] | ||||
|             new_args[-1] += f"={args[i + 1]}" | ||||
|             del args[i + 1] | ||||
|         elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # merge ['arg=', 'val'] | ||||
|             new_args.append(f'{arg}{args[i + 1]}') | ||||
|         elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]:  # merge ['arg=', 'val'] | ||||
|             new_args.append(f"{arg}{args[i + 1]}") | ||||
|             del args[i + 1] | ||||
|         elif arg.startswith('=') and i > 0:  # merge ['arg', '=val'] | ||||
|         elif arg.startswith("=") and i > 0:  # merge ['arg', '=val'] | ||||
|             new_args[-1] += arg | ||||
|         else: | ||||
|             new_args.append(arg) | ||||
| @ -259,11 +351,11 @@ def handle_yolo_hub(args: List[str]) -> None: | ||||
|     """ | ||||
|     from ultralytics import hub | ||||
| 
 | ||||
|     if args[0] == 'login': | ||||
|         key = args[1] if len(args) > 1 else '' | ||||
|     if args[0] == "login": | ||||
|         key = args[1] if len(args) > 1 else "" | ||||
|         # Log in to Ultralytics HUB using the provided API key | ||||
|         hub.login(key) | ||||
|     elif args[0] == 'logout': | ||||
|     elif args[0] == "logout": | ||||
|         # Log out from Ultralytics HUB | ||||
|         hub.logout() | ||||
| 
 | ||||
| @ -283,19 +375,19 @@ def handle_yolo_settings(args: List[str]) -> None: | ||||
|         python my_script.py yolo settings reset | ||||
|         ``` | ||||
|     """ | ||||
|     url = 'https://docs.ultralytics.com/quickstart/#ultralytics-settings'  # help URL | ||||
|     url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings"  # help URL | ||||
|     try: | ||||
|         if any(args): | ||||
|             if args[0] == 'reset': | ||||
|             if args[0] == "reset": | ||||
|                 SETTINGS_YAML.unlink()  # delete the settings file | ||||
|                 SETTINGS.reset()  # create new settings | ||||
|                 LOGGER.info('Settings reset successfully')  # inform the user that settings have been reset | ||||
|                 LOGGER.info("Settings reset successfully")  # inform the user that settings have been reset | ||||
|             else:  # save a new setting | ||||
|                 new = dict(parse_key_value_pair(a) for a in args) | ||||
|                 check_dict_alignment(SETTINGS, new) | ||||
|                 SETTINGS.update(new) | ||||
| 
 | ||||
|         LOGGER.info(f'💡 Learn about settings at {url}') | ||||
|         LOGGER.info(f"💡 Learn about settings at {url}") | ||||
|         yaml_print(SETTINGS_YAML)  # print the current settings | ||||
|     except Exception as e: | ||||
|         LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.") | ||||
| @ -303,13 +395,13 @@ def handle_yolo_settings(args: List[str]) -> None: | ||||
| 
 | ||||
| def handle_explorer(): | ||||
|     """Open the Ultralytics Explorer GUI.""" | ||||
|     checks.check_requirements('streamlit') | ||||
|     subprocess.run(['streamlit', 'run', ROOT / 'data/explorer/gui/dash.py', '--server.maxMessageSize', '2048']) | ||||
|     checks.check_requirements("streamlit") | ||||
|     subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"]) | ||||
| 
 | ||||
| 
 | ||||
| def parse_key_value_pair(pair): | ||||
|     """Parse one 'key=value' pair and return key and value.""" | ||||
|     k, v = pair.split('=', 1)  # split on first '=' sign | ||||
|     k, v = pair.split("=", 1)  # split on first '=' sign | ||||
|     k, v = k.strip(), v.strip()  # remove spaces | ||||
|     assert v, f"missing '{k}' value" | ||||
|     return k, smart_value(v) | ||||
| @ -318,11 +410,11 @@ def parse_key_value_pair(pair): | ||||
| def smart_value(v): | ||||
|     """Convert a string to an underlying type such as int, float, bool, etc.""" | ||||
|     v_lower = v.lower() | ||||
|     if v_lower == 'none': | ||||
|     if v_lower == "none": | ||||
|         return None | ||||
|     elif v_lower == 'true': | ||||
|     elif v_lower == "true": | ||||
|         return True | ||||
|     elif v_lower == 'false': | ||||
|     elif v_lower == "false": | ||||
|         return False | ||||
|     else: | ||||
|         with contextlib.suppress(Exception): | ||||
| @ -330,7 +422,7 @@ def smart_value(v): | ||||
|         return v | ||||
| 
 | ||||
| 
 | ||||
| def entrypoint(debug=''): | ||||
| def entrypoint(debug=""): | ||||
|     """ | ||||
|     This function is the ultralytics package entrypoint; it is responsible for parsing the command line arguments passed | ||||
|     to the package. | ||||
| @ -345,139 +437,150 @@ def entrypoint(debug=''): | ||||
|     It uses the package's default cfg and initializes it using the passed overrides. | ||||
|     Then it calls the CLI function with the composed cfg. | ||||
|     """ | ||||
|     args = (debug.split(' ') if debug else sys.argv)[1:] | ||||
|     args = (debug.split(" ") if debug else sys.argv)[1:] | ||||
|     if not args:  # no arguments passed | ||||
|         LOGGER.info(CLI_HELP_MSG) | ||||
|         return | ||||
| 
 | ||||
|     special = { | ||||
|         'help': lambda: LOGGER.info(CLI_HELP_MSG), | ||||
|         'checks': checks.collect_system_info, | ||||
|         'version': lambda: LOGGER.info(__version__), | ||||
|         'settings': lambda: handle_yolo_settings(args[1:]), | ||||
|         'cfg': lambda: yaml_print(DEFAULT_CFG_PATH), | ||||
|         'hub': lambda: handle_yolo_hub(args[1:]), | ||||
|         'login': lambda: handle_yolo_hub(args), | ||||
|         'copy-cfg': copy_default_cfg, | ||||
|         'explorer': lambda: handle_explorer()} | ||||
|         "help": lambda: LOGGER.info(CLI_HELP_MSG), | ||||
|         "checks": checks.collect_system_info, | ||||
|         "version": lambda: LOGGER.info(__version__), | ||||
|         "settings": lambda: handle_yolo_settings(args[1:]), | ||||
|         "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), | ||||
|         "hub": lambda: handle_yolo_hub(args[1:]), | ||||
|         "login": lambda: handle_yolo_hub(args), | ||||
|         "copy-cfg": copy_default_cfg, | ||||
|         "explorer": lambda: handle_explorer(), | ||||
|     } | ||||
|     full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} | ||||
| 
 | ||||
|     # Define common misuses of special commands, i.e. -h, -help, --help | ||||
|     special.update({k[0]: v for k, v in special.items()})  # singular | ||||
|     special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')})  # singular | ||||
|     special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}} | ||||
|     special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")})  # singular | ||||
|     special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} | ||||
| 
 | ||||
|     overrides = {}  # basic overrides, i.e. imgsz=320 | ||||
|     for a in merge_equals_args(args):  # merge spaces around '=' sign | ||||
|         if a.startswith('--'): | ||||
|         if a.startswith("--"): | ||||
|             LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") | ||||
|             a = a[2:] | ||||
|         if a.endswith(','): | ||||
|         if a.endswith(","): | ||||
|             LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") | ||||
|             a = a[:-1] | ||||
|         if '=' in a: | ||||
|         if "=" in a: | ||||
|             try: | ||||
|                 k, v = parse_key_value_pair(a) | ||||
|                 if k == 'cfg' and v is not None:  # custom.yaml passed | ||||
|                     LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}') | ||||
|                     overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'} | ||||
|                 if k == "cfg" and v is not None:  # custom.yaml passed | ||||
|                     LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") | ||||
|                     overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} | ||||
|                 else: | ||||
|                     overrides[k] = v | ||||
|             except (NameError, SyntaxError, ValueError, AssertionError) as e: | ||||
|                 check_dict_alignment(full_args_dict, {a: ''}, e) | ||||
|                 check_dict_alignment(full_args_dict, {a: ""}, e) | ||||
| 
 | ||||
|         elif a in TASKS: | ||||
|             overrides['task'] = a | ||||
|             overrides["task"] = a | ||||
|         elif a in MODES: | ||||
|             overrides['mode'] = a | ||||
|             overrides["mode"] = a | ||||
|         elif a.lower() in special: | ||||
|             special[a.lower()]() | ||||
|             return | ||||
|         elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): | ||||
|             overrides[a] = True  # auto-True for default bool args, i.e. 'yolo show' sets show=True | ||||
|         elif a in DEFAULT_CFG_DICT: | ||||
|             raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " | ||||
|                               f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}") | ||||
|             raise SyntaxError( | ||||
|                 f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " | ||||
|                 f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" | ||||
|             ) | ||||
|         else: | ||||
|             check_dict_alignment(full_args_dict, {a: ''}) | ||||
|             check_dict_alignment(full_args_dict, {a: ""}) | ||||
| 
 | ||||
|     # Check keys | ||||
|     check_dict_alignment(full_args_dict, overrides) | ||||
| 
 | ||||
|     # Mode | ||||
|     mode = overrides.get('mode') | ||||
|     mode = overrides.get("mode") | ||||
|     if mode is None: | ||||
|         mode = DEFAULT_CFG.mode or 'predict' | ||||
|         mode = DEFAULT_CFG.mode or "predict" | ||||
|         LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") | ||||
|     elif mode not in MODES: | ||||
|         raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") | ||||
| 
 | ||||
|     # Task | ||||
|     task = overrides.pop('task', None) | ||||
|     task = overrides.pop("task", None) | ||||
|     if task: | ||||
|         if task not in TASKS: | ||||
|             raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") | ||||
|         if 'model' not in overrides: | ||||
|             overrides['model'] = TASK2MODEL[task] | ||||
|         if "model" not in overrides: | ||||
|             overrides["model"] = TASK2MODEL[task] | ||||
| 
 | ||||
|     # Model | ||||
|     model = overrides.pop('model', DEFAULT_CFG.model) | ||||
|     model = overrides.pop("model", DEFAULT_CFG.model) | ||||
|     if model is None: | ||||
|         model = 'yolov8n.pt' | ||||
|         model = "yolov8n.pt" | ||||
|         LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.") | ||||
|     overrides['model'] = model | ||||
|     overrides["model"] = model | ||||
|     stem = Path(model).stem.lower() | ||||
|     if 'rtdetr' in stem:  # guess architecture | ||||
|     if "rtdetr" in stem:  # guess architecture | ||||
|         from ultralytics import RTDETR | ||||
| 
 | ||||
|         model = RTDETR(model)  # no task argument | ||||
|     elif 'fastsam' in stem: | ||||
|     elif "fastsam" in stem: | ||||
|         from ultralytics import FastSAM | ||||
| 
 | ||||
|         model = FastSAM(model) | ||||
|     elif 'sam' in stem: | ||||
|     elif "sam" in stem: | ||||
|         from ultralytics import SAM | ||||
| 
 | ||||
|         model = SAM(model) | ||||
|     else: | ||||
|         from ultralytics import YOLO | ||||
| 
 | ||||
|         model = YOLO(model, task=task) | ||||
|     if isinstance(overrides.get('pretrained'), str): | ||||
|         model.load(overrides['pretrained']) | ||||
|     if isinstance(overrides.get("pretrained"), str): | ||||
|         model.load(overrides["pretrained"]) | ||||
| 
 | ||||
|     # Task Update | ||||
|     if task != model.task: | ||||
|         if task: | ||||
|             LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " | ||||
|                            f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.") | ||||
|             LOGGER.warning( | ||||
|                 f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " | ||||
|                 f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." | ||||
|             ) | ||||
|         task = model.task | ||||
| 
 | ||||
|     # Mode | ||||
|     if mode in ('predict', 'track') and 'source' not in overrides: | ||||
|         overrides['source'] = DEFAULT_CFG.source or ASSETS | ||||
|     if mode in ("predict", "track") and "source" not in overrides: | ||||
|         overrides["source"] = DEFAULT_CFG.source or ASSETS | ||||
|         LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") | ||||
|     elif mode in ('train', 'val'): | ||||
|         if 'data' not in overrides and 'resume' not in overrides: | ||||
|             overrides['data'] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) | ||||
|     elif mode in ("train", "val"): | ||||
|         if "data" not in overrides and "resume" not in overrides: | ||||
|             overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) | ||||
|             LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.") | ||||
|     elif mode == 'export': | ||||
|         if 'format' not in overrides: | ||||
|             overrides['format'] = DEFAULT_CFG.format or 'torchscript' | ||||
|     elif mode == "export": | ||||
|         if "format" not in overrides: | ||||
|             overrides["format"] = DEFAULT_CFG.format or "torchscript" | ||||
|             LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.") | ||||
| 
 | ||||
|     # Run command in python | ||||
|     getattr(model, mode)(**overrides)  # default args from model | ||||
| 
 | ||||
|     # Show help | ||||
|     LOGGER.info(f'💡 Learn more at https://docs.ultralytics.com/modes/{mode}') | ||||
|     LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}") | ||||
| 
 | ||||
| 
 | ||||
| # Special modes -------------------------------------------------------------------------------------------------------- | ||||
| def copy_default_cfg(): | ||||
|     """Copy and create a new default configuration file with '_copy' appended to its name.""" | ||||
|     new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml') | ||||
|     new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml") | ||||
|     shutil.copy2(DEFAULT_CFG_PATH, new_file) | ||||
|     LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n' | ||||
|                 f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8") | ||||
|     LOGGER.info( | ||||
|         f"{DEFAULT_CFG_PATH} copied to {new_file}\n" | ||||
|         f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8" | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     # Example: entrypoint(debug='yolo predict model=yolov8n.pt') | ||||
|     entrypoint(debug='') | ||||
|     entrypoint(debug="") | ||||
|  | ||||
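The helpers reformatted above are plain module-level functions, so their behavior is easy to check in isolation. A short sketch of `smart_value` and `parse_key_value_pair`; the final `entrypoint` call is left commented out because it runs a full prediction:

```python
from ultralytics.cfg import parse_key_value_pair, smart_value

# smart_value() coerces CLI strings into Python types.
assert smart_value("true") is True
assert smart_value("none") is None
assert smart_value("0.25") == 0.25

# parse_key_value_pair() splits one 'key=value' token and coerces the value.
print(parse_key_value_pair("imgsz=320"))  # ('imgsz', 320)

# entrypoint() accepts a debug string and runs a full command in-process, e.g.:
# from ultralytics.cfg import entrypoint
# entrypoint(debug="yolo predict model=yolov8n.pt imgsz=320")
```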
| @ -4,5 +4,12 @@ from .base import BaseDataset | ||||
| from .build import build_dataloader, build_yolo_dataset, load_inference_source | ||||
| from .dataset import ClassificationDataset, SemanticDataset, YOLODataset | ||||
| 
 | ||||
| __all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset', | ||||
|            'build_dataloader', 'load_inference_source') | ||||
| __all__ = ( | ||||
|     "BaseDataset", | ||||
|     "ClassificationDataset", | ||||
|     "SemanticDataset", | ||||
|     "YOLODataset", | ||||
|     "build_yolo_dataset", | ||||
|     "build_dataloader", | ||||
|     "load_inference_source", | ||||
| ) | ||||
|  | ||||
| @ -5,7 +5,7 @@ from pathlib import Path | ||||
| from ultralytics import SAM, YOLO | ||||
| 
 | ||||
| 
 | ||||
| def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): | ||||
| def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None): | ||||
|     """ | ||||
|     Automatically annotates images using a YOLO object detection model and a SAM segmentation model. | ||||
| 
 | ||||
| @ -29,7 +29,7 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', | ||||
| 
 | ||||
|     data = Path(data) | ||||
|     if not output_dir: | ||||
|         output_dir = data.parent / f'{data.stem}_auto_annotate_labels' | ||||
|         output_dir = data.parent / f"{data.stem}_auto_annotate_labels" | ||||
|     Path(output_dir).mkdir(exist_ok=True, parents=True) | ||||
| 
 | ||||
|     det_results = det_model(data, stream=True, device=device) | ||||
| @ -41,10 +41,10 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', | ||||
|             sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device) | ||||
|             segments = sam_results[0].masks.xyn  # noqa | ||||
| 
 | ||||
|             with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f: | ||||
|             with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f: | ||||
|                 for i in range(len(segments)): | ||||
|                     s = segments[i] | ||||
|                     if len(s) == 0: | ||||
|                         continue | ||||
|                     segment = map(str, segments[i].reshape(-1).tolist()) | ||||
|                     f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n') | ||||
|                     f.write(f"{class_ids[i]} " + " ".join(segment) + "\n") | ||||
|  | ||||
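A usage sketch of the `auto_annotate` function touched above (`path/to/images` is a placeholder): it runs the detector, feeds each image's boxes to SAM, and writes one `<class_id> x1 y1 x2 y2 ...` normalized-polygon line per object.

```python
from ultralytics.data.annotator import auto_annotate

# Labels land in a sibling 'path/to/images_auto_annotate_labels/' folder by default.
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam_b.pt")
```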
| @ -117,11 +117,11 @@ class BaseMixTransform: | ||||
|         if self.pre_transform is not None: | ||||
|             for i, data in enumerate(mix_labels): | ||||
|                 mix_labels[i] = self.pre_transform(data) | ||||
|         labels['mix_labels'] = mix_labels | ||||
|         labels["mix_labels"] = mix_labels | ||||
| 
 | ||||
|         # Mosaic or MixUp | ||||
|         labels = self._mix_transform(labels) | ||||
|         labels.pop('mix_labels', None) | ||||
|         labels.pop("mix_labels", None) | ||||
|         return labels | ||||
| 
 | ||||
|     def _mix_transform(self, labels): | ||||
| @ -149,8 +149,8 @@ class Mosaic(BaseMixTransform): | ||||
| 
 | ||||
|     def __init__(self, dataset, imgsz=640, p=1.0, n=4): | ||||
|         """Initializes the object with a dataset, image size, probability, and border.""" | ||||
|         assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.' | ||||
|         assert n in (4, 9), 'grid must be equal to 4 or 9.' | ||||
|         assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." | ||||
|         assert n in (4, 9), "grid must be equal to 4 or 9." | ||||
|         super().__init__(dataset=dataset, p=p) | ||||
|         self.dataset = dataset | ||||
|         self.imgsz = imgsz | ||||
| @ -166,20 +166,21 @@ class Mosaic(BaseMixTransform): | ||||
| 
 | ||||
|     def _mix_transform(self, labels): | ||||
|         """Apply mixup transformation to the input image and labels.""" | ||||
|         assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.' | ||||
|         assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.' | ||||
|         return self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9( | ||||
|             labels)  # This code is modified for mosaic3 method. | ||||
|         assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive." | ||||
|         assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment." | ||||
|         return ( | ||||
|             self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels) | ||||
|         )  # This code is modified for mosaic3 method. | ||||
| 
 | ||||
|     def _mosaic3(self, labels): | ||||
|         """Create a 1x3 image mosaic.""" | ||||
|         mosaic_labels = [] | ||||
|         s = self.imgsz | ||||
|         for i in range(3): | ||||
|             labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] | ||||
|             labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] | ||||
|             # Load image | ||||
|             img = labels_patch['img'] | ||||
|             h, w = labels_patch.pop('resized_shape') | ||||
|             img = labels_patch["img"] | ||||
|             h, w = labels_patch.pop("resized_shape") | ||||
| 
 | ||||
|             # Place img in img3 | ||||
|             if i == 0:  # center | ||||
| @ -194,7 +195,7 @@ class Mosaic(BaseMixTransform): | ||||
|             padw, padh = c[:2] | ||||
|             x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords | ||||
| 
 | ||||
|             img3[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:]  # img3[ymin:ymax, xmin:xmax] | ||||
|             img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :]  # img3[ymin:ymax, xmin:xmax] | ||||
|             # hp, wp = h, w  # height, width previous for next iteration | ||||
| 
 | ||||
|             # Labels assuming imgsz*2 mosaic size | ||||
| @ -202,7 +203,7 @@ class Mosaic(BaseMixTransform): | ||||
|             mosaic_labels.append(labels_patch) | ||||
|         final_labels = self._cat_labels(mosaic_labels) | ||||
| 
 | ||||
|         final_labels['img'] = img3[-self.border[0]:self.border[0], -self.border[1]:self.border[1]] | ||||
|         final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] | ||||
|         return final_labels | ||||
| 
 | ||||
|     def _mosaic4(self, labels): | ||||
| @ -211,10 +212,10 @@ class Mosaic(BaseMixTransform): | ||||
|         s = self.imgsz | ||||
|         yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y | ||||
|         for i in range(4): | ||||
|             labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] | ||||
|             labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] | ||||
|             # Load image | ||||
|             img = labels_patch['img'] | ||||
|             h, w = labels_patch.pop('resized_shape') | ||||
|             img = labels_patch["img"] | ||||
|             h, w = labels_patch.pop("resized_shape") | ||||
| 
 | ||||
|             # Place img in img4 | ||||
|             if i == 0:  # top left | ||||
| @ -238,7 +239,7 @@ class Mosaic(BaseMixTransform): | ||||
|             labels_patch = self._update_labels(labels_patch, padw, padh) | ||||
|             mosaic_labels.append(labels_patch) | ||||
|         final_labels = self._cat_labels(mosaic_labels) | ||||
|         final_labels['img'] = img4 | ||||
|         final_labels["img"] = img4 | ||||
|         return final_labels | ||||
| 
 | ||||
|     def _mosaic9(self, labels): | ||||
| @ -247,10 +248,10 @@ class Mosaic(BaseMixTransform): | ||||
|         s = self.imgsz | ||||
|         hp, wp = -1, -1  # height, width previous | ||||
|         for i in range(9): | ||||
|             labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] | ||||
|             labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] | ||||
|             # Load image | ||||
|             img = labels_patch['img'] | ||||
|             h, w = labels_patch.pop('resized_shape') | ||||
|             img = labels_patch["img"] | ||||
|             h, w = labels_patch.pop("resized_shape") | ||||
| 
 | ||||
|             # Place img in img9 | ||||
|             if i == 0:  # center | ||||
| @ -278,7 +279,7 @@ class Mosaic(BaseMixTransform): | ||||
|             x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords | ||||
| 
 | ||||
|             # Image | ||||
|             img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:]  # img9[ymin:ymax, xmin:xmax] | ||||
|             img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :]  # img9[ymin:ymax, xmin:xmax] | ||||
|             hp, wp = h, w  # height, width previous for next iteration | ||||
| 
 | ||||
|             # Labels assuming imgsz*2 mosaic size | ||||
| @ -286,16 +287,16 @@ class Mosaic(BaseMixTransform): | ||||
|             mosaic_labels.append(labels_patch) | ||||
|         final_labels = self._cat_labels(mosaic_labels) | ||||
| 
 | ||||
|         final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]] | ||||
|         final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] | ||||
|         return final_labels | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _update_labels(labels, padw, padh): | ||||
|         """Update labels.""" | ||||
|         nh, nw = labels['img'].shape[:2] | ||||
|         labels['instances'].convert_bbox(format='xyxy') | ||||
|         labels['instances'].denormalize(nw, nh) | ||||
|         labels['instances'].add_padding(padw, padh) | ||||
|         nh, nw = labels["img"].shape[:2] | ||||
|         labels["instances"].convert_bbox(format="xyxy") | ||||
|         labels["instances"].denormalize(nw, nh) | ||||
|         labels["instances"].add_padding(padw, padh) | ||||
|         return labels | ||||
| 
 | ||||
|     def _cat_labels(self, mosaic_labels): | ||||
| @ -306,18 +307,20 @@ class Mosaic(BaseMixTransform): | ||||
|         instances = [] | ||||
|         imgsz = self.imgsz * 2  # mosaic imgsz | ||||
|         for labels in mosaic_labels: | ||||
|             cls.append(labels['cls']) | ||||
|             instances.append(labels['instances']) | ||||
|             cls.append(labels["cls"]) | ||||
|             instances.append(labels["instances"]) | ||||
|         # Final labels | ||||
|         final_labels = { | ||||
|             'im_file': mosaic_labels[0]['im_file'], | ||||
|             'ori_shape': mosaic_labels[0]['ori_shape'], | ||||
|             'resized_shape': (imgsz, imgsz), | ||||
|             'cls': np.concatenate(cls, 0), | ||||
|             'instances': Instances.concatenate(instances, axis=0), | ||||
|             'mosaic_border': self.border}  # final_labels | ||||
|         final_labels['instances'].clip(imgsz, imgsz) | ||||
|         good = final_labels['instances'].remove_zero_area_boxes() | ||||
|         final_labels['cls'] = final_labels['cls'][good] | ||||
|             "im_file": mosaic_labels[0]["im_file"], | ||||
|             "ori_shape": mosaic_labels[0]["ori_shape"], | ||||
|             "resized_shape": (imgsz, imgsz), | ||||
|             "cls": np.concatenate(cls, 0), | ||||
|             "instances": Instances.concatenate(instances, axis=0), | ||||
|             "mosaic_border": self.border, | ||||
|         } | ||||
|         final_labels["instances"].clip(imgsz, imgsz) | ||||
|         good = final_labels["instances"].remove_zero_area_boxes() | ||||
|         final_labels["cls"] = final_labels["cls"][good] | ||||
|         return final_labels | ||||
| 
 | ||||
| 
 | ||||
| @ -335,10 +338,10 @@ class MixUp(BaseMixTransform): | ||||
|     def _mix_transform(self, labels): | ||||
|         """Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf.""" | ||||
|         r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0 | ||||
|         labels2 = labels['mix_labels'][0] | ||||
|         labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8) | ||||
|         labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0) | ||||
|         labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0) | ||||
|         labels2 = labels["mix_labels"][0] | ||||
|         labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8) | ||||
|         labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0) | ||||
|         labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0) | ||||
|         return labels | ||||
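The `Beta(32, 32)` draw above concentrates tightly around 0.5, so MixUp blends the two images almost evenly; a quick numeric check of the same arithmetic:

```python
import numpy as np

rng = np.random.default_rng(0)
r = rng.beta(32.0, 32.0)                        # same distribution as np.random.beta(32.0, 32.0)
img1 = np.full((2, 2, 3), 200, dtype=np.uint8)
img2 = np.full((2, 2, 3), 50, dtype=np.uint8)
mixed = (img1 * r + img2 * (1 - r)).astype(np.uint8)
print(round(r, 3), int(mixed[0, 0, 0]))         # r ≈ 0.5, so the blended pixel lands near 125
```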
| 
 | ||||
| 
 | ||||
| @ -366,14 +369,9 @@ class RandomPerspective: | ||||
|         box_candidates(box1, box2): Filters out boxes that become too small, too elongated, or lose too much area post-transformation. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, | ||||
|                  degrees=0.0, | ||||
|                  translate=0.1, | ||||
|                  scale=0.5, | ||||
|                  shear=0.0, | ||||
|                  perspective=0.0, | ||||
|                  border=(0, 0), | ||||
|                  pre_transform=None): | ||||
|     def __init__( | ||||
|         self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None | ||||
|     ): | ||||
|         """Initializes RandomPerspective object with transformation parameters.""" | ||||
| 
 | ||||
|         self.degrees = degrees | ||||
| @ -519,18 +517,18 @@ class RandomPerspective: | ||||
|         Args: | ||||
|             labels (dict): a dict of `bboxes`, `segments`, `keypoints`. | ||||
|         """ | ||||
|         if self.pre_transform and 'mosaic_border' not in labels: | ||||
|         if self.pre_transform and "mosaic_border" not in labels: | ||||
|             labels = self.pre_transform(labels) | ||||
|         labels.pop('ratio_pad', None)  # do not need ratio pad | ||||
|         labels.pop("ratio_pad", None)  # do not need ratio pad | ||||
| 
 | ||||
|         img = labels['img'] | ||||
|         cls = labels['cls'] | ||||
|         instances = labels.pop('instances') | ||||
|         img = labels["img"] | ||||
|         cls = labels["cls"] | ||||
|         instances = labels.pop("instances") | ||||
|         # Make sure the coord formats are right | ||||
|         instances.convert_bbox(format='xyxy') | ||||
|         instances.convert_bbox(format="xyxy") | ||||
|         instances.denormalize(*img.shape[:2][::-1]) | ||||
| 
 | ||||
|         border = labels.pop('mosaic_border', self.border) | ||||
|         border = labels.pop("mosaic_border", self.border) | ||||
|         self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2  # w, h | ||||
|         # M is affine matrix | ||||
|         # Scale for func:`box_candidates` | ||||
| @ -546,20 +544,20 @@ class RandomPerspective: | ||||
| 
 | ||||
|         if keypoints is not None: | ||||
|             keypoints = self.apply_keypoints(keypoints, M) | ||||
|         new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False) | ||||
|         new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False) | ||||
|         # Clip | ||||
|         new_instances.clip(*self.size) | ||||
| 
 | ||||
|         # Filter instances | ||||
|         instances.scale(scale_w=scale, scale_h=scale, bbox_only=True) | ||||
|         # Make the bboxes have the same scale with new_bboxes | ||||
|         i = self.box_candidates(box1=instances.bboxes.T, | ||||
|                                 box2=new_instances.bboxes.T, | ||||
|                                 area_thr=0.01 if len(segments) else 0.10) | ||||
|         labels['instances'] = new_instances[i] | ||||
|         labels['cls'] = cls[i] | ||||
|         labels['img'] = img | ||||
|         labels['resized_shape'] = img.shape[:2] | ||||
|         i = self.box_candidates( | ||||
|             box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10 | ||||
|         ) | ||||
|         labels["instances"] = new_instances[i] | ||||
|         labels["cls"] = cls[i] | ||||
|         labels["img"] = img | ||||
|         labels["resized_shape"] = img.shape[:2] | ||||
|         return labels | ||||
| 
 | ||||
|     def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): | ||||
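The `box_candidates` body falls outside this hunk; a reconstructed sketch (not verbatim upstream code) of the filter its signature implies: keep boxes that stay at least `wh_thr` pixels on each side, retain more than `area_thr` of their pre-warp area, and avoid aspect ratios beyond `ar_thr`.

```python
import numpy as np

def box_candidates_sketch(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
    """box1/box2 are 4xN arrays of pre-/post-transform xyxy boxes."""
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio, orientation-agnostic
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)
```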
| @ -611,7 +609,7 @@ class RandomHSV: | ||||
| 
 | ||||
|         The modified image replaces the original image in the input 'labels' dict. | ||||
|         """ | ||||
|         img = labels['img'] | ||||
|         img = labels["img"] | ||||
|         if self.hgain or self.sgain or self.vgain: | ||||
|             r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1  # random gains | ||||
|             hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) | ||||
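With the default hyperparameters (`hsv_h=0.015`, `hsv_s=0.7`, `hsv_v=0.4`), the gain line above draws per-channel multipliers from ranges centered on 1:

```python
import numpy as np

r = np.random.uniform(-1, 1, 3) * [0.015, 0.7, 0.4] + 1
print(r)  # hue ∈ [0.985, 1.015], saturation ∈ [0.3, 1.7], value ∈ [0.6, 1.4]
```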
| @ -634,7 +632,7 @@ class RandomFlip: | ||||
|     Also updates any instances (bounding boxes, keypoints, etc.) accordingly. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None: | ||||
|     def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None: | ||||
|         """ | ||||
|         Initializes the RandomFlip class with probability and direction. | ||||
| 
 | ||||
| @ -644,7 +642,7 @@ class RandomFlip: | ||||
|                 Default is 'horizontal'. | ||||
|             flip_idx (array-like, optional): Index mapping for flipping keypoints, if any. | ||||
|         """ | ||||
|         assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}' | ||||
|         assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}" | ||||
|         assert 0 <= p <= 1.0 | ||||
| 
 | ||||
|         self.p = p | ||||
| @ -662,25 +660,25 @@ class RandomFlip: | ||||
|         Returns: | ||||
|             (dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys. | ||||
|         """ | ||||
|         img = labels['img'] | ||||
|         instances = labels.pop('instances') | ||||
|         instances.convert_bbox(format='xywh') | ||||
|         img = labels["img"] | ||||
|         instances = labels.pop("instances") | ||||
|         instances.convert_bbox(format="xywh") | ||||
|         h, w = img.shape[:2] | ||||
|         h = 1 if instances.normalized else h | ||||
|         w = 1 if instances.normalized else w | ||||
| 
 | ||||
|         # Flip up-down | ||||
|         if self.direction == 'vertical' and random.random() < self.p: | ||||
|         if self.direction == "vertical" and random.random() < self.p: | ||||
|             img = np.flipud(img) | ||||
|             instances.flipud(h) | ||||
|         if self.direction == 'horizontal' and random.random() < self.p: | ||||
|         if self.direction == "horizontal" and random.random() < self.p: | ||||
|             img = np.fliplr(img) | ||||
|             instances.fliplr(w) | ||||
|             # For keypoints | ||||
|             if self.flip_idx is not None and instances.keypoints is not None: | ||||
|                 instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :]) | ||||
|         labels['img'] = np.ascontiguousarray(img) | ||||
|         labels['instances'] = instances | ||||
|         labels["img"] = np.ascontiguousarray(img) | ||||
|         labels["instances"] = instances | ||||
|         return labels | ||||
| 
 | ||||
| 
 | ||||
| @ -700,9 +698,9 @@ class LetterBox: | ||||
|         """Return updated labels and image with added border.""" | ||||
|         if labels is None: | ||||
|             labels = {} | ||||
|         img = labels.get('img') if image is None else image | ||||
|         img = labels.get("img") if image is None else image | ||||
|         shape = img.shape[:2]  # current shape [height, width] | ||||
|         new_shape = labels.pop('rect_shape', self.new_shape) | ||||
|         new_shape = labels.pop("rect_shape", self.new_shape) | ||||
|         if isinstance(new_shape, int): | ||||
|             new_shape = (new_shape, new_shape) | ||||
| 
 | ||||
| @ -730,25 +728,26 @@ class LetterBox: | ||||
|             img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) | ||||
|         top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1)) | ||||
|         left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1)) | ||||
|         img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, | ||||
|                                  value=(114, 114, 114))  # add border | ||||
|         if labels.get('ratio_pad'): | ||||
|             labels['ratio_pad'] = (labels['ratio_pad'], (left, top))  # for evaluation | ||||
|         img = cv2.copyMakeBorder( | ||||
|             img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) | ||||
|         )  # add border | ||||
|         if labels.get("ratio_pad"): | ||||
|             labels["ratio_pad"] = (labels["ratio_pad"], (left, top))  # for evaluation | ||||
| 
 | ||||
|         if len(labels): | ||||
|             labels = self._update_labels(labels, ratio, dw, dh) | ||||
|             labels['img'] = img | ||||
|             labels['resized_shape'] = new_shape | ||||
|             labels["img"] = img | ||||
|             labels["resized_shape"] = new_shape | ||||
|             return labels | ||||
|         else: | ||||
|             return img | ||||
| 
 | ||||
|     def _update_labels(self, labels, ratio, padw, padh): | ||||
|         """Update labels.""" | ||||
|         labels['instances'].convert_bbox(format='xyxy') | ||||
|         labels['instances'].denormalize(*labels['img'].shape[:2][::-1]) | ||||
|         labels['instances'].scale(*ratio) | ||||
|         labels['instances'].add_padding(padw, padh) | ||||
|         labels["instances"].convert_bbox(format="xyxy") | ||||
|         labels["instances"].denormalize(*labels["img"].shape[:2][::-1]) | ||||
|         labels["instances"].scale(*ratio) | ||||
|         labels["instances"].add_padding(padw, padh) | ||||
|         return labels | ||||
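A worked example of the letterbox math above, fitting a 1280x720 frame into 640x640:

```python
shape, new_shape = (720, 1280), (640, 640)                  # (h, w)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])   # 0.5
new_unpad = round(shape[1] * r), round(shape[0] * r)        # (640, 360)
dw = (new_shape[1] - new_unpad[0]) / 2                      # 0.0
dh = (new_shape[0] - new_unpad[1]) / 2                      # 140.0 -> 140 px of gray (114) top and bottom
print(r, new_unpad, (dw, dh))
```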
| 
 | ||||
| 
 | ||||
| @ -785,11 +784,11 @@ class CopyPaste: | ||||
|             1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work. | ||||
|             2. This method modifies the input dictionary 'labels' in place. | ||||
|         """ | ||||
|         im = labels['img'] | ||||
|         cls = labels['cls'] | ||||
|         im = labels["img"] | ||||
|         cls = labels["cls"] | ||||
|         h, w = im.shape[:2] | ||||
|         instances = labels.pop('instances') | ||||
|         instances.convert_bbox(format='xyxy') | ||||
|         instances = labels.pop("instances") | ||||
|         instances.convert_bbox(format="xyxy") | ||||
|         instances.denormalize(w, h) | ||||
|         if self.p and len(instances.segments): | ||||
|             n = len(instances) | ||||
| @ -812,9 +811,9 @@ class CopyPaste: | ||||
|             i = cv2.flip(im_new, 1).astype(bool) | ||||
|             im[i] = result[i] | ||||
| 
 | ||||
|         labels['img'] = im | ||||
|         labels['cls'] = cls | ||||
|         labels['instances'] = instances | ||||
|         labels["img"] = im | ||||
|         labels["cls"] = cls | ||||
|         labels["instances"] = instances | ||||
|         return labels | ||||
| 
 | ||||
| 
 | ||||
| @ -831,12 +830,13 @@ class Albumentations: | ||||
|         """Initialize the transform object for YOLO bbox formatted params.""" | ||||
|         self.p = p | ||||
|         self.transform = None | ||||
|         prefix = colorstr('albumentations: ') | ||||
|         prefix = colorstr("albumentations: ") | ||||
|         try: | ||||
|             import albumentations as A | ||||
| 
 | ||||
|             check_version(A.__version__, '1.0.3', hard=True)  # version requirement | ||||
|             check_version(A.__version__, "1.0.3", hard=True)  # version requirement | ||||
| 
 | ||||
|             # Transforms | ||||
|             T = [ | ||||
|                 A.Blur(p=0.01), | ||||
|                 A.MedianBlur(p=0.01), | ||||
| @ -844,31 +844,32 @@ class Albumentations: | ||||
|                 A.CLAHE(p=0.01), | ||||
|                 A.RandomBrightnessContrast(p=0.0), | ||||
|                 A.RandomGamma(p=0.0), | ||||
|                 A.ImageCompression(quality_lower=75, p=0.0)]  # transforms | ||||
|             self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])) | ||||
|                 A.ImageCompression(quality_lower=75, p=0.0), | ||||
|             ] | ||||
|             self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"])) | ||||
| 
 | ||||
|             LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p)) | ||||
|             LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p)) | ||||
|         except ImportError:  # package not installed, skip | ||||
|             pass | ||||
|         except Exception as e: | ||||
|             LOGGER.info(f'{prefix}{e}') | ||||
|             LOGGER.info(f"{prefix}{e}") | ||||
| 
 | ||||
|     def __call__(self, labels): | ||||
|         """Generates object detections and returns a dictionary with detection results.""" | ||||
|         im = labels['img'] | ||||
|         cls = labels['cls'] | ||||
|         im = labels["img"] | ||||
|         cls = labels["cls"] | ||||
|         if len(cls): | ||||
|             labels['instances'].convert_bbox('xywh') | ||||
|             labels['instances'].normalize(*im.shape[:2][::-1]) | ||||
|             bboxes = labels['instances'].bboxes | ||||
|             labels["instances"].convert_bbox("xywh") | ||||
|             labels["instances"].normalize(*im.shape[:2][::-1]) | ||||
|             bboxes = labels["instances"].bboxes | ||||
|             # TODO: add supports of segments and keypoints | ||||
|             if self.transform and random.random() < self.p: | ||||
|                 new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed | ||||
|                 if len(new['class_labels']) > 0:  # skip update if no bbox in new im | ||||
|                     labels['img'] = new['image'] | ||||
|                     labels['cls'] = np.array(new['class_labels']) | ||||
|                     bboxes = np.array(new['bboxes'], dtype=np.float32) | ||||
|             labels['instances'].update(bboxes=bboxes) | ||||
|                 if len(new["class_labels"]) > 0:  # skip update if no bbox in new im | ||||
|                     labels["img"] = new["image"] | ||||
|                     labels["cls"] = np.array(new["class_labels"]) | ||||
|                     bboxes = np.array(new["bboxes"], dtype=np.float32) | ||||
|             labels["instances"].update(bboxes=bboxes) | ||||
|         return labels | ||||
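The try/except above makes the extra augmentations strictly optional; without `pip install "albumentations>=1.0.3"` the wrapper stays inert rather than failing:

```python
from ultralytics.data.augment import Albumentations

aug = Albumentations(p=1.0)
print(aug.transform)  # None when the package is missing, an A.Compose pipeline otherwise
```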
| 
 | ||||
| 
 | ||||
| @ -888,15 +889,17 @@ class Format: | ||||
|         batch_idx (bool): Keep batch indexes. Default is True. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, | ||||
|                  bbox_format='xywh', | ||||
|     def __init__( | ||||
|         self, | ||||
|         bbox_format="xywh", | ||||
|         normalize=True, | ||||
|         return_mask=False, | ||||
|         return_keypoint=False, | ||||
|         return_obb=False, | ||||
|         mask_ratio=4, | ||||
|         mask_overlap=True, | ||||
|                  batch_idx=True): | ||||
|         batch_idx=True, | ||||
|     ): | ||||
|         """Initializes the Format class with given parameters.""" | ||||
|         self.bbox_format = bbox_format | ||||
|         self.normalize = normalize | ||||
| @ -909,10 +912,10 @@ class Format: | ||||
| 
 | ||||
|     def __call__(self, labels): | ||||
|         """Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'.""" | ||||
|         img = labels.pop('img') | ||||
|         img = labels.pop("img") | ||||
|         h, w = img.shape[:2] | ||||
|         cls = labels.pop('cls') | ||||
|         instances = labels.pop('instances') | ||||
|         cls = labels.pop("cls") | ||||
|         instances = labels.pop("instances") | ||||
|         instances.convert_bbox(format=self.bbox_format) | ||||
|         instances.denormalize(w, h) | ||||
|         nl = len(instances) | ||||
| @ -922,22 +925,24 @@ class Format: | ||||
|                 masks, instances, cls = self._format_segments(instances, cls, w, h) | ||||
|                 masks = torch.from_numpy(masks) | ||||
|             else: | ||||
|                 masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, | ||||
|                                     img.shape[1] // self.mask_ratio) | ||||
|             labels['masks'] = masks | ||||
|                 masks = torch.zeros( | ||||
|                     1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio | ||||
|                 ) | ||||
|             labels["masks"] = masks | ||||
|         if self.normalize: | ||||
|             instances.normalize(w, h) | ||||
|         labels['img'] = self._format_img(img) | ||||
|         labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl) | ||||
|         labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) | ||||
|         labels["img"] = self._format_img(img) | ||||
|         labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl) | ||||
|         labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) | ||||
|         if self.return_keypoint: | ||||
|             labels['keypoints'] = torch.from_numpy(instances.keypoints) | ||||
|             labels["keypoints"] = torch.from_numpy(instances.keypoints) | ||||
|         if self.return_obb: | ||||
|             labels['bboxes'] = xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len( | ||||
|                 instances.segments) else torch.zeros((0, 5)) | ||||
|             labels["bboxes"] = ( | ||||
|                 xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5)) | ||||
|             ) | ||||
|         # Then we can use collate_fn | ||||
|         if self.batch_idx: | ||||
|             labels['batch_idx'] = torch.zeros(nl) | ||||
|             labels["batch_idx"] = torch.zeros(nl) | ||||
|         return labels | ||||
| 
 | ||||
|     def _format_img(self, img): | ||||
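A shape sketch of the empty-mask branch above (assumed `imgsz=640`, default `mask_ratio=4`): with `mask_overlap=True` all instances share one index-coded mask plane, otherwise each instance gets its own plane.

```python
import torch

nl, mask_overlap, mask_ratio = 3, True, 4
masks = torch.zeros(1 if mask_overlap else nl, 640 // mask_ratio, 640 // mask_ratio)
print(masks.shape)  # torch.Size([1, 160, 160]); (3, 160, 160) when mask_overlap=False
```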
| @ -964,7 +969,8 @@ class Format: | ||||
| 
 | ||||
| def v8_transforms(dataset, imgsz, hyp, stretch=False): | ||||
|     """Convert images to a size suitable for YOLOv8 training.""" | ||||
|     pre_transform = Compose([ | ||||
|     pre_transform = Compose( | ||||
|         [ | ||||
|             Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), | ||||
|             CopyPaste(p=hyp.copy_paste), | ||||
|             RandomPerspective( | ||||
| @ -974,23 +980,28 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False): | ||||
|                 shear=hyp.shear, | ||||
|                 perspective=hyp.perspective, | ||||
|                 pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), | ||||
|         )]) | ||||
|     flip_idx = dataset.data.get('flip_idx', [])  # for keypoints augmentation | ||||
|             ), | ||||
|         ] | ||||
|     ) | ||||
|     flip_idx = dataset.data.get("flip_idx", [])  # for keypoints augmentation | ||||
|     if dataset.use_keypoints: | ||||
|         kpt_shape = dataset.data.get('kpt_shape', None) | ||||
|         kpt_shape = dataset.data.get("kpt_shape", None) | ||||
|         if len(flip_idx) == 0 and hyp.fliplr > 0.0: | ||||
|             hyp.fliplr = 0.0 | ||||
|             LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'") | ||||
|         elif flip_idx and (len(flip_idx) != kpt_shape[0]): | ||||
|             raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}') | ||||
|             raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}") | ||||
| 
 | ||||
|     return Compose([ | ||||
|     return Compose( | ||||
|         [ | ||||
|             pre_transform, | ||||
|             MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), | ||||
|             Albumentations(p=1.0), | ||||
|             RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), | ||||
|         RandomFlip(direction='vertical', p=hyp.flipud), | ||||
|         RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)])  # transforms | ||||
|             RandomFlip(direction="vertical", p=hyp.flipud), | ||||
|             RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx), | ||||
|         ] | ||||
|     )  # transforms | ||||
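For pose models, `flip_idx` tells `RandomFlip` which keypoint takes over each slot after mirroring; the standard 17-point COCO mapping used in the Ultralytics pose dataset YAMLs pairs left/right joints:

```python
import numpy as np

flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]  # nose stays, pairs swap
kpts = np.arange(17 * 2, dtype=float).reshape(1, 17, 2)
flipped = np.ascontiguousarray(kpts[:, flip_idx, :])  # the same indexing RandomFlip applies
print(np.array_equal(flipped[0, 1], kpts[0, 2]))      # True: slot 1 (left eye) takes the old right-eye coords
```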
| 
 | ||||
| 
 | ||||
| # Classification augmentations ----------------------------------------------------------------------------------------- | ||||
| @ -1031,10 +1042,13 @@ def classify_transforms( | ||||
|         tfl = [T.Resize(scale_size)] | ||||
|     tfl += [T.CenterCrop(size)] | ||||
| 
 | ||||
|     tfl += [T.ToTensor(), T.Normalize( | ||||
|     tfl += [ | ||||
|         T.ToTensor(), | ||||
|         T.Normalize( | ||||
|             mean=torch.tensor(mean), | ||||
|             std=torch.tensor(std), | ||||
|     )] | ||||
|         ), | ||||
|     ] | ||||
| 
 | ||||
|     return T.Compose(tfl) | ||||
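A usage sketch for the evaluation pipeline above (default arguments assumed): resize, center-crop, tensorize, normalize.

```python
from PIL import Image
from ultralytics.data.augment import classify_transforms

tf = classify_transforms(224)
x = tf(Image.new("RGB", (320, 240)))
print(x.shape)  # torch.Size([3, 224, 224])
```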
| 
 | ||||
| @ -1053,7 +1067,7 @@ def classify_augmentations( | ||||
|     hsv_s=0.4,  # image HSV-Saturation augmentation (fraction) | ||||
|     hsv_v=0.4,  # image HSV-Value augmentation (fraction) | ||||
|     force_color_jitter=False, | ||||
|     erasing=0., | ||||
|     erasing=0.0, | ||||
|     interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, | ||||
| ): | ||||
|     """ | ||||
| @ -1080,13 +1094,13 @@ def classify_augmentations( | ||||
|     """ | ||||
|     # Transforms to apply if albumentations not installed | ||||
|     if not isinstance(size, int): | ||||
|         raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)') | ||||
|         raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") | ||||
|     scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range | ||||
|     ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range | ||||
|     ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))  # default imagenet ratio range | ||||
|     primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] | ||||
|     if hflip > 0.: | ||||
|     if hflip > 0.0: | ||||
|         primary_tfl += [T.RandomHorizontalFlip(p=hflip)] | ||||
|     if vflip > 0.: | ||||
|     if vflip > 0.0: | ||||
|         primary_tfl += [T.RandomVerticalFlip(p=vflip)] | ||||
| 
 | ||||
|     secondary_tfl = [] | ||||
| @ -1097,27 +1111,29 @@ def classify_augmentations( | ||||
|         # this allows override without breaking old hparam cfgs | ||||
|         disable_color_jitter = not force_color_jitter | ||||
| 
 | ||||
|         if auto_augment == 'randaugment': | ||||
|         if auto_augment == "randaugment": | ||||
|             if TORCHVISION_0_11: | ||||
|                 secondary_tfl += [T.RandAugment(interpolation=interpolation)] | ||||
|             else: | ||||
|                 LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.') | ||||
| 
 | ||||
|         elif auto_augment == 'augmix': | ||||
|         elif auto_augment == "augmix": | ||||
|             if TORCHVISION_0_13: | ||||
|                 secondary_tfl += [T.AugMix(interpolation=interpolation)] | ||||
|             else: | ||||
|                 LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.') | ||||
| 
 | ||||
|         elif auto_augment == 'autoaugment': | ||||
|         elif auto_augment == "autoaugment": | ||||
|             if TORCHVISION_0_10: | ||||
|                 secondary_tfl += [T.AutoAugment(interpolation=interpolation)] | ||||
|             else: | ||||
|                 LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.') | ||||
| 
 | ||||
|         else: | ||||
|             raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", ' | ||||
|                              f'"augmix", "autoaugment" or None') | ||||
|             raise ValueError( | ||||
|                 f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", ' | ||||
|                 f'"augmix", "autoaugment" or None' | ||||
|             ) | ||||
| 
 | ||||
|     if not disable_color_jitter: | ||||
|         secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)] | ||||
| @ -1125,7 +1141,8 @@ def classify_augmentations( | ||||
|     final_tfl = [ | ||||
|         T.ToTensor(), | ||||
|         T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), | ||||
|         T.RandomErasing(p=erasing, inplace=True)] | ||||
|         T.RandomErasing(p=erasing, inplace=True), | ||||
|     ] | ||||
| 
 | ||||
|     return T.Compose(primary_tfl + secondary_tfl + final_tfl) | ||||
| 
 | ||||
| @ -1177,7 +1194,7 @@ class ClassifyLetterBox: | ||||
| 
 | ||||
|         # Create padded image | ||||
|         im_out = np.full((hs, ws, 3), 114, dtype=im.dtype) | ||||
|         im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) | ||||
|         im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) | ||||
|         return im_out | ||||
| 
 | ||||
| 
 | ||||
| @ -1205,7 +1222,7 @@ class CenterCrop: | ||||
|         imh, imw = im.shape[:2] | ||||
|         m = min(imh, imw)  # min dimension | ||||
|         top, left = (imh - m) // 2, (imw - m) // 2 | ||||
|         return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) | ||||
|         return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) | ||||
| 
 | ||||
| 
 | ||||
| # NOTE: keep this class for backward compatibility | ||||
|  | ||||
| @ -47,20 +47,22 @@ class BaseDataset(Dataset): | ||||
|         transforms (callable): Image transformation function. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, | ||||
|     def __init__( | ||||
|         self, | ||||
|         img_path, | ||||
|         imgsz=640, | ||||
|         cache=False, | ||||
|         augment=True, | ||||
|         hyp=DEFAULT_CFG, | ||||
|                  prefix='', | ||||
|         prefix="", | ||||
|         rect=False, | ||||
|         batch_size=16, | ||||
|         stride=32, | ||||
|         pad=0.5, | ||||
|         single_cls=False, | ||||
|         classes=None, | ||||
|                  fraction=1.0): | ||||
|         fraction=1.0, | ||||
|     ): | ||||
|         """Initialize BaseDataset with given configuration and options.""" | ||||
|         super().__init__() | ||||
|         self.img_path = img_path | ||||
| @ -86,10 +88,10 @@ class BaseDataset(Dataset): | ||||
|         self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 | ||||
| 
 | ||||
|         # Cache images | ||||
|         if cache == 'ram' and not self.check_cache_ram(): | ||||
|         if cache == "ram" and not self.check_cache_ram(): | ||||
|             cache = False | ||||
|         self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni | ||||
|         self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] | ||||
|         self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files] | ||||
|         if cache: | ||||
|             self.cache_images(cache) | ||||
| 
 | ||||
| @ -103,23 +105,23 @@ class BaseDataset(Dataset): | ||||
|             for p in img_path if isinstance(img_path, list) else [img_path]: | ||||
|                 p = Path(p)  # os-agnostic | ||||
|                 if p.is_dir():  # dir | ||||
|                     f += glob.glob(str(p / '**' / '*.*'), recursive=True) | ||||
|                     f += glob.glob(str(p / "**" / "*.*"), recursive=True) | ||||
|                     # F = list(p.rglob('*.*'))  # pathlib | ||||
|                 elif p.is_file():  # file | ||||
|                     with open(p) as t: | ||||
|                         t = t.read().strip().splitlines() | ||||
|                         parent = str(p.parent) + os.sep | ||||
|                         f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path | ||||
|                         f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path | ||||
|                         # F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib) | ||||
|                 else: | ||||
|                     raise FileNotFoundError(f'{self.prefix}{p} does not exist') | ||||
|             im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS) | ||||
|                     raise FileNotFoundError(f"{self.prefix}{p} does not exist") | ||||
|             im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS) | ||||
|             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib | ||||
|             assert im_files, f'{self.prefix}No images found in {img_path}' | ||||
|             assert im_files, f"{self.prefix}No images found in {img_path}" | ||||
|         except Exception as e: | ||||
|             raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e | ||||
|             raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e | ||||
|         if self.fraction < 1: | ||||
|             im_files = im_files[:round(len(im_files) * self.fraction)] | ||||
|             im_files = im_files[: round(len(im_files) * self.fraction)] | ||||
|         return im_files | ||||
| 
 | ||||
|     def update_labels(self, include_class: Optional[list]): | ||||
| @ -127,19 +129,19 @@ class BaseDataset(Dataset): | ||||
|         include_class_array = np.array(include_class).reshape(1, -1) | ||||
|         for i in range(len(self.labels)): | ||||
|             if include_class is not None: | ||||
|                 cls = self.labels[i]['cls'] | ||||
|                 bboxes = self.labels[i]['bboxes'] | ||||
|                 segments = self.labels[i]['segments'] | ||||
|                 keypoints = self.labels[i]['keypoints'] | ||||
|                 cls = self.labels[i]["cls"] | ||||
|                 bboxes = self.labels[i]["bboxes"] | ||||
|                 segments = self.labels[i]["segments"] | ||||
|                 keypoints = self.labels[i]["keypoints"] | ||||
|                 j = (cls == include_class_array).any(1) | ||||
|                 self.labels[i]['cls'] = cls[j] | ||||
|                 self.labels[i]['bboxes'] = bboxes[j] | ||||
|                 self.labels[i]["cls"] = cls[j] | ||||
|                 self.labels[i]["bboxes"] = bboxes[j] | ||||
|                 if segments: | ||||
|                     self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx] | ||||
|                     self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx] | ||||
|                 if keypoints is not None: | ||||
|                     self.labels[i]['keypoints'] = keypoints[j] | ||||
|                     self.labels[i]["keypoints"] = keypoints[j] | ||||
|             if self.single_cls: | ||||
|                 self.labels[i]['cls'][:, 0] = 0 | ||||
|                 self.labels[i]["cls"][:, 0] = 0 | ||||
| 
 | ||||
|     def load_image(self, i, rect_mode=True): | ||||
|         """Loads 1 image from dataset index 'i', returns (im, resized hw).""" | ||||
| @ -149,13 +151,13 @@ class BaseDataset(Dataset): | ||||
|                 try: | ||||
|                     im = np.load(fn) | ||||
|                 except Exception as e: | ||||
|                     LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}') | ||||
|                     LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}") | ||||
|                     Path(fn).unlink(missing_ok=True) | ||||
|                     im = cv2.imread(f)  # BGR | ||||
|             else:  # read image | ||||
|                 im = cv2.imread(f)  # BGR | ||||
|             if im is None: | ||||
|                 raise FileNotFoundError(f'Image Not Found {f}') | ||||
|                 raise FileNotFoundError(f"Image Not Found {f}") | ||||
| 
 | ||||
|             h0, w0 = im.shape[:2]  # orig hw | ||||
|             if rect_mode:  # resize long side to imgsz while maintaining aspect ratio | ||||
| @ -181,17 +183,17 @@ class BaseDataset(Dataset): | ||||
|     def cache_images(self, cache): | ||||
|         """Cache images to memory or disk.""" | ||||
|         b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes | ||||
|         fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image | ||||
|         fcn = self.cache_images_to_disk if cache == "disk" else self.load_image | ||||
|         with ThreadPool(NUM_THREADS) as pool: | ||||
|             results = pool.imap(fcn, range(self.ni)) | ||||
|             pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0) | ||||
|             for i, x in pbar: | ||||
|                 if cache == 'disk': | ||||
|                 if cache == "disk": | ||||
|                     b += self.npy_files[i].stat().st_size | ||||
|                 else:  # 'ram' | ||||
|                     self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i) | ||||
|                     b += self.ims[i].nbytes | ||||
|                 pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})' | ||||
|                 pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})" | ||||
|             pbar.close() | ||||
| 
 | ||||
|     def cache_images_to_disk(self, i): | ||||
| @ -207,15 +209,17 @@ class BaseDataset(Dataset): | ||||
|         for _ in range(n): | ||||
|             im = cv2.imread(random.choice(self.im_files))  # sample image | ||||
|             ratio = self.imgsz / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio | ||||
|             b += im.nbytes * ratio ** 2 | ||||
|             b += im.nbytes * ratio**2 | ||||
|         mem_required = b * self.ni / n * (1 + safety_margin)  # GB required to cache dataset into RAM | ||||
|         mem = psutil.virtual_memory() | ||||
|         cache = mem_required < mem.available  # to cache or not to cache, that is the question | ||||
|         if not cache: | ||||
|             LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images ' | ||||
|             LOGGER.info( | ||||
|                 f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images ' | ||||
|                 f'with {int(safety_margin * 100)}% safety margin but only ' | ||||
|                 f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, ' | ||||
|                         f"{'caching images ✅' if cache else 'not caching images ⚠️'}") | ||||
|                 f"{'caching images ✅' if cache else 'not caching images ⚠️'}" | ||||
|             ) | ||||
|         return cache | ||||
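A back-of-envelope version of the estimate above, with assumed numbers (30 sampled images averaging ~1.1 MB after the resize-ratio correction, a 118k-image dataset, 50% safety margin):

```python
n, ni, safety_margin = 30, 118_000, 0.5
avg_resized_bytes = 1.1e6                                    # mean of im.nbytes * ratio**2 over n samples
mem_required = avg_resized_bytes * ni * (1 + safety_margin)  # equals b * ni / n with b = n * avg
print(f"{mem_required / (1 << 30):.1f} GB")                  # ≈ 181.3 GB -> caching likely skipped
```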
| 
 | ||||
|     def set_rectangle(self): | ||||
| @ -223,7 +227,7 @@ class BaseDataset(Dataset): | ||||
|         bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index | ||||
|         nb = bi[-1] + 1  # number of batches | ||||
| 
 | ||||
|         s = np.array([x.pop('shape') for x in self.labels])  # hw | ||||
|         s = np.array([x.pop("shape") for x in self.labels])  # hw | ||||
|         ar = s[:, 0] / s[:, 1]  # aspect ratio | ||||
|         irect = ar.argsort() | ||||
|         self.im_files = [self.im_files[i] for i in irect] | ||||
| @ -250,12 +254,14 @@ class BaseDataset(Dataset): | ||||
|     def get_image_and_label(self, index): | ||||
|         """Get and return label information from the dataset.""" | ||||
|         label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 | ||||
|         label.pop('shape', None)  # shape is for rect, remove it | ||||
|         label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index) | ||||
|         label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0], | ||||
|                               label['resized_shape'][1] / label['ori_shape'][1])  # for evaluation | ||||
|         label.pop("shape", None)  # shape is for rect, remove it | ||||
|         label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index) | ||||
|         label["ratio_pad"] = ( | ||||
|             label["resized_shape"][0] / label["ori_shape"][0], | ||||
|             label["resized_shape"][1] / label["ori_shape"][1], | ||||
|         )  # for evaluation | ||||
|         if self.rect: | ||||
|             label['rect_shape'] = self.batch_shapes[self.batch[index]] | ||||
|             label["rect_shape"] = self.batch_shapes[self.batch[index]] | ||||
|         return self.update_labels_info(label) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|  | ||||
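For reference, the `ratio_pad` computed in `get_image_and_label` above is just the per-axis resize factor, later combined with letterbox padding at evaluation time:

```python
ori_shape, resized_shape = (720, 1280), (360, 640)  # (h, w) before/after load_image
ratio_pad = (resized_shape[0] / ori_shape[0], resized_shape[1] / ori_shape[1])
print(ratio_pad)  # (0.5, 0.5)
```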
| @ -9,8 +9,16 @@ import torch | ||||
| from PIL import Image | ||||
| from torch.utils.data import dataloader, distributed | ||||
| 
 | ||||
| from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor, | ||||
|                                       SourceTypes, autocast_list) | ||||
| from ultralytics.data.loaders import ( | ||||
|     LOADERS, | ||||
|     LoadImages, | ||||
|     LoadPilAndNumpy, | ||||
|     LoadScreenshots, | ||||
|     LoadStreams, | ||||
|     LoadTensor, | ||||
|     SourceTypes, | ||||
|     autocast_list, | ||||
| ) | ||||
| from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS | ||||
| from ultralytics.utils import RANK, colorstr | ||||
| from ultralytics.utils.checks import check_file | ||||
| @ -29,7 +37,7 @@ class InfiniteDataLoader(dataloader.DataLoader): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         """Dataloader that infinitely recycles workers, inherits from DataLoader.""" | ||||
|         super().__init__(*args, **kwargs) | ||||
|         object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) | ||||
|         object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler)) | ||||
|         self.iterator = super().__iter__() | ||||
| 
 | ||||
|     def __len__(self): | ||||
| @ -70,29 +78,30 @@ class _RepeatSampler: | ||||
| 
 | ||||
| def seed_worker(worker_id):  # noqa | ||||
|     """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader.""" | ||||
|     worker_seed = torch.initial_seed() % 2 ** 32 | ||||
|     worker_seed = torch.initial_seed() % 2**32 | ||||
|     np.random.seed(worker_seed) | ||||
|     random.seed(worker_seed) | ||||
| 
 | ||||
| 
 | ||||
| def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32): | ||||
| def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32): | ||||
|     """Build YOLO Dataset.""" | ||||
|     return YOLODataset( | ||||
|         img_path=img_path, | ||||
|         imgsz=cfg.imgsz, | ||||
|         batch_size=batch, | ||||
|         augment=mode == 'train',  # augmentation | ||||
|         augment=mode == "train",  # augmentation | ||||
|         hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function | ||||
|         rect=cfg.rect or rect,  # rectangular batches | ||||
|         cache=cfg.cache or None, | ||||
|         single_cls=cfg.single_cls or False, | ||||
|         stride=int(stride), | ||||
|         pad=0.0 if mode == 'train' else 0.5, | ||||
|         prefix=colorstr(f'{mode}: '), | ||||
|         pad=0.0 if mode == "train" else 0.5, | ||||
|         prefix=colorstr(f"{mode}: "), | ||||
|         task=cfg.task, | ||||
|         classes=cfg.classes, | ||||
|         data=data, | ||||
|         fraction=cfg.fraction if mode == 'train' else 1.0) | ||||
|         fraction=cfg.fraction if mode == "train" else 1.0, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): | ||||
| @ -103,15 +112,17 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): | ||||
|     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) | ||||
|     generator = torch.Generator() | ||||
|     generator.manual_seed(6148914691236517205 + RANK) | ||||
|     return InfiniteDataLoader(dataset=dataset, | ||||
|     return InfiniteDataLoader( | ||||
|         dataset=dataset, | ||||
|         batch_size=batch, | ||||
|         shuffle=shuffle and sampler is None, | ||||
|         num_workers=nw, | ||||
|         sampler=sampler, | ||||
|         pin_memory=PIN_MEMORY, | ||||
|                               collate_fn=getattr(dataset, 'collate_fn', None), | ||||
|         collate_fn=getattr(dataset, "collate_fn", None), | ||||
|         worker_init_fn=seed_worker, | ||||
|                               generator=generator) | ||||
|         generator=generator, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def check_source(source): | ||||
| @ -120,9 +131,9 @@ def check_source(source): | ||||
|     if isinstance(source, (str, int, Path)):  # int for local usb camera | ||||
|         source = str(source) | ||||
|         is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) | ||||
|         is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')) | ||||
|         webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) | ||||
|         screenshot = source.lower() == 'screen' | ||||
|         is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")) | ||||
|         webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file) | ||||
|         screenshot = source.lower() == "screen" | ||||
|         if is_url and is_file: | ||||
|             source = check_file(source)  # download | ||||
|     elif isinstance(source, LOADERS): | ||||
| @ -135,7 +146,7 @@ def check_source(source): | ||||
|     elif isinstance(source, torch.Tensor): | ||||
|         tensor = True | ||||
|     else: | ||||
|         raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict') | ||||
|         raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict") | ||||
| 
 | ||||
|     return source, webcam, screenshot, from_img, in_memory, tensor | ||||
| 
 | ||||
| @ -171,6 +182,6 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False): | ||||
|         dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride) | ||||
| 
 | ||||
|     # Attach source types to the dataset | ||||
|     setattr(dataset, 'source_type', source_type) | ||||
|     setattr(dataset, "source_type", source_type) | ||||
| 
 | ||||
|     return dataset | ||||
|  | ||||
| @ -20,10 +20,98 @@ def coco91_to_coco80_class(): | ||||
|             corresponding 91-index class ID. | ||||
|     """ | ||||
|     return [ | ||||
|         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None, | ||||
|         None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, | ||||
|         51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, | ||||
|         None, 73, 74, 75, 76, 77, 78, 79, None] | ||||
|         0, | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         4, | ||||
|         5, | ||||
|         6, | ||||
|         7, | ||||
|         8, | ||||
|         9, | ||||
|         10, | ||||
|         None, | ||||
|         11, | ||||
|         12, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         16, | ||||
|         17, | ||||
|         18, | ||||
|         19, | ||||
|         20, | ||||
|         21, | ||||
|         22, | ||||
|         23, | ||||
|         None, | ||||
|         24, | ||||
|         25, | ||||
|         None, | ||||
|         None, | ||||
|         26, | ||||
|         27, | ||||
|         28, | ||||
|         29, | ||||
|         30, | ||||
|         31, | ||||
|         32, | ||||
|         33, | ||||
|         34, | ||||
|         35, | ||||
|         36, | ||||
|         37, | ||||
|         38, | ||||
|         39, | ||||
|         None, | ||||
|         40, | ||||
|         41, | ||||
|         42, | ||||
|         43, | ||||
|         44, | ||||
|         45, | ||||
|         46, | ||||
|         47, | ||||
|         48, | ||||
|         49, | ||||
|         50, | ||||
|         51, | ||||
|         52, | ||||
|         53, | ||||
|         54, | ||||
|         55, | ||||
|         56, | ||||
|         57, | ||||
|         58, | ||||
|         59, | ||||
|         None, | ||||
|         60, | ||||
|         None, | ||||
|         None, | ||||
|         61, | ||||
|         None, | ||||
|         62, | ||||
|         63, | ||||
|         64, | ||||
|         65, | ||||
|         66, | ||||
|         67, | ||||
|         68, | ||||
|         69, | ||||
|         70, | ||||
|         71, | ||||
|         72, | ||||
|         None, | ||||
|         73, | ||||
|         74, | ||||
|         75, | ||||
|         76, | ||||
|         77, | ||||
|         78, | ||||
|         79, | ||||
|         None, | ||||
|     ] | ||||
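A quick sanity check of the mapping above: the index is the 91-class category id minus 1, and the value is the 80-class index or None (the label name in the comment is an assumption from the standard COCO ordering):

```python
coco80 = coco91_to_coco80_class()
print(coco80[0])   # 0: category id 1 (person) -> 80-class index 0
print(coco80[11])  # None: category id 12 has no counterpart in the 80-class set
```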
| 
 | ||||
| 
 | ||||
| def coco80_to_coco91_class(): | ||||
| @ -42,16 +130,96 @@ def coco80_to_coco91_class(): | ||||
|         ``` | ||||
|     """ | ||||
|     return [ | ||||
|         1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, | ||||
|         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, | ||||
|         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         4, | ||||
|         5, | ||||
|         6, | ||||
|         7, | ||||
|         8, | ||||
|         9, | ||||
|         10, | ||||
|         11, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         16, | ||||
|         17, | ||||
|         18, | ||||
|         19, | ||||
|         20, | ||||
|         21, | ||||
|         22, | ||||
|         23, | ||||
|         24, | ||||
|         25, | ||||
|         27, | ||||
|         28, | ||||
|         31, | ||||
|         32, | ||||
|         33, | ||||
|         34, | ||||
|         35, | ||||
|         36, | ||||
|         37, | ||||
|         38, | ||||
|         39, | ||||
|         40, | ||||
|         41, | ||||
|         42, | ||||
|         43, | ||||
|         44, | ||||
|         46, | ||||
|         47, | ||||
|         48, | ||||
|         49, | ||||
|         50, | ||||
|         51, | ||||
|         52, | ||||
|         53, | ||||
|         54, | ||||
|         55, | ||||
|         56, | ||||
|         57, | ||||
|         58, | ||||
|         59, | ||||
|         60, | ||||
|         61, | ||||
|         62, | ||||
|         63, | ||||
|         64, | ||||
|         65, | ||||
|         67, | ||||
|         70, | ||||
|         72, | ||||
|         73, | ||||
|         74, | ||||
|         75, | ||||
|         76, | ||||
|         77, | ||||
|         78, | ||||
|         79, | ||||
|         80, | ||||
|         81, | ||||
|         82, | ||||
|         84, | ||||
|         85, | ||||
|         86, | ||||
|         87, | ||||
|         88, | ||||
|         89, | ||||
|         90, | ||||
|     ] | ||||
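And the inverse direction, sketched the same way:

```python
coco91 = coco80_to_coco91_class()
print(coco91[0])   # 1: 80-class index 0 (person) -> category id 1
print(coco91[79])  # 90: the last 80-class index maps to category id 90
```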
| 
 | ||||
| 
 | ||||
| def convert_coco(labels_dir='../coco/annotations/', | ||||
|                  save_dir='coco_converted/', | ||||
| def convert_coco( | ||||
|     labels_dir="../coco/annotations/", | ||||
|     save_dir="coco_converted/", | ||||
|     use_segments=False, | ||||
|     use_keypoints=False, | ||||
|                  cls91to80=True): | ||||
|     cls91to80=True, | ||||
| ): | ||||
|     """ | ||||
|     Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models. | ||||
| 
 | ||||
| @ -75,76 +243,78 @@ def convert_coco(labels_dir='../coco/annotations/', | ||||
| 
 | ||||
|     # Create dataset directory | ||||
|     save_dir = increment_path(save_dir)  # increment if save directory already exists | ||||
|     for p in save_dir / 'labels', save_dir / 'images': | ||||
|     for p in save_dir / "labels", save_dir / "images": | ||||
|         p.mkdir(parents=True, exist_ok=True)  # make dir | ||||
| 
 | ||||
|     # Convert classes | ||||
|     coco80 = coco91_to_coco80_class() | ||||
| 
 | ||||
|     # Import json | ||||
|     for json_file in sorted(Path(labels_dir).resolve().glob('*.json')): | ||||
|         fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '')  # folder name | ||||
|     for json_file in sorted(Path(labels_dir).resolve().glob("*.json")): | ||||
|         fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "")  # folder name | ||||
|         fn.mkdir(parents=True, exist_ok=True) | ||||
|         with open(json_file) as f: | ||||
|             data = json.load(f) | ||||
| 
 | ||||
|         # Create image dict | ||||
|         images = {f'{x["id"]:d}': x for x in data['images']} | ||||
|         images = {f'{x["id"]:d}': x for x in data["images"]} | ||||
|         # Create image-annotations dict | ||||
|         imgToAnns = defaultdict(list) | ||||
|         for ann in data['annotations']: | ||||
|             imgToAnns[ann['image_id']].append(ann) | ||||
|         for ann in data["annotations"]: | ||||
|             imgToAnns[ann["image_id"]].append(ann) | ||||
| 
 | ||||
|         # Write labels file | ||||
|         for img_id, anns in TQDM(imgToAnns.items(), desc=f'Annotations {json_file}'): | ||||
|             img = images[f'{img_id:d}'] | ||||
|             h, w, f = img['height'], img['width'], img['file_name'] | ||||
|         for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"): | ||||
|             img = images[f"{img_id:d}"] | ||||
|             h, w, f = img["height"], img["width"], img["file_name"] | ||||
| 
 | ||||
|             bboxes = [] | ||||
|             segments = [] | ||||
|             keypoints = [] | ||||
|             for ann in anns: | ||||
|                 if ann['iscrowd']: | ||||
|                 if ann["iscrowd"]: | ||||
|                     continue | ||||
|                 # The COCO box format is [top left x, top left y, width, height] | ||||
|                 box = np.array(ann['bbox'], dtype=np.float64) | ||||
|                 box = np.array(ann["bbox"], dtype=np.float64) | ||||
|                 box[:2] += box[2:] / 2  # xy top-left corner to center | ||||
|                 box[[0, 2]] /= w  # normalize x | ||||
|                 box[[1, 3]] /= h  # normalize y | ||||
|                 if box[2] <= 0 or box[3] <= 0:  # if w <= 0 or h <= 0 | ||||
|                     continue | ||||
| 
 | ||||
|                 cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1  # class | ||||
|                 cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1  # class | ||||
|                 box = [cls] + box.tolist() | ||||
|                 if box not in bboxes: | ||||
|                     bboxes.append(box) | ||||
|                     if use_segments and ann.get('segmentation') is not None: | ||||
|                         if len(ann['segmentation']) == 0: | ||||
|                     if use_segments and ann.get("segmentation") is not None: | ||||
|                         if len(ann["segmentation"]) == 0: | ||||
|                             segments.append([]) | ||||
|                             continue | ||||
|                         elif len(ann['segmentation']) > 1: | ||||
|                             s = merge_multi_segment(ann['segmentation']) | ||||
|                         elif len(ann["segmentation"]) > 1: | ||||
|                             s = merge_multi_segment(ann["segmentation"]) | ||||
|                             s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist() | ||||
|                         else: | ||||
|                             s = [j for i in ann['segmentation'] for j in i]  # all segments concatenated | ||||
|                             s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated | ||||
|                             s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist() | ||||
|                         s = [cls] + s | ||||
|                         segments.append(s) | ||||
|                     if use_keypoints and ann.get('keypoints') is not None: | ||||
|                         keypoints.append(box + (np.array(ann['keypoints']).reshape(-1, 3) / | ||||
|                                                 np.array([w, h, 1])).reshape(-1).tolist()) | ||||
|                     if use_keypoints and ann.get("keypoints") is not None: | ||||
|                         keypoints.append( | ||||
|                             box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist() | ||||
|                         ) | ||||
| 
 | ||||
|             # Write | ||||
|             with open((fn / f).with_suffix('.txt'), 'a') as file: | ||||
|             with open((fn / f).with_suffix(".txt"), "a") as file: | ||||
|                 for i in range(len(bboxes)): | ||||
|                     if use_keypoints: | ||||
|                         line = *(keypoints[i]),  # cls, box, keypoints | ||||
|                         line = (*(keypoints[i]),)  # cls, box, keypoints | ||||
|                     else: | ||||
|                         line = *(segments[i] | ||||
|                                  if use_segments and len(segments[i]) > 0 else bboxes[i]),  # cls, box or segments | ||||
|                     file.write(('%g ' * len(line)).rstrip() % line + '\n') | ||||
|                         line = ( | ||||
|                             *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]), | ||||
|                         )  # cls, box or segments | ||||
|                     file.write(("%g " * len(line)).rstrip() % line + "\n") | ||||
| 
 | ||||
|     LOGGER.info(f'COCO data converted successfully.\nResults saved to {save_dir.resolve()}') | ||||
|     LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}") | ||||
| 
 | ||||
| 
 | ||||
| def convert_dota_to_yolo_obb(dota_root_path: str): | ||||
| @ -184,31 +354,32 @@ def convert_dota_to_yolo_obb(dota_root_path: str): | ||||
| 
 | ||||
|     # Class names to indices mapping | ||||
|     class_mapping = { | ||||
|         'plane': 0, | ||||
|         'ship': 1, | ||||
|         'storage-tank': 2, | ||||
|         'baseball-diamond': 3, | ||||
|         'tennis-court': 4, | ||||
|         'basketball-court': 5, | ||||
|         'ground-track-field': 6, | ||||
|         'harbor': 7, | ||||
|         'bridge': 8, | ||||
|         'large-vehicle': 9, | ||||
|         'small-vehicle': 10, | ||||
|         'helicopter': 11, | ||||
|         'roundabout': 12, | ||||
|         'soccer-ball-field': 13, | ||||
|         'swimming-pool': 14, | ||||
|         'container-crane': 15, | ||||
|         'airport': 16, | ||||
|         'helipad': 17} | ||||
|         "plane": 0, | ||||
|         "ship": 1, | ||||
|         "storage-tank": 2, | ||||
|         "baseball-diamond": 3, | ||||
|         "tennis-court": 4, | ||||
|         "basketball-court": 5, | ||||
|         "ground-track-field": 6, | ||||
|         "harbor": 7, | ||||
|         "bridge": 8, | ||||
|         "large-vehicle": 9, | ||||
|         "small-vehicle": 10, | ||||
|         "helicopter": 11, | ||||
|         "roundabout": 12, | ||||
|         "soccer-ball-field": 13, | ||||
|         "swimming-pool": 14, | ||||
|         "container-crane": 15, | ||||
|         "airport": 16, | ||||
|         "helipad": 17, | ||||
|     } | ||||
| 
 | ||||
|     def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir): | ||||
|         """Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory.""" | ||||
|         orig_label_path = orig_label_dir / f'{image_name}.txt' | ||||
|         save_path = save_dir / f'{image_name}.txt' | ||||
|         orig_label_path = orig_label_dir / f"{image_name}.txt" | ||||
|         save_path = save_dir / f"{image_name}.txt" | ||||
| 
 | ||||
|         with orig_label_path.open('r') as f, save_path.open('w') as g: | ||||
|         with orig_label_path.open("r") as f, save_path.open("w") as g: | ||||
|             lines = f.readlines() | ||||
|             for line in lines: | ||||
|                 parts = line.strip().split() | ||||
| @ -218,20 +389,21 @@ def convert_dota_to_yolo_obb(dota_root_path: str): | ||||
|                 class_idx = class_mapping[class_name] | ||||
|                 coords = [float(p) for p in parts[:8]] | ||||
|                 normalized_coords = [ | ||||
|                     coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)] | ||||
|                 formatted_coords = ['{:.6g}'.format(coord) for coord in normalized_coords] | ||||
|                     coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8) | ||||
|                 ] | ||||
|                 formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords] | ||||
|                 g.write(f"{class_idx} {' '.join(formatted_coords)}\n") | ||||
| 
 | ||||
|     for phase in ['train', 'val']: | ||||
|         image_dir = dota_root_path / 'images' / phase | ||||
|         orig_label_dir = dota_root_path / 'labels' / f'{phase}_original' | ||||
|         save_dir = dota_root_path / 'labels' / phase | ||||
|     for phase in ["train", "val"]: | ||||
|         image_dir = dota_root_path / "images" / phase | ||||
|         orig_label_dir = dota_root_path / "labels" / f"{phase}_original" | ||||
|         save_dir = dota_root_path / "labels" / phase | ||||
| 
 | ||||
|         save_dir.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|         image_paths = list(image_dir.iterdir()) | ||||
|         for image_path in TQDM(image_paths, desc=f'Processing {phase} images'): | ||||
|             if image_path.suffix != '.png': | ||||
|         for image_path in TQDM(image_paths, desc=f"Processing {phase} images"): | ||||
|             if image_path.suffix != ".png": | ||||
|                 continue | ||||
|             image_name_without_ext = image_path.stem | ||||
|             img = cv2.imread(str(image_path)) | ||||
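Likewise for the DOTA converter, a minimal sketch under the directory layout assumed by the code above (`images/{train,val}` plus `labels/{train,val}_original`):

```python
from ultralytics.data.converter import convert_dota_to_yolo_obb  # import path assumed

convert_dota_to_yolo_obb("path/to/DOTAv2")  # hypothetical dataset root
```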
| @ -293,7 +465,7 @@ def merge_multi_segment(segments): | ||||
|                     s.append(segments[i]) | ||||
|                 else: | ||||
|                     idx = [0, idx[1] - idx[0]] | ||||
|                     s.append(segments[i][idx[0]:idx[1] + 1]) | ||||
|                     s.append(segments[i][idx[0] : idx[1] + 1]) | ||||
| 
 | ||||
|         else: | ||||
|             for i in range(len(idx_list) - 1, -1, -1): | ||||
|  | ||||
| @ -18,7 +18,7 @@ from .base import BaseDataset | ||||
| from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label | ||||
| 
 | ||||
| # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 | ||||
| DATASET_CACHE_VERSION = '1.0.3' | ||||
| DATASET_CACHE_VERSION = "1.0.3" | ||||
| 
 | ||||
| 
 | ||||
| class YOLODataset(BaseDataset): | ||||
| @ -33,16 +33,16 @@ class YOLODataset(BaseDataset): | ||||
|         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, *args, data=None, task='detect', **kwargs): | ||||
|     def __init__(self, *args, data=None, task="detect", **kwargs): | ||||
|         """Initializes the YOLODataset with optional configurations for segments and keypoints.""" | ||||
|         self.use_segments = task == 'segment' | ||||
|         self.use_keypoints = task == 'pose' | ||||
|         self.use_obb = task == 'obb' | ||||
|         self.use_segments = task == "segment" | ||||
|         self.use_keypoints = task == "pose" | ||||
|         self.use_obb = task == "obb" | ||||
|         self.data = data | ||||
|         assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.' | ||||
|         assert not (self.use_segments and self.use_keypoints), "Cannot use both segments and keypoints." | ||||
|         super().__init__(*args, **kwargs) | ||||
| 
 | ||||
|     def cache_labels(self, path=Path('./labels.cache')): | ||||
|     def cache_labels(self, path=Path("./labels.cache")): | ||||
|         """ | ||||
|         Cache dataset labels, check images and read shapes. | ||||
| 
 | ||||
| @ -51,19 +51,29 @@ class YOLODataset(BaseDataset): | ||||
|         Returns: | ||||
|             (dict): labels. | ||||
|         """ | ||||
|         x = {'labels': []} | ||||
|         x = {"labels": []} | ||||
|         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages | ||||
|         desc = f'{self.prefix}Scanning {path.parent / path.stem}...' | ||||
|         desc = f"{self.prefix}Scanning {path.parent / path.stem}..." | ||||
|         total = len(self.im_files) | ||||
|         nkpt, ndim = self.data.get('kpt_shape', (0, 0)) | ||||
|         nkpt, ndim = self.data.get("kpt_shape", (0, 0)) | ||||
|         if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)): | ||||
|             raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " | ||||
|                              "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'") | ||||
|             raise ValueError( | ||||
|                 "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " | ||||
|                 "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" | ||||
|             ) | ||||
|         with ThreadPool(NUM_THREADS) as pool: | ||||
|             results = pool.imap(func=verify_image_label, | ||||
|                                 iterable=zip(self.im_files, self.label_files, repeat(self.prefix), | ||||
|                                              repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt), | ||||
|                                              repeat(ndim))) | ||||
|             results = pool.imap( | ||||
|                 func=verify_image_label, | ||||
|                 iterable=zip( | ||||
|                     self.im_files, | ||||
|                     self.label_files, | ||||
|                     repeat(self.prefix), | ||||
|                     repeat(self.use_keypoints), | ||||
|                     repeat(len(self.data["names"])), | ||||
|                     repeat(nkpt), | ||||
|                     repeat(ndim), | ||||
|                 ), | ||||
|             ) | ||||
|             pbar = TQDM(results, desc=desc, total=total) | ||||
|             for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar: | ||||
|                 nm += nm_f | ||||
| @ -71,7 +81,7 @@ class YOLODataset(BaseDataset): | ||||
|                 ne += ne_f | ||||
|                 nc += nc_f | ||||
|                 if im_file: | ||||
|                     x['labels'].append( | ||||
|                     x["labels"].append( | ||||
|                         dict( | ||||
|                             im_file=im_file, | ||||
|                             shape=shape, | ||||
| @ -80,60 +90,63 @@ class YOLODataset(BaseDataset): | ||||
|                             segments=segments, | ||||
|                             keypoints=keypoint, | ||||
|                             normalized=True, | ||||
|                             bbox_format='xywh')) | ||||
|                             bbox_format="xywh", | ||||
|                         ) | ||||
|                     ) | ||||
|                 if msg: | ||||
|                     msgs.append(msg) | ||||
|                 pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt' | ||||
|                 pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" | ||||
|             pbar.close() | ||||
| 
 | ||||
|         if msgs: | ||||
|             LOGGER.info('\n'.join(msgs)) | ||||
|             LOGGER.info("\n".join(msgs)) | ||||
|         if nf == 0: | ||||
|             LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}') | ||||
|         x['hash'] = get_hash(self.label_files + self.im_files) | ||||
|         x['results'] = nf, nm, ne, nc, len(self.im_files) | ||||
|         x['msgs'] = msgs  # warnings | ||||
|             LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}") | ||||
|         x["hash"] = get_hash(self.label_files + self.im_files) | ||||
|         x["results"] = nf, nm, ne, nc, len(self.im_files) | ||||
|         x["msgs"] = msgs  # warnings | ||||
|         save_dataset_cache_file(self.prefix, path, x) | ||||
|         return x | ||||
| 
 | ||||
|     def get_labels(self): | ||||
|         """Returns dictionary of labels for YOLO training.""" | ||||
|         self.label_files = img2label_paths(self.im_files) | ||||
|         cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') | ||||
|         cache_path = Path(self.label_files[0]).parent.with_suffix(".cache") | ||||
|         try: | ||||
|             cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file | ||||
|             assert cache['version'] == DATASET_CACHE_VERSION  # matches current version | ||||
|             assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash | ||||
|             assert cache["version"] == DATASET_CACHE_VERSION  # matches current version | ||||
|             assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash | ||||
|         except (FileNotFoundError, AssertionError, AttributeError): | ||||
|             cache, exists = self.cache_labels(cache_path), False  # run cache ops | ||||
| 
 | ||||
|         # Display cache | ||||
|         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total | ||||
|         nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total | ||||
|         if exists and LOCAL_RANK in (-1, 0): | ||||
|             d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt' | ||||
|             d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" | ||||
|             TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results | ||||
|             if cache['msgs']: | ||||
|                 LOGGER.info('\n'.join(cache['msgs']))  # display warnings | ||||
|             if cache["msgs"]: | ||||
|                 LOGGER.info("\n".join(cache["msgs"]))  # display warnings | ||||
| 
 | ||||
|         # Read cache | ||||
|         [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items | ||||
|         labels = cache['labels'] | ||||
|         [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items | ||||
|         labels = cache["labels"] | ||||
|         if not labels: | ||||
|             LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}') | ||||
|         self.im_files = [lb['im_file'] for lb in labels]  # update im_files | ||||
|             LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}") | ||||
|         self.im_files = [lb["im_file"] for lb in labels]  # update im_files | ||||
| 
 | ||||
|         # Check if the dataset is all boxes or all segments | ||||
|         lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels) | ||||
|         lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels) | ||||
|         len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths)) | ||||
|         if len_segments and len_boxes != len_segments: | ||||
|             LOGGER.warning( | ||||
|                 f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, ' | ||||
|                 f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. ' | ||||
|                 'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.') | ||||
|                 f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, " | ||||
|                 f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " | ||||
|                 "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset." | ||||
|             ) | ||||
|             for lb in labels: | ||||
|                 lb['segments'] = [] | ||||
|                 lb["segments"] = [] | ||||
|         if len_cls == 0: | ||||
|             LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}') | ||||
|             LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}") | ||||
|         return labels | ||||
| 
 | ||||
|     def build_transforms(self, hyp=None): | ||||
| @ -145,14 +158,17 @@ class YOLODataset(BaseDataset): | ||||
|         else: | ||||
|             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)]) | ||||
|         transforms.append( | ||||
|             Format(bbox_format='xywh', | ||||
|             Format( | ||||
|                 bbox_format="xywh", | ||||
|                 normalize=True, | ||||
|                 return_mask=self.use_segments, | ||||
|                 return_keypoint=self.use_keypoints, | ||||
|                 return_obb=self.use_obb, | ||||
|                 batch_idx=True, | ||||
|                 mask_ratio=hyp.mask_ratio, | ||||
|                    mask_overlap=hyp.overlap_mask)) | ||||
|                 mask_overlap=hyp.overlap_mask, | ||||
|             ) | ||||
|         ) | ||||
|         return transforms | ||||
| 
 | ||||
|     def close_mosaic(self, hyp): | ||||
| @ -166,11 +182,11 @@ class YOLODataset(BaseDataset): | ||||
|         """Custom your label format here.""" | ||||
|         # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label | ||||
|         # We can make it also support classification and semantic segmentation by add or remove some dict keys there. | ||||
|         bboxes = label.pop('bboxes') | ||||
|         segments = label.pop('segments', []) | ||||
|         keypoints = label.pop('keypoints', None) | ||||
|         bbox_format = label.pop('bbox_format') | ||||
|         normalized = label.pop('normalized') | ||||
|         bboxes = label.pop("bboxes") | ||||
|         segments = label.pop("segments", []) | ||||
|         keypoints = label.pop("keypoints", None) | ||||
|         bbox_format = label.pop("bbox_format") | ||||
|         normalized = label.pop("normalized") | ||||
| 
 | ||||
|         # NOTE: do NOT resample oriented boxes | ||||
|         segment_resamples = 100 if self.use_obb else 1000 | ||||
| @ -180,7 +196,7 @@ class YOLODataset(BaseDataset): | ||||
|             segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0) | ||||
|         else: | ||||
|             segments = np.zeros((0, segment_resamples, 2), dtype=np.float32) | ||||
|         label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) | ||||
|         label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) | ||||
|         return label | ||||
| 
 | ||||
|     @staticmethod | ||||
| @ -191,15 +207,15 @@ class YOLODataset(BaseDataset): | ||||
|         values = list(zip(*[list(b.values()) for b in batch])) | ||||
|         for i, k in enumerate(keys): | ||||
|             value = values[i] | ||||
|             if k == 'img': | ||||
|             if k == "img": | ||||
|                 value = torch.stack(value, 0) | ||||
|             if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']: | ||||
|             if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]: | ||||
|                 value = torch.cat(value, 0) | ||||
|             new_batch[k] = value | ||||
|         new_batch['batch_idx'] = list(new_batch['batch_idx']) | ||||
|         for i in range(len(new_batch['batch_idx'])): | ||||
|             new_batch['batch_idx'][i] += i  # add target image index for build_targets() | ||||
|         new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0) | ||||
|         new_batch["batch_idx"] = list(new_batch["batch_idx"]) | ||||
|         for i in range(len(new_batch["batch_idx"])): | ||||
|             new_batch["batch_idx"][i] += i  # add target image index for build_targets() | ||||
|         new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0) | ||||
|         return new_batch | ||||
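The `batch_idx` bookkeeping in `collate_fn` is easy to miss; here is a self-contained sketch of just that step (the box counts are made up):

```python
import torch

# Each sample arrives with one zero per box; adding the sample index i tags
# every box with the image it came from before concatenation.
batch_idx = [torch.zeros(2), torch.zeros(3)]  # 2 boxes in image 0, 3 in image 1
for i in range(len(batch_idx)):
    batch_idx[i] += i
print(torch.cat(batch_idx, 0))  # tensor([0., 0., 1., 1., 1.])
```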
| 
 | ||||
| 
 | ||||
| @ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
|         album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, root, args, augment=False, cache=False, prefix=''): | ||||
|     def __init__(self, root, args, augment=False, cache=False, prefix=""): | ||||
|         """ | ||||
|         Initialize YOLO object with root, image size, augmentations, and cache settings. | ||||
| 
 | ||||
| @ -231,14 +247,16 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
|         """ | ||||
|         super().__init__(root=root) | ||||
|         if augment and args.fraction < 1.0:  # reduce training fraction | ||||
|             self.samples = self.samples[:round(len(self.samples) * args.fraction)] | ||||
|         self.prefix = colorstr(f'{prefix}: ') if prefix else '' | ||||
|         self.cache_ram = cache is True or cache == 'ram' | ||||
|         self.cache_disk = cache == 'disk' | ||||
|             self.samples = self.samples[: round(len(self.samples) * args.fraction)] | ||||
|         self.prefix = colorstr(f"{prefix}: ") if prefix else "" | ||||
|         self.cache_ram = cache is True or cache == "ram" | ||||
|         self.cache_disk = cache == "disk" | ||||
|         self.samples = self.verify_images()  # filter out bad images | ||||
|         self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im | ||||
|         self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im | ||||
|         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0) | ||||
|         self.torch_transforms = classify_augmentations(size=args.imgsz, | ||||
|         self.torch_transforms = ( | ||||
|             classify_augmentations( | ||||
|                 size=args.imgsz, | ||||
|                 scale=scale, | ||||
|                 hflip=args.fliplr, | ||||
|                 vflip=args.flipud, | ||||
| @ -246,8 +264,11 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
|                 auto_augment=args.auto_augment, | ||||
|                 hsv_h=args.hsv_h, | ||||
|                 hsv_s=args.hsv_s, | ||||
|                                                        hsv_v=args.hsv_v) if augment else classify_transforms( | ||||
|                                                            size=args.imgsz, crop_fraction=args.crop_fraction) | ||||
|                 hsv_v=args.hsv_v, | ||||
|             ) | ||||
|             if augment | ||||
|             else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction) | ||||
|         ) | ||||
| 
 | ||||
|     def __getitem__(self, i): | ||||
|         """Returns subset of data and targets corresponding to given indices.""" | ||||
| @ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
|         # Convert NumPy array to PIL image | ||||
|         im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) | ||||
|         sample = self.torch_transforms(im) | ||||
|         return {'img': sample, 'cls': j} | ||||
|         return {"img": sample, "cls": j} | ||||
| 
 | ||||
|     def __len__(self) -> int: | ||||
|         """Return the total number of samples in the dataset.""" | ||||
| @ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
| 
 | ||||
|     def verify_images(self): | ||||
|         """Verify all images in dataset.""" | ||||
|         desc = f'{self.prefix}Scanning {self.root}...' | ||||
|         path = Path(self.root).with_suffix('.cache')  # *.cache file path | ||||
|         desc = f"{self.prefix}Scanning {self.root}..." | ||||
|         path = Path(self.root).with_suffix(".cache")  # *.cache file path | ||||
| 
 | ||||
|         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError): | ||||
|             cache = load_dataset_cache_file(path)  # attempt to load a *.cache file | ||||
|             assert cache['version'] == DATASET_CACHE_VERSION  # matches current version | ||||
|             assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash | ||||
|             nf, nc, n, samples = cache.pop('results')  # found, missing, empty, corrupt, total | ||||
|             assert cache["version"] == DATASET_CACHE_VERSION  # matches current version | ||||
|             assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash | ||||
|             nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total | ||||
|             if LOCAL_RANK in (-1, 0): | ||||
|                 d = f'{desc} {nf} images, {nc} corrupt' | ||||
|                 d = f"{desc} {nf} images, {nc} corrupt" | ||||
|                 TQDM(None, desc=d, total=n, initial=n) | ||||
|                 if cache['msgs']: | ||||
|                     LOGGER.info('\n'.join(cache['msgs']))  # display warnings | ||||
|                 if cache["msgs"]: | ||||
|                     LOGGER.info("\n".join(cache["msgs"]))  # display warnings | ||||
|             return samples | ||||
| 
 | ||||
|         # Run scan if *.cache retrieval failed | ||||
| @ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
|                     msgs.append(msg) | ||||
|                 nf += nf_f | ||||
|                 nc += nc_f | ||||
|                 pbar.desc = f'{desc} {nf} images, {nc} corrupt' | ||||
|                 pbar.desc = f"{desc} {nf} images, {nc} corrupt" | ||||
|             pbar.close() | ||||
|         if msgs: | ||||
|             LOGGER.info('\n'.join(msgs)) | ||||
|         x['hash'] = get_hash([x[0] for x in self.samples]) | ||||
|         x['results'] = nf, nc, len(samples), samples | ||||
|         x['msgs'] = msgs  # warnings | ||||
|             LOGGER.info("\n".join(msgs)) | ||||
|         x["hash"] = get_hash([x[0] for x in self.samples]) | ||||
|         x["results"] = nf, nc, len(samples), samples | ||||
|         x["msgs"] = msgs  # warnings | ||||
|         save_dataset_cache_file(self.prefix, path, x) | ||||
|         return samples | ||||
| 
 | ||||
| @ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
| def load_dataset_cache_file(path): | ||||
|     """Load an Ultralytics *.cache dictionary from path.""" | ||||
|     import gc | ||||
| 
 | ||||
|     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 | ||||
|     cache = np.load(str(path), allow_pickle=True).item()  # load dict | ||||
|     gc.enable() | ||||
| @ -320,15 +342,15 @@ def load_dataset_cache_file(path): | ||||
| 
 | ||||
| def save_dataset_cache_file(prefix, path, x): | ||||
|     """Save an Ultralytics dataset *.cache dictionary x to path.""" | ||||
|     x['version'] = DATASET_CACHE_VERSION  # add cache version | ||||
|     x["version"] = DATASET_CACHE_VERSION  # add cache version | ||||
|     if is_dir_writeable(path.parent): | ||||
|         if path.exists(): | ||||
|             path.unlink()  # remove *.cache file if exists | ||||
|         np.save(str(path), x)  # save cache for next time | ||||
|         path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix | ||||
|         LOGGER.info(f'{prefix}New cache created: {path}') | ||||
|         path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix | ||||
|         LOGGER.info(f"{prefix}New cache created: {path}") | ||||
|     else: | ||||
|         LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.') | ||||
|         LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") | ||||
| 
 | ||||
| 
 | ||||
| # TODO: support semantic segmentation | ||||
|  | ||||
| @ -2,4 +2,4 @@ | ||||
| 
 | ||||
| from .utils import plot_query_result | ||||
| 
 | ||||
| __all__ = ['plot_query_result'] | ||||
| __all__ = ["plot_query_result"] | ||||
|  | ||||
| @ -22,7 +22,6 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr | ||||
| 
 | ||||
| 
 | ||||
| class ExplorerDataset(YOLODataset): | ||||
| 
 | ||||
|     def __init__(self, *args, data: dict = None, **kwargs) -> None: | ||||
|         super().__init__(*args, data=data, **kwargs) | ||||
| 
 | ||||
| @ -35,7 +34,7 @@ class ExplorerDataset(YOLODataset): | ||||
|             else:  # read image | ||||
|                 im = cv2.imread(f)  # BGR | ||||
|                 if im is None: | ||||
|                     raise FileNotFoundError(f'Image Not Found {f}') | ||||
|                     raise FileNotFoundError(f"Image Not Found {f}") | ||||
|             h0, w0 = im.shape[:2]  # orig hw | ||||
|             return im, (h0, w0), im.shape[:2] | ||||
| 
 | ||||
| @ -44,7 +43,7 @@ class ExplorerDataset(YOLODataset): | ||||
|     def build_transforms(self, hyp: IterableSimpleNamespace = None): | ||||
|         """Creates transforms for dataset images without resizing.""" | ||||
|         return Format( | ||||
|             bbox_format='xyxy', | ||||
|             bbox_format="xyxy", | ||||
|             normalize=False, | ||||
|             return_mask=self.use_segments, | ||||
|             return_keypoint=self.use_keypoints, | ||||
| @ -55,17 +54,16 @@ class ExplorerDataset(YOLODataset): | ||||
| 
 | ||||
| 
 | ||||
| class Explorer: | ||||
| 
 | ||||
|     def __init__(self, | ||||
|                  data: Union[str, Path] = 'coco128.yaml', | ||||
|                  model: str = 'yolov8n.pt', | ||||
|                  uri: str = '~/ultralytics/explorer') -> None: | ||||
|         checks.check_requirements(['lancedb>=0.4.3', 'duckdb']) | ||||
|     def __init__( | ||||
|         self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer" | ||||
|     ) -> None: | ||||
|         checks.check_requirements(["lancedb>=0.4.3", "duckdb"]) | ||||
|         import lancedb | ||||
| 
 | ||||
|         self.connection = lancedb.connect(uri) | ||||
|         self.table_name = Path(data).name.lower() + '_' + model.lower() | ||||
|         self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower( | ||||
|         self.table_name = Path(data).name.lower() + "_" + model.lower() | ||||
|         self.sim_idx_base_name = ( | ||||
|             f"{self.table_name}_sim_idx".lower() | ||||
|         )  # Use this name and append thres and top_k to reuse the table | ||||
|         self.model = YOLO(model) | ||||
|         self.data = data  # None | ||||
| @ -74,7 +72,7 @@ class Explorer: | ||||
|         self.table = None | ||||
|         self.progress = 0 | ||||
| 
 | ||||
|     def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None: | ||||
|     def create_embeddings_table(self, force: bool = False, split: str = "train") -> None: | ||||
|         """ | ||||
|         Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it | ||||
|         already exists. Pass force=True to overwrite the existing table. | ||||
| @ -90,20 +88,20 @@ class Explorer: | ||||
|             ``` | ||||
|         """ | ||||
|         if self.table is not None and not force: | ||||
|             LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.') | ||||
|             LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.") | ||||
|             return | ||||
|         if self.table_name in self.connection.table_names() and not force: | ||||
|             LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.') | ||||
|             LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.") | ||||
|             self.table = self.connection.open_table(self.table_name) | ||||
|             self.progress = 1 | ||||
|             return | ||||
|         if self.data is None: | ||||
|             raise ValueError('Data must be provided to create embeddings table') | ||||
|             raise ValueError("Data must be provided to create embeddings table") | ||||
| 
 | ||||
|         data_info = check_det_dataset(self.data) | ||||
|         if split not in data_info: | ||||
|             raise ValueError( | ||||
|                 f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}' | ||||
|                 f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}" | ||||
|             ) | ||||
| 
 | ||||
|         choice_set = data_info[split] | ||||
| @ -113,13 +111,16 @@ class Explorer: | ||||
| 
 | ||||
|         # Create the table schema | ||||
|         batch = dataset[0] | ||||
|         vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0] | ||||
|         table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite') | ||||
|         vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0] | ||||
|         table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite") | ||||
|         table.add( | ||||
|             self._yield_batches(dataset, | ||||
|             self._yield_batches( | ||||
|                 dataset, | ||||
|                 data_info, | ||||
|                 self.model, | ||||
|                                 exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx'])) | ||||
|                 exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"], | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         self.table = table | ||||
| 
 | ||||
| @ -131,12 +132,12 @@ class Explorer: | ||||
|             for k in exclude_keys: | ||||
|                 batch.pop(k, None) | ||||
|             batch = sanitize_batch(batch, data_info) | ||||
|             batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist() | ||||
|             batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist() | ||||
|             yield [batch] | ||||
| 
 | ||||
|     def query(self, | ||||
|               imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, | ||||
|               limit: int = 25) -> Any:  # pyarrow.Table | ||||
|     def query( | ||||
|         self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25 | ||||
|     ) -> Any:  # pyarrow.Table | ||||
|         """ | ||||
|         Query the table for similar images. Accepts a single image or a list of images. | ||||
| 
 | ||||
| @ -157,18 +158,18 @@ class Explorer: | ||||
|             ``` | ||||
|         """ | ||||
|         if self.table is None: | ||||
|             raise ValueError('Table is not created. Please create the table first.') | ||||
|             raise ValueError("Table is not created. Please create the table first.") | ||||
|         if isinstance(imgs, str): | ||||
|             imgs = [imgs] | ||||
|         assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}' | ||||
|         assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}" | ||||
|         embeds = self.model.embed(imgs) | ||||
|         # Get avg if multiple images are passed (len > 1) | ||||
|         embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() | ||||
|         return self.table.search(embeds).limit(limit).to_arrow() | ||||
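The multi-image branch of `query()` averages embeddings into one query vector; a standalone sketch (512 is an assumed embedding size, not taken from the model):

```python
import torch

embeds = [torch.randn(512) for _ in range(3)]  # three hypothetical image embeddings
query_vec = torch.mean(torch.stack(embeds), 0).cpu().numpy()
print(query_vec.shape)  # (512,) -> a single vector searched against the LanceDB table
```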
| 
 | ||||
|     def sql_query(self, | ||||
|                   query: str, | ||||
|                   return_type: str = 'pandas') -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table | ||||
|     def sql_query( | ||||
|         self, query: str, return_type: str = "pandas" | ||||
|     ) -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table | ||||
|         """ | ||||
|         Run a SQL-like query on the table. Utilizes LanceDB predicate pushdown. | ||||
| 
 | ||||
| @ -187,27 +188,29 @@ class Explorer: | ||||
|             result = exp.sql_query(query) | ||||
|             ``` | ||||
|         """ | ||||
|         assert return_type in ['pandas', | ||||
|                                'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}' | ||||
|         assert return_type in [ | ||||
|             "pandas", | ||||
|             "arrow", | ||||
|         ], f"Return type should be either `pandas` or `arrow`, but got {return_type}" | ||||
|         import duckdb | ||||
| 
 | ||||
|         if self.table is None: | ||||
|             raise ValueError('Table is not created. Please create the table first.') | ||||
|             raise ValueError("Table is not created. Please create the table first.") | ||||
| 
 | ||||
|         # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this. | ||||
|         table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB | ||||
|         if not query.startswith('SELECT') and not query.startswith('WHERE'): | ||||
|         if not query.startswith("SELECT") and not query.startswith("WHERE"): | ||||
|             raise ValueError( | ||||
|                 f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}' | ||||
|                 f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}" | ||||
|             ) | ||||
|         if query.startswith('WHERE'): | ||||
|         if query.startswith("WHERE"): | ||||
|             query = f"SELECT * FROM 'table' {query}" | ||||
|         LOGGER.info(f'Running query: {query}') | ||||
|         LOGGER.info(f"Running query: {query}") | ||||
| 
 | ||||
|         rs = duckdb.sql(query) | ||||
|         if return_type == 'pandas': | ||||
|         if return_type == "pandas": | ||||
|             return rs.df() | ||||
|         elif return_type == 'arrow': | ||||
|         elif return_type == "arrow": | ||||
|             return rs.arrow() | ||||
| 
 | ||||
|     def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: | ||||
| @ -228,18 +231,20 @@ class Explorer: | ||||
|             result = exp.plot_sql_query(query) | ||||
|             ``` | ||||
|         """ | ||||
|         result = self.sql_query(query, return_type='arrow') | ||||
|         result = self.sql_query(query, return_type="arrow") | ||||
|         if len(result) == 0: | ||||
|             LOGGER.info('No results found.') | ||||
|             LOGGER.info("No results found.") | ||||
|             return None | ||||
|         img = plot_query_result(result, plot_labels=labels) | ||||
|         return Image.fromarray(img) | ||||
| 
 | ||||
|     def get_similar(self, | ||||
|     def get_similar( | ||||
|         self, | ||||
|         img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, | ||||
|         idx: Union[int, List[int]] = None, | ||||
|         limit: int = 25, | ||||
|                     return_type: str = 'pandas') -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table | ||||
|         return_type: str = "pandas", | ||||
|     ) -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table | ||||
|         """ | ||||
|         Query the table for similar images. Accepts a single image or a list of images. | ||||
| 
 | ||||
| @ -259,21 +264,25 @@ class Explorer: | ||||
|             similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg') | ||||
|             ``` | ||||
|         """ | ||||
|         assert return_type in ['pandas', | ||||
|                                'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}' | ||||
|         assert return_type in [ | ||||
|             "pandas", | ||||
|             "arrow", | ||||
|         ], f"Return type should be either `pandas` or `arrow`, but got {return_type}" | ||||
|         img = self._check_imgs_or_idxs(img, idx) | ||||
|         similar = self.query(img, limit=limit) | ||||
| 
 | ||||
|         if return_type == 'pandas': | ||||
|         if return_type == "pandas": | ||||
|             return similar.to_pandas() | ||||
|         elif return_type == 'arrow': | ||||
|         elif return_type == "arrow": | ||||
|             return similar | ||||
| 
 | ||||
|     def plot_similar(self, | ||||
|     def plot_similar( | ||||
|         self, | ||||
|         img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, | ||||
|         idx: Union[int, List[int]] = None, | ||||
|         limit: int = 25, | ||||
|                      labels: bool = True) -> Image.Image: | ||||
|         labels: bool = True, | ||||
|     ) -> Image.Image: | ||||
|         """ | ||||
|         Plot the similar images. Accepts images or indexes. | ||||
| 
 | ||||
| @ -293,9 +302,9 @@ class Explorer: | ||||
|             similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg') | ||||
|             ``` | ||||
|         """ | ||||
|         similar = self.get_similar(img, idx, limit, return_type='arrow') | ||||
|         similar = self.get_similar(img, idx, limit, return_type="arrow") | ||||
|         if len(similar) == 0: | ||||
|             LOGGER.info('No results found.') | ||||
|             LOGGER.info("No results found.") | ||||
|             return None | ||||
|         img = plot_query_result(similar, plot_labels=labels) | ||||
|         return Image.fromarray(img) | ||||
| @ -323,34 +332,37 @@ class Explorer: | ||||
|             ``` | ||||
|         """ | ||||
|         if self.table is None: | ||||
|             raise ValueError('Table is not created. Please create the table first.') | ||||
|         sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower() | ||||
|             raise ValueError("Table is not created. Please create the table first.") | ||||
|         sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower() | ||||
|         if sim_idx_table_name in self.connection.table_names() and not force: | ||||
|             LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.') | ||||
|             LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.") | ||||
|             return self.connection.open_table(sim_idx_table_name).to_pandas() | ||||
| 
 | ||||
|         if top_k and not (1.0 >= top_k >= 0.0): | ||||
|             raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}') | ||||
|             raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}") | ||||
|         if max_dist < 0.0: | ||||
|             raise ValueError(f'max_dist must be greater than 0. Got {max_dist}') | ||||
|             raise ValueError(f"max_dist must be greater than 0. Got {max_dist}") | ||||
| 
 | ||||
|         top_k = int(top_k * len(self.table)) if top_k else len(self.table) | ||||
|         top_k = max(top_k, 1) | ||||
|         features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict() | ||||
|         im_files = features['im_file'] | ||||
|         embeddings = features['vector'] | ||||
|         features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict() | ||||
|         im_files = features["im_file"] | ||||
|         embeddings = features["vector"] | ||||
| 
 | ||||
|         sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite') | ||||
|         sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite") | ||||
| 
 | ||||
|         def _yield_sim_idx(): | ||||
|             """Generates a dataframe with similarity indices and distances for images.""" | ||||
|             for i in tqdm(range(len(embeddings))): | ||||
|                 sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}') | ||||
|                 yield [{ | ||||
|                     'idx': i, | ||||
|                     'im_file': im_files[i], | ||||
|                     'count': len(sim_idx), | ||||
|                     'sim_im_files': sim_idx['im_file'].tolist()}] | ||||
|                 sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}") | ||||
|                 yield [ | ||||
|                     { | ||||
|                         "idx": i, | ||||
|                         "im_file": im_files[i], | ||||
|                         "count": len(sim_idx), | ||||
|                         "sim_im_files": sim_idx["im_file"].tolist(), | ||||
|                     } | ||||
|                 ] | ||||
| 
 | ||||
|         sim_table.add(_yield_sim_idx()) | ||||
|         self.sim_index = sim_table | ||||
| @ -381,7 +393,7 @@ class Explorer: | ||||
|             ``` | ||||
|         """ | ||||
|         sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) | ||||
|         sim_count = sim_idx['count'].tolist() | ||||
|         sim_count = sim_idx["count"].tolist() | ||||
|         sim_count = np.array(sim_count) | ||||
| 
 | ||||
|         indices = np.arange(len(sim_count)) | ||||
| @ -390,25 +402,26 @@ class Explorer: | ||||
|         plt.bar(indices, sim_count) | ||||
| 
 | ||||
|         # Customize the plot (optional) | ||||
|         plt.xlabel('data idx') | ||||
|         plt.ylabel('Count') | ||||
|         plt.title('Similarity Count') | ||||
|         plt.xlabel("data idx") | ||||
|         plt.ylabel("Count") | ||||
|         plt.title("Similarity Count") | ||||
|         buffer = BytesIO() | ||||
|         plt.savefig(buffer, format='png') | ||||
|         plt.savefig(buffer, format="png") | ||||
|         buffer.seek(0) | ||||
| 
 | ||||
|         # Use Pillow to open the image from the buffer | ||||
|         return Image.fromarray(np.array(Image.open(buffer))) | ||||
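The figure-to-PIL round trip used above also works in isolation; a minimal sketch:

```python
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

plt.bar([0, 1, 2], [3, 1, 2])  # stand-in for the similarity-count bars
buffer = BytesIO()
plt.savefig(buffer, format="png")
buffer.seek(0)
img = Image.fromarray(np.array(Image.open(buffer)))
print(img.size)  # pixel size depends on the figure's dpi
```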
| 
 | ||||
|     def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], | ||||
|                             idx: Union[None, int, List[int]]) -> List[np.ndarray]: | ||||
|     def _check_imgs_or_idxs( | ||||
|         self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]] | ||||
|     ) -> List[np.ndarray]: | ||||
|         if img is None and idx is None: | ||||
|             raise ValueError('Either img or idx must be provided.') | ||||
|             raise ValueError("Either img or idx must be provided.") | ||||
|         if img is not None and idx is not None: | ||||
|             raise ValueError('Only one of img or idx must be provided.') | ||||
|             raise ValueError("Only one of img or idx must be provided.") | ||||
|         if idx is not None: | ||||
|             idx = idx if isinstance(idx, list) else [idx] | ||||
|             img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file'] | ||||
|             img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"] | ||||
| 
 | ||||
|         return img if isinstance(img, list) else [img] | ||||
| 
 | ||||
| @ -433,7 +446,7 @@ class Explorer: | ||||
|         try: | ||||
|             df = self.sql_query(result) | ||||
|         except Exception as e: | ||||
|             LOGGER.error('AI generated query is not valid. Please try again with a different prompt') | ||||
|             LOGGER.error("AI generated query is not valid. Please try again with a different prompt") | ||||
|             LOGGER.error(e) | ||||
|             return None | ||||
|         return df | ||||
|  | ||||
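For orientation, the Explorer API being reformatted above composes as follows. This is a minimal sketch based only on the calls visible in these hunks; the dataset, model and image names are illustrative:

    from ultralytics import Explorer

    exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
    exp.create_embeddings_table()  # builds the LanceDB embeddings table once
    similar = exp.get_similar(img="bus.jpg", limit=25, return_type="arrow")
    print(similar.to_pydict()["im_file"])  # paths of the most similar images
    sim_idx = exp.similarity_index(max_dist=0.2)  # per-image similarity counts, as plotted above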
| @ -9,100 +9,114 @@ from ultralytics import Explorer | ||||
| from ultralytics.utils import ROOT, SETTINGS | ||||
| from ultralytics.utils.checks import check_requirements | ||||
| 
 | ||||
| check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2')) | ||||
| check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2")) | ||||
| 
 | ||||
| import streamlit as st | ||||
| from streamlit_select import image_select | ||||
| 
 | ||||
| 
 | ||||
| def _get_explorer(): | ||||
|     """Initializes and returns an instance of the Explorer class.""" | ||||
|     exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model')) | ||||
|     thread = Thread(target=exp.create_embeddings_table, | ||||
|                     kwargs={'force': st.session_state.get('force_recreate_embeddings')}) | ||||
|     exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model")) | ||||
|     thread = Thread( | ||||
|         target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")} | ||||
|     ) | ||||
|     thread.start() | ||||
|     progress_bar = st.progress(0, text='Creating embeddings table...') | ||||
|     progress_bar = st.progress(0, text="Creating embeddings table...") | ||||
|     while exp.progress < 1: | ||||
|         time.sleep(0.1) | ||||
|         progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%') | ||||
|         progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%") | ||||
|     thread.join() | ||||
|     st.session_state['explorer'] = exp | ||||
|     st.session_state["explorer"] = exp | ||||
|     progress_bar.empty() | ||||
| 
 | ||||
| 
 | ||||
| def init_explorer_form(): | ||||
|     """Initializes an Explorer instance and creates embeddings table with progress tracking.""" | ||||
|     datasets = ROOT / 'cfg' / 'datasets' | ||||
|     ds = [d.name for d in datasets.glob('*.yaml')] | ||||
|     datasets = ROOT / "cfg" / "datasets" | ||||
|     ds = [d.name for d in datasets.glob("*.yaml")] | ||||
|     models = [ | ||||
|         'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt', | ||||
|         'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt', | ||||
|         'yolov8l-pose.pt', 'yolov8x-pose.pt'] | ||||
|     with st.form(key='explorer_init_form'): | ||||
|         "yolov8n.pt", | ||||
|         "yolov8s.pt", | ||||
|         "yolov8m.pt", | ||||
|         "yolov8l.pt", | ||||
|         "yolov8x.pt", | ||||
|         "yolov8n-seg.pt", | ||||
|         "yolov8s-seg.pt", | ||||
|         "yolov8m-seg.pt", | ||||
|         "yolov8l-seg.pt", | ||||
|         "yolov8x-seg.pt", | ||||
|         "yolov8n-pose.pt", | ||||
|         "yolov8s-pose.pt", | ||||
|         "yolov8m-pose.pt", | ||||
|         "yolov8l-pose.pt", | ||||
|         "yolov8x-pose.pt", | ||||
|     ] | ||||
|     with st.form(key="explorer_init_form"): | ||||
|         col1, col2 = st.columns(2) | ||||
|         with col1: | ||||
|             st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml')) | ||||
|             st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml")) | ||||
|         with col2: | ||||
|             st.selectbox('Select model', models, key='model') | ||||
|         st.checkbox('Force recreate embeddings', key='force_recreate_embeddings') | ||||
|             st.selectbox("Select model", models, key="model") | ||||
|         st.checkbox("Force recreate embeddings", key="force_recreate_embeddings") | ||||
| 
 | ||||
|         st.form_submit_button('Explore', on_click=_get_explorer) | ||||
|         st.form_submit_button("Explore", on_click=_get_explorer) | ||||
| 
 | ||||
| 
 | ||||
| def query_form(): | ||||
|     """Sets up a form in Streamlit to initialize Explorer with dataset and model selection.""" | ||||
|     with st.form('query_form'): | ||||
|     with st.form("query_form"): | ||||
|         col1, col2 = st.columns([0.8, 0.2]) | ||||
|         with col1: | ||||
|             st.text_input('Query', | ||||
|             st.text_input( | ||||
|                 "Query", | ||||
|                 "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'", | ||||
|                           label_visibility='collapsed', | ||||
|                           key='query') | ||||
|                 label_visibility="collapsed", | ||||
|                 key="query", | ||||
|             ) | ||||
|         with col2: | ||||
|             st.form_submit_button('Query', on_click=run_sql_query) | ||||
|             st.form_submit_button("Query", on_click=run_sql_query) | ||||
| 
 | ||||
| 
 | ||||
| def ai_query_form(): | ||||
|     """Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection.""" | ||||
|     with st.form('ai_query_form'): | ||||
|     with st.form("ai_query_form"): | ||||
|         col1, col2 = st.columns([0.8, 0.2]) | ||||
|         with col1: | ||||
|             st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query') | ||||
|             st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query") | ||||
|         with col2: | ||||
|             st.form_submit_button('Ask AI', on_click=run_ai_query) | ||||
|             st.form_submit_button("Ask AI", on_click=run_ai_query) | ||||
| 
 | ||||
| 
 | ||||
| def find_similar_imgs(imgs): | ||||
|     """Initializes a Streamlit form for AI-based image querying with custom input.""" | ||||
|     exp = st.session_state['explorer'] | ||||
|     similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow') | ||||
|     paths = similar.to_pydict()['im_file'] | ||||
|     st.session_state['imgs'] = paths | ||||
|     exp = st.session_state["explorer"] | ||||
|     similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow") | ||||
|     paths = similar.to_pydict()["im_file"] | ||||
|     st.session_state["imgs"] = paths | ||||
| 
 | ||||
| 
 | ||||
| def similarity_form(selected_imgs): | ||||
|     """Initializes a form for AI-based image querying with custom input in Streamlit.""" | ||||
|     st.write('Similarity Search') | ||||
|     with st.form('similarity_form'): | ||||
|     st.write("Similarity Search") | ||||
|     with st.form("similarity_form"): | ||||
|         subcol1, subcol2 = st.columns([1, 1]) | ||||
|         with subcol1: | ||||
|             st.number_input('limit', | ||||
|                             min_value=None, | ||||
|                             max_value=None, | ||||
|                             value=25, | ||||
|                             label_visibility='collapsed', | ||||
|                             key='limit') | ||||
|             st.number_input( | ||||
|                 "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit" | ||||
|             ) | ||||
| 
 | ||||
|         with subcol2: | ||||
|             disabled = not len(selected_imgs) | ||||
|             st.write('Selected: ', len(selected_imgs)) | ||||
|             st.write("Selected: ", len(selected_imgs)) | ||||
|             st.form_submit_button( | ||||
|                 'Search', | ||||
|                 "Search", | ||||
|                 disabled=disabled, | ||||
|                 on_click=find_similar_imgs, | ||||
|                 args=(selected_imgs, ), | ||||
|                 args=(selected_imgs,), | ||||
|             ) | ||||
|         if disabled: | ||||
|             st.error('Select at least one image to search.') | ||||
|             st.error("Select at least one image to search.") | ||||
| 
 | ||||
| 
 | ||||
| # def persist_reset_form(): | ||||
| @ -117,100 +131,108 @@ def similarity_form(selected_imgs): | ||||
| 
 | ||||
| def run_sql_query(): | ||||
|     """Executes an SQL query and returns the results.""" | ||||
|     st.session_state['error'] = None | ||||
|     query = st.session_state.get('query') | ||||
|     st.session_state["error"] = None | ||||
|     query = st.session_state.get("query") | ||||
|     if query.strip(): | ||||
|         exp = st.session_state['explorer'] | ||||
|         res = exp.sql_query(query, return_type='arrow') | ||||
|         st.session_state['imgs'] = res.to_pydict()['im_file'] | ||||
|         exp = st.session_state["explorer"] | ||||
|         res = exp.sql_query(query, return_type="arrow") | ||||
|         st.session_state["imgs"] = res.to_pydict()["im_file"] | ||||
| 
 | ||||
| 
 | ||||
| def run_ai_query(): | ||||
|     """Execute SQL query and update session state with query results.""" | ||||
|     if not SETTINGS['openai_api_key']: | ||||
|     if not SETTINGS["openai_api_key"]: | ||||
|         st.session_state[ | ||||
|             'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."' | ||||
|             "error" | ||||
|         ] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."' | ||||
|         return | ||||
|     st.session_state['error'] = None | ||||
|     query = st.session_state.get('ai_query') | ||||
|     st.session_state["error"] = None | ||||
|     query = st.session_state.get("ai_query") | ||||
|     if query.strip(): | ||||
|         exp = st.session_state['explorer'] | ||||
|         exp = st.session_state["explorer"] | ||||
|         res = exp.ask_ai(query) | ||||
|         if not isinstance(res, pd.DataFrame) or res.empty: | ||||
|             st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.' | ||||
|             st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it." | ||||
|             return | ||||
|         st.session_state['imgs'] = res['im_file'].to_list() | ||||
|         st.session_state["imgs"] = res["im_file"].to_list() | ||||
| 
 | ||||
| 
 | ||||
| def reset_explorer(): | ||||
|     """Resets the explorer to its initial state by clearing session variables.""" | ||||
|     st.session_state['explorer'] = None | ||||
|     st.session_state['imgs'] = None | ||||
|     st.session_state['error'] = None | ||||
|     st.session_state["explorer"] = None | ||||
|     st.session_state["imgs"] = None | ||||
|     st.session_state["error"] = None | ||||
| 
 | ||||
| 
 | ||||
| def ultralytics_explorer_docs_callback(): | ||||
|     """Displays a container linking to the Ultralytics Explorer API documentation.""" | ||||
|     with st.container(border=True): | ||||
|         st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg', | ||||
|                  width=100) | ||||
|         st.image( | ||||
|             "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg", | ||||
|             width=100, | ||||
|         ) | ||||
|         st.markdown( | ||||
|             "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>", | ||||
|             unsafe_allow_html=True, | ||||
|             help=None) | ||||
|         st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/') | ||||
|             help=None, | ||||
|         ) | ||||
|         st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/") | ||||
| 
 | ||||
| 
 | ||||
| def layout(): | ||||
|     """Resets explorer session variables and provides documentation with a link to API docs.""" | ||||
|     st.set_page_config(layout='wide', initial_sidebar_state='collapsed') | ||||
|     st.set_page_config(layout="wide", initial_sidebar_state="collapsed") | ||||
|     st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True) | ||||
| 
 | ||||
|     if st.session_state.get('explorer') is None: | ||||
|     if st.session_state.get("explorer") is None: | ||||
|         init_explorer_form() | ||||
|         return | ||||
| 
 | ||||
|     st.button(':arrow_backward: Select Dataset', on_click=reset_explorer) | ||||
|     exp = st.session_state.get('explorer') | ||||
|     col1, col2 = st.columns([0.75, 0.25], gap='small') | ||||
|     st.button(":arrow_backward: Select Dataset", on_click=reset_explorer) | ||||
|     exp = st.session_state.get("explorer") | ||||
|     col1, col2 = st.columns([0.75, 0.25], gap="small") | ||||
|     imgs = [] | ||||
|     if st.session_state.get('error'): | ||||
|         st.error(st.session_state['error']) | ||||
|     if st.session_state.get("error"): | ||||
|         st.error(st.session_state["error"]) | ||||
|     else: | ||||
|         imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file'] | ||||
|         imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"] | ||||
|     total_imgs, selected_imgs = len(imgs), [] | ||||
|     with col1: | ||||
|         subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5) | ||||
|         with subcol1: | ||||
|             st.write('Max Images Displayed:') | ||||
|             st.write("Max Images Displayed:") | ||||
|         with subcol2: | ||||
|             num = st.number_input('Max Images Displayed', | ||||
|             num = st.number_input( | ||||
|                 "Max Images Displayed", | ||||
|                 min_value=0, | ||||
|                 max_value=total_imgs, | ||||
|                 value=min(500, total_imgs), | ||||
|                                   key='num_imgs_displayed', | ||||
|                                   label_visibility='collapsed') | ||||
|                 key="num_imgs_displayed", | ||||
|                 label_visibility="collapsed", | ||||
|             ) | ||||
|         with subcol3: | ||||
|             st.write('Start Index:') | ||||
|             st.write("Start Index:") | ||||
|         with subcol4: | ||||
|             start_idx = st.number_input('Start Index', | ||||
|             start_idx = st.number_input( | ||||
|                 "Start Index", | ||||
|                 min_value=0, | ||||
|                 max_value=total_imgs, | ||||
|                 value=0, | ||||
|                                         key='start_index', | ||||
|                                         label_visibility='collapsed') | ||||
|                 key="start_index", | ||||
|                 label_visibility="collapsed", | ||||
|             ) | ||||
|         with subcol5: | ||||
|             reset = st.button('Reset', use_container_width=False, key='reset') | ||||
|             reset = st.button("Reset", use_container_width=False, key="reset") | ||||
|             if reset: | ||||
|                 st.session_state['imgs'] = None | ||||
|                 st.session_state["imgs"] = None | ||||
|                 st.experimental_rerun() | ||||
| 
 | ||||
|         query_form() | ||||
|         ai_query_form() | ||||
|         if total_imgs: | ||||
|             imgs_displayed = imgs[start_idx:start_idx + num] | ||||
|             imgs_displayed = imgs[start_idx : start_idx + num] | ||||
|             selected_imgs = image_select( | ||||
|                 f'Total samples: {total_imgs}', | ||||
|                 f"Total samples: {total_imgs}", | ||||
|                 images=imgs_displayed, | ||||
|                 use_container_width=False, | ||||
|                 # indices=[i for i in range(num)] if select_all else None, | ||||
| @ -222,5 +244,5 @@ def layout(): | ||||
|         ultralytics_explorer_docs_callback() | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     layout() | ||||
|  | ||||
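The dashboard above is a standard single-file Streamlit app, so once streamlit and streamlit-select are installed it can presumably be launched with streamlit run path/to/dash.py. Outside the UI, run_sql_query() reduces to roughly the following sketch; the dataset name and query text are illustrative, and note that sql_query() accepts the bare WHERE-clause form used as the form default above:

    from ultralytics import Explorer

    exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
    exp.create_embeddings_table()
    res = exp.sql_query("WHERE labels LIKE '%person%' AND labels LIKE '%dog%'", return_type="arrow")
    paths = res.to_pydict()["im_file"]  # the list the dashboard keeps in st.session_state["imgs"]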
| @ -46,14 +46,13 @@ def get_sim_index_schema(): | ||||
| 
 | ||||
| def sanitize_batch(batch, dataset_info): | ||||
|     """Sanitizes input batch for inference, ensuring correct format and dimensions.""" | ||||
|     batch['cls'] = batch['cls'].flatten().int().tolist() | ||||
|     box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1]) | ||||
|     batch['bboxes'] = [box for box, _ in box_cls_pair] | ||||
|     batch['cls'] = [cls for _, cls in box_cls_pair] | ||||
|     batch['labels'] = [dataset_info['names'][i] for i in batch['cls']] | ||||
|     batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]] | ||||
|     batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]] | ||||
| 
 | ||||
|     batch["cls"] = batch["cls"].flatten().int().tolist() | ||||
|     box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1]) | ||||
|     batch["bboxes"] = [box for box, _ in box_cls_pair] | ||||
|     batch["cls"] = [cls for _, cls in box_cls_pair] | ||||
|     batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]] | ||||
|     batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]] | ||||
|     batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]] | ||||
|     return batch | ||||
| 
 | ||||
| 
 | ||||
| @ -65,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True): | ||||
|         similar_set (list): Pyarrow or pandas object containing the similar data points | ||||
|         plot_labels (bool): Whether to plot labels or not | ||||
|     """ | ||||
|     similar_set = similar_set.to_dict( | ||||
|         orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict() | ||||
|     similar_set = ( | ||||
|         similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict() | ||||
|     ) | ||||
|     empty_masks = [[[]]] | ||||
|     empty_boxes = [[]] | ||||
|     images = similar_set.get('im_file', []) | ||||
|     bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else [] | ||||
|     masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else [] | ||||
|     kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else [] | ||||
|     cls = similar_set.get('cls', []) | ||||
|     images = similar_set.get("im_file", []) | ||||
|     bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else [] | ||||
|     masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else [] | ||||
|     kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else [] | ||||
|     cls = similar_set.get("cls", []) | ||||
| 
 | ||||
|     plot_size = 640 | ||||
|     imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], [] | ||||
| @ -104,34 +104,26 @@ def plot_query_result(similar_set, plot_labels=True): | ||||
|     batch_idx = np.concatenate(batch_idx, axis=0) | ||||
|     cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0) | ||||
| 
 | ||||
|     return plot_images(imgs, | ||||
|                        batch_idx, | ||||
|                        cls, | ||||
|                        bboxes=boxes, | ||||
|                        masks=masks, | ||||
|                        kpts=kpts, | ||||
|                        max_subplots=len(images), | ||||
|                        save=False, | ||||
|                        threaded=False) | ||||
|     return plot_images( | ||||
|         imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def prompt_sql_query(query): | ||||
|     """Plots images with optional labels from a similar data set.""" | ||||
|     check_requirements('openai>=1.6.1') | ||||
|     check_requirements("openai>=1.6.1") | ||||
|     from openai import OpenAI | ||||
| 
 | ||||
|     if not SETTINGS['openai_api_key']: | ||||
|         logger.warning('OpenAI API key not found in settings. Please enter your API key below.') | ||||
|         openai_api_key = getpass.getpass('OpenAI API key: ') | ||||
|         SETTINGS.update({'openai_api_key': openai_api_key}) | ||||
|     openai = OpenAI(api_key=SETTINGS['openai_api_key']) | ||||
|     if not SETTINGS["openai_api_key"]: | ||||
|         logger.warning("OpenAI API key not found in settings. Please enter your API key below.") | ||||
|         openai_api_key = getpass.getpass("OpenAI API key: ") | ||||
|         SETTINGS.update({"openai_api_key": openai_api_key}) | ||||
|     openai = OpenAI(api_key=SETTINGS["openai_api_key"]) | ||||
| 
 | ||||
|     messages = [ | ||||
|         { | ||||
|             'role': | ||||
|             'system', | ||||
|             'content': | ||||
|             ''' | ||||
|             "role": "system", | ||||
|             "content": """ | ||||
|                 You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on | ||||
|                 the following schema and a user request. You only need to output the format with fixed selection | ||||
|                 statement that selects everything from "'table'", like `SELECT * from 'table'` | ||||
| @ -165,10 +157,10 @@ def prompt_sql_query(query): | ||||
|                 request - Get all data points that contain 2 or more people and at least one dog | ||||
|                 correct query- | ||||
|                 SELECT * FROM 'table' WHERE  ARRAY_LENGTH(cls) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1; | ||||
|              '''}, | ||||
|         { | ||||
|             'role': 'user', | ||||
|             'content': f'{query}'}, ] | ||||
|              """, | ||||
|         }, | ||||
|         {"role": "user", "content": f"{query}"}, | ||||
|     ] | ||||
| 
 | ||||
|     response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages) | ||||
|     response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages) | ||||
|     return response.choices[0].message.content | ||||
|  | ||||
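As a usage sketch for the two helpers reformatted above: plot_query_result() consumes the pandas or pyarrow output of a query and returns the rendered mosaic (per the plot_images(..., save=False) call), while prompt_sql_query() turns a natural-language request into one SQL string. The module path below is an assumption; names are illustrative:

    from PIL import Image

    from ultralytics import Explorer
    from ultralytics.data.explorer.utils import plot_query_result, prompt_sql_query  # module path assumed

    exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
    exp.create_embeddings_table()
    similar = exp.get_similar(img="bus.jpg", limit=9)
    Image.fromarray(plot_query_result(similar)).show()  # annotated grid of the matches
    print(prompt_sql_query("Show images with 1 person and 1 dog"))  # requires an OpenAI API key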
| @ -23,6 +23,7 @@ from ultralytics.utils.checks import check_requirements | ||||
| @dataclass | ||||
| class SourceTypes: | ||||
|     """Class to represent various types of input sources for predictions.""" | ||||
| 
 | ||||
|     webcam: bool = False | ||||
|     screenshot: bool = False | ||||
|     from_img: bool = False | ||||
| @ -59,12 +60,12 @@ class LoadStreams: | ||||
|         __len__: Return the length of the sources object. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False): | ||||
|     def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False): | ||||
|         """Initialize instance variables and check for consistent input stream shapes.""" | ||||
|         torch.backends.cudnn.benchmark = True  # faster for fixed-size inference | ||||
|         self.buffer = buffer  # buffer input streams | ||||
|         self.running = True  # running flag for Thread | ||||
|         self.mode = 'stream' | ||||
|         self.mode = "stream" | ||||
|         self.imgsz = imgsz | ||||
|         self.vid_stride = vid_stride  # video frame-rate stride | ||||
| 
 | ||||
| @ -79,33 +80,36 @@ class LoadStreams: | ||||
|         self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later | ||||
|         for i, s in enumerate(sources):  # index, source | ||||
|             # Start thread to read frames from video stream | ||||
|             st = f'{i + 1}/{n}: {s}... ' | ||||
|             if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video | ||||
|             st = f"{i + 1}/{n}: {s}... " | ||||
|             if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"):  # if source is YouTube video | ||||
|                 # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4' | ||||
|                 s = get_best_youtube_url(s) | ||||
|             s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam; int() avoids eval on untrusted input | ||||
|             if s == 0 and (is_colab() or is_kaggle()): | ||||
|                 raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. " | ||||
|                                           "Try running 'source=0' in a local environment.") | ||||
|                 raise NotImplementedError( | ||||
|                     "'source=0' webcam not supported in Colab and Kaggle notebooks. " | ||||
|                     "Try running 'source=0' in a local environment." | ||||
|                 ) | ||||
|             self.caps[i] = cv2.VideoCapture(s)  # store video capture object | ||||
|             if not self.caps[i].isOpened(): | ||||
|                 raise ConnectionError(f'{st}Failed to open {s}') | ||||
|                 raise ConnectionError(f"{st}Failed to open {s}") | ||||
|             w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH)) | ||||
|             h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT)) | ||||
|             fps = self.caps[i].get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan | ||||
|             self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float( | ||||
|                 'inf')  # infinite stream fallback | ||||
|                 "inf" | ||||
|             )  # infinite stream fallback | ||||
|             self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback | ||||
| 
 | ||||
|             success, im = self.caps[i].read()  # guarantee first frame | ||||
|             if not success or im is None: | ||||
|                 raise ConnectionError(f'{st}Failed to read images from {s}') | ||||
|                 raise ConnectionError(f"{st}Failed to read images from {s}") | ||||
|             self.imgs[i].append(im) | ||||
|             self.shape[i] = im.shape | ||||
|             self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True) | ||||
|             LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)') | ||||
|             LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)") | ||||
|             self.threads[i].start() | ||||
|         LOGGER.info('')  # newline | ||||
|         LOGGER.info("")  # newline | ||||
| 
 | ||||
|         # Check for common shapes | ||||
|         self.bs = self.__len__() | ||||
| @ -121,7 +125,7 @@ class LoadStreams: | ||||
|                     success, im = cap.retrieve() | ||||
|                     if not success: | ||||
|                         im = np.zeros(self.shape[i], dtype=np.uint8) | ||||
|                         LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.') | ||||
|                         LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.") | ||||
|                         cap.open(stream)  # re-open stream if signal was lost | ||||
|                     if self.buffer: | ||||
|                         self.imgs[i].append(im) | ||||
| @ -140,7 +144,7 @@ class LoadStreams: | ||||
|             try: | ||||
|                 cap.release()  # release video capture | ||||
|             except Exception as e: | ||||
|                 LOGGER.warning(f'WARNING ⚠️ Could not release VideoCapture object: {e}') | ||||
|                 LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}") | ||||
|         cv2.destroyAllWindows() | ||||
| 
 | ||||
|     def __iter__(self): | ||||
| @ -154,16 +158,15 @@ class LoadStreams: | ||||
| 
 | ||||
|         images = [] | ||||
|         for i, x in enumerate(self.imgs): | ||||
| 
 | ||||
|             # Wait until a frame is available in each buffer | ||||
|             while not x: | ||||
|                 if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'):  # q to quit | ||||
|                 if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"):  # q to quit | ||||
|                     self.close() | ||||
|                     raise StopIteration | ||||
|                 time.sleep(1 / min(self.fps)) | ||||
|                 x = self.imgs[i] | ||||
|                 if not x: | ||||
|                     LOGGER.warning(f'WARNING ⚠️ Waiting for stream {i}') | ||||
|                     LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}") | ||||
| 
 | ||||
|             # Get and remove the first frame from imgs buffer | ||||
|             if self.buffer: | ||||
| @ -174,7 +177,7 @@ class LoadStreams: | ||||
|                 images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8)) | ||||
|                 x.clear() | ||||
| 
 | ||||
|         return self.sources, images, None, '' | ||||
|         return self.sources, images, None, "" | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         """Return the length of the sources object.""" | ||||
| @ -209,7 +212,7 @@ class LoadScreenshots: | ||||
| 
 | ||||
|     def __init__(self, source, imgsz=640): | ||||
|         """Source = [screen_number left top width height] (pixels).""" | ||||
|         check_requirements('mss') | ||||
|         check_requirements("mss") | ||||
|         import mss  # noqa | ||||
| 
 | ||||
|         source, *params = source.split() | ||||
| @ -221,18 +224,18 @@ class LoadScreenshots: | ||||
|         elif len(params) == 5: | ||||
|             self.screen, left, top, width, height = (int(x) for x in params) | ||||
|         self.imgsz = imgsz | ||||
|         self.mode = 'stream' | ||||
|         self.mode = "stream" | ||||
|         self.frame = 0 | ||||
|         self.sct = mss.mss() | ||||
|         self.bs = 1 | ||||
| 
 | ||||
|         # Parse monitor shape | ||||
|         monitor = self.sct.monitors[self.screen] | ||||
|         self.top = monitor['top'] if top is None else (monitor['top'] + top) | ||||
|         self.left = monitor['left'] if left is None else (monitor['left'] + left) | ||||
|         self.width = width or monitor['width'] | ||||
|         self.height = height or monitor['height'] | ||||
|         self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height} | ||||
|         self.top = monitor["top"] if top is None else (monitor["top"] + top) | ||||
|         self.left = monitor["left"] if left is None else (monitor["left"] + left) | ||||
|         self.width = width or monitor["width"] | ||||
|         self.height = height or monitor["height"] | ||||
|         self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height} | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|         """Returns an iterator of the object.""" | ||||
| @ -241,7 +244,7 @@ class LoadScreenshots: | ||||
|     def __next__(self): | ||||
|         """mss screen capture: get raw pixels from the screen as np array.""" | ||||
|         im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3]  # BGRA to BGR | ||||
|         s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: ' | ||||
|         s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " | ||||
| 
 | ||||
|         self.frame += 1 | ||||
|         return [str(self.screen)], [im0], None, s  # screen, img, vid_cap, string | ||||
| @ -274,32 +277,32 @@ class LoadImages: | ||||
|     def __init__(self, path, imgsz=640, vid_stride=1): | ||||
|         """Initialize the Dataloader and raise FileNotFoundError if file not found.""" | ||||
|         parent = None | ||||
|         if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line | ||||
|         if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line | ||||
|             parent = Path(path).parent | ||||
|             path = Path(path).read_text().splitlines()  # list of sources | ||||
|         files = [] | ||||
|         for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: | ||||
|             a = str(Path(p).absolute())  # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912 | ||||
|             if '*' in a: | ||||
|             if "*" in a: | ||||
|                 files.extend(sorted(glob.glob(a, recursive=True)))  # glob | ||||
|             elif os.path.isdir(a): | ||||
|                 files.extend(sorted(glob.glob(os.path.join(a, '*.*'))))  # dir | ||||
|                 files.extend(sorted(glob.glob(os.path.join(a, "*.*"))))  # dir | ||||
|             elif os.path.isfile(a): | ||||
|                 files.append(a)  # files (absolute or relative to CWD) | ||||
|             elif parent and (parent / p).is_file(): | ||||
|                 files.append(str((parent / p).absolute()))  # files (relative to *.txt file parent) | ||||
|             else: | ||||
|                 raise FileNotFoundError(f'{p} does not exist') | ||||
|                 raise FileNotFoundError(f"{p} does not exist") | ||||
| 
 | ||||
|         images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS] | ||||
|         videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS] | ||||
|         images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS] | ||||
|         videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS] | ||||
|         ni, nv = len(images), len(videos) | ||||
| 
 | ||||
|         self.imgsz = imgsz | ||||
|         self.files = images + videos | ||||
|         self.nf = ni + nv  # number of files | ||||
|         self.video_flag = [False] * ni + [True] * nv | ||||
|         self.mode = 'image' | ||||
|         self.mode = "image" | ||||
|         self.vid_stride = vid_stride  # video frame-rate stride | ||||
|         self.bs = 1 | ||||
|         if any(videos): | ||||
| @ -307,8 +310,10 @@ class LoadImages: | ||||
|         else: | ||||
|             self.cap = None | ||||
|         if self.nf == 0: | ||||
|             raise FileNotFoundError(f'No images or videos found in {p}. ' | ||||
|                                     f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}') | ||||
|             raise FileNotFoundError( | ||||
|                 f"No images or videos found in {p}. " | ||||
|                 f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" | ||||
|             ) | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|         """Returns an iterator object for VideoStream or ImageFolder.""" | ||||
| @ -323,7 +328,7 @@ class LoadImages: | ||||
| 
 | ||||
|         if self.video_flag[self.count]: | ||||
|             # Read video | ||||
|             self.mode = 'video' | ||||
|             self.mode = "video" | ||||
|             for _ in range(self.vid_stride): | ||||
|                 self.cap.grab() | ||||
|             success, im0 = self.cap.retrieve() | ||||
| @ -338,15 +343,15 @@ class LoadImages: | ||||
| 
 | ||||
|             self.frame += 1 | ||||
|             # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False | ||||
|             s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ' | ||||
|             s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: " | ||||
| 
 | ||||
|         else: | ||||
|             # Read image | ||||
|             self.count += 1 | ||||
|             im0 = cv2.imread(path)  # BGR | ||||
|             if im0 is None: | ||||
|                 raise FileNotFoundError(f'Image Not Found {path}') | ||||
|             s = f'image {self.count}/{self.nf} {path}: ' | ||||
|                 raise FileNotFoundError(f"Image Not Found {path}") | ||||
|             s = f"image {self.count}/{self.nf} {path}: " | ||||
| 
 | ||||
|         return [path], [im0], self.cap, s | ||||
| 
 | ||||
| @ -385,20 +390,20 @@ class LoadPilAndNumpy: | ||||
|         """Initialize PIL and Numpy Dataloader.""" | ||||
|         if not isinstance(im0, list): | ||||
|             im0 = [im0] | ||||
|         self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] | ||||
|         self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] | ||||
|         self.im0 = [self._single_check(im) for im in im0] | ||||
|         self.imgsz = imgsz | ||||
|         self.mode = 'image' | ||||
|         self.mode = "image" | ||||
|         # Generate fake paths | ||||
|         self.bs = len(self.im0) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _single_check(im): | ||||
|         """Validate and format an image to numpy array.""" | ||||
|         assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}' | ||||
|         assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" | ||||
|         if isinstance(im, Image.Image): | ||||
|             if im.mode != 'RGB': | ||||
|                 im = im.convert('RGB') | ||||
|             if im.mode != "RGB": | ||||
|                 im = im.convert("RGB") | ||||
|             im = np.asarray(im)[:, :, ::-1] | ||||
|             im = np.ascontiguousarray(im)  # contiguous | ||||
|         return im | ||||
| @ -412,7 +417,7 @@ class LoadPilAndNumpy: | ||||
|         if self.count == 1:  # loop only once as it's batch inference | ||||
|             raise StopIteration | ||||
|         self.count += 1 | ||||
|         return self.paths, self.im0, None, '' | ||||
|         return self.paths, self.im0, None, "" | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|         """Enables iteration for class LoadPilAndNumpy.""" | ||||
| @ -441,14 +446,16 @@ class LoadTensor: | ||||
|         """Initialize Tensor Dataloader.""" | ||||
|         self.im0 = self._single_check(im0) | ||||
|         self.bs = self.im0.shape[0] | ||||
|         self.mode = 'image' | ||||
|         self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] | ||||
|         self.mode = "image" | ||||
|         self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _single_check(im, stride=32): | ||||
|         """Validate and format an image to torch.Tensor.""" | ||||
|         s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \ | ||||
|             f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.' | ||||
|         s = ( | ||||
|             f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) " | ||||
|             f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible." | ||||
|         ) | ||||
|         if len(im.shape) != 4: | ||||
|             if len(im.shape) != 3: | ||||
|                 raise ValueError(s) | ||||
| @ -457,8 +464,10 @@ class LoadTensor: | ||||
|         if im.shape[2] % stride or im.shape[3] % stride: | ||||
|             raise ValueError(s) | ||||
|         if im.max() > 1.0 + torch.finfo(im.dtype).eps:  # torch.float32 eps is 1.2e-07 | ||||
|             LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. ' | ||||
|                            f'Dividing input by 255.') | ||||
|             LOGGER.warning( | ||||
|                 f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. " | ||||
|                 f"Dividing input by 255." | ||||
|             ) | ||||
|             im = im.float() / 255.0 | ||||
| 
 | ||||
|         return im | ||||
| @ -473,7 +482,7 @@ class LoadTensor: | ||||
|         if self.count == 1: | ||||
|             raise StopIteration | ||||
|         self.count += 1 | ||||
|         return self.paths, self.im0, None, '' | ||||
|         return self.paths, self.im0, None, "" | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         """Returns the batch size.""" | ||||
| @ -485,12 +494,14 @@ def autocast_list(source): | ||||
|     files = [] | ||||
|     for im in source: | ||||
|         if isinstance(im, (str, Path)):  # filename or uri | ||||
|             files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im)) | ||||
|             files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im)) | ||||
|         elif isinstance(im, (Image.Image, np.ndarray)):  # PIL or np Image | ||||
|             files.append(im) | ||||
|         else: | ||||
|             raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n' | ||||
|                             f'See https://docs.ultralytics.com/modes/predict for supported source types.') | ||||
|             raise TypeError( | ||||
|                 f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n" | ||||
|                 f"See https://docs.ultralytics.com/modes/predict for supported source types." | ||||
|             ) | ||||
| 
 | ||||
|     return files | ||||
| 
 | ||||
| @ -513,16 +524,18 @@ def get_best_youtube_url(url, use_pafy=True): | ||||
|         (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found. | ||||
|     """ | ||||
|     if use_pafy: | ||||
|         check_requirements(('pafy', 'youtube_dl==2020.12.2')) | ||||
|         check_requirements(("pafy", "youtube_dl==2020.12.2")) | ||||
|         import pafy  # noqa | ||||
|         return pafy.new(url).getbestvideo(preftype='mp4').url | ||||
| 
 | ||||
|         return pafy.new(url).getbestvideo(preftype="mp4").url | ||||
|     else: | ||||
|         check_requirements('yt-dlp') | ||||
|         check_requirements("yt-dlp") | ||||
|         import yt_dlp | ||||
|         with yt_dlp.YoutubeDL({'quiet': True}) as ydl: | ||||
| 
 | ||||
|         with yt_dlp.YoutubeDL({"quiet": True}) as ydl: | ||||
|             info_dict = ydl.extract_info(url, download=False)  # extract info | ||||
|         for f in reversed(info_dict.get('formats', [])):  # reversed because best is usually last | ||||
|         for f in reversed(info_dict.get("formats", [])):  # reversed because best is usually last | ||||
|             # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size | ||||
|             good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080 | ||||
|             if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4': | ||||
|                 return f.get('url') | ||||
|             good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080 | ||||
|             if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4": | ||||
|                 return f.get("url") | ||||
|  | ||||
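All of the loaders above share the same iterator contract, yielding (paths, images, cap, string) 4-tuples, so they can be driven directly. A minimal sketch (the module path is assumed from the repo layout and the source path is illustrative):

    from ultralytics.data.loaders import LoadImages  # module path assumed

    loader = LoadImages("path/to/images", imgsz=640, vid_stride=1)
    for paths, imgs, cap, s in loader:
        im0 = imgs[0]  # BGR numpy array, as produced by cv2.imread / cap.retrieve above
        print(s, im0.shape)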
| @ -14,7 +14,7 @@ from tqdm import tqdm | ||||
| from ultralytics.data.utils import exif_size, img2label_paths | ||||
| from ultralytics.utils.checks import check_requirements | ||||
| 
 | ||||
| check_requirements('shapely') | ||||
| check_requirements("shapely") | ||||
| from shapely.geometry import Polygon | ||||
| 
 | ||||
| 
 | ||||
| @ -54,7 +54,7 @@ def bbox_iof(polygon1, bbox2, eps=1e-6): | ||||
|     return outputs | ||||
| 
 | ||||
| 
 | ||||
| def load_yolo_dota(data_root, split='train'): | ||||
| def load_yolo_dota(data_root, split="train"): | ||||
|     """ | ||||
|     Load DOTA dataset. | ||||
| 
 | ||||
| @ -72,10 +72,10 @@ def load_yolo_dota(data_root, split='train'): | ||||
|                     - train | ||||
|                     - val | ||||
|     """ | ||||
|     assert split in ['train', 'val'] | ||||
|     im_dir = os.path.join(data_root, f'images/{split}') | ||||
|     assert split in ["train", "val"] | ||||
|     im_dir = os.path.join(data_root, f"images/{split}") | ||||
|     assert Path(im_dir).exists(), f"Can't find {im_dir}, please check your data root." | ||||
|     im_files = glob(os.path.join(data_root, f'images/{split}/*')) | ||||
|     im_files = glob(os.path.join(data_root, f"images/{split}/*")) | ||||
|     lb_files = img2label_paths(im_files) | ||||
|     annos = [] | ||||
|     for im_file, lb_file in zip(im_files, lb_files): | ||||
| @ -100,7 +100,7 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0 | ||||
|     h, w = im_size | ||||
|     windows = [] | ||||
|     for crop_size, gap in zip(crop_sizes, gaps): | ||||
|         assert crop_size > gap, f'invaild crop_size gap pair [{crop_size} {gap}]' | ||||
|         assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]" | ||||
|         step = crop_size - gap | ||||
| 
 | ||||
|         xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1) | ||||
| @ -132,8 +132,8 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0 | ||||
| 
 | ||||
| def get_window_obj(anno, windows, iof_thr=0.7): | ||||
|     """Get objects for each window.""" | ||||
|     h, w = anno['ori_size'] | ||||
|     label = anno['label'] | ||||
|     h, w = anno["ori_size"] | ||||
|     label = anno["label"] | ||||
|     if len(label): | ||||
|         label[:, 1::2] *= w | ||||
|         label[:, 2::2] *= h | ||||
| @ -166,15 +166,15 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir): | ||||
|                     - train | ||||
|                     - val | ||||
|     """ | ||||
|     im = cv2.imread(anno['filepath']) | ||||
|     name = Path(anno['filepath']).stem | ||||
|     im = cv2.imread(anno["filepath"]) | ||||
|     name = Path(anno["filepath"]).stem | ||||
|     for i, window in enumerate(windows): | ||||
|         x_start, y_start, x_stop, y_stop = window.tolist() | ||||
|         new_name = name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start) | ||||
|         new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start) | ||||
|         patch_im = im[y_start:y_stop, x_start:x_stop] | ||||
|         ph, pw = patch_im.shape[:2] | ||||
| 
 | ||||
|         cv2.imwrite(os.path.join(im_dir, f'{new_name}.jpg'), patch_im) | ||||
|         cv2.imwrite(os.path.join(im_dir, f"{new_name}.jpg"), patch_im) | ||||
|         label = window_objs[i] | ||||
|         if len(label) == 0: | ||||
|             continue | ||||
| @ -183,13 +183,13 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir): | ||||
|         label[:, 1::2] /= pw | ||||
|         label[:, 2::2] /= ph | ||||
| 
 | ||||
|         with open(os.path.join(lb_dir, f'{new_name}.txt'), 'w') as f: | ||||
|         with open(os.path.join(lb_dir, f"{new_name}.txt"), "w") as f: | ||||
|             for lb in label: | ||||
|                 formatted_coords = ['{:.6g}'.format(coord) for coord in lb[1:]] | ||||
|                 formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]] | ||||
|                 f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n") | ||||
| 
 | ||||
| 
 | ||||
| def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024], gaps=[200]): | ||||
| def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=[1024], gaps=[200]): | ||||
|     """ | ||||
|     Split both images and labels. | ||||
| 
 | ||||
| @ -207,14 +207,14 @@ def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024 | ||||
|                 - labels | ||||
|                     - split | ||||
|     """ | ||||
|     im_dir = Path(save_dir) / 'images' / split | ||||
|     im_dir = Path(save_dir) / "images" / split | ||||
|     im_dir.mkdir(parents=True, exist_ok=True) | ||||
|     lb_dir = Path(save_dir) / 'labels' / split | ||||
|     lb_dir = Path(save_dir) / "labels" / split | ||||
|     lb_dir.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|     annos = load_yolo_dota(data_root, split=split) | ||||
|     for anno in tqdm(annos, total=len(annos), desc=split): | ||||
|         windows = get_windows(anno['ori_size'], crop_sizes, gaps) | ||||
|         windows = get_windows(anno["ori_size"], crop_sizes, gaps) | ||||
|         window_objs = get_window_obj(anno, windows) | ||||
|         crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir)) | ||||
| 
 | ||||
| @ -245,7 +245,7 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]): | ||||
|     for r in rates: | ||||
|         crop_sizes.append(int(crop_size / r)) | ||||
|         gaps.append(int(gap / r)) | ||||
|     for split in ['train', 'val']: | ||||
|     for split in ["train", "val"]: | ||||
|         split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps) | ||||
| 
 | ||||
| 
 | ||||
| @ -267,30 +267,30 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]): | ||||
|     for r in rates: | ||||
|         crop_sizes.append(int(crop_size / r)) | ||||
|         gaps.append(int(gap / r)) | ||||
|     save_dir = Path(save_dir) / 'images' / 'test' | ||||
|     save_dir = Path(save_dir) / "images" / "test" | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|     im_dir = Path(os.path.join(data_root, 'images/test')) | ||||
|     im_dir = Path(os.path.join(data_root, "images/test")) | ||||
|     assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root." | ||||
|     im_files = glob(str(im_dir / '*')) | ||||
|     for im_file in tqdm(im_files, total=len(im_files), desc='test'): | ||||
|     im_files = glob(str(im_dir / "*")) | ||||
|     for im_file in tqdm(im_files, total=len(im_files), desc="test"): | ||||
|         w, h = exif_size(Image.open(im_file)) | ||||
|         windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps) | ||||
|         im = cv2.imread(im_file) | ||||
|         name = Path(im_file).stem | ||||
|         for window in windows: | ||||
|             x_start, y_start, x_stop, y_stop = window.tolist() | ||||
|             new_name = (name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start)) | ||||
|             new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start) | ||||
|             patch_im = im[y_start:y_stop, x_start:x_stop] | ||||
|             cv2.imwrite(os.path.join(str(save_dir), f'{new_name}.jpg'), patch_im) | ||||
|             cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     split_trainval( | ||||
|         data_root='DOTAv2', | ||||
|         save_dir='DOTAv2-split', | ||||
|         data_root="DOTAv2", | ||||
|         save_dir="DOTAv2-split", | ||||
|     ) | ||||
|     split_test( | ||||
|         data_root='DOTAv2', | ||||
|         save_dir='DOTAv2-split', | ||||
|         data_root="DOTAv2", | ||||
|         save_dir="DOTAv2-split", | ||||
|     ) | ||||
|  | ||||
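For the window arithmetic that drives the splitter above: get_windows() tiles an (h, w) image with crop_size x crop_size patches stepping by crop_size - gap pixels, keeping windows whose image coverage exceeds im_rate_thr. A small sketch (module path assumed, sizes illustrative):

    from ultralytics.data.split_dota import get_windows  # module path assumed

    windows = get_windows((1080, 1920), crop_sizes=[1024], gaps=[200])  # input is (h, w)
    for w in windows:
        x_start, y_start, x_stop, y_stop = w.tolist()  # pixel box of one overlapping patch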
| @ -17,36 +17,47 @@ import numpy as np | ||||
| from PIL import Image, ImageOps | ||||
| 
 | ||||
| from ultralytics.nn.autobackend import check_class_names | ||||
| from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr, | ||||
|                                emojis, yaml_load, yaml_save) | ||||
| from ultralytics.utils import ( | ||||
|     DATASETS_DIR, | ||||
|     LOGGER, | ||||
|     NUM_THREADS, | ||||
|     ROOT, | ||||
|     SETTINGS_YAML, | ||||
|     TQDM, | ||||
|     clean_url, | ||||
|     colorstr, | ||||
|     emojis, | ||||
|     yaml_load, | ||||
|     yaml_save, | ||||
| ) | ||||
| from ultralytics.utils.checks import check_file, check_font, is_ascii | ||||
| from ultralytics.utils.downloads import download, safe_download, unzip_file | ||||
| from ultralytics.utils.ops import segments2boxes | ||||
| 
 | ||||
| HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.' | ||||
| IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # image suffixes | ||||
| VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm'  # video suffixes | ||||
| PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders | ||||
| HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance." | ||||
| IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # image suffixes | ||||
| VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"  # video suffixes | ||||
| PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders | ||||
| 
 | ||||
| 
 | ||||
| def img2label_paths(img_paths): | ||||
|     """Define label paths as a function of image paths.""" | ||||
|     sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings | ||||
|     return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths] | ||||
|     sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings | ||||
|     return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths] | ||||
| 
 | ||||
| 
 | ||||
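As a concrete illustration of the mapping above, img2label_paths() mirrors each .../images/... path into .../labels/... and swaps the suffix for .txt (path below is illustrative):

    img2label_paths(["datasets/coco128/images/train2017/im0.jpg"])
    # -> ['datasets/coco128/labels/train2017/im0.txt']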
| def get_hash(paths): | ||||
|     """Returns a single hash value of a list of paths (files or dirs).""" | ||||
|     size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes | ||||
|     h = hashlib.sha256(str(size).encode())  # hash sizes | ||||
|     h.update(''.join(paths).encode())  # hash paths | ||||
|     h.update("".join(paths).encode())  # hash paths | ||||
|     return h.hexdigest()  # return hash | ||||
| 
 | ||||
| 
 | ||||
| def exif_size(img: Image.Image): | ||||
|     """Returns exif-corrected PIL size.""" | ||||
|     s = img.size  # (width, height) | ||||
|     if img.format == 'JPEG':  # only support JPEG images | ||||
|     if img.format == "JPEG":  # only support JPEG images | ||||
|         with contextlib.suppress(Exception): | ||||
|             exif = img.getexif() | ||||
|             if exif: | ||||
| @ -60,24 +71,24 @@ def verify_image(args): | ||||
|     """Verify one image.""" | ||||
|     (im_file, cls), prefix = args | ||||
|     # Number (found, corrupt), message | ||||
|     nf, nc, msg = 0, 0, '' | ||||
|     nf, nc, msg = 0, 0, "" | ||||
|     try: | ||||
|         im = Image.open(im_file) | ||||
|         im.verify()  # PIL verify | ||||
|         shape = exif_size(im)  # image size | ||||
|         shape = (shape[1], shape[0])  # hw | ||||
|         assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' | ||||
|         assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' | ||||
|         if im.format.lower() in ('jpg', 'jpeg'): | ||||
|             with open(im_file, 'rb') as f: | ||||
|         assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" | ||||
|         assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" | ||||
|         if im.format.lower() in ("jpg", "jpeg"): | ||||
|             with open(im_file, "rb") as f: | ||||
|                 f.seek(-2, 2) | ||||
|                 if f.read() != b'\xff\xd9':  # corrupt JPEG | ||||
|                     ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) | ||||
|                     msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' | ||||
|                 if f.read() != b"\xff\xd9":  # corrupt JPEG | ||||
|                     ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) | ||||
|                     msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" | ||||
|         nf = 1 | ||||
|     except Exception as e: | ||||
|         nc = 1 | ||||
|         msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' | ||||
|         msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" | ||||
|     return (im_file, cls), nf, nc, msg | ||||
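A hedged usage sketch for the verifier above (image file hypothetical); nf and nc are the found/corrupt counters that the caller accumulates:

    (im_file, cls), nf, nc, msg = verify_image((("bus.jpg", 0), ""))
    if msg:
        print(msg)  # e.g. a restored corrupt JPEG, or the reason the image was skipped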
| 
 | ||||
| 
 | ||||
| @ -85,21 +96,21 @@ def verify_image_label(args): | ||||
|     """Verify one image-label pair.""" | ||||
|     im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args | ||||
|     # Number (missing, found, empty, corrupt), message, segments, keypoints | ||||
|     nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None | ||||
|     nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None | ||||
|     try: | ||||
|         # Verify images | ||||
|         im = Image.open(im_file) | ||||
|         im.verify()  # PIL verify | ||||
|         shape = exif_size(im)  # image size | ||||
|         shape = (shape[1], shape[0])  # hw | ||||
|         assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' | ||||
|         assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' | ||||
|         if im.format.lower() in ('jpg', 'jpeg'): | ||||
|             with open(im_file, 'rb') as f: | ||||
|         assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" | ||||
|         assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" | ||||
|         if im.format.lower() in ("jpg", "jpeg"): | ||||
|             with open(im_file, "rb") as f: | ||||
|                 f.seek(-2, 2) | ||||
|                 if f.read() != b'\xff\xd9':  # corrupt JPEG | ||||
|                     ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) | ||||
|                     msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' | ||||
|                 if f.read() != b"\xff\xd9":  # corrupt JPEG | ||||
|                     ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) | ||||
|                     msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" | ||||
| 
 | ||||
|         # Verify labels | ||||
|         if os.path.isfile(lb_file): | ||||
| @ -114,25 +125,26 @@ def verify_image_label(args): | ||||
|             nl = len(lb) | ||||
|             if nl: | ||||
|                 if keypoint: | ||||
|                     assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each' | ||||
|                     assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each" | ||||
|                     points = lb[:, 5:].reshape(-1, ndim)[:, :2] | ||||
|                 else: | ||||
|                     assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected' | ||||
|                     assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" | ||||
|                     points = lb[:, 1:] | ||||
|                 assert points.max() <= 1, f'non-normalized or out of bounds coordinates {points[points > 1]}' | ||||
|                 assert lb.min() >= 0, f'negative label values {lb[lb < 0]}' | ||||
|                 assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}" | ||||
|                 assert lb.min() >= 0, f"negative label values {lb[lb < 0]}" | ||||
| 
 | ||||
|                 # All labels | ||||
|                 max_cls = lb[:, 0].max()  # max class index | ||||
|                 assert max_cls <= num_cls, \ | ||||
|                     f'Label class {int(max_cls)} exceeds dataset class count {num_cls}. ' \ | ||||
|                     f'Possible class labels are 0-{num_cls - 1}' | ||||
|                 assert max_cls <= num_cls, ( | ||||
|                     f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. " | ||||
|                     f"Possible class labels are 0-{num_cls - 1}" | ||||
|                 ) | ||||
|                 _, i = np.unique(lb, axis=0, return_index=True) | ||||
|                 if len(i) < nl:  # duplicate row check | ||||
|                     lb = lb[i]  # remove duplicates | ||||
|                     if segments: | ||||
|                         segments = [segments[x] for x in i] | ||||
|                     msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed' | ||||
|                     msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed" | ||||
|             else: | ||||
|                 ne = 1  # label empty | ||||
|                 lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32) | ||||
| @ -148,7 +160,7 @@ def verify_image_label(args): | ||||
|         return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg | ||||
|     except Exception as e: | ||||
|         nc = 1 | ||||
|         msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' | ||||
|         msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" | ||||
|         return [None, None, None, None, None, nm, nf, ne, nc, msg] | ||||
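The assertions above spell out the detect-format label contract: one row per object with 5 columns (class, x_center, y_center, width, height), coordinates normalized to [0, 1], no negative values, class indices below the dataset class count, and no duplicate rows. A hypothetical label file that would pass:

    from pathlib import Path

    # Two objects: class 0 and class 16, centers and sizes normalized to [0, 1].
    Path("0001.txt").write_text("0 0.481 0.634 0.690 0.713\n16 0.127 0.548 0.233 0.622\n")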
| 
 | ||||
| 
 | ||||
| @ -194,8 +206,10 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1): | ||||
| 
 | ||||
| def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): | ||||
|     """Return a (640, 640) overlap mask.""" | ||||
|     masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), | ||||
|                      dtype=np.int32 if len(segments) > 255 else np.uint8) | ||||
|     masks = np.zeros( | ||||
|         (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), | ||||
|         dtype=np.int32 if len(segments) > 255 else np.uint8, | ||||
|     ) | ||||
|     areas = [] | ||||
|     ms = [] | ||||
|     for si in range(len(segments)): | ||||
| @ -226,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path: | ||||
|     Returns: | ||||
|         (Path): The path of the found YAML file. | ||||
|     """ | ||||
|     files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml'))  # try root level first and then recursive | ||||
|     files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))  # try root level first and then recursive | ||||
|     assert files, f"No YAML file found in '{path.resolve()}'" | ||||
|     if len(files) > 1: | ||||
|         files = [f for f in files if f.stem == path.stem]  # prefer *.yaml files that match | ||||
| @ -253,7 +267,7 @@ def check_det_dataset(dataset, autodownload=True): | ||||
|     file = check_file(dataset) | ||||
| 
 | ||||
|     # Download (optional) | ||||
|     extract_dir = '' | ||||
|     extract_dir = "" | ||||
|     if zipfile.is_zipfile(file) or is_tarfile(file): | ||||
|         new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False) | ||||
|         file = find_dataset_yaml(DATASETS_DIR / new_dir) | ||||
| @ -263,43 +277,44 @@ def check_det_dataset(dataset, autodownload=True): | ||||
|     data = yaml_load(file, append_filename=True)  # dictionary | ||||
| 
 | ||||
|     # Checks | ||||
|     for k in 'train', 'val': | ||||
|     for k in "train", "val": | ||||
|         if k not in data: | ||||
|             if k != 'val' or 'validation' not in data: | ||||
|             if k != "val" or "validation" not in data: | ||||
|                 raise SyntaxError( | ||||
|                     emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")) | ||||
|                     emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.") | ||||
|                 ) | ||||
|             LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.") | ||||
|             data['val'] = data.pop('validation')  # replace 'validation' key with 'val' key | ||||
|     if 'names' not in data and 'nc' not in data: | ||||
|             data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key | ||||
|     if "names" not in data and "nc" not in data: | ||||
|         raise SyntaxError(emojis(f"{dataset} key missing ❌.\nEither 'names' or 'nc' is required in all data YAMLs.")) | ||||
|     if 'names' in data and 'nc' in data and len(data['names']) != data['nc']: | ||||
|     if "names" in data and "nc" in data and len(data["names"]) != data["nc"]: | ||||
|         raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match.")) | ||||
|     if 'names' not in data: | ||||
|         data['names'] = [f'class_{i}' for i in range(data['nc'])] | ||||
|     if "names" not in data: | ||||
|         data["names"] = [f"class_{i}" for i in range(data["nc"])] | ||||
|     else: | ||||
|         data['nc'] = len(data['names']) | ||||
|         data["nc"] = len(data["names"]) | ||||
| 
 | ||||
|     data['names'] = check_class_names(data['names']) | ||||
|     data["names"] = check_class_names(data["names"]) | ||||
| 
 | ||||
|     # Resolve paths | ||||
|     path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset root | ||||
|     path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root | ||||
|     if not path.is_absolute(): | ||||
|         path = (DATASETS_DIR / path).resolve() | ||||
| 
 | ||||
|     # Set paths | ||||
|     data['path'] = path  # download scripts | ||||
|     for k in 'train', 'val', 'test': | ||||
|     data["path"] = path  # download scripts | ||||
|     for k in "train", "val", "test": | ||||
|         if data.get(k):  # prepend path | ||||
|             if isinstance(data[k], str): | ||||
|                 x = (path / data[k]).resolve() | ||||
|                 if not x.exists() and data[k].startswith('../'): | ||||
|                 if not x.exists() and data[k].startswith("../"): | ||||
|                     x = (path / data[k][3:]).resolve() | ||||
|                 data[k] = str(x) | ||||
|             else: | ||||
|                 data[k] = [str((path / x).resolve()) for x in data[k]] | ||||
| 
 | ||||
|     # Parse YAML | ||||
|     val, s = (data.get(x) for x in ('val', 'download')) | ||||
|     val, s = (data.get(x) for x in ("val", "download")) | ||||
|     if val: | ||||
|         val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path | ||||
|         if not all(x.exists() for x in val): | ||||
| @ -312,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True): | ||||
|                 raise FileNotFoundError(m) | ||||
|             t = time.time() | ||||
|             r = None  # success | ||||
|             if s.startswith('http') and s.endswith('.zip'):  # URL | ||||
|             if s.startswith("http") and s.endswith(".zip"):  # URL | ||||
|                 safe_download(url=s, dir=DATASETS_DIR, delete=True) | ||||
|             elif s.startswith('bash '):  # bash script | ||||
|                 LOGGER.info(f'Running {s} ...') | ||||
|             elif s.startswith("bash "):  # bash script | ||||
|                 LOGGER.info(f"Running {s} ...") | ||||
|                 r = os.system(s) | ||||
|             else:  # python script | ||||
|                 exec(s, {'yaml': data}) | ||||
|             dt = f'({round(time.time() - t, 1)}s)' | ||||
|             s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌' | ||||
|             LOGGER.info(f'Dataset download {s}\n') | ||||
|     check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fonts | ||||
|                 exec(s, {"yaml": data}) | ||||
|             dt = f"({round(time.time() - t, 1)}s)" | ||||
|             s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" | ||||
|             LOGGER.info(f"Dataset download {s}\n") | ||||
|     check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts | ||||
| 
 | ||||
|     return data  # dictionary | ||||
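Summarizing the checks above: a detection data YAML needs 'train' and 'val' keys plus either 'names' or 'nc', with the missing one derived. A minimal config that would pass, dumped with PyYAML for illustration (paths hypothetical):

    import yaml

    cfg = {
        "path": "coco8",  # dataset root; resolved under DATASETS_DIR when relative
        "train": "images/train",
        "val": "images/val",
        "names": {0: "person", 1: "bicycle"},  # 'nc' is then derived as len(names)
    }
    print(yaml.safe_dump(cfg, sort_keys=False))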
| 
 | ||||
| 
 | ||||
| def check_cls_dataset(dataset, split=''): | ||||
| def check_cls_dataset(dataset, split=""): | ||||
|     """ | ||||
|     Checks a classification dataset such as ImageNet. | ||||
| 
 | ||||
| @ -348,54 +363,59 @@ def check_cls_dataset(dataset, split=''): | ||||
|     """ | ||||
| 
 | ||||
|     # Download (optional if dataset=https://file.zip is passed directly) | ||||
|     if str(dataset).startswith(('http:/', 'https:/')): | ||||
|     if str(dataset).startswith(("http:/", "https:/")): | ||||
|         dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False) | ||||
| 
 | ||||
|     dataset = Path(dataset) | ||||
|     data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve() | ||||
|     if not data_dir.is_dir(): | ||||
|         LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') | ||||
|         LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...") | ||||
|         t = time.time() | ||||
|         if str(dataset) == 'imagenet': | ||||
|         if str(dataset) == "imagenet": | ||||
|             subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|         else: | ||||
|             url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' | ||||
|             url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip" | ||||
|             download(url, dir=data_dir.parent) | ||||
|         s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" | ||||
|         LOGGER.info(s) | ||||
|     train_set = data_dir / 'train' | ||||
|     val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \ | ||||
|         (data_dir / 'validation').exists() else None  # data/test or data/val | ||||
|     test_set = data_dir / 'test' if (data_dir / 'test').exists() else None  # data/val or data/test | ||||
|     if split == 'val' and not val_set: | ||||
|     train_set = data_dir / "train" | ||||
|     val_set = ( | ||||
|         data_dir / "val" | ||||
|         if (data_dir / "val").exists() | ||||
|         else data_dir / "validation" | ||||
|         if (data_dir / "validation").exists() | ||||
|         else None | ||||
|     )  # data/val, else data/validation | ||||
|     test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/test if present | ||||
|     if split == "val" and not val_set: | ||||
|         LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") | ||||
|     elif split == 'test' and not test_set: | ||||
|     elif split == "test" and not test_set: | ||||
|         LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.") | ||||
| 
 | ||||
|     nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes | ||||
|     names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list | ||||
|     nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes | ||||
|     names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list | ||||
|     names = dict(enumerate(sorted(names))) | ||||
| 
 | ||||
|     # Print to console | ||||
|     for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items(): | ||||
|     for k, v in {"train": train_set, "val": val_set, "test": test_set}.items(): | ||||
|         prefix = f'{colorstr(f"{k}:")} {v}...' | ||||
|         if v is None: | ||||
|             LOGGER.info(prefix) | ||||
|         else: | ||||
|             files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS] | ||||
|             files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS] | ||||
|             nf = len(files)  # number of files | ||||
|             nd = len({file.parent for file in files})  # number of directories | ||||
|             if nf == 0: | ||||
|                 if k == 'train': | ||||
|                 if k == "train": | ||||
|                     raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ ")) | ||||
|                 else: | ||||
|                     LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found') | ||||
|                     LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found") | ||||
|             elif nd != nc: | ||||
|                 LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}') | ||||
|                 LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}") | ||||
|             else: | ||||
|                 LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ') | ||||
|                 LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ") | ||||
| 
 | ||||
|     return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names} | ||||
|     return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names} | ||||
| 
 | ||||
| 
 | ||||
| class HUBDatasetStats: | ||||
| @ -423,42 +443,43 @@ class HUBDatasetStats: | ||||
|         ``` | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, path='coco8.yaml', task='detect', autodownload=False): | ||||
|     def __init__(self, path="coco8.yaml", task="detect", autodownload=False): | ||||
|         """Initialize class.""" | ||||
|         path = Path(path).resolve() | ||||
|         LOGGER.info(f'Starting HUB dataset checks for {path}....') | ||||
|         LOGGER.info(f"Starting HUB dataset checks for {path}....") | ||||
| 
 | ||||
|         self.task = task  # detect, segment, pose, classify | ||||
|         if self.task == 'classify': | ||||
|         if self.task == "classify": | ||||
|             unzip_dir = unzip_file(path) | ||||
|             data = check_cls_dataset(unzip_dir) | ||||
|             data['path'] = unzip_dir | ||||
|             data["path"] = unzip_dir | ||||
|         else:  # detect, segment, pose | ||||
|             _, data_dir, yaml_path = self._unzip(Path(path)) | ||||
|             try: | ||||
|                 # Load YAML with checks | ||||
|                 data = yaml_load(yaml_path) | ||||
|                 data['path'] = ''  # strip path since YAML should be in dataset root for all HUB datasets | ||||
|                 data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets | ||||
|                 yaml_save(yaml_path, data) | ||||
|                 data = check_det_dataset(yaml_path, autodownload)  # dict | ||||
|                 data['path'] = data_dir  # YAML path should be set to '' (relative) or parent (absolute) | ||||
|                 data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute) | ||||
|             except Exception as e: | ||||
|                 raise Exception('error/HUB/dataset_stats/init') from e | ||||
|                 raise Exception("error/HUB/dataset_stats/init") from e | ||||
| 
 | ||||
|         self.hub_dir = Path(f'{data["path"]}-hub') | ||||
|         self.im_dir = self.hub_dir / 'images' | ||||
|         self.im_dir = self.hub_dir / "images" | ||||
|         self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images | ||||
|         self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())}  # statistics dictionary | ||||
|         self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary | ||||
|         self.data = data | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _unzip(path): | ||||
|         """Unzip data.zip.""" | ||||
|         if not str(path).endswith('.zip'):  # path is data.yaml | ||||
|         if not str(path).endswith(".zip"):  # path is data.yaml | ||||
|             return False, None, path | ||||
|         unzip_dir = unzip_file(path, path=path.parent) | ||||
|         assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \ | ||||
|                                    f'path/to/abc.zip MUST unzip to path/to/abc/' | ||||
|         assert unzip_dir.is_dir(), ( | ||||
|             f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/" | ||||
|         ) | ||||
|         return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path | ||||
| 
 | ||||
|     def _hub_ops(self, f): | ||||
| @ -470,31 +491,31 @@ class HUBDatasetStats: | ||||
| 
 | ||||
|         def _round(labels): | ||||
|             """Update labels to integer class and 4 decimal place floats.""" | ||||
|             if self.task == 'detect': | ||||
|                 coordinates = labels['bboxes'] | ||||
|             elif self.task == 'segment': | ||||
|                 coordinates = [x.flatten() for x in labels['segments']] | ||||
|             elif self.task == 'pose': | ||||
|                 n = labels['keypoints'].shape[0] | ||||
|                 coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1) | ||||
|             if self.task == "detect": | ||||
|                 coordinates = labels["bboxes"] | ||||
|             elif self.task == "segment": | ||||
|                 coordinates = [x.flatten() for x in labels["segments"]] | ||||
|             elif self.task == "pose": | ||||
|                 n = labels["keypoints"].shape[0] | ||||
|                 coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1) | ||||
|             else: | ||||
|                 raise ValueError('Undefined dataset task.') | ||||
|             zipped = zip(labels['cls'], coordinates) | ||||
|                 raise ValueError("Undefined dataset task.") | ||||
|             zipped = zip(labels["cls"], coordinates) | ||||
|             return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped] | ||||
| 
 | ||||
|         for split in 'train', 'val', 'test': | ||||
|         for split in "train", "val", "test": | ||||
|             self.stats[split] = None  # predefine | ||||
|             path = self.data.get(split) | ||||
| 
 | ||||
|             # Check split | ||||
|             if path is None:  # no split | ||||
|                 continue | ||||
|             files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split | ||||
|             files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split | ||||
|             if not files:  # no images | ||||
|                 continue | ||||
| 
 | ||||
|             # Get dataset statistics | ||||
|             if self.task == 'classify': | ||||
|             if self.task == "classify": | ||||
|                 from torchvision.datasets import ImageFolder | ||||
| 
 | ||||
|                 dataset = ImageFolder(self.data[split]) | ||||
| @ -504,38 +525,35 @@ class HUBDatasetStats: | ||||
|                     x[im[1]] += 1 | ||||
| 
 | ||||
|                 self.stats[split] = { | ||||
|                     'instance_stats': { | ||||
|                         'total': len(dataset), | ||||
|                         'per_class': x.tolist()}, | ||||
|                     'image_stats': { | ||||
|                         'total': len(dataset), | ||||
|                         'unlabelled': 0, | ||||
|                         'per_class': x.tolist()}, | ||||
|                     'labels': [{ | ||||
|                         Path(k).name: v} for k, v in dataset.imgs]} | ||||
|                     "instance_stats": {"total": len(dataset), "per_class": x.tolist()}, | ||||
|                     "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()}, | ||||
|                     "labels": [{Path(k).name: v} for k, v in dataset.imgs], | ||||
|                 } | ||||
|             else: | ||||
|                 from ultralytics.data import YOLODataset | ||||
| 
 | ||||
|                 dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task) | ||||
|                 x = np.array([ | ||||
|                     np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc']) | ||||
|                     for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')])  # shape(128x80) | ||||
|                 x = np.array( | ||||
|                     [ | ||||
|                         np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"]) | ||||
|                         for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics") | ||||
|                     ] | ||||
|                 )  # shape (num_images, nc), e.g. (128, 80) | ||||
|                 self.stats[split] = { | ||||
|                     'instance_stats': { | ||||
|                         'total': int(x.sum()), | ||||
|                         'per_class': x.sum(0).tolist()}, | ||||
|                     'image_stats': { | ||||
|                         'total': len(dataset), | ||||
|                         'unlabelled': int(np.all(x == 0, 1).sum()), | ||||
|                         'per_class': (x > 0).sum(0).tolist()}, | ||||
|                     'labels': [{ | ||||
|                         Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]} | ||||
|                     "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()}, | ||||
|                     "image_stats": { | ||||
|                         "total": len(dataset), | ||||
|                         "unlabelled": int(np.all(x == 0, 1).sum()), | ||||
|                         "per_class": (x > 0).sum(0).tolist(), | ||||
|                     }, | ||||
|                     "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)], | ||||
|                 } | ||||
| 
 | ||||
|         # Save, print and return | ||||
|         if save: | ||||
|             stats_path = self.hub_dir / 'stats.json' | ||||
|             LOGGER.info(f'Saving {stats_path.resolve()}...') | ||||
|             with open(stats_path, 'w') as f: | ||||
|             stats_path = self.hub_dir / "stats.json" | ||||
|             LOGGER.info(f"Saving {stats_path.resolve()}...") | ||||
|             with open(stats_path, "w") as f: | ||||
|                 json.dump(self.stats, f)  # save stats.json | ||||
|         if verbose: | ||||
|             LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False)) | ||||
| @ -545,14 +563,14 @@ class HUBDatasetStats: | ||||
|         """Compress images for Ultralytics HUB.""" | ||||
|         from ultralytics.data import YOLODataset  # ClassificationDataset | ||||
| 
 | ||||
|         for split in 'train', 'val', 'test': | ||||
|         for split in "train", "val", "test": | ||||
|             if self.data.get(split) is None: | ||||
|                 continue | ||||
|             dataset = YOLODataset(img_path=self.data[split], data=self.data) | ||||
|             with ThreadPool(NUM_THREADS) as pool: | ||||
|                 for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'): | ||||
|                 for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"): | ||||
|                     pass | ||||
|         LOGGER.info(f'Done. All images saved to {self.im_dir}') | ||||
|         LOGGER.info(f"Done. All images saved to {self.im_dir}") | ||||
|         return self.im_dir | ||||
| 
 | ||||
| 
 | ||||
| @ -583,9 +601,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50): | ||||
|         r = max_dim / max(im.height, im.width)  # ratio | ||||
|         if r < 1.0:  # image too large | ||||
|             im = im.resize((int(im.width * r), int(im.height * r))) | ||||
|         im.save(f_new or f, 'JPEG', quality=quality, optimize=True)  # save | ||||
|         im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save | ||||
|     except Exception as e:  # use OpenCV | ||||
|         LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}') | ||||
|         LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}") | ||||
|         im = cv2.imread(f) | ||||
|         im_height, im_width = im.shape[:2] | ||||
|         r = max_dim / max(im_height, im_width)  # ratio | ||||
| @ -594,7 +612,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50): | ||||
|         cv2.imwrite(str(f_new or f), im) | ||||
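A usage sketch for the compressor above (directory hypothetical, function assumed in scope): images are re-saved in place as JPEG, downscaled so the long side is at most max_dim pixels:

    from pathlib import Path

    for f in Path("datasets/coco8/images").rglob("*.jpg"):
        compress_one_image(f)  # defaults: max_dim=1920, quality=50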
| 
 | ||||
| 
 | ||||
| def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False): | ||||
| def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False): | ||||
|     """ | ||||
|     Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files. | ||||
| 
 | ||||
| @ -612,18 +630,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot | ||||
|     """ | ||||
| 
 | ||||
|     path = Path(path)  # images dir | ||||
|     files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only | ||||
|     files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only | ||||
|     n = len(files)  # number of files | ||||
|     random.seed(0)  # for reproducibility | ||||
|     indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split | ||||
| 
 | ||||
|     txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files | ||||
|     txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files | ||||
|     for x in txt: | ||||
|         if (path.parent / x).exists(): | ||||
|             (path.parent / x).unlink()  # remove existing | ||||
| 
 | ||||
|     LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only) | ||||
|     LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only) | ||||
|     for i, img in TQDM(zip(indices, files), total=n): | ||||
|         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label | ||||
|             with open(path.parent / txt[i], 'a') as f: | ||||
|                 f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file | ||||
|             with open(path.parent / txt[i], "a") as f: | ||||
|                 f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file | ||||
|  | ||||
										
											
File diff suppressed because it is too large
							| @ -53,7 +53,7 @@ class Model(nn.Module): | ||||
|         list(ultralytics.engine.results.Results): The prediction results. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None: | ||||
|     def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None) -> None: | ||||
|         """ | ||||
|         Initializes the YOLO model. | ||||
| 
 | ||||
| @ -89,7 +89,7 @@ class Model(nn.Module): | ||||
| 
 | ||||
|         # Load or create new YOLO model | ||||
|         model = checks.check_model_file_from_stem(model)  # add suffix, i.e. yolov8n -> yolov8n.pt | ||||
|         if Path(model).suffix in ('.yaml', '.yml'): | ||||
|         if Path(model).suffix in (".yaml", ".yml"): | ||||
|             self._new(model, task) | ||||
|         else: | ||||
|             self._load(model, task) | ||||
| @ -112,16 +112,20 @@ class Model(nn.Module): | ||||
|     def is_triton_model(model): | ||||
|         """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>""" | ||||
|         from urllib.parse import urlsplit | ||||
| 
 | ||||
|         url = urlsplit(model) | ||||
|         return url.netloc and url.path and url.scheme in {'http', 'grpc'} | ||||
|         return url.netloc and url.path and url.scheme in {"http", "grpc"} | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def is_hub_model(model): | ||||
|         """Check if the provided model is a HUB model.""" | ||||
|         return any(( | ||||
|             model.startswith(f'{HUB_WEB_ROOT}/models/'),  # i.e. https://hub.ultralytics.com/models/MODEL_ID | ||||
|             [len(x) for x in model.split('_')] == [42, 20],  # APIKEY_MODELID | ||||
|             len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\')))  # MODELID | ||||
|         return any( | ||||
|             ( | ||||
|                 model.startswith(f"{HUB_WEB_ROOT}/models/"),  # i.e. https://hub.ultralytics.com/models/MODEL_ID | ||||
|                 [len(x) for x in model.split("_")] == [42, 20],  # APIKEY_MODELID | ||||
|                 len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), | ||||
|             ) | ||||
|         )  # MODELID | ||||
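For reference, the three model identifiers the reformatted check accepts (values hypothetical):

    # 1. f"{HUB_WEB_ROOT}/models/<MODEL_ID>"    full HUB URL
    # 2. "<42-char APIKEY>_<20-char MODELID>"   API key and model id joined by an underscore
    # 3. "<20-char MODELID>"                    bare id: not an existing path, no '.', '/' or '\'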
| 
 | ||||
|     def _new(self, cfg: str, task=None, model=None, verbose=True): | ||||
|         """ | ||||
| @ -136,9 +140,9 @@ class Model(nn.Module): | ||||
|         cfg_dict = yaml_model_load(cfg) | ||||
|         self.cfg = cfg | ||||
|         self.task = task or guess_model_task(cfg_dict) | ||||
|         self.model = (model or self._smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model | ||||
|         self.overrides['model'] = self.cfg | ||||
|         self.overrides['task'] = self.task | ||||
|         self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1)  # build model | ||||
|         self.overrides["model"] = self.cfg | ||||
|         self.overrides["task"] = self.task | ||||
| 
 | ||||
|         # Below added to allow export from YAMLs | ||||
|         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args) | ||||
| @ -153,9 +157,9 @@ class Model(nn.Module): | ||||
|             task (str | None): model task | ||||
|         """ | ||||
|         suffix = Path(weights).suffix | ||||
|         if suffix == '.pt': | ||||
|         if suffix == ".pt": | ||||
|             self.model, self.ckpt = attempt_load_one_weight(weights) | ||||
|             self.task = self.model.args['task'] | ||||
|             self.task = self.model.args["task"] | ||||
|             self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) | ||||
|             self.ckpt_path = self.model.pt_path | ||||
|         else: | ||||
| @ -163,12 +167,12 @@ class Model(nn.Module): | ||||
|             self.model, self.ckpt = weights, None | ||||
|             self.task = task or guess_model_task(weights) | ||||
|             self.ckpt_path = weights | ||||
|         self.overrides['model'] = weights | ||||
|         self.overrides['task'] = self.task | ||||
|         self.overrides["model"] = weights | ||||
|         self.overrides["task"] = self.task | ||||
| 
 | ||||
|     def _check_is_pytorch_model(self): | ||||
|         """Raises TypeError is model is not a PyTorch model.""" | ||||
|         pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt' | ||||
|         pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt" | ||||
|         pt_module = isinstance(self.model, nn.Module) | ||||
|         if not (pt_module or pt_str): | ||||
|             raise TypeError( | ||||
| @ -176,19 +180,20 @@ class Model(nn.Module): | ||||
|                 f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported " | ||||
|                 f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, " | ||||
|                 f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device " | ||||
|                 f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'") | ||||
|                 f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'" | ||||
|             ) | ||||
| 
 | ||||
|     def reset_weights(self): | ||||
|         """Resets the model modules parameters to randomly initialized values, losing all training information.""" | ||||
|         self._check_is_pytorch_model() | ||||
|         for m in self.model.modules(): | ||||
|             if hasattr(m, 'reset_parameters'): | ||||
|             if hasattr(m, "reset_parameters"): | ||||
|                 m.reset_parameters() | ||||
|         for p in self.model.parameters(): | ||||
|             p.requires_grad = True | ||||
|         return self | ||||
| 
 | ||||
|     def load(self, weights='yolov8n.pt'): | ||||
|     def load(self, weights="yolov8n.pt"): | ||||
|         """Transfers parameters with matching names and shapes from 'weights' to model.""" | ||||
|         self._check_is_pytorch_model() | ||||
|         if isinstance(weights, (str, Path)): | ||||
| @ -226,8 +231,8 @@ class Model(nn.Module): | ||||
|         Returns: | ||||
|             (List[torch.Tensor]): A list of image embeddings. | ||||
|         """ | ||||
|         if not kwargs.get('embed'): | ||||
|             kwargs['embed'] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed | ||||
|         if not kwargs.get("embed"): | ||||
|             kwargs["embed"] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed | ||||
|         return self.predict(source, stream, **kwargs) | ||||
| 
 | ||||
|     def predict(self, source=None, stream=False, predictor=None, **kwargs): | ||||
| @ -249,21 +254,22 @@ class Model(nn.Module): | ||||
|             source = ASSETS | ||||
|             LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") | ||||
| 
 | ||||
|         is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any( | ||||
|             x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track')) | ||||
|         is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any( | ||||
|             x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track") | ||||
|         ) | ||||
| 
 | ||||
|         custom = {'conf': 0.25, 'save': is_cli}  # method defaults | ||||
|         args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'}  # highest priority args on the right | ||||
|         prompts = args.pop('prompts', None)  # for SAM-type models | ||||
|         custom = {"conf": 0.25, "save": is_cli}  # method defaults | ||||
|         args = {**self.overrides, **custom, **kwargs, "mode": "predict"}  # highest priority args on the right | ||||
|         prompts = args.pop("prompts", None)  # for SAM-type models | ||||
| 
 | ||||
|         if not self.predictor: | ||||
|             self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks) | ||||
|             self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks) | ||||
|             self.predictor.setup_model(model=self.model, verbose=is_cli) | ||||
|         else:  # only update args if predictor is already setup | ||||
|             self.predictor.args = get_cfg(self.predictor.args, args) | ||||
|             if 'project' in args or 'name' in args: | ||||
|             if "project" in args or "name" in args: | ||||
|                 self.predictor.save_dir = get_save_dir(self.predictor.args) | ||||
|         if prompts and hasattr(self.predictor, 'set_prompts'):  # for SAM-type models | ||||
|         if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type models | ||||
|             self.predictor.set_prompts(prompts) | ||||
|         return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) | ||||
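A typical call against the refactored predict method, assuming a local image file; conf=0.25 and save-on-CLI are the method defaults set above:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    results = model.predict(source="bus.jpg", conf=0.25)
    for r in results:
        print(r.boxes.xyxy)  # one (N, 4) tensor of boxes per image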
| 
 | ||||
| @ -280,11 +286,12 @@ class Model(nn.Module): | ||||
|         Returns: | ||||
|             (List[ultralytics.engine.results.Results]): The tracking results. | ||||
|         """ | ||||
|         if not hasattr(self.predictor, 'trackers'): | ||||
|         if not hasattr(self.predictor, "trackers"): | ||||
|             from ultralytics.trackers import register_tracker | ||||
| 
 | ||||
|             register_tracker(self, persist) | ||||
|         kwargs['conf'] = kwargs.get('conf') or 0.1  # ByteTrack-based method needs low confidence predictions as input | ||||
|         kwargs['mode'] = 'track' | ||||
|         kwargs["conf"] = kwargs.get("conf") or 0.1  # ByteTrack-based method needs low confidence predictions as input | ||||
|         kwargs["mode"] = "track" | ||||
|         return self.predict(source=source, stream=stream, **kwargs) | ||||
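A short usage sketch, reusing the model object from the predict example; the conf floor of 0.1 exists because ByteTrack-style association needs low-confidence candidates as input (video path hypothetical):

    results = model.track(source="video.mp4", persist=True)  # persist keeps track state across calls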
| 
 | ||||
|     def val(self, validator=None, **kwargs): | ||||
| @ -295,10 +302,10 @@ class Model(nn.Module): | ||||
|             validator (BaseValidator): Customized validator. | ||||
|             **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs | ||||
|         """ | ||||
|         custom = {'rect': True}  # method defaults | ||||
|         args = {**self.overrides, **custom, **kwargs, 'mode': 'val'}  # highest priority args on the right | ||||
|         custom = {"rect": True}  # method defaults | ||||
|         args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # highest priority args on the right | ||||
| 
 | ||||
|         validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks) | ||||
|         validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks) | ||||
|         validator(model=self.model) | ||||
|         self.metrics = validator.metrics | ||||
|         return validator.metrics | ||||
| @ -313,16 +320,17 @@ class Model(nn.Module): | ||||
|         self._check_is_pytorch_model() | ||||
|         from ultralytics.utils.benchmarks import benchmark | ||||
| 
 | ||||
|         custom = {'verbose': False}  # method defaults | ||||
|         args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'} | ||||
|         custom = {"verbose": False}  # method defaults | ||||
|         args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"} | ||||
|         return benchmark( | ||||
|             model=self, | ||||
|             data=kwargs.get('data'),  # if no 'data' argument passed set data=None for default datasets | ||||
|             imgsz=args['imgsz'], | ||||
|             half=args['half'], | ||||
|             int8=args['int8'], | ||||
|             device=args['device'], | ||||
|             verbose=kwargs.get('verbose')) | ||||
|             data=kwargs.get("data"),  # if no 'data' argument passed set data=None for default datasets | ||||
|             imgsz=args["imgsz"], | ||||
|             half=args["half"], | ||||
|             int8=args["int8"], | ||||
|             device=args["device"], | ||||
|             verbose=kwargs.get("verbose"), | ||||
|         ) | ||||
| 
 | ||||
|     def export(self, **kwargs): | ||||
|         """ | ||||
| @ -334,8 +342,8 @@ class Model(nn.Module): | ||||
|         self._check_is_pytorch_model() | ||||
|         from .exporter import Exporter | ||||
| 
 | ||||
|         custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False}  # method defaults | ||||
|         args = {**self.overrides, **custom, **kwargs, 'mode': 'export'}  # highest priority args on the right | ||||
|         custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False}  # method defaults | ||||
|         args = {**self.overrides, **custom, **kwargs, "mode": "export"}  # highest priority args on the right | ||||
|         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) | ||||
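A usage sketch for the exporter defaults set above (batch=1 and the model's trained imgsz), again reusing the model object from the predict example:

    path = model.export(format="onnx")  # returns the path of the exported file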
| 
 | ||||
|     def train(self, trainer=None, **kwargs): | ||||
| @ -347,32 +355,32 @@ class Model(nn.Module): | ||||
|             **kwargs (Any): Any number of arguments representing the training configuration. | ||||
|         """ | ||||
|         self._check_is_pytorch_model() | ||||
|         if hasattr(self.session, 'model') and self.session.model.id:  # Ultralytics HUB session with loaded model | ||||
|         if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model | ||||
|             if any(kwargs): | ||||
|                 LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.') | ||||
|                 LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.") | ||||
|             kwargs = self.session.train_args  # overwrite kwargs | ||||
| 
 | ||||
|         checks.check_pip_update_available() | ||||
| 
 | ||||
|         overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides | ||||
|         custom = {'data': DEFAULT_CFG_DICT['data'] or TASK2DATA[self.task]}  # method defaults | ||||
|         args = {**overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right | ||||
|         if args.get('resume'): | ||||
|             args['resume'] = self.ckpt_path | ||||
|         overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides | ||||
|         custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]}  # method defaults | ||||
|         args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right | ||||
|         if args.get("resume"): | ||||
|             args["resume"] = self.ckpt_path | ||||
| 
 | ||||
|         self.trainer = (trainer or self._smart_load('trainer'))(overrides=args, _callbacks=self.callbacks) | ||||
|         if not args.get('resume'):  # manually set model only if not resuming | ||||
|         self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks) | ||||
|         if not args.get("resume"):  # manually set model only if not resuming | ||||
|             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) | ||||
|             self.model = self.trainer.model | ||||
| 
 | ||||
|             if SETTINGS['hub'] is True and not self.session: | ||||
|             if SETTINGS["hub"] is True and not self.session: | ||||
|                 # Create a model in HUB | ||||
|                 try: | ||||
|                     self.session = self._get_hub_session(self.model_name) | ||||
|                     if self.session: | ||||
|                         self.session.create_model(args) | ||||
|                         # Check model was created | ||||
|                         if not getattr(self.session.model, 'id', None): | ||||
|                         if not getattr(self.session.model, "id", None): | ||||
|                             self.session = None | ||||
|                 except PermissionError: | ||||
|                     # Ignore permission error | ||||
| @ -385,7 +393,7 @@ class Model(nn.Module): | ||||
|             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last | ||||
|             self.model, _ = attempt_load_one_weight(ckpt) | ||||
|             self.overrides = self.model.args | ||||
|             self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP | ||||
|             self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP | ||||
|         return self.metrics | ||||
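A usage sketch for the training entry point; note from the logic above that an attached HUB session overrides local arguments. coco8.yaml ships with ultralytics, so this should run as-is:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    metrics = model.train(data="coco8.yaml", epochs=3, imgsz=640)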
| 
 | ||||
|     def tune(self, use_ray=False, iterations=10, *args, **kwargs): | ||||
| @ -398,12 +406,13 @@ class Model(nn.Module): | ||||
|         self._check_is_pytorch_model() | ||||
|         if use_ray: | ||||
|             from ultralytics.utils.tuner import run_ray_tune | ||||
| 
 | ||||
|             return run_ray_tune(self, max_samples=iterations, *args, **kwargs) | ||||
|         else: | ||||
|             from .tuner import Tuner | ||||
| 
 | ||||
|             custom = {}  # method defaults | ||||
|             args = {**self.overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right | ||||
|             args = {**self.overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right | ||||
|             return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) | ||||
| 
 | ||||
|     def _apply(self, fn): | ||||
| @ -411,13 +420,13 @@ class Model(nn.Module): | ||||
|         self._check_is_pytorch_model() | ||||
|         self = super()._apply(fn)  # noqa | ||||
|         self.predictor = None  # reset predictor as device may have changed | ||||
|         self.overrides['device'] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' | ||||
|         self.overrides["device"] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' | ||||
|         return self | ||||
| 
 | ||||
|     @property | ||||
|     def names(self): | ||||
|         """Returns class names of the loaded model.""" | ||||
|         return self.model.names if hasattr(self.model, 'names') else None | ||||
|         return self.model.names if hasattr(self.model, "names") else None | ||||
| 
 | ||||
|     @property | ||||
|     def device(self): | ||||
| @ -427,7 +436,7 @@ class Model(nn.Module): | ||||
|     @property | ||||
|     def transforms(self): | ||||
|         """Returns transform of the loaded model.""" | ||||
|         return self.model.transforms if hasattr(self.model, 'transforms') else None | ||||
|         return self.model.transforms if hasattr(self.model, "transforms") else None | ||||
| 
 | ||||
|     def add_callback(self, event: str, func): | ||||
|         """Add a callback.""" | ||||
| @ -445,7 +454,7 @@ class Model(nn.Module): | ||||
|     @staticmethod | ||||
|     def _reset_ckpt_args(args): | ||||
|         """Reset arguments when loading a PyTorch model.""" | ||||
|         include = {'imgsz', 'data', 'task', 'single_cls'}  # only remember these arguments when loading a PyTorch model | ||||
|         include = {"imgsz", "data", "task", "single_cls"}  # only remember these arguments when loading a PyTorch model | ||||
|         return {k: v for k, v in args.items() if k in include} | ||||
| 
 | ||||
|     # def __getattr__(self, attr): | ||||
| @ -461,7 +470,8 @@ class Model(nn.Module): | ||||
|             name = self.__class__.__name__ | ||||
|             mode = inspect.stack()[1][3]  # get the function name. | ||||
|             raise NotImplementedError( | ||||
|                 emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e | ||||
|                 emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") | ||||
|             ) from e | ||||
| 
 | ||||
|     @property | ||||
|     def task_map(self): | ||||
| @ -471,4 +481,4 @@ class Model(nn.Module): | ||||
|         Returns: | ||||
|             task_map (dict): The map of model task to mode classes. | ||||
|         """ | ||||
|         raise NotImplementedError('Please provide task map for your model!') | ||||
|         raise NotImplementedError("Please provide task map for your model!") | ||||
|  | ||||
| @ -132,8 +132,11 @@ class BasePredictor: | ||||
| 
 | ||||
|     def inference(self, im, *args, **kwargs): | ||||
|         """Runs inference on a given image using the specified model and arguments.""" | ||||
|         visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem, | ||||
|                                    mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False | ||||
|         visualize = ( | ||||
|             increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True) | ||||
|             if self.args.visualize and (not self.source_type.tensor) | ||||
|             else False | ||||
|         ) | ||||
|         return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs) | ||||
| 
 | ||||
|     def pre_transform(self, im): | ||||
| @ -153,35 +156,38 @@ class BasePredictor: | ||||
|     def write_results(self, idx, results, batch): | ||||
|         """Write inference results to a file or directory.""" | ||||
|         p, im, _ = batch | ||||
|         log_string = '' | ||||
|         log_string = "" | ||||
|         if len(im.shape) == 3: | ||||
|             im = im[None]  # expand for batch dim | ||||
|         if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1 | ||||
|             log_string += f'{idx}: ' | ||||
|             log_string += f"{idx}: " | ||||
|             frame = self.dataset.count | ||||
|         else: | ||||
|             frame = getattr(self.dataset, 'frame', 0) | ||||
|             frame = getattr(self.dataset, "frame", 0) | ||||
|         self.data_path = p | ||||
|         self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') | ||||
|         log_string += '%gx%g ' % im.shape[2:]  # print string | ||||
|         self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}") | ||||
|         log_string += "%gx%g " % im.shape[2:]  # print string | ||||
|         result = results[idx] | ||||
|         log_string += result.verbose() | ||||
| 
 | ||||
|         if self.args.save or self.args.show:  # Add bbox to image | ||||
|             plot_args = { | ||||
|                 'line_width': self.args.line_width, | ||||
|                 'boxes': self.args.show_boxes, | ||||
|                 'conf': self.args.show_conf, | ||||
|                 'labels': self.args.show_labels} | ||||
|                 "line_width": self.args.line_width, | ||||
|                 "boxes": self.args.show_boxes, | ||||
|                 "conf": self.args.show_conf, | ||||
|                 "labels": self.args.show_labels, | ||||
|             } | ||||
|             if not self.args.retina_masks: | ||||
|                 plot_args['im_gpu'] = im[idx] | ||||
|                 plot_args["im_gpu"] = im[idx] | ||||
|             self.plotted_img = result.plot(**plot_args) | ||||
|         # Write | ||||
|         if self.args.save_txt: | ||||
|             result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf) | ||||
|             result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) | ||||
|         if self.args.save_crop: | ||||
|             result.save_crop(save_dir=self.save_dir / 'crops', | ||||
|                              file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}')) | ||||
|             result.save_crop( | ||||
|                 save_dir=self.save_dir / "crops", | ||||
|                 file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"), | ||||
|             ) | ||||
| 
 | ||||
|         return log_string | ||||
| 
 | ||||
| @ -210,17 +216,24 @@ class BasePredictor: | ||||
|     def setup_source(self, source): | ||||
|         """Sets up source and inference mode.""" | ||||
|         self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size | ||||
|         self.transforms = getattr( | ||||
|             self.model.model, 'transforms', classify_transforms( | ||||
|                 self.imgsz[0], crop_fraction=self.args.crop_fraction)) if self.args.task == 'classify' else None | ||||
|         self.dataset = load_inference_source(source=source, | ||||
|                                              imgsz=self.imgsz, | ||||
|                                              vid_stride=self.args.vid_stride, | ||||
|                                              buffer=self.args.stream_buffer) | ||||
|         self.transforms = ( | ||||
|             getattr( | ||||
|                 self.model.model, | ||||
|                 "transforms", | ||||
|                 classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction), | ||||
|             ) | ||||
|             if self.args.task == "classify" | ||||
|             else None | ||||
|         ) | ||||
|         self.dataset = load_inference_source( | ||||
|             source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer | ||||
|         ) | ||||
|         self.source_type = self.dataset.source_type | ||||
|         if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or  # streams | ||||
|                                                   len(self.dataset) > 1000 or  # images | ||||
|                                                   any(getattr(self.dataset, 'video_flag', [False]))):  # videos | ||||
|         if not getattr(self, "stream", True) and ( | ||||
|             self.dataset.mode == "stream"  # streams | ||||
|             or len(self.dataset) > 1000  # images | ||||
|             or any(getattr(self.dataset, "video_flag", [False])) | ||||
|         ):  # videos | ||||
|             LOGGER.warning(STREAM_WARNING) | ||||
|         self.vid_path = [None] * self.dataset.bs | ||||
|         self.vid_writer = [None] * self.dataset.bs | ||||
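The `STREAM_WARNING` guard above fires when a non-streaming call would buffer many results in RAM. A sketch of the generator mode that avoids it (source path assumed):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    # stream=True yields one Results object at a time instead of accumulating a list
    for result in model.predict("video.mp4", stream=True):
        pass  # consume each frame's result here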
| @ -230,7 +243,7 @@ class BasePredictor: | ||||
|     def stream_inference(self, source=None, model=None, *args, **kwargs): | ||||
|         """Streams real-time inference on camera feed and saves results to file.""" | ||||
|         if self.args.verbose: | ||||
|             LOGGER.info('') | ||||
|             LOGGER.info("") | ||||
| 
 | ||||
|         # Setup model | ||||
|         if not self.model: | ||||
| @ -242,7 +255,7 @@ class BasePredictor: | ||||
| 
 | ||||
|             # Check if save_dir/ label file exists | ||||
|             if self.args.save or self.args.save_txt: | ||||
|                 (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) | ||||
|                 (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|             # Warmup model | ||||
|             if not self.done_warmup: | ||||
| @ -250,10 +263,10 @@ class BasePredictor: | ||||
|                 self.done_warmup = True | ||||
| 
 | ||||
|             self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile()) | ||||
|             self.run_callbacks('on_predict_start') | ||||
|             self.run_callbacks("on_predict_start") | ||||
| 
 | ||||
|             for batch in self.dataset: | ||||
|                 self.run_callbacks('on_predict_batch_start') | ||||
|                 self.run_callbacks("on_predict_batch_start") | ||||
|                 self.batch = batch | ||||
|                 path, im0s, vid_cap, s = batch | ||||
| 
 | ||||
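The `run_callbacks` calls above dispatch to user-registered hooks. A sketch using the public `add_callback` API (the handler name is illustrative; handlers receive the predictor instance):

    from ultralytics import YOLO

    def on_predict_batch_end(predictor):  # illustrative handler
        print(f"batch of {len(predictor.results)} results ready")

    model = YOLO("yolov8n.pt")
    model.add_callback("on_predict_batch_end", on_predict_batch_end)
    model.predict("bus.jpg")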
| @ -272,15 +285,16 @@ class BasePredictor: | ||||
|                 with profilers[2]: | ||||
|                     self.results = self.postprocess(preds, im, im0s) | ||||
| 
 | ||||
|                 self.run_callbacks('on_predict_postprocess_end') | ||||
|                 self.run_callbacks("on_predict_postprocess_end") | ||||
|                 # Visualize, save, write results | ||||
|                 n = len(im0s) | ||||
|                 for i in range(n): | ||||
|                     self.seen += 1 | ||||
|                     self.results[i].speed = { | ||||
|                         'preprocess': profilers[0].dt * 1E3 / n, | ||||
|                         'inference': profilers[1].dt * 1E3 / n, | ||||
|                         'postprocess': profilers[2].dt * 1E3 / n} | ||||
|                         "preprocess": profilers[0].dt * 1e3 / n, | ||||
|                         "inference": profilers[1].dt * 1e3 / n, | ||||
|                         "postprocess": profilers[2].dt * 1e3 / n, | ||||
|                     } | ||||
|                     p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy() | ||||
|                     p = Path(p) | ||||
| 
 | ||||
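Each `Results` object carries the per-image timings assembled above, in milliseconds. A sketch of reading them back:

    r = results[0]  # from a prior model.predict(...) call
    print(r.speed)  # e.g. {'preprocess': 1.2, 'inference': 5.3, 'postprocess': 0.8}
    total_ms = sum(v for v in r.speed.values() if v is not None)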
| @ -293,12 +307,12 @@ class BasePredictor: | ||||
|                     if self.args.save and self.plotted_img is not None: | ||||
|                         self.save_preds(vid_cap, i, str(self.save_dir / p.name)) | ||||
| 
 | ||||
|                 self.run_callbacks('on_predict_batch_end') | ||||
|                 self.run_callbacks("on_predict_batch_end") | ||||
|                 yield from self.results | ||||
| 
 | ||||
|                 # Print time (inference-only) | ||||
|                 if self.args.verbose: | ||||
|                     LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms') | ||||
|                     LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms") | ||||
| 
 | ||||
|         # Release assets | ||||
|         if isinstance(self.vid_writer[-1], cv2.VideoWriter): | ||||
| @ -306,25 +320,29 @@ class BasePredictor: | ||||
| 
 | ||||
|         # Print results | ||||
|         if self.args.verbose and self.seen: | ||||
|             t = tuple(x.t / self.seen * 1E3 for x in profilers)  # speeds per image | ||||
|             LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape ' | ||||
|                         f'{(1, 3, *im.shape[2:])}' % t) | ||||
|             t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image | ||||
|             LOGGER.info( | ||||
|                 f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " | ||||
|                 f"{(1, 3, *im.shape[2:])}" % t | ||||
|             ) | ||||
|         if self.args.save or self.args.save_txt or self.args.save_crop: | ||||
|             nl = len(list(self.save_dir.glob('labels/*.txt')))  # number of labels | ||||
|             s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else '' | ||||
|             nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels | ||||
|             s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" | ||||
|             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") | ||||
| 
 | ||||
|         self.run_callbacks('on_predict_end') | ||||
|         self.run_callbacks("on_predict_end") | ||||
| 
 | ||||
|     def setup_model(self, model, verbose=True): | ||||
|         """Initialize YOLO model with given parameters and set it to evaluation mode.""" | ||||
|         self.model = AutoBackend(model or self.args.model, | ||||
|         self.model = AutoBackend( | ||||
|             model or self.args.model, | ||||
|             device=select_device(self.args.device, verbose=verbose), | ||||
|             dnn=self.args.dnn, | ||||
|             data=self.args.data, | ||||
|             fp16=self.args.half, | ||||
|             fuse=True, | ||||
|                                  verbose=verbose) | ||||
|             verbose=verbose, | ||||
|         ) | ||||
| 
 | ||||
|         self.device = self.model.device  # update device | ||||
|         self.args.half = self.model.fp16  # update half | ||||
| @ -333,18 +351,18 @@ class BasePredictor: | ||||
|     def show(self, p): | ||||
|         """Display an image in a window using OpenCV imshow().""" | ||||
|         im0 = self.plotted_img | ||||
|         if platform.system() == 'Linux' and p not in self.windows: | ||||
|         if platform.system() == "Linux" and p not in self.windows: | ||||
|             self.windows.append(p) | ||||
|             cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux) | ||||
|             cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) | ||||
|         cv2.imshow(str(p), im0) | ||||
|         cv2.waitKey(500 if self.batch[3].startswith('image') else 1)  # 1 millisecond | ||||
|         cv2.waitKey(500 if self.batch[3].startswith("image") else 1)  # 1 millisecond | ||||
| 
 | ||||
|     def save_preds(self, vid_cap, idx, save_path): | ||||
|         """Save video predictions as mp4 at specified path.""" | ||||
|         im0 = self.plotted_img | ||||
|         # Save imgs | ||||
|         if self.dataset.mode == 'image': | ||||
|         if self.dataset.mode == "image": | ||||
|             cv2.imwrite(save_path, im0) | ||||
|         else:  # 'video' or 'stream' | ||||
|             frames_path = f'{save_path.split(".", 1)[0]}_frames/' | ||||
| @ -361,15 +379,16 @@ class BasePredictor: | ||||
|                     h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | ||||
|                 else:  # stream | ||||
|                     fps, w, h = 30, im0.shape[1], im0.shape[0] | ||||
|                 suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG') | ||||
|                 self.vid_writer[idx] = cv2.VideoWriter(str(Path(save_path).with_suffix(suffix)), | ||||
|                                                        cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) | ||||
|                 suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG") | ||||
|                 self.vid_writer[idx] = cv2.VideoWriter( | ||||
|                     str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h) | ||||
|                 ) | ||||
|             # Write video | ||||
|             self.vid_writer[idx].write(im0) | ||||
| 
 | ||||
|             # Write frame | ||||
|             if self.args.save_frames: | ||||
|                 cv2.imwrite(f'{frames_path}{self.vid_frame[idx]}.jpg', im0) | ||||
|                 cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0) | ||||
|                 self.vid_frame[idx] += 1 | ||||
| 
 | ||||
|     def run_callbacks(self, event: str): | ||||
|  | ||||
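The platform-dependent suffix/fourcc pairing in `save_preds` above maps to plain OpenCV calls. A standalone sketch with the Linux default chosen above (output path, fps, and frame size assumed):

    import cv2
    import numpy as np

    suffix, fourcc = (".avi", "MJPG")  # Linux default from the branch above
    writer = cv2.VideoWriter("out" + suffix, cv2.VideoWriter_fourcc(*fourcc), 30.0, (640, 480))
    writer.write(np.zeros((480, 640, 3), dtype=np.uint8))  # write one black frame
    writer.release()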
| @ -98,15 +98,15 @@ class Results(SimpleClass): | ||||
|         self.probs = Probs(probs) if probs is not None else None | ||||
|         self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None | ||||
|         self.obb = OBB(obb, self.orig_shape) if obb is not None else None | ||||
|         self.speed = {'preprocess': None, 'inference': None, 'postprocess': None}  # milliseconds per image | ||||
|         self.speed = {"preprocess": None, "inference": None, "postprocess": None}  # milliseconds per image | ||||
|         self.names = names | ||||
|         self.path = path | ||||
|         self.save_dir = None | ||||
|         self._keys = 'boxes', 'masks', 'probs', 'keypoints', 'obb' | ||||
|         self._keys = "boxes", "masks", "probs", "keypoints", "obb" | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         """Return a Results object for the specified index.""" | ||||
|         return self._apply('__getitem__', idx) | ||||
|         return self._apply("__getitem__", idx) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         """Return the number of detections in the Results object.""" | ||||
| @ -146,19 +146,19 @@ class Results(SimpleClass): | ||||
| 
 | ||||
|     def cpu(self): | ||||
|         """Return a copy of the Results object with all tensors on CPU memory.""" | ||||
|         return self._apply('cpu') | ||||
|         return self._apply("cpu") | ||||
| 
 | ||||
|     def numpy(self): | ||||
|         """Return a copy of the Results object with all tensors as numpy arrays.""" | ||||
|         return self._apply('numpy') | ||||
|         return self._apply("numpy") | ||||
| 
 | ||||
|     def cuda(self): | ||||
|         """Return a copy of the Results object with all tensors on GPU memory.""" | ||||
|         return self._apply('cuda') | ||||
|         return self._apply("cuda") | ||||
| 
 | ||||
|     def to(self, *args, **kwargs): | ||||
|         """Return a copy of the Results object with tensors on the specified device and dtype.""" | ||||
|         return self._apply('to', *args, **kwargs) | ||||
|         return self._apply("to", *args, **kwargs) | ||||
| 
 | ||||
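All four conversions above funnel through `_apply`, which forwards the named method to the tensor attributes and returns a new Results. A usage sketch (a CUDA device is assumed available for the last line):

    r = results[0]          # from a prior predict call
    r_cpu = r.cpu()         # tensors moved to CPU
    r_np = r.numpy()        # tensors as numpy arrays
    r_gpu = r.to("cuda:0")  # assumed CUDA device; args forwarded via _apply("to", ...)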
|     def new(self): | ||||
|         """Return a new Results object with the same image, path, and names.""" | ||||
| @ -169,7 +169,7 @@ class Results(SimpleClass): | ||||
|         conf=True, | ||||
|         line_width=None, | ||||
|         font_size=None, | ||||
|         font='Arial.ttf', | ||||
|         font="Arial.ttf", | ||||
|         pil=False, | ||||
|         img=None, | ||||
|         im_gpu=None, | ||||
| @ -229,14 +229,20 @@ class Results(SimpleClass): | ||||
|             font_size, | ||||
|             font, | ||||
|             pil or (pred_probs is not None and show_probs),  # Classify tasks default to pil=True | ||||
|             example=names) | ||||
|             example=names, | ||||
|         ) | ||||
| 
 | ||||
|         # Plot Segment results | ||||
|         if pred_masks and show_masks: | ||||
|             if im_gpu is None: | ||||
|                 img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) | ||||
|                 im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute( | ||||
|                     2, 0, 1).flip(0).contiguous() / 255 | ||||
|                 im_gpu = ( | ||||
|                     torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device) | ||||
|                     .permute(2, 0, 1) | ||||
|                     .flip(0) | ||||
|                     .contiguous() | ||||
|                     / 255 | ||||
|                 ) | ||||
|             idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) | ||||
|             annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) | ||||
| 
 | ||||
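`plot()` returns the annotated BGR array built by the Annotator above. A sketch of saving it:

    import cv2

    im = results[0].plot(conf=True, labels=True)  # numpy BGR image with boxes/masks drawn
    cv2.imwrite("annotated.jpg", im)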
| @ -244,14 +250,14 @@ class Results(SimpleClass): | ||||
|         if pred_boxes is not None and show_boxes: | ||||
|             for d in reversed(pred_boxes): | ||||
|                 c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) | ||||
|                 name = ('' if id is None else f'id:{id} ') + names[c] | ||||
|                 label = (f'{name} {conf:.2f}' if conf else name) if labels else None | ||||
|                 name = ("" if id is None else f"id:{id} ") + names[c] | ||||
|                 label = (f"{name} {conf:.2f}" if conf else name) if labels else None | ||||
|                 box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze() | ||||
|                 annotator.box_label(box, label, color=colors(c, True), rotated=is_obb) | ||||
| 
 | ||||
|         # Plot Classify results | ||||
|         if pred_probs is not None and show_probs: | ||||
|             text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5) | ||||
|             text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5) | ||||
|             x = round(self.orig_shape[0] * 0.03) | ||||
|             annotator.text([x, x], text, txt_color=(255, 255, 255))  # TODO: allow setting colors | ||||
| 
 | ||||
| @ -264,11 +270,11 @@ class Results(SimpleClass): | ||||
| 
 | ||||
|     def verbose(self): | ||||
|         """Return log string for each task.""" | ||||
|         log_string = '' | ||||
|         log_string = "" | ||||
|         probs = self.probs | ||||
|         boxes = self.boxes | ||||
|         if len(self) == 0: | ||||
|             return log_string if probs is not None else f'{log_string}(no detections), ' | ||||
|             return log_string if probs is not None else f"{log_string}(no detections), " | ||||
|         if probs is not None: | ||||
|             log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, " | ||||
|         if boxes: | ||||
| @ -293,7 +299,7 @@ class Results(SimpleClass): | ||||
|         texts = [] | ||||
|         if probs is not None: | ||||
|             # Classify | ||||
|             texts.extend(f'{probs.data[j]:.2f} {self.names[j]}' for j in probs.top5) | ||||
|             texts.extend(f"{probs.data[j]:.2f} {self.names[j]}" for j in probs.top5) | ||||
|         elif boxes: | ||||
|             # Detect/segment/pose | ||||
|             for j, d in enumerate(boxes): | ||||
| @ -304,16 +310,16 @@ class Results(SimpleClass): | ||||
|                     line = (c, *seg) | ||||
|                 if kpts is not None: | ||||
|                     kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn | ||||
|                     line += (*kpt.reshape(-1).tolist(), ) | ||||
|                 line += (conf, ) * save_conf + (() if id is None else (id, )) | ||||
|                 texts.append(('%g ' * len(line)).rstrip() % line) | ||||
|                     line += (*kpt.reshape(-1).tolist(),) | ||||
|                 line += (conf,) * save_conf + (() if id is None else (id,)) | ||||
|                 texts.append(("%g " * len(line)).rstrip() % line) | ||||
| 
 | ||||
|         if texts: | ||||
|             Path(txt_file).parent.mkdir(parents=True, exist_ok=True)  # make directory | ||||
|             with open(txt_file, 'a') as f: | ||||
|                 f.writelines(text + '\n' for text in texts) | ||||
|             with open(txt_file, "a") as f: | ||||
|                 f.writelines(text + "\n" for text in texts) | ||||
| 
 | ||||
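`save_txt` writes one `%g`-formatted line per detection: class id first, then coordinates, then confidence when `save_conf=True`. A sketch of parsing a saved file back (the path is assumed):

    from pathlib import Path

    for line in Path("runs/detect/predict/labels/bus.txt").read_text().splitlines():
        vals = line.split()
        cls_id = int(vals[0])                # class id written first
        rest = [float(v) for v in vals[1:]]  # normalized coords (+ conf if save_conf=True)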
|     def save_crop(self, save_dir, file_name=Path('im.jpg')): | ||||
|     def save_crop(self, save_dir, file_name=Path("im.jpg")): | ||||
|         """ | ||||
|         Save cropped predictions to `save_dir/cls/file_name.jpg`. | ||||
| 
 | ||||
| @ -322,21 +328,23 @@ class Results(SimpleClass): | ||||
|             file_name (str | pathlib.Path): File name. | ||||
|         """ | ||||
|         if self.probs is not None: | ||||
|             LOGGER.warning('WARNING ⚠️ Classify task does not support `save_crop`.') | ||||
|             LOGGER.warning("WARNING ⚠️ Classify task does not support `save_crop`.") | ||||
|             return | ||||
|         if self.obb is not None: | ||||
|             LOGGER.warning('WARNING ⚠️ OBB task does not support `save_crop`.') | ||||
|             LOGGER.warning("WARNING ⚠️ OBB task does not support `save_crop`.") | ||||
|             return | ||||
|         for d in self.boxes: | ||||
|             save_one_box(d.xyxy, | ||||
|             save_one_box( | ||||
|                 d.xyxy, | ||||
|                 self.orig_img.copy(), | ||||
|                          file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name)}.jpg', | ||||
|                          BGR=True) | ||||
|                 file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg", | ||||
|                 BGR=True, | ||||
|             ) | ||||
| 
 | ||||
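A usage sketch for `save_crop`, which writes each detection to `save_dir/<class name>/<file_name>.jpg` and appends the `.jpg` suffix itself (directory and stem assumed):

    results[0].save_crop(save_dir="runs/crops", file_name="bus")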
|     def tojson(self, normalize=False): | ||||
|         """Convert the object to JSON format.""" | ||||
|         if self.probs is not None: | ||||
|             LOGGER.warning('Warning: Classify task does not support `tojson` yet.') | ||||
|             LOGGER.warning("Warning: Classify task does not support `tojson` yet.") | ||||
|             return | ||||
| 
 | ||||
|         import json | ||||
| @ -346,19 +354,19 @@ class Results(SimpleClass): | ||||
|         data = self.boxes.data.cpu().tolist() | ||||
|         h, w = self.orig_shape if normalize else (1, 1) | ||||
|         for i, row in enumerate(data):  # xyxy, track_id if tracking, conf, class_id | ||||
|             box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h} | ||||
|             box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h} | ||||
|             conf = row[-2] | ||||
|             class_id = int(row[-1]) | ||||
|             name = self.names[class_id] | ||||
|             result = {'name': name, 'class': class_id, 'confidence': conf, 'box': box} | ||||
|             result = {"name": name, "class": class_id, "confidence": conf, "box": box} | ||||
|             if self.boxes.is_track: | ||||
|                 result['track_id'] = int(row[-3])  # track ID | ||||
|                 result["track_id"] = int(row[-3])  # track ID | ||||
|             if self.masks: | ||||
|                 x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1]  # numpy array | ||||
|                 result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()} | ||||
|                 result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()} | ||||
|             if self.keypoints is not None: | ||||
|                 x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1)  # torch Tensor | ||||
|                 result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()} | ||||
|                 result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()} | ||||
|             results.append(result) | ||||
| 
 | ||||
|         # Convert detections to JSON | ||||
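The hunk ends just before the method's return; assuming it returns the list built above serialized as a JSON string, a round-trip sketch:

    import json

    detections = json.loads(results[0].tojson(normalize=True))
    for det in detections:
        print(det["name"], det["confidence"], det["box"])  # keys assembled in the loop above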
| @ -397,7 +405,7 @@ class Boxes(BaseTensor): | ||||
|         if boxes.ndim == 1: | ||||
|             boxes = boxes[None, :] | ||||
|         n = boxes.shape[-1] | ||||
|         assert n in (6, 7), f'expected 6 or 7 values but got {n}'  # xyxy, track_id, conf, cls | ||||
|         assert n in (6, 7), f"expected 6 or 7 values but got {n}"  # xyxy, track_id, conf, cls | ||||
|         super().__init__(boxes, orig_shape) | ||||
|         self.is_track = n == 7 | ||||
|         self.orig_shape = orig_shape | ||||
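Given the 6- or 7-column layout asserted above (xyxy, optional track id, conf, cls), the class exposes per-column views. A sketch:

    b = results[0].boxes
    print(b.xyxy)         # (n, 4) corner coordinates
    print(b.conf, b.cls)  # confidence and class-id columns
    if b.is_track:
        print(b.id)       # track ids when the 7th column is present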
| @ -474,7 +482,8 @@ class Masks(BaseTensor): | ||||
|         """Return normalized segments.""" | ||||
|         return [ | ||||
|             ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True) | ||||
|             for x in ops.masks2segments(self.data)] | ||||
|             for x in ops.masks2segments(self.data) | ||||
|         ] | ||||
| 
 | ||||
|     @property | ||||
|     @lru_cache(maxsize=1) | ||||
| @ -482,7 +491,8 @@ class Masks(BaseTensor): | ||||
|         """Return segments in pixel coordinates.""" | ||||
|         return [ | ||||
|             ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False) | ||||
|             for x in ops.masks2segments(self.data)] | ||||
|             for x in ops.masks2segments(self.data) | ||||
|         ] | ||||
| 
 | ||||
| 
 | ||||
| class Keypoints(BaseTensor): | ||||
| @ -610,7 +620,7 @@ class OBB(BaseTensor): | ||||
|         if boxes.ndim == 1: | ||||
|             boxes = boxes[None, :] | ||||
|         n = boxes.shape[-1] | ||||
|         assert n in (7, 8), f'expected 7 or 8 values but got {n}'  # xywh, rotation, track_id, conf, cls | ||||
|         assert n in (7, 8), f"expected 7 or 8 values but got {n}"  # xywh, rotation, track_id, conf, cls | ||||
|         super().__init__(boxes, orig_shape) | ||||
|         self.is_track = n == 8 | ||||
|         self.orig_shape = orig_shape | ||||
|  | ||||
| @ -23,14 +23,31 @@ from torch import nn, optim | ||||
| from ultralytics.cfg import get_cfg, get_save_dir | ||||
| from ultralytics.data.utils import check_cls_dataset, check_det_dataset | ||||
| from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights | ||||
| from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis, | ||||
|                                yaml_save) | ||||
| from ultralytics.utils import ( | ||||
|     DEFAULT_CFG, | ||||
|     LOGGER, | ||||
|     RANK, | ||||
|     TQDM, | ||||
|     __version__, | ||||
|     callbacks, | ||||
|     clean_url, | ||||
|     colorstr, | ||||
|     emojis, | ||||
|     yaml_save, | ||||
| ) | ||||
| from ultralytics.utils.autobatch import check_train_batch_size | ||||
| from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args | ||||
| from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command | ||||
| from ultralytics.utils.files import get_latest_run | ||||
| from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device, | ||||
|                                            strip_optimizer) | ||||
| from ultralytics.utils.torch_utils import ( | ||||
|     EarlyStopping, | ||||
|     ModelEMA, | ||||
|     de_parallel, | ||||
|     init_seeds, | ||||
|     one_cycle, | ||||
|     select_device, | ||||
|     strip_optimizer, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class BaseTrainer: | ||||
| @ -89,12 +106,12 @@ class BaseTrainer: | ||||
|         # Dirs | ||||
|         self.save_dir = get_save_dir(self.args) | ||||
|         self.args.name = self.save_dir.name  # update name for loggers | ||||
|         self.wdir = self.save_dir / 'weights'  # weights dir | ||||
|         self.wdir = self.save_dir / "weights"  # weights dir | ||||
|         if RANK in (-1, 0): | ||||
|             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir | ||||
|             self.args.save_dir = str(self.save_dir) | ||||
|             yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args | ||||
|         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths | ||||
|             yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args | ||||
|         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths | ||||
|         self.save_period = self.args.save_period | ||||
| 
 | ||||
|         self.batch_size = self.args.batch | ||||
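The run directory, `args.yaml`, and `weights/last.pt` / `weights/best.pt` layout above is produced by any training call. A minimal sketch (dataset YAML assumed):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    model.train(data="coco128.yaml", epochs=3, imgsz=640)  # writes args.yaml plus last/best checkpoints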
| @ -104,18 +121,18 @@ class BaseTrainer: | ||||
|             print_args(vars(self.args)) | ||||
| 
 | ||||
|         # Device | ||||
|         if self.device.type in ('cpu', 'mps'): | ||||
|         if self.device.type in ("cpu", "mps"): | ||||
|             self.args.workers = 0  # faster CPU training as time is dominated by inference, not dataloading | ||||
| 
 | ||||
|         # Model and Dataset | ||||
|         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt | ||||
|         try: | ||||
|             if self.args.task == 'classify': | ||||
|             if self.args.task == "classify": | ||||
|                 self.data = check_cls_dataset(self.args.data) | ||||
|             elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'): | ||||
|             elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"): | ||||
|                 self.data = check_det_dataset(self.args.data) | ||||
|                 if 'yaml_file' in self.data: | ||||
|                     self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage | ||||
|                 if "yaml_file" in self.data: | ||||
|                     self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage | ||||
|         except Exception as e: | ||||
|             raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e | ||||
| 
 | ||||
| @ -131,8 +148,8 @@ class BaseTrainer: | ||||
|         self.fitness = None | ||||
|         self.loss = None | ||||
|         self.tloss = None | ||||
|         self.loss_names = ['Loss'] | ||||
|         self.csv = self.save_dir / 'results.csv' | ||||
|         self.loss_names = ["Loss"] | ||||
|         self.csv = self.save_dir / "results.csv" | ||||
|         self.plot_idx = [0, 1, 2] | ||||
| 
 | ||||
|         # Callbacks | ||||
| @ -156,7 +173,7 @@ class BaseTrainer: | ||||
|     def train(self): | ||||
|         """Allow device='', device=None on Multi-GPU systems to default to device=0.""" | ||||
|         if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3' | ||||
|             world_size = len(self.args.device.split(',')) | ||||
|             world_size = len(self.args.device.split(",")) | ||||
|         elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list) | ||||
|             world_size = len(self.args.device) | ||||
|         elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number | ||||
| @ -165,14 +182,16 @@ class BaseTrainer: | ||||
|             world_size = 0 | ||||
| 
 | ||||
|         # Run subprocess if DDP training, else train normally | ||||
|         if world_size > 1 and 'LOCAL_RANK' not in os.environ: | ||||
|         if world_size > 1 and "LOCAL_RANK" not in os.environ: | ||||
|             # Argument checks | ||||
|             if self.args.rect: | ||||
|                 LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'") | ||||
|                 self.args.rect = False | ||||
|             if self.args.batch == -1: | ||||
|                 LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting " | ||||
|                                "default 'batch=16'") | ||||
|                 LOGGER.warning( | ||||
|                     "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting " | ||||
|                     "default 'batch=16'" | ||||
|                 ) | ||||
|                 self.args.batch = 16 | ||||
| 
 | ||||
|             # Command | ||||
| @ -199,37 +218,45 @@ class BaseTrainer: | ||||
|     def _setup_ddp(self, world_size): | ||||
|         """Initializes and sets the DistributedDataParallel parameters for training.""" | ||||
|         torch.cuda.set_device(RANK) | ||||
|         self.device = torch.device('cuda', RANK) | ||||
|         self.device = torch.device("cuda", RANK) | ||||
|         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') | ||||
|         os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout | ||||
|         os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout | ||||
|         dist.init_process_group( | ||||
|             'nccl' if dist.is_nccl_available() else 'gloo', | ||||
|             "nccl" if dist.is_nccl_available() else "gloo", | ||||
|             timeout=timedelta(seconds=10800),  # 3 hours | ||||
|             rank=RANK, | ||||
|             world_size=world_size) | ||||
|             world_size=world_size, | ||||
|         ) | ||||
| 
 | ||||
|     def _setup_train(self, world_size): | ||||
|         """Builds dataloaders and optimizer on correct rank process.""" | ||||
| 
 | ||||
|         # Model | ||||
|         self.run_callbacks('on_pretrain_routine_start') | ||||
|         self.run_callbacks("on_pretrain_routine_start") | ||||
|         ckpt = self.setup_model() | ||||
|         self.model = self.model.to(self.device) | ||||
|         self.set_model_attributes() | ||||
| 
 | ||||
|         # Freeze layers | ||||
|         freeze_list = self.args.freeze if isinstance( | ||||
|             self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else [] | ||||
|         always_freeze_names = ['.dfl']  # always freeze these layers | ||||
|         freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names | ||||
|         freeze_list = ( | ||||
|             self.args.freeze | ||||
|             if isinstance(self.args.freeze, list) | ||||
|             else range(self.args.freeze) | ||||
|             if isinstance(self.args.freeze, int) | ||||
|             else [] | ||||
|         ) | ||||
|         always_freeze_names = [".dfl"]  # always freeze these layers | ||||
|         freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names | ||||
|         for k, v in self.model.named_parameters(): | ||||
|             # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results) | ||||
|             if any(x in k for x in freeze_layer_names): | ||||
|                 LOGGER.info(f"Freezing layer '{k}'") | ||||
|                 v.requires_grad = False | ||||
|             elif not v.requires_grad: | ||||
|                 LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " | ||||
|                             'See ultralytics.engine.trainer for customization of frozen layers.') | ||||
|                 LOGGER.info( | ||||
|                     f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " | ||||
|                     "See ultralytics.engine.trainer for customization of frozen layers." | ||||
|                 ) | ||||
|                 v.requires_grad = True | ||||
| 
 | ||||
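The `freeze` argument handled above accepts an int (first N layers) or an explicit list of layer indices. A sketch:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    model.train(data="coco128.yaml", epochs=3, freeze=10)  # freezes model.0. ... model.9. plus the always-frozen .dfl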
|         # Check AMP | ||||
| @ -246,7 +273,7 @@ class BaseTrainer: | ||||
|             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK]) | ||||
| 
 | ||||
|         # Check imgsz | ||||
|         gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride) | ||||
|         gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride) | ||||
|         self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1) | ||||
|         self.stride = gs  # for multi-scale training | ||||
| 
 | ||||
| @ -256,15 +283,14 @@ class BaseTrainer: | ||||
| 
 | ||||
|         # Dataloaders | ||||
|         batch_size = self.batch_size // max(world_size, 1) | ||||
|         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train') | ||||
|         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train") | ||||
|         if RANK in (-1, 0): | ||||
|             # NOTE: When training the DOTA dataset, doubling the batch size could cause OOM, since some images contain more than 2000 objects. | ||||
|             self.test_loader = self.get_dataloader(self.testset, | ||||
|                                                    batch_size=batch_size if self.args.task == 'obb' else batch_size * 2, | ||||
|                                                    rank=-1, | ||||
|                                                    mode='val') | ||||
|             self.test_loader = self.get_dataloader( | ||||
|                 self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val" | ||||
|             ) | ||||
|             self.validator = self.get_validator() | ||||
|             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val') | ||||
|             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val") | ||||
|             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) | ||||
|             self.ema = ModelEMA(self.model) | ||||
|             if self.args.plots: | ||||
| @ -274,18 +300,20 @@ class BaseTrainer: | ||||
|         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing | ||||
|         weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay | ||||
|         iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs | ||||
|         self.optimizer = self.build_optimizer(model=self.model, | ||||
|         self.optimizer = self.build_optimizer( | ||||
|             model=self.model, | ||||
|             name=self.args.optimizer, | ||||
|             lr=self.args.lr0, | ||||
|             momentum=self.args.momentum, | ||||
|             decay=weight_decay, | ||||
|                                               iterations=iterations) | ||||
|             iterations=iterations, | ||||
|         ) | ||||
|         # Scheduler | ||||
|         self._setup_scheduler() | ||||
|         self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False | ||||
|         self.resume_training(ckpt) | ||||
|         self.scheduler.last_epoch = self.start_epoch - 1  # do not move | ||||
|         self.run_callbacks('on_pretrain_routine_end') | ||||
|         self.run_callbacks("on_pretrain_routine_end") | ||||
| 
 | ||||
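The weight-decay scaling in `_setup_train` keeps the effective decay constant under gradient accumulation. A worked example with assumed defaults (nbs=64, batch=16, weight_decay=0.0005):

    nbs, batch_size, weight_decay = 64, 16, 0.0005         # assumed defaults
    accumulate = max(round(nbs / batch_size), 1)           # 4 batches accumulated per optimizer step
    scaled = weight_decay * batch_size * accumulate / nbs  # 0.0005 * 16 * 4 / 64 = 0.0005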
|     def _do_train(self, world_size=1): | ||||
|         """Train completed, evaluate and plot if specified by arguments.""" | ||||
| @ -299,19 +327,23 @@ class BaseTrainer: | ||||
|         self.epoch_time = None | ||||
|         self.epoch_time_start = time.time() | ||||
|         self.train_time_start = time.time() | ||||
|         self.run_callbacks('on_train_start') | ||||
|         LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' | ||||
|         self.run_callbacks("on_train_start") | ||||
|         LOGGER.info( | ||||
|             f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' | ||||
|             f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' | ||||
|             f"Logging results to {colorstr('bold', self.save_dir)}\n" | ||||
|             f'Starting training for ' | ||||
|                     f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...') | ||||
|             f'{self.args.time} hours...' | ||||
|             if self.args.time | ||||
|             else f"{self.epochs} epochs..." | ||||
|         ) | ||||
|         if self.args.close_mosaic: | ||||
|             base_idx = (self.epochs - self.args.close_mosaic) * nb | ||||
|             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) | ||||
|         epoch = self.epochs  # predefine for resume fully trained model edge cases | ||||
|         for epoch in range(self.start_epoch, self.epochs): | ||||
|             self.epoch = epoch | ||||
|             self.run_callbacks('on_train_epoch_start') | ||||
|             self.run_callbacks("on_train_epoch_start") | ||||
|             self.model.train() | ||||
|             if RANK != -1: | ||||
|                 self.train_loader.sampler.set_epoch(epoch) | ||||
| @ -327,7 +359,7 @@ class BaseTrainer: | ||||
|             self.tloss = None | ||||
|             self.optimizer.zero_grad() | ||||
|             for i, batch in pbar: | ||||
|                 self.run_callbacks('on_train_batch_start') | ||||
|                 self.run_callbacks("on_train_batch_start") | ||||
|                 # Warmup | ||||
|                 ni = i + nb * epoch | ||||
|                 if ni <= nw: | ||||
| @ -335,10 +367,11 @@ class BaseTrainer: | ||||
|                     self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())) | ||||
|                     for j, x in enumerate(self.optimizer.param_groups): | ||||
|                         # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 | ||||
|                         x['lr'] = np.interp( | ||||
|                             ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)]) | ||||
|                         if 'momentum' in x: | ||||
|                             x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) | ||||
|                         x["lr"] = np.interp( | ||||
|                             ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)] | ||||
|                         ) | ||||
|                         if "momentum" in x: | ||||
|                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) | ||||
| 
 | ||||
|                 # Forward | ||||
|                 with torch.cuda.amp.autocast(self.amp): | ||||
| @ -346,8 +379,9 @@ class BaseTrainer: | ||||
|                     self.loss, self.loss_items = self.model(batch) | ||||
|                     if RANK != -1: | ||||
|                         self.loss *= world_size | ||||
|                     self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ | ||||
|                         else self.loss_items | ||||
|                     self.tloss = ( | ||||
|                         (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items | ||||
|                     ) | ||||
| 
 | ||||
|                 # Backward | ||||
|                 self.scaler.scale(self.loss).backward() | ||||
| @ -368,24 +402,25 @@ class BaseTrainer: | ||||
|                             break | ||||
| 
 | ||||
|                 # Log | ||||
|                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB) | ||||
|                 mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB) | ||||
|                 loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1 | ||||
|                 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) | ||||
|                 if RANK in (-1, 0): | ||||
|                     pbar.set_description( | ||||
|                         ('%11s' * 2 + '%11.4g' * (2 + loss_len)) % | ||||
|                         (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1])) | ||||
|                     self.run_callbacks('on_batch_end') | ||||
|                         ("%11s" * 2 + "%11.4g" * (2 + loss_len)) | ||||
|                         % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]) | ||||
|                     ) | ||||
|                     self.run_callbacks("on_batch_end") | ||||
|                     if self.args.plots and ni in self.plot_idx: | ||||
|                         self.plot_training_samples(batch, ni) | ||||
| 
 | ||||
|                 self.run_callbacks('on_train_batch_end') | ||||
|                 self.run_callbacks("on_train_batch_end") | ||||
| 
 | ||||
|             self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers | ||||
|             self.run_callbacks('on_train_epoch_end') | ||||
|             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers | ||||
|             self.run_callbacks("on_train_epoch_end") | ||||
|             if RANK in (-1, 0): | ||||
|                 final_epoch = epoch + 1 == self.epochs | ||||
|                 self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) | ||||
|                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) | ||||
| 
 | ||||
|                 # Validation | ||||
|                 if self.args.val or final_epoch or self.stopper.possible_stop or self.stop: | ||||
| @ -398,14 +433,14 @@ class BaseTrainer: | ||||
|                 # Save model | ||||
|                 if self.args.save or final_epoch: | ||||
|                     self.save_model() | ||||
|                     self.run_callbacks('on_model_save') | ||||
|                     self.run_callbacks("on_model_save") | ||||
| 
 | ||||
|             # Scheduler | ||||
|             t = time.time() | ||||
|             self.epoch_time = t - self.epoch_time_start | ||||
|             self.epoch_time_start = t | ||||
|             with warnings.catch_warnings(): | ||||
|                 warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()' | ||||
|                 warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()' | ||||
|                 if self.args.time: | ||||
|                     mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1) | ||||
|                     self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time) | ||||
| @ -413,7 +448,7 @@ class BaseTrainer: | ||||
|                     self.scheduler.last_epoch = self.epoch  # do not move | ||||
|                     self.stop |= epoch >= self.epochs  # stop if exceeded epochs | ||||
|                 self.scheduler.step() | ||||
|             self.run_callbacks('on_fit_epoch_end') | ||||
|             self.run_callbacks("on_fit_epoch_end") | ||||
|             torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors | ||||
| 
 | ||||
|             # Early Stopping | ||||
| @ -426,39 +461,43 @@ class BaseTrainer: | ||||
| 
 | ||||
|         if RANK in (-1, 0): | ||||
|             # Do final val with best.pt | ||||
|             LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in ' | ||||
|                         f'{(time.time() - self.train_time_start) / 3600:.3f} hours.') | ||||
|             LOGGER.info( | ||||
|                 f"\n{epoch - self.start_epoch + 1} epochs completed in " | ||||
|                 f"{(time.time() - self.train_time_start) / 3600:.3f} hours." | ||||
|             ) | ||||
|             self.final_eval() | ||||
|             if self.args.plots: | ||||
|                 self.plot_metrics() | ||||
|             self.run_callbacks('on_train_end') | ||||
|             self.run_callbacks("on_train_end") | ||||
|         torch.cuda.empty_cache() | ||||
|         self.run_callbacks('teardown') | ||||
|         self.run_callbacks("teardown") | ||||
| 
 | ||||
|     def save_model(self): | ||||
|         """Save model training checkpoints with additional metadata.""" | ||||
|         import pandas as pd  # scope for faster startup | ||||
|         metrics = {**self.metrics, **{'fitness': self.fitness}} | ||||
|         results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient='list').items()} | ||||
| 
 | ||||
|         metrics = {**self.metrics, **{"fitness": self.fitness}} | ||||
|         results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()} | ||||
|         ckpt = { | ||||
|             'epoch': self.epoch, | ||||
|             'best_fitness': self.best_fitness, | ||||
|             'model': deepcopy(de_parallel(self.model)).half(), | ||||
|             'ema': deepcopy(self.ema.ema).half(), | ||||
|             'updates': self.ema.updates, | ||||
|             'optimizer': self.optimizer.state_dict(), | ||||
|             'train_args': vars(self.args),  # save as dict | ||||
|             'train_metrics': metrics, | ||||
|             'train_results': results, | ||||
|             'date': datetime.now().isoformat(), | ||||
|             'version': __version__} | ||||
|             "epoch": self.epoch, | ||||
|             "best_fitness": self.best_fitness, | ||||
|             "model": deepcopy(de_parallel(self.model)).half(), | ||||
|             "ema": deepcopy(self.ema.ema).half(), | ||||
|             "updates": self.ema.updates, | ||||
|             "optimizer": self.optimizer.state_dict(), | ||||
|             "train_args": vars(self.args),  # save as dict | ||||
|             "train_metrics": metrics, | ||||
|             "train_results": results, | ||||
|             "date": datetime.now().isoformat(), | ||||
|             "version": __version__, | ||||
|         } | ||||
| 
 | ||||
|         # Save last and best | ||||
|         torch.save(ckpt, self.last) | ||||
|         if self.best_fitness == self.fitness: | ||||
|             torch.save(ckpt, self.best) | ||||
|         if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0): | ||||
|             torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt') | ||||
|             torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt") | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def get_dataset(data): | ||||
| @ -467,7 +506,7 @@ class BaseTrainer: | ||||
| 
 | ||||
|         Returns None if data format is not recognized. | ||||
|         """ | ||||
|         return data['train'], data.get('val') or data.get('test') | ||||
|         return data["train"], data.get("val") or data.get("test") | ||||
| 
 | ||||
|     def setup_model(self): | ||||
|         """Load/create/download model for any task.""" | ||||
| @ -476,9 +515,9 @@ class BaseTrainer: | ||||
| 
 | ||||
|         model, weights = self.model, None | ||||
|         ckpt = None | ||||
|         if str(model).endswith('.pt'): | ||||
|         if str(model).endswith(".pt"): | ||||
|             weights, ckpt = attempt_load_one_weight(model) | ||||
|             cfg = ckpt['model'].yaml | ||||
|             cfg = ckpt["model"].yaml | ||||
|         else: | ||||
|             cfg = model | ||||
|         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights) | ||||
| @ -505,7 +544,7 @@ class BaseTrainer: | ||||
|         The returned dict is expected to contain the "fitness" key. | ||||
|         """ | ||||
|         metrics = self.validator(self) | ||||
|         fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found | ||||
|         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found | ||||
|         if not self.best_fitness or self.best_fitness < fitness: | ||||
|             self.best_fitness = fitness | ||||
|         return metrics, fitness | ||||
| @ -516,24 +555,24 @@ class BaseTrainer: | ||||
| 
 | ||||
|     def get_validator(self): | ||||
|         """Returns a NotImplementedError when the get_validator function is called.""" | ||||
|         raise NotImplementedError('get_validator function not implemented in trainer') | ||||
|         raise NotImplementedError("get_validator function not implemented in trainer") | ||||
| 
 | ||||
|     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): | ||||
|     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): | ||||
|         """Returns dataloader derived from torch.data.Dataloader.""" | ||||
|         raise NotImplementedError('get_dataloader function not implemented in trainer') | ||||
|         raise NotImplementedError("get_dataloader function not implemented in trainer") | ||||
| 
 | ||||
|     def build_dataset(self, img_path, mode='train', batch=None): | ||||
|     def build_dataset(self, img_path, mode="train", batch=None): | ||||
|         """Build dataset.""" | ||||
|         raise NotImplementedError('build_dataset function not implemented in trainer') | ||||
|         raise NotImplementedError("build_dataset function not implemented in trainer") | ||||
| 
 | ||||
|     def label_loss_items(self, loss_items=None, prefix='train'): | ||||
|     def label_loss_items(self, loss_items=None, prefix="train"): | ||||
|         """Returns a loss dict with labelled training loss items tensor.""" | ||||
|         # Not needed for classification but necessary for segmentation & detection | ||||
|         return {'loss': loss_items} if loss_items is not None else ['loss'] | ||||
|         return {"loss": loss_items} if loss_items is not None else ["loss"] | ||||
| 
 | ||||
|     def set_model_attributes(self): | ||||
|         """To set or update model parameters before training.""" | ||||
|         self.model.names = self.data['names'] | ||||
|         self.model.names = self.data["names"] | ||||
| 
 | ||||
|     def build_targets(self, preds, targets): | ||||
|         """Builds target tensors for training YOLO model.""" | ||||
| @ -541,7 +580,7 @@ class BaseTrainer: | ||||
| 
 | ||||
|     def progress_string(self): | ||||
|         """Returns a string describing training progress.""" | ||||
|         return '' | ||||
|         return "" | ||||
| 
 | ||||
|     # TODO: may need to put these following functions into callback | ||||
|     def plot_training_samples(self, batch, ni): | ||||
| @ -556,9 +595,9 @@ class BaseTrainer: | ||||
|         """Saves training metrics to a CSV file.""" | ||||
|         keys, vals = list(metrics.keys()), list(metrics.values()) | ||||
|         n = len(metrics) + 1  # number of cols | ||||
|         s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header | ||||
|         with open(self.csv, 'a') as f: | ||||
|             f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n') | ||||
|         s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header | ||||
|         with open(self.csv, "a") as f: | ||||
|             f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n") | ||||
| 
 | ||||
|     def plot_metrics(self): | ||||
|         """Plot and display metrics visually.""" | ||||
| @ -567,7 +606,7 @@ class BaseTrainer: | ||||
|     def on_plot(self, name, data=None): | ||||
|         """Registers plots (e.g. to be consumed in callbacks)""" | ||||
|         path = Path(name) | ||||
|         self.plots[path] = {'data': data, 'timestamp': time.time()} | ||||
|         self.plots[path] = {"data": data, "timestamp": time.time()} | ||||
| 
 | ||||
|     def final_eval(self): | ||||
|         """Performs final evaluation and validation for object detection YOLO model.""" | ||||
| @ -575,11 +614,11 @@ class BaseTrainer: | ||||
|             if f.exists(): | ||||
|                 strip_optimizer(f)  # strip optimizers | ||||
|                 if f is self.best: | ||||
|                     LOGGER.info(f'\nValidating {f}...') | ||||
|                     LOGGER.info(f"\nValidating {f}...") | ||||
|                     self.validator.args.plots = self.args.plots | ||||
|                     self.metrics = self.validator(model=f) | ||||
|                     self.metrics.pop('fitness', None) | ||||
|                     self.run_callbacks('on_fit_epoch_end') | ||||
|                     self.metrics.pop("fitness", None) | ||||
|                     self.run_callbacks("on_fit_epoch_end") | ||||
| 
 | ||||
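`check_resume` and `resume_training` below restore optimizer state, EMA, and epoch counters from a checkpoint. The public entry point is a sketch like the following (checkpoint path assumed):

    from ultralytics import YOLO

    model = YOLO("runs/detect/train/weights/last.pt")  # assumed interrupted run
    model.train(resume=True)  # imgsz/batch may also be overridden, per the loop above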
|     def check_resume(self, overrides): | ||||
|         """Check if resume checkpoint exists and update arguments accordingly.""" | ||||
| @ -591,19 +630,21 @@ class BaseTrainer: | ||||
| 
 | ||||
|                 # Check that resume data YAML exists, otherwise strip to force re-download of dataset | ||||
|                 ckpt_args = attempt_load_weights(last).args | ||||
|                 if not Path(ckpt_args['data']).exists(): | ||||
|                     ckpt_args['data'] = self.args.data | ||||
|                 if not Path(ckpt_args["data"]).exists(): | ||||
|                     ckpt_args["data"] = self.args.data | ||||
| 
 | ||||
|                 resume = True | ||||
|                 self.args = get_cfg(ckpt_args) | ||||
|                 self.args.model = str(last)  # reinstate model | ||||
|                 for k in 'imgsz', 'batch':  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM | ||||
|                 for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM | ||||
|                     if k in overrides: | ||||
|                         setattr(self.args, k, overrides[k]) | ||||
| 
 | ||||
|             except Exception as e: | ||||
|                 raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, ' | ||||
|                                         "i.e. 'yolo train resume model=path/to/last.pt'") from e | ||||
|                 raise FileNotFoundError( | ||||
|                     "Resume checkpoint not found. Please pass a valid checkpoint to resume from, " | ||||
|                     "i.e. 'yolo train resume model=path/to/last.pt'" | ||||
|                 ) from e | ||||
|         self.resume = resume | ||||
| 
 | ||||
|     def resume_training(self, ckpt): | ||||
| @ -611,23 +652,26 @@ class BaseTrainer: | ||||
|         if ckpt is None: | ||||
|             return | ||||
|         best_fitness = 0.0 | ||||
|         start_epoch = ckpt['epoch'] + 1 | ||||
|         if ckpt['optimizer'] is not None: | ||||
|             self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer | ||||
|             best_fitness = ckpt['best_fitness'] | ||||
|         if self.ema and ckpt.get('ema'): | ||||
|             self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA | ||||
|             self.ema.updates = ckpt['updates'] | ||||
|         start_epoch = ckpt["epoch"] + 1 | ||||
|         if ckpt["optimizer"] is not None: | ||||
|             self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer | ||||
|             best_fitness = ckpt["best_fitness"] | ||||
|         if self.ema and ckpt.get("ema"): | ||||
|             self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA | ||||
|             self.ema.updates = ckpt["updates"] | ||||
|         if self.resume: | ||||
|             assert start_epoch > 0, \ | ||||
|                 f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ | ||||
|             assert start_epoch > 0, ( | ||||
|                 f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n" | ||||
|                 f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" | ||||
|             ) | ||||
|             LOGGER.info( | ||||
|                 f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs') | ||||
|                 f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs" | ||||
|             ) | ||||
|         if self.epochs < start_epoch: | ||||
|             LOGGER.info( | ||||
|                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.") | ||||
|             self.epochs += ckpt['epoch']  # finetune additional epochs | ||||
|                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." | ||||
|             ) | ||||
|             self.epochs += ckpt["epoch"]  # finetune additional epochs | ||||
|         self.best_fitness = best_fitness | ||||
|         self.start_epoch = start_epoch | ||||
|         if start_epoch > (self.epochs - self.args.close_mosaic): | ||||
| @ -635,13 +679,13 @@ class BaseTrainer: | ||||
| 
 | ||||
|     def _close_dataloader_mosaic(self): | ||||
|         """Update dataloaders to stop using mosaic augmentation.""" | ||||
|         if hasattr(self.train_loader.dataset, 'mosaic'): | ||||
|         if hasattr(self.train_loader.dataset, "mosaic"): | ||||
|             self.train_loader.dataset.mosaic = False | ||||
|         if hasattr(self.train_loader.dataset, 'close_mosaic'): | ||||
|             LOGGER.info('Closing dataloader mosaic') | ||||
|         if hasattr(self.train_loader.dataset, "close_mosaic"): | ||||
|             LOGGER.info("Closing dataloader mosaic") | ||||
|             self.train_loader.dataset.close_mosaic(hyp=self.args) | ||||
| 
 | ||||
|     def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): | ||||
|     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): | ||||
|         """ | ||||
|         Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum, | ||||
|         weight decay, and number of iterations. | ||||
| @ -661,41 +705,45 @@ class BaseTrainer: | ||||
|         """ | ||||
| 
 | ||||
|         g = [], [], []  # optimizer parameter groups | ||||
|         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d() | ||||
|         if name == 'auto': | ||||
|             LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, " | ||||
|         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d() | ||||
|         if name == "auto": | ||||
|             LOGGER.info( | ||||
|                 f"{colorstr('optimizer:')} 'optimizer=auto' found, " | ||||
|                 f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and " | ||||
|                         f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ") | ||||
|             nc = getattr(model, 'nc', 10)  # number of classes | ||||
|                 f"determining best 'optimizer', 'lr0' and 'momentum' automatically... " | ||||
|             ) | ||||
|             nc = getattr(model, "nc", 10)  # number of classes | ||||
|             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places | ||||
|             name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9) | ||||
|             name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9) | ||||
|             self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam | ||||
| 
 | ||||
|         for module_name, module in model.named_modules(): | ||||
|             for param_name, param in module.named_parameters(recurse=False): | ||||
|                 fullname = f'{module_name}.{param_name}' if module_name else param_name | ||||
|                 if 'bias' in fullname:  # bias (no decay) | ||||
|                 fullname = f"{module_name}.{param_name}" if module_name else param_name | ||||
|                 if "bias" in fullname:  # bias (no decay) | ||||
|                     g[2].append(param) | ||||
|                 elif isinstance(module, bn):  # weight (no decay) | ||||
|                     g[1].append(param) | ||||
|                 else:  # weight (with decay) | ||||
|                     g[0].append(param) | ||||
| 
 | ||||
|         if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'): | ||||
|         if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"): | ||||
|             optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) | ||||
|         elif name == 'RMSProp': | ||||
|         elif name == "RMSProp": | ||||
|             optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) | ||||
|         elif name == 'SGD': | ||||
|         elif name == "SGD": | ||||
|             optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) | ||||
|         else: | ||||
|             raise NotImplementedError( | ||||
|                 f"Optimizer '{name}' not found in list of available optimizers " | ||||
|                 f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].' | ||||
|                 'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.') | ||||
|                 f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]." | ||||
|                 "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics." | ||||
|             ) | ||||
| 
 | ||||
|         optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay | ||||
|         optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights) | ||||
|         optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay | ||||
|         optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights) | ||||
|         LOGGER.info( | ||||
|             f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups " | ||||
|             f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)') | ||||
|             f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)' | ||||
|         ) | ||||
|         return optimizer | ||||
|  | ||||
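To make the `optimizer=auto` branch above easier to follow in isolation, here is a self-contained sketch of the selection rule it implements; `nc` and `iterations` mirror the variables in the diff, and the 10,000-iteration threshold and lr-fit equation are taken directly from it:

```python
def auto_select_optimizer(nc: int = 10, iterations: float = 1e5):
    """Return (name, lr0, momentum) per the 'auto' rule: SGD for long schedules, AdamW otherwise."""
    lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation, rounded to 6 decimal places
    return ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)

print(auto_select_optimizer(nc=80, iterations=5000))    # ('AdamW', 0.000119, 0.9)
print(auto_select_optimizer(nc=80, iterations=200000))  # ('SGD', 0.01, 0.9)
```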
| @ -73,40 +73,43 @@ class Tuner: | ||||
|         Args: | ||||
|             args (dict, optional): Configuration for hyperparameter evolution. | ||||
|         """ | ||||
|         self.space = args.pop('space', None) or {  # key: (min, max, gain(optional)) | ||||
|         self.space = args.pop("space", None) or {  # key: (min, max, gain(optional)) | ||||
|             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), | ||||
|             'lr0': (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3) | ||||
|             'lrf': (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf) | ||||
|             'momentum': (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1 | ||||
|             'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4 | ||||
|             'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok) | ||||
|             'warmup_momentum': (0.0, 0.95),  # warmup initial momentum | ||||
|             'box': (1.0, 20.0),  # box loss gain | ||||
|             'cls': (0.2, 4.0),  # cls loss gain (scale with pixels) | ||||
|             'dfl': (0.4, 6.0),  # dfl loss gain | ||||
|             'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction) | ||||
|             'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction) | ||||
|             'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction) | ||||
|             'degrees': (0.0, 45.0),  # image rotation (+/- deg) | ||||
|             'translate': (0.0, 0.9),  # image translation (+/- fraction) | ||||
|             'scale': (0.0, 0.95),  # image scale (+/- gain) | ||||
|             'shear': (0.0, 10.0),  # image shear (+/- deg) | ||||
|             'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001 | ||||
|             'flipud': (0.0, 1.0),  # image flip up-down (probability) | ||||
|             'fliplr': (0.0, 1.0),  # image flip left-right (probability) | ||||
|             'mosaic': (0.0, 1.0),  # image mixup (probability) | ||||
|             'mixup': (0.0, 1.0),  # image mixup (probability) | ||||
|             'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability) | ||||
|             "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3) | ||||
|             "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf) | ||||
|             "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1 | ||||
|             "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4 | ||||
|             "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok) | ||||
|             "warmup_momentum": (0.0, 0.95),  # warmup initial momentum | ||||
|             "box": (1.0, 20.0),  # box loss gain | ||||
|             "cls": (0.2, 4.0),  # cls loss gain (scale with pixels) | ||||
|             "dfl": (0.4, 6.0),  # dfl loss gain | ||||
|             "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction) | ||||
|             "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction) | ||||
|             "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction) | ||||
|             "degrees": (0.0, 45.0),  # image rotation (+/- deg) | ||||
|             "translate": (0.0, 0.9),  # image translation (+/- fraction) | ||||
|             "scale": (0.0, 0.95),  # image scale (+/- gain) | ||||
|             "shear": (0.0, 10.0),  # image shear (+/- deg) | ||||
|             "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001 | ||||
|             "flipud": (0.0, 1.0),  # image flip up-down (probability) | ||||
|             "fliplr": (0.0, 1.0),  # image flip left-right (probability) | ||||
|             "mosaic": (0.0, 1.0),  # image mixup (probability) | ||||
|             "mixup": (0.0, 1.0),  # image mixup (probability) | ||||
|             "copy_paste": (0.0, 1.0),  # segment copy-paste (probability) | ||||
|         } | ||||
|         self.args = get_cfg(overrides=args) | ||||
|         self.tune_dir = get_save_dir(self.args, name='tune') | ||||
|         self.tune_csv = self.tune_dir / 'tune_results.csv' | ||||
|         self.tune_dir = get_save_dir(self.args, name="tune") | ||||
|         self.tune_csv = self.tune_dir / "tune_results.csv" | ||||
|         self.callbacks = _callbacks or callbacks.get_default_callbacks() | ||||
|         self.prefix = colorstr('Tuner: ') | ||||
|         self.prefix = colorstr("Tuner: ") | ||||
|         callbacks.add_integration_callbacks(self) | ||||
|         LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n" | ||||
|                     f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning') | ||||
|         LOGGER.info( | ||||
|             f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n" | ||||
|             f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning" | ||||
|         ) | ||||
| 
 | ||||
|     def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2): | ||||
|     def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2): | ||||
|         """ | ||||
|         Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`. | ||||
| 
 | ||||
| @ -121,15 +124,15 @@ class Tuner: | ||||
|         """ | ||||
|         if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate | ||||
|             # Select parent(s) | ||||
|             x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1) | ||||
|             x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) | ||||
|             fitness = x[:, 0]  # first column | ||||
|             n = min(n, len(x))  # number of previous results to consider | ||||
|             x = x[np.argsort(-fitness)][:n]  # top n mutations | ||||
|             w = x[:, 0] - x[:, 0].min() + 1E-6  # weights (sum > 0) | ||||
|             if parent == 'single' or len(x) == 1: | ||||
|             w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0) | ||||
|             if parent == "single" or len(x) == 1: | ||||
|                 # x = x[random.randint(0, n - 1)]  # random selection | ||||
|                 x = x[random.choices(range(n), weights=w)[0]]  # weighted selection | ||||
|             elif parent == 'weighted': | ||||
|             elif parent == "weighted": | ||||
|                 x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination | ||||
| 
 | ||||
|             # Mutate | ||||
| @ -174,44 +177,44 @@ class Tuner: | ||||
| 
 | ||||
|         t0 = time.time() | ||||
|         best_save_dir, best_metrics = None, None | ||||
|         (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True) | ||||
|         (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True) | ||||
|         for i in range(iterations): | ||||
|             # Mutate hyperparameters | ||||
|             mutated_hyp = self._mutate() | ||||
|             LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}') | ||||
|             LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}") | ||||
| 
 | ||||
|             metrics = {} | ||||
|             train_args = {**vars(self.args), **mutated_hyp} | ||||
|             save_dir = get_save_dir(get_cfg(train_args)) | ||||
|             weights_dir = save_dir / 'weights' | ||||
|             ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt') | ||||
|             weights_dir = save_dir / "weights" | ||||
|             ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt") | ||||
|             try: | ||||
|                 # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang) | ||||
|                 cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())] | ||||
|                 cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())] | ||||
|                 return_code = subprocess.run(cmd, check=True).returncode | ||||
|                 metrics = torch.load(ckpt_file)['train_metrics'] | ||||
|                 assert return_code == 0, 'training failed' | ||||
|                 metrics = torch.load(ckpt_file)["train_metrics"] | ||||
|                 assert return_code == 0, "training failed" | ||||
| 
 | ||||
|             except Exception as e: | ||||
|                 LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}') | ||||
|                 LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}") | ||||
| 
 | ||||
|             # Save results and mutated_hyp to CSV | ||||
|             fitness = metrics.get('fitness', 0.0) | ||||
|             fitness = metrics.get("fitness", 0.0) | ||||
|             log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()] | ||||
|             headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n') | ||||
|             with open(self.tune_csv, 'a') as f: | ||||
|                 f.write(headers + ','.join(map(str, log_row)) + '\n') | ||||
|             headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n") | ||||
|             with open(self.tune_csv, "a") as f: | ||||
|                 f.write(headers + ",".join(map(str, log_row)) + "\n") | ||||
| 
 | ||||
|             # Get best results | ||||
|             x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1) | ||||
|             x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) | ||||
|             fitness = x[:, 0]  # first column | ||||
|             best_idx = fitness.argmax() | ||||
|             best_is_current = best_idx == i | ||||
|             if best_is_current: | ||||
|                 best_save_dir = save_dir | ||||
|                 best_metrics = {k: round(v, 5) for k, v in metrics.items()} | ||||
|                 for ckpt in weights_dir.glob('*.pt'): | ||||
|                     shutil.copy2(ckpt, self.tune_dir / 'weights') | ||||
|                 for ckpt in weights_dir.glob("*.pt"): | ||||
|                     shutil.copy2(ckpt, self.tune_dir / "weights") | ||||
|             elif cleanup: | ||||
|                 shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space | ||||
| 
 | ||||
| @ -219,15 +222,19 @@ class Tuner: | ||||
|             plot_tune_results(self.tune_csv) | ||||
| 
 | ||||
|             # Save and print tune results | ||||
|             header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n' | ||||
|             header = ( | ||||
|                 f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n' | ||||
|                 f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n' | ||||
|                 f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n' | ||||
|                 f'{self.prefix}Best fitness metrics are {best_metrics}\n' | ||||
|                 f'{self.prefix}Best fitness model is {best_save_dir}\n' | ||||
|                       f'{self.prefix}Best fitness hyperparameters are printed below.\n') | ||||
|             LOGGER.info('\n' + header) | ||||
|                 f'{self.prefix}Best fitness hyperparameters are printed below.\n' | ||||
|             ) | ||||
|             LOGGER.info("\n" + header) | ||||
|             data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())} | ||||
|             yaml_save(self.tune_dir / 'best_hyperparameters.yaml', | ||||
|             yaml_save( | ||||
|                 self.tune_dir / "best_hyperparameters.yaml", | ||||
|                 data=data, | ||||
|                       header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n') | ||||
|             yaml_print(self.tune_dir / 'best_hyperparameters.yaml') | ||||
|                 header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n", | ||||
|             ) | ||||
|             yaml_print(self.tune_dir / "best_hyperparameters.yaml") | ||||
|  | ||||
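As a quick illustration of the parent-selection logic in `_mutate` above, this standalone sketch performs the same fitness-weighted pick and weighted combination over the top-n rows of a results matrix (the fitness values here are made up):

```python
import random

import numpy as np

x = np.array([[0.62, 0.010], [0.55, 0.005], [0.58, 0.020]])  # column 0 is fitness, column 1 a hyperparameter
n = min(5, len(x))
x = x[np.argsort(-x[:, 0])][:n]      # keep the top-n rows by fitness
w = x[:, 0] - x[:, 0].min() + 1e-6   # shifted weights so they are positive and sum > 0

single = x[random.choices(range(n), weights=w)[0]]  # parent='single': weighted random pick
weighted = (x * w.reshape(n, 1)).sum(0) / w.sum()   # parent='weighted': weighted combination
print(single, weighted)
```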
| @ -89,10 +89,10 @@ class BaseValidator: | ||||
|         self.nc = None | ||||
|         self.iouv = None | ||||
|         self.jdict = None | ||||
|         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} | ||||
|         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} | ||||
| 
 | ||||
|         self.save_dir = save_dir or get_save_dir(self.args) | ||||
|         (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) | ||||
|         (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) | ||||
|         if self.args.conf is None: | ||||
|             self.args.conf = 0.001  # default conf=0.001 | ||||
|         self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) | ||||
| @ -110,7 +110,7 @@ class BaseValidator: | ||||
|         if self.training: | ||||
|             self.device = trainer.device | ||||
|             self.data = trainer.data | ||||
|             self.args.half = self.device.type != 'cpu'  # force FP16 val during training | ||||
|             self.args.half = self.device.type != "cpu"  # force FP16 val during training | ||||
|             model = trainer.ema.ema or trainer.model | ||||
|             model = model.half() if self.args.half else model.float() | ||||
|             # self.model = model | ||||
| @ -119,11 +119,13 @@ class BaseValidator: | ||||
|             model.eval() | ||||
|         else: | ||||
|             callbacks.add_integration_callbacks(self) | ||||
|             model = AutoBackend(model or self.args.model, | ||||
|             model = AutoBackend( | ||||
|                 model or self.args.model, | ||||
|                 device=select_device(self.args.device, self.args.batch), | ||||
|                 dnn=self.args.dnn, | ||||
|                 data=self.args.data, | ||||
|                                 fp16=self.args.half) | ||||
|                 fp16=self.args.half, | ||||
|             ) | ||||
|             # self.model = model | ||||
|             self.device = model.device  # update device | ||||
|             self.args.half = model.fp16  # update half | ||||
| @ -133,16 +135,16 @@ class BaseValidator: | ||||
|                 self.args.batch = model.batch_size | ||||
|             elif not pt and not jit: | ||||
|                 self.args.batch = 1  # export.py models default to batch-size 1 | ||||
|                 LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') | ||||
|                 LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models") | ||||
| 
 | ||||
|             if str(self.args.data).split('.')[-1] in ('yaml', 'yml'): | ||||
|             if str(self.args.data).split(".")[-1] in ("yaml", "yml"): | ||||
|                 self.data = check_det_dataset(self.args.data) | ||||
|             elif self.args.task == 'classify': | ||||
|             elif self.args.task == "classify": | ||||
|                 self.data = check_cls_dataset(self.args.data, split=self.args.split) | ||||
|             else: | ||||
|                 raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) | ||||
| 
 | ||||
|             if self.device.type in ('cpu', 'mps'): | ||||
|             if self.device.type in ("cpu", "mps"): | ||||
|                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading | ||||
|             if not pt: | ||||
|                 self.args.rect = False | ||||
| @ -152,13 +154,13 @@ class BaseValidator: | ||||
|             model.eval() | ||||
|             model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup | ||||
| 
 | ||||
|         self.run_callbacks('on_val_start') | ||||
|         self.run_callbacks("on_val_start") | ||||
|         dt = Profile(), Profile(), Profile(), Profile() | ||||
|         bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader)) | ||||
|         self.init_metrics(de_parallel(model)) | ||||
|         self.jdict = []  # empty before each val | ||||
|         for batch_i, batch in enumerate(bar): | ||||
|             self.run_callbacks('on_val_batch_start') | ||||
|             self.run_callbacks("on_val_batch_start") | ||||
|             self.batch_i = batch_i | ||||
|             # Preprocess | ||||
|             with dt[0]: | ||||
| @ -166,7 +168,7 @@ class BaseValidator: | ||||
| 
 | ||||
|             # Inference | ||||
|             with dt[1]: | ||||
|                 preds = model(batch['img'], augment=augment) | ||||
|                 preds = model(batch["img"], augment=augment) | ||||
| 
 | ||||
|             # Loss | ||||
|             with dt[2]: | ||||
| @ -182,23 +184,25 @@ class BaseValidator: | ||||
|                 self.plot_val_samples(batch, batch_i) | ||||
|                 self.plot_predictions(batch, preds, batch_i) | ||||
| 
 | ||||
|             self.run_callbacks('on_val_batch_end') | ||||
|             self.run_callbacks("on_val_batch_end") | ||||
|         stats = self.get_stats() | ||||
|         self.check_stats(stats) | ||||
|         self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt))) | ||||
|         self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt))) | ||||
|         self.finalize_metrics() | ||||
|         self.print_results() | ||||
|         self.run_callbacks('on_val_end') | ||||
|         self.run_callbacks("on_val_end") | ||||
|         if self.training: | ||||
|             model.float() | ||||
|             results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')} | ||||
|             results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} | ||||
|             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats | ||||
|         else: | ||||
|             LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' % | ||||
|                         tuple(self.speed.values())) | ||||
|             LOGGER.info( | ||||
|                 "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image" | ||||
|                 % tuple(self.speed.values()) | ||||
|             ) | ||||
|             if self.args.save_json and self.jdict: | ||||
|                 with open(str(self.save_dir / 'predictions.json'), 'w') as f: | ||||
|                     LOGGER.info(f'Saving {f.name}...') | ||||
|                 with open(str(self.save_dir / "predictions.json"), "w") as f: | ||||
|                     LOGGER.info(f"Saving {f.name}...") | ||||
|                     json.dump(self.jdict, f)  # flatten and save | ||||
|                 stats = self.eval_json(stats)  # update stats | ||||
|             if self.args.plots or self.args.save_json: | ||||
| @ -228,6 +232,7 @@ class BaseValidator: | ||||
|             if use_scipy: | ||||
|                 # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708 | ||||
|                 import scipy  # scope import to avoid importing for all commands | ||||
| 
 | ||||
|                 cost_matrix = iou * (iou >= threshold) | ||||
|                 if cost_matrix.any(): | ||||
|                     labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True) | ||||
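The scipy branch above treats prediction-to-label matching as an assignment problem on the IoU matrix, zeroing out pairs below the threshold before solving. A tiny standalone example with illustrative IoU values:

```python
import numpy as np
import scipy.optimize

iou = np.array([[0.9, 0.1], [0.2, 0.7]])  # IoU of 2 labels (rows) vs 2 detections (columns)
threshold = 0.5

cost_matrix = iou * (iou >= threshold)  # suppress pairs below the IoU threshold
if cost_matrix.any():
    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
    print(list(zip(labels_idx, detections_idx)))  # [(0, 0), (1, 1)]
```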
| @ -257,11 +262,11 @@ class BaseValidator: | ||||
| 
 | ||||
|     def get_dataloader(self, dataset_path, batch_size): | ||||
|         """Get data loader from dataset path and batch size.""" | ||||
|         raise NotImplementedError('get_dataloader function not implemented for this validator') | ||||
|         raise NotImplementedError("get_dataloader function not implemented for this validator") | ||||
| 
 | ||||
|     def build_dataset(self, img_path): | ||||
|         """Build dataset.""" | ||||
|         raise NotImplementedError('build_dataset function not implemented in validator') | ||||
|         raise NotImplementedError("build_dataset function not implemented in validator") | ||||
| 
 | ||||
|     def preprocess(self, batch): | ||||
|         """Preprocesses an input batch.""" | ||||
| @ -306,7 +311,7 @@ class BaseValidator: | ||||
| 
 | ||||
|     def on_plot(self, name, data=None): | ||||
|         """Registers plots (e.g. to be consumed in callbacks)""" | ||||
|         self.plots[Path(name)] = {'data': data, 'timestamp': time.time()} | ||||
|         self.plots[Path(name)] = {"data": data, "timestamp": time.time()} | ||||
| 
 | ||||
|     # TODO: may need to put these following functions into callback | ||||
|     def plot_val_samples(self, batch, ni): | ||||
|  | ||||
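The speed report assembled at the end of `__call__` above reduces to converting the total seconds accumulated by each `Profile()` timer into milliseconds per image. A minimal sketch with hypothetical totals:

```python
speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
elapsed_s = [1.2, 8.4, 0.3, 2.1]  # hypothetical total seconds from the four profilers
num_images = 5000                 # stands in for len(dataloader.dataset)

speed = dict(zip(speed.keys(), (t / num_images * 1e3 for t in elapsed_s)))
print("Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image" % tuple(speed.values()))
```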
| @ -21,10 +21,10 @@ def login(api_key: str = None, save=True) -> bool: | ||||
|     Returns: | ||||
|         bool: True if authentication is successful, False otherwise. | ||||
|     """ | ||||
|     api_key_url = f'{HUB_WEB_ROOT}/settings?tab=api+keys'  # Set the redirect URL | ||||
|     saved_key = SETTINGS.get('api_key') | ||||
|     api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # set the redirect URL | ||||
|     saved_key = SETTINGS.get("api_key") | ||||
|     active_key = api_key or saved_key | ||||
|     credentials = {'api_key': active_key} if active_key and active_key != '' else None  # Set credentials | ||||
|     credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials | ||||
| 
 | ||||
|     client = HUBClient(credentials)  # initialize HUBClient | ||||
| 
 | ||||
| @ -32,17 +32,18 @@ def login(api_key: str = None, save=True) -> bool: | ||||
|         # Successfully authenticated with HUB | ||||
| 
 | ||||
|         if save and client.api_key != saved_key: | ||||
|             SETTINGS.update({'api_key': client.api_key})  # update settings with valid API key | ||||
|             SETTINGS.update({"api_key": client.api_key})  # update settings with valid API key | ||||
| 
 | ||||
|         # Set message based on whether key was provided or retrieved from settings | ||||
|         log_message = ('New authentication successful ✅' | ||||
|                        if client.api_key == api_key or not credentials else 'Authenticated ✅') | ||||
|         LOGGER.info(f'{PREFIX}{log_message}') | ||||
|         log_message = ( | ||||
|             "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅" | ||||
|         ) | ||||
|         LOGGER.info(f"{PREFIX}{log_message}") | ||||
| 
 | ||||
|         return True | ||||
|     else: | ||||
|         # Failed to authenticate with HUB | ||||
|         LOGGER.info(f'{PREFIX}Retrieve API key from {api_key_url}') | ||||
|         LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}") | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| @ -57,50 +58,50 @@ def logout(): | ||||
|         hub.logout() | ||||
|         ``` | ||||
|     """ | ||||
|     SETTINGS['api_key'] = '' | ||||
|     SETTINGS["api_key"] = "" | ||||
|     SETTINGS.save() | ||||
|     LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.") | ||||
| 
 | ||||
| 
 | ||||
| def reset_model(model_id=''): | ||||
| def reset_model(model_id=""): | ||||
|     """Reset a trained model to an untrained state.""" | ||||
|     r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'modelId': model_id}, headers={'x-api-key': Auth().api_key}) | ||||
|     r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}) | ||||
|     if r.status_code == 200: | ||||
|         LOGGER.info(f'{PREFIX}Model reset successfully') | ||||
|         LOGGER.info(f"{PREFIX}Model reset successfully") | ||||
|         return | ||||
|     LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}') | ||||
|     LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}") | ||||
| 
 | ||||
| 
 | ||||
| def export_fmts_hub(): | ||||
|     """Returns a list of HUB-supported export formats.""" | ||||
|     from ultralytics.engine.exporter import export_formats | ||||
|     return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml'] | ||||
| 
 | ||||
|     return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"] | ||||
| 
 | ||||
| 
 | ||||
| def export_model(model_id='', format='torchscript'): | ||||
| def export_model(model_id="", format="torchscript"): | ||||
|     """Export a model to all formats.""" | ||||
|     assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" | ||||
|     r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export', | ||||
|                       json={'format': format}, | ||||
|                       headers={'x-api-key': Auth().api_key}) | ||||
|     assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}' | ||||
|     LOGGER.info(f'{PREFIX}{format} export started ✅') | ||||
|     r = requests.post( | ||||
|         f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key} | ||||
|     ) | ||||
|     assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}" | ||||
|     LOGGER.info(f"{PREFIX}{format} export started ✅") | ||||
| 
 | ||||
| 
 | ||||
| def get_export(model_id='', format='torchscript'): | ||||
| def get_export(model_id="", format="torchscript"): | ||||
|     """Get an exported model dictionary with download URL.""" | ||||
|     assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" | ||||
|     r = requests.post(f'{HUB_API_ROOT}/get-export', | ||||
|                       json={ | ||||
|                           'apiKey': Auth().api_key, | ||||
|                           'modelId': model_id, | ||||
|                           'format': format}, | ||||
|                           headers={'x-api-key': Auth().api_key}) | ||||
|     assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}' | ||||
|     r = requests.post( | ||||
|         f"{HUB_API_ROOT}/get-export", | ||||
|         json={"apiKey": Auth().api_key, "modelId": model_id, "format": format}, | ||||
|         headers={"x-api-key": Auth().api_key}, | ||||
|     ) | ||||
|     assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}" | ||||
|     return r.json() | ||||
| 
 | ||||
| 
 | ||||
| def check_dataset(path='', task='detect'): | ||||
| def check_dataset(path="", task="detect"): | ||||
|     """ | ||||
|     Check a HUB dataset Zip file for errors before it is uploaded to the HUB. Usage examples are given below. | ||||
| @ -119,4 +120,4 @@ def check_dataset(path='', task='detect'): | ||||
|         ``` | ||||
|     """ | ||||
|     HUBDatasetStats(path=path, task=task).get_json() | ||||
|     LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.') | ||||
|     LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.") | ||||
|  | ||||
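Taken together, the functions above support a short HUB session like the following (the API key, model ID, and dataset path are placeholders, not real values):

```python
from ultralytics import hub

hub.login(api_key="YOUR_API_KEY")  # placeholder; falls back to the saved SETTINGS key if omitted
hub.export_model(model_id="YOUR_MODEL_ID", format="torchscript")  # format must be in export_fmts_hub()
hub.check_dataset("path/to/coco8.zip", task="detect")  # error-check a dataset Zip before upload
hub.logout()
```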
| @ -6,7 +6,7 @@ from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT | ||||
| from ultralytics.hub.utils import PREFIX, request_with_credentials | ||||
| from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab | ||||
| 
 | ||||
| API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys' | ||||
| API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys" | ||||
| 
 | ||||
| 
 | ||||
| class Auth: | ||||
| @ -23,9 +23,10 @@ class Auth: | ||||
|         api_key (str or bool): API key for authentication, initialized as False. | ||||
|         model_key (bool): Placeholder for model key, initialized as False. | ||||
|     """ | ||||
| 
 | ||||
|     id_token = api_key = model_key = False | ||||
| 
 | ||||
|     def __init__(self, api_key='', verbose=False): | ||||
|     def __init__(self, api_key="", verbose=False): | ||||
|         """ | ||||
|         Initialize the Auth class with an optional API key. | ||||
| 
 | ||||
| @ -33,18 +34,18 @@ class Auth: | ||||
|             api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id | ||||
|         """ | ||||
|         # Split the input API key in case it contains a combined key_model and keep only the API key part | ||||
|         api_key = api_key.split('_')[0] | ||||
|         api_key = api_key.split("_")[0] | ||||
| 
 | ||||
|         # Set API key attribute as value passed or SETTINGS API key if none passed | ||||
|         self.api_key = api_key or SETTINGS.get('api_key', '') | ||||
|         self.api_key = api_key or SETTINGS.get("api_key", "") | ||||
| 
 | ||||
|         # If an API key is provided | ||||
|         if self.api_key: | ||||
|             # If the provided API key matches the API key in the SETTINGS | ||||
|             if self.api_key == SETTINGS.get('api_key'): | ||||
|             if self.api_key == SETTINGS.get("api_key"): | ||||
|                 # Log that the user is already logged in | ||||
|                 if verbose: | ||||
|                     LOGGER.info(f'{PREFIX}Authenticated ✅') | ||||
|                     LOGGER.info(f"{PREFIX}Authenticated ✅") | ||||
|                 return | ||||
|             else: | ||||
|                 # Attempt to authenticate with the provided API key | ||||
| @ -59,12 +60,12 @@ class Auth: | ||||
| 
 | ||||
|         # Update SETTINGS with the new API key after successful authentication | ||||
|         if success: | ||||
|             SETTINGS.update({'api_key': self.api_key}) | ||||
|             SETTINGS.update({"api_key": self.api_key}) | ||||
|             # Log that the new login was successful | ||||
|             if verbose: | ||||
|                 LOGGER.info(f'{PREFIX}New authentication successful ✅') | ||||
|                 LOGGER.info(f"{PREFIX}New authentication successful ✅") | ||||
|         elif verbose: | ||||
|             LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}') | ||||
|             LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}") | ||||
| 
 | ||||
|     def request_api_key(self, max_attempts=3): | ||||
|         """ | ||||
| @ -73,13 +74,14 @@ class Auth: | ||||
|         Returns the model ID. | ||||
|         """ | ||||
|         import getpass | ||||
| 
 | ||||
|         for attempts in range(max_attempts): | ||||
|             LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}') | ||||
|             input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ') | ||||
|             self.api_key = input_key.split('_')[0]  # remove model id if present | ||||
|             LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}") | ||||
|             input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ") | ||||
|             self.api_key = input_key.split("_")[0]  # remove model id if present | ||||
|             if self.authenticate(): | ||||
|                 return True | ||||
|         raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌')) | ||||
|         raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌")) | ||||
| 
 | ||||
|     def authenticate(self) -> bool: | ||||
|         """ | ||||
| @ -90,14 +92,14 @@ class Auth: | ||||
|         """ | ||||
|         try: | ||||
|             if header := self.get_auth_header(): | ||||
|                 r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header) | ||||
|                 if not r.json().get('success', False): | ||||
|                     raise ConnectionError('Unable to authenticate.') | ||||
|                 r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header) | ||||
|                 if not r.json().get("success", False): | ||||
|                     raise ConnectionError("Unable to authenticate.") | ||||
|                 return True | ||||
|             raise ConnectionError('User has not authenticated locally.') | ||||
|             raise ConnectionError("User has not authenticated locally.") | ||||
|         except ConnectionError: | ||||
|             self.id_token = self.api_key = False  # reset invalid | ||||
|             LOGGER.warning(f'{PREFIX}Invalid API key ⚠️') | ||||
|             LOGGER.warning(f"{PREFIX}Invalid API key ⚠️") | ||||
|             return False | ||||
| 
 | ||||
|     def auth_with_cookies(self) -> bool: | ||||
| @ -111,12 +113,12 @@ class Auth: | ||||
|         if not is_colab(): | ||||
|             return False  # Currently only works with Colab | ||||
|         try: | ||||
|             authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto') | ||||
|             if authn.get('success', False): | ||||
|                 self.id_token = authn.get('data', {}).get('idToken', None) | ||||
|             authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto") | ||||
|             if authn.get("success", False): | ||||
|                 self.id_token = authn.get("data", {}).get("idToken", None) | ||||
|                 self.authenticate() | ||||
|                 return True | ||||
|             raise ConnectionError('Unable to fetch browser authentication details.') | ||||
|             raise ConnectionError("Unable to fetch browser authentication details.") | ||||
|         except ConnectionError: | ||||
|             self.id_token = False  # reset invalid | ||||
|             return False | ||||
| @ -129,7 +131,7 @@ class Auth: | ||||
|             (dict): The authentication header if id_token or API key is set, None otherwise. | ||||
|         """ | ||||
|         if self.id_token: | ||||
|             return {'authorization': f'Bearer {self.id_token}'} | ||||
|             return {"authorization": f"Bearer {self.id_token}"} | ||||
|         elif self.api_key: | ||||
|             return {'x-api-key': self.api_key} | ||||
|             return {"x-api-key": self.api_key} | ||||
|         # else returns None | ||||
|  | ||||
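The header precedence encoded at the end of `Auth.get_auth_header` above (cookie-derived id_token first, then API key, otherwise None) can be summarized in isolation:

```python
def get_auth_header(id_token=None, api_key=None):
    """Mirror Auth.get_auth_header: a bearer id_token takes precedence over an x-api-key."""
    if id_token:
        return {"authorization": f"Bearer {id_token}"}
    if api_key:
        return {"x-api-key": api_key}
    return None  # no credentials set

assert get_auth_header(api_key="abc123") == {"x-api-key": "abc123"}
assert get_auth_header(id_token="tok", api_key="abc123")["authorization"] == "Bearer tok"
```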
| @ -12,16 +12,13 @@ from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM | ||||
| from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab | ||||
| from ultralytics.utils.errors import HUBModelError | ||||
| 
 | ||||
| AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local') | ||||
| AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local" | ||||
| 
 | ||||
| 
 | ||||
| class HUBTrainingSession: | ||||
|     """ | ||||
|     HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing. | ||||
| 
 | ||||
|     Args: | ||||
|         url (str): Model identifier used to initialize the HUB training session. | ||||
| 
 | ||||
|     Attributes: | ||||
|         agent_id (str): Identifier for the instance communicating with the server. | ||||
|         model_id (str): Identifier for the YOLO model being trained. | ||||
| @ -40,7 +37,7 @@ class HUBTrainingSession: | ||||
|         Initialize the HUBTrainingSession with the provided model identifier. | ||||
| 
 | ||||
|         Args: | ||||
|             url (str): Model identifier used to initialize the HUB training session. | ||||
|             identifier (str): Model identifier used to initialize the HUB training session. | ||||
|                 It can be a URL string or a model key with specific format. | ||||
| 
 | ||||
|         Raises: | ||||
| @ -48,9 +45,10 @@ class HUBTrainingSession: | ||||
|             ConnectionError: If connecting with global API key is not supported. | ||||
|         """ | ||||
|         self.rate_limits = { | ||||
|             'metrics': 3.0, | ||||
|             'ckpt': 900.0, | ||||
|             'heartbeat': 300.0, }  # rate limits (seconds) | ||||
|             "metrics": 3.0, | ||||
|             "ckpt": 900.0, | ||||
|             "heartbeat": 300.0, | ||||
|         }  # rate limits (seconds) | ||||
|         self.metrics_queue = {}  # holds metrics for each epoch until upload | ||||
|         self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py | ||||
| 
 | ||||
| @ -58,8 +56,8 @@ class HUBTrainingSession: | ||||
|         api_key, model_id, self.filename = self._parse_identifier(identifier) | ||||
| 
 | ||||
|         # Get credentials | ||||
|         active_key = api_key or SETTINGS.get('api_key') | ||||
|         credentials = {'api_key': active_key} if active_key else None  # set credentials | ||||
|         active_key = api_key or SETTINGS.get("api_key") | ||||
|         credentials = {"api_key": active_key} if active_key else None  # set credentials | ||||
| 
 | ||||
|         # Initialize client | ||||
|         self.client = HUBClient(credentials) | ||||
| @ -72,35 +70,37 @@ class HUBTrainingSession: | ||||
|     def load_model(self, model_id): | ||||
|         # Initialize model | ||||
|         self.model = self.client.model(model_id) | ||||
|         self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}' | ||||
|         self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" | ||||
| 
 | ||||
|         self._set_train_args() | ||||
| 
 | ||||
|         # Start heartbeats for HUB to monitor agent | ||||
|         self.model.start_heartbeat(self.rate_limits['heartbeat']) | ||||
|         LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀') | ||||
|         self.model.start_heartbeat(self.rate_limits["heartbeat"]) | ||||
|         LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") | ||||
| 
 | ||||
|     def create_model(self, model_args): | ||||
|         # Initialize model | ||||
|         payload = { | ||||
|             'config': { | ||||
|                 'batchSize': model_args.get('batch', -1), | ||||
|                 'epochs': model_args.get('epochs', 300), | ||||
|                 'imageSize': model_args.get('imgsz', 640), | ||||
|                 'patience': model_args.get('patience', 100), | ||||
|                 'device': model_args.get('device', ''), | ||||
|                 'cache': model_args.get('cache', 'ram'), }, | ||||
|             'dataset': { | ||||
|                 'name': model_args.get('data')}, | ||||
|             'lineage': { | ||||
|                 'architecture': { | ||||
|                     'name': self.filename.replace('.pt', '').replace('.yaml', ''), }, | ||||
|                 'parent': {}, }, | ||||
|             'meta': { | ||||
|                 'name': self.filename}, } | ||||
|             "config": { | ||||
|                 "batchSize": model_args.get("batch", -1), | ||||
|                 "epochs": model_args.get("epochs", 300), | ||||
|                 "imageSize": model_args.get("imgsz", 640), | ||||
|                 "patience": model_args.get("patience", 100), | ||||
|                 "device": model_args.get("device", ""), | ||||
|                 "cache": model_args.get("cache", "ram"), | ||||
|             }, | ||||
|             "dataset": {"name": model_args.get("data")}, | ||||
|             "lineage": { | ||||
|                 "architecture": { | ||||
|                     "name": self.filename.replace(".pt", "").replace(".yaml", ""), | ||||
|                 }, | ||||
|                 "parent": {}, | ||||
|             }, | ||||
|             "meta": {"name": self.filename}, | ||||
|         } | ||||
| 
 | ||||
|         if self.filename.endswith('.pt'): | ||||
|             payload['lineage']['parent']['name'] = self.filename | ||||
|         if self.filename.endswith(".pt"): | ||||
|             payload["lineage"]["parent"]["name"] = self.filename | ||||
| 
 | ||||
|         self.model.create_model(payload) | ||||
| 
 | ||||
| @ -109,12 +109,12 @@ class HUBTrainingSession: | ||||
|         if not self.model.id: | ||||
|             return | ||||
| 
 | ||||
|         self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}' | ||||
|         self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" | ||||
| 
 | ||||
|         # Start heartbeats for HUB to monitor agent | ||||
|         self.model.start_heartbeat(self.rate_limits['heartbeat']) | ||||
|         self.model.start_heartbeat(self.rate_limits["heartbeat"]) | ||||
| 
 | ||||
|         LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀') | ||||
|         LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") | ||||
| 
 | ||||
|     def _parse_identifier(self, identifier): | ||||
|         """ | ||||
| @ -140,12 +140,12 @@ class HUBTrainingSession: | ||||
|         api_key, model_id, filename = None, None, None | ||||
| 
 | ||||
|         # Check if identifier is a HUB URL | ||||
|         if identifier.startswith(f'{HUB_WEB_ROOT}/models/'): | ||||
|         if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): | ||||
|             # Extract the model_id after the HUB_WEB_ROOT URL | ||||
|             model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1] | ||||
|             model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1] | ||||
|         else: | ||||
|             # Split the identifier based on underscores only if it's not a HUB URL | ||||
|             parts = identifier.split('_') | ||||
|             parts = identifier.split("_") | ||||
| 
 | ||||
|             # Check if identifier is in the format of API key and model ID | ||||
|             if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20: | ||||
| @ -154,43 +154,46 @@ class HUBTrainingSession: | ||||
|             elif len(parts) == 1 and len(parts[0]) == 20: | ||||
|                 model_id = parts[0] | ||||
|             # Check if identifier is a local filename | ||||
|             elif identifier.endswith('.pt') or identifier.endswith('.yaml'): | ||||
|             elif identifier.endswith(".pt") or identifier.endswith(".yaml"): | ||||
|                 filename = identifier | ||||
|             else: | ||||
|                 raise HUBModelError( | ||||
|                     f"model='{identifier}' could not be parsed. Check format is correct. " | ||||
|                     f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.') | ||||
|                     f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file." | ||||
|                 ) | ||||
| 
 | ||||
|         return api_key, model_id, filename | ||||
| 
 | ||||
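The branching in `_parse_identifier` above distinguishes the four accepted identifier shapes by URL prefix, segment lengths (42-character API key, 20-character model ID), and file suffix. A condensed sketch of the same rules (error handling omitted; `hub_root` is a stand-in for `HUB_WEB_ROOT`):

```python
def parse_identifier(identifier, hub_root="https://hub.ultralytics.com"):
    """Return (api_key, model_id, filename) using the same shape rules as _parse_identifier."""
    api_key = model_id = filename = None
    if identifier.startswith(f"{hub_root}/models/"):
        model_id = identifier.split(f"{hub_root}/models/")[-1]  # HUB URL
    elif identifier.endswith((".pt", ".yaml")):
        filename = identifier  # local weights or config file
    else:
        parts = identifier.split("_")
        if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
            api_key, model_id = parts  # combined key_model identifier
        elif len(parts) == 1 and len(parts[0]) == 20:
            model_id = parts[0]  # bare model ID
    return api_key, model_id, filename
```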
|     def _set_train_args(self, **kwargs): | ||||
|         if self.model.is_trained(): | ||||
|             # Model is already trained | ||||
|             raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀')) | ||||
|             raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) | ||||
| 
 | ||||
|         if self.model.is_resumable(): | ||||
|             # Model has saved weights | ||||
|             self.train_args = {'data': self.model.get_dataset_url(), 'resume': True} | ||||
|             self.model_file = self.model.get_weights_url('last') | ||||
|             self.train_args = {"data": self.model.get_dataset_url(), "resume": True} | ||||
|             self.model_file = self.model.get_weights_url("last") | ||||
|         else: | ||||
|             # Model has no saved weights | ||||
|             def get_train_args(config): | ||||
|                 return { | ||||
|                     'batch': config['batchSize'], | ||||
|                     'epochs': config['epochs'], | ||||
|                     'imgsz': config['imageSize'], | ||||
|                     'patience': config['patience'], | ||||
|                     'device': config['device'], | ||||
|                     'cache': config['cache'], | ||||
|                     'data': self.model.get_dataset_url(), } | ||||
|                     "batch": config["batchSize"], | ||||
|                     "epochs": config["epochs"], | ||||
|                     "imgsz": config["imageSize"], | ||||
|                     "patience": config["patience"], | ||||
|                     "device": config["device"], | ||||
|                     "cache": config["cache"], | ||||
|                     "data": self.model.get_dataset_url(), | ||||
|                 } | ||||
| 
 | ||||
|             self.train_args = get_train_args(self.model.data.get('config')) | ||||
|             self.train_args = get_train_args(self.model.data.get("config")) | ||||
|             # Set the model file as either a *.pt or *.yaml file | ||||
|             self.model_file = (self.model.get_weights_url('parent') | ||||
|                                if self.model.is_pretrained() else self.model.get_architecture()) | ||||
|             self.model_file = ( | ||||
|                 self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture() | ||||
|             ) | ||||
| 
 | ||||
|         if not self.train_args.get('data'): | ||||
|             raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix | ||||
|         if not self.train_args.get("data"): | ||||
|             raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix | ||||
| 
 | ||||
|         self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u | ||||
|         self.model_id = self.model.id | ||||
| @ -206,12 +209,11 @@ class HUBTrainingSession: | ||||
|         *args, | ||||
|         **kwargs, | ||||
|     ): | ||||
| 
 | ||||
|         def retry_request(): | ||||
|             t0 = time.time()  # Record the start time for the timeout | ||||
|             for i in range(retry + 1): | ||||
|                 if (time.time() - t0) > timeout: | ||||
|                     LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}') | ||||
|                     LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}") | ||||
|                     break  # Timeout reached, exit loop | ||||
| 
 | ||||
|                 response = request_func(*args, **kwargs) | ||||
| @ -219,8 +221,8 @@ class HUBTrainingSession: | ||||
|                     self._show_upload_progress(progress_total, response) | ||||
| 
 | ||||
|                 if response is None: | ||||
|                     LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}') | ||||
|                     time.sleep(2 ** i)  # Exponential backoff before retrying | ||||
|                     LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}") | ||||
|                     time.sleep(2**i)  # Exponential backoff before retrying | ||||
|                     continue  # Skip further processing and retry | ||||
| 
 | ||||
|                 if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES: | ||||
| @ -231,13 +233,13 @@ class HUBTrainingSession: | ||||
|                     message = self._get_failure_message(response, retry, timeout) | ||||
| 
 | ||||
|                     if verbose: | ||||
|                         LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})') | ||||
|                         LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})") | ||||
| 
 | ||||
|                 if not self._should_retry(response.status_code): | ||||
|                     LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code}') | ||||
|                     LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}") | ||||
|                     break  # Not an error that should be retried, exit loop | ||||
| 
 | ||||
|                 time.sleep(2 ** i)  # Exponential backoff for retries | ||||
|                 time.sleep(2**i)  # Exponential backoff for retries | ||||
| 
 | ||||
|             return response | ||||
| 
 | ||||
| @ -253,7 +255,8 @@ class HUBTrainingSession: | ||||
|         retry_codes = { | ||||
|             HTTPStatus.REQUEST_TIMEOUT, | ||||
|             HTTPStatus.BAD_GATEWAY, | ||||
|             HTTPStatus.GATEWAY_TIMEOUT, } | ||||
|             HTTPStatus.GATEWAY_TIMEOUT, | ||||
|         } | ||||
|         return status_code in retry_codes | ||||
| 
 | ||||
|     def _get_failure_message(self, response: requests.Response, retry: int, timeout: int): | ||||
| @ -269,16 +272,18 @@ class HUBTrainingSession: | ||||
|             str: The retry message. | ||||
|         """ | ||||
|         if self._should_retry(response.status_code): | ||||
|             return f'Retrying {retry}x for {timeout}s.' if retry else '' | ||||
|             return f"Retrying {retry}x for {timeout}s." if retry else "" | ||||
|         elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit | ||||
|             headers = response.headers | ||||
|             return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). " | ||||
|                     f"Please retry after {headers['Retry-After']}s.") | ||||
|             return ( | ||||
|                 f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). " | ||||
|                 f"Please retry after {headers['Retry-After']}s." | ||||
|             ) | ||||
|         else: | ||||
|             try: | ||||
|                 return response.json().get('message', 'No JSON message.') | ||||
|                 return response.json().get("message", "No JSON message.") | ||||
|             except AttributeError: | ||||
|                 return 'Unable to read JSON.' | ||||
|                 return "Unable to read JSON." | ||||
| 
 | ||||
|     def upload_metrics(self): | ||||
|         """Upload model metrics to Ultralytics HUB.""" | ||||
| @ -303,7 +308,7 @@ class HUBTrainingSession: | ||||
|             final (bool): Indicates if the model is the final model after training. | ||||
|         """ | ||||
|         if Path(weights).is_file(): | ||||
|             progress_total = (Path(weights).stat().st_size if final else None)  # Only show progress if final | ||||
|             progress_total = Path(weights).stat().st_size if final else None  # Only show progress if final | ||||
|             self.request_queue( | ||||
|                 self.model.upload_model, | ||||
|                 epoch=epoch, | ||||
| @ -317,7 +322,7 @@ class HUBTrainingSession: | ||||
|                 progress_total=progress_total, | ||||
|             ) | ||||
|         else: | ||||
|             LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.') | ||||
|             LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.") | ||||
| 
 | ||||
|     def _show_upload_progress(self, content_length: int, response: requests.Response) -> None: | ||||
|         """ | ||||
| @ -330,6 +335,6 @@ class HUBTrainingSession: | ||||
|         Returns: | ||||
|             (None) | ||||
|         """ | ||||
|         with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar: | ||||
|         with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar: | ||||
|             for data in response.iter_content(chunk_size=1024): | ||||
|                 pbar.update(len(data)) | ||||
|  | ||||
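The retry loop above combines three behaviors worth noting: a 2xx short-circuit, a whitelist of transient status codes, and `2**i` exponential backoff between attempts. A minimal standalone sketch of the same pattern follows; `send` is a hypothetical stand-in for the queued HUB request and is not part of this commit.

import time
from http import HTTPStatus

RETRY_CODES = {HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT}

def request_with_backoff(send, retries=3):
    """Call send() until it succeeds, backing off 2**i seconds after each transient failure."""
    response = None
    for i in range(retries + 1):
        response = send()
        if response is None:
            time.sleep(2**i)  # no response at all: back off, then retry
            continue
        if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
            return response  # 2xx: success
        if response.status_code not in RETRY_CODES:
            return response  # permanent failure: retrying will not help
        time.sleep(2**i)  # transient failure: back off before the next attempt
    return response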
| @ -9,12 +9,26 @@ from pathlib import Path | ||||
| 
 | ||||
| import requests | ||||
| 
 | ||||
| from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__, | ||||
|                                colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package) | ||||
| from ultralytics.utils import ( | ||||
|     ENVIRONMENT, | ||||
|     LOGGER, | ||||
|     ONLINE, | ||||
|     RANK, | ||||
|     SETTINGS, | ||||
|     TESTS_RUNNING, | ||||
|     TQDM, | ||||
|     TryExcept, | ||||
|     __version__, | ||||
|     colorstr, | ||||
|     get_git_origin_url, | ||||
|     is_colab, | ||||
|     is_git_dir, | ||||
|     is_pip_package, | ||||
| ) | ||||
| from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES | ||||
| 
 | ||||
| PREFIX = colorstr('Ultralytics HUB: ') | ||||
| HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.' | ||||
| PREFIX = colorstr("Ultralytics HUB: ") | ||||
| HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance." | ||||
| 
 | ||||
| 
 | ||||
| def request_with_credentials(url: str) -> any: | ||||
| @ -31,11 +45,13 @@ def request_with_credentials(url: str) -> any: | ||||
|         OSError: If the function is not run in a Google Colab environment. | ||||
|     """ | ||||
|     if not is_colab(): | ||||
|         raise OSError('request_with_credentials() must run in a Colab environment') | ||||
|         raise OSError("request_with_credentials() must run in a Colab environment") | ||||
|     from google.colab import output  # noqa | ||||
|     from IPython import display  # noqa | ||||
| 
 | ||||
|     display.display( | ||||
|         display.Javascript(""" | ||||
|         display.Javascript( | ||||
|             """ | ||||
|             window._hub_tmp = new Promise((resolve, reject) => { | ||||
|                 const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000) | ||||
|                 fetch("%s", { | ||||
| @ -50,8 +66,11 @@ def request_with_credentials(url: str) -> any: | ||||
|                     reject(err); | ||||
|                 }); | ||||
|             }); | ||||
|             """ % url)) | ||||
|     return output.eval_js('_hub_tmp') | ||||
|             """ | ||||
|             % url | ||||
|         ) | ||||
|     ) | ||||
|     return output.eval_js("_hub_tmp") | ||||
| 
 | ||||
| 
 | ||||
| def requests_with_progress(method, url, **kwargs): | ||||
| @ -71,13 +90,13 @@ def requests_with_progress(method, url, **kwargs): | ||||
|         content length. | ||||
|         - If 'progress' is a number, the progress bar will display assuming content length = progress. | ||||
|     """ | ||||
|     progress = kwargs.pop('progress', False) | ||||
|     progress = kwargs.pop("progress", False) | ||||
|     if not progress: | ||||
|         return requests.request(method, url, **kwargs) | ||||
|     response = requests.request(method, url, stream=True, **kwargs) | ||||
|     total = int(response.headers.get('content-length', 0) if isinstance(progress, bool) else progress)  # total size | ||||
|     total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size | ||||
|     try: | ||||
|         pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024) | ||||
|         pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024) | ||||
|         for data in response.iter_content(chunk_size=1024): | ||||
|             pbar.update(len(data)) | ||||
|         pbar.close() | ||||
| @ -118,25 +137,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos | ||||
|             if r.status_code < 300:  # return codes in the 2xx range are generally considered "good" or "successful" | ||||
|                 break | ||||
|             try: | ||||
|                 m = r.json().get('message', 'No JSON message.') | ||||
|                 m = r.json().get("message", "No JSON message.") | ||||
|             except AttributeError: | ||||
|                 m = 'Unable to read JSON.' | ||||
|                 m = "Unable to read JSON." | ||||
|             if i == 0: | ||||
|                 if r.status_code in retry_codes: | ||||
|                     m += f' Retrying {retry}x for {timeout}s.' if retry else '' | ||||
|                     m += f" Retrying {retry}x for {timeout}s." if retry else "" | ||||
|                 elif r.status_code == 429:  # rate limit | ||||
|                     h = r.headers  # response headers | ||||
|                     m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \ | ||||
|                     m = ( | ||||
|                         f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " | ||||
|                         f"Please retry after {h['Retry-After']}s." | ||||
|                     ) | ||||
|                 if verbose: | ||||
|                     LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})') | ||||
|                     LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})") | ||||
|                 if r.status_code not in retry_codes: | ||||
|                     return r | ||||
|             time.sleep(2 ** i)  # exponential standoff | ||||
|             time.sleep(2**i)  # exponential backoff | ||||
|         return r | ||||
| 
 | ||||
|     args = method, url | ||||
|     kwargs['progress'] = progress | ||||
|     kwargs["progress"] = progress | ||||
|     if thread: | ||||
|         threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start() | ||||
|     else: | ||||
| @ -155,7 +176,7 @@ class Events: | ||||
|         enabled (bool): A flag to enable or disable Events based on certain conditions. | ||||
|     """ | ||||
| 
 | ||||
|     url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw' | ||||
|     url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw" | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         """Initializes the Events object with default values for events, rate_limit, and metadata.""" | ||||
| @ -163,19 +184,21 @@ class Events: | ||||
|         self.rate_limit = 60.0  # rate limit (seconds) | ||||
|         self.t = 0.0  # rate limit timer (seconds) | ||||
|         self.metadata = { | ||||
|             'cli': Path(sys.argv[0]).name == 'yolo', | ||||
|             'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other', | ||||
|             'python': '.'.join(platform.python_version_tuple()[:2]),  # i.e. 3.10 | ||||
|             'version': __version__, | ||||
|             'env': ENVIRONMENT, | ||||
|             'session_id': round(random.random() * 1E15), | ||||
|             'engagement_time_msec': 1000} | ||||
|         self.enabled = \ | ||||
|             SETTINGS['sync'] and \ | ||||
|             RANK in (-1, 0) and \ | ||||
|             not TESTS_RUNNING and \ | ||||
|             ONLINE and \ | ||||
|             (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git') | ||||
|             "cli": Path(sys.argv[0]).name == "yolo", | ||||
|             "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other", | ||||
|             "python": ".".join(platform.python_version_tuple()[:2]),  # e.g. 3.10 | ||||
|             "version": __version__, | ||||
|             "env": ENVIRONMENT, | ||||
|             "session_id": round(random.random() * 1e15), | ||||
|             "engagement_time_msec": 1000, | ||||
|         } | ||||
|         self.enabled = ( | ||||
|             SETTINGS["sync"] | ||||
|             and RANK in (-1, 0) | ||||
|             and not TESTS_RUNNING | ||||
|             and ONLINE | ||||
|             and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git") | ||||
|         ) | ||||
| 
 | ||||
|     def __call__(self, cfg): | ||||
|         """ | ||||
| @ -191,11 +214,13 @@ class Events: | ||||
|         # Attempt to add to events | ||||
|         if len(self.events) < 25:  # Events list limited to 25 events (drop any events past this) | ||||
|             params = { | ||||
|                 **self.metadata, 'task': cfg.task, | ||||
|                 'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'} | ||||
|             if cfg.mode == 'export': | ||||
|                 params['format'] = cfg.format | ||||
|             self.events.append({'name': cfg.mode, 'params': params}) | ||||
|                 **self.metadata, | ||||
|                 "task": cfg.task, | ||||
|                 "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom", | ||||
|             } | ||||
|             if cfg.mode == "export": | ||||
|                 params["format"] = cfg.format | ||||
|             self.events.append({"name": cfg.mode, "params": params}) | ||||
| 
 | ||||
|         # Check rate limit | ||||
|         t = time.time() | ||||
| @ -204,10 +229,10 @@ class Events: | ||||
|             return | ||||
| 
 | ||||
|         # Time is over rate limiter, send now | ||||
|         data = {'client_id': SETTINGS['uuid'], 'events': self.events}  # SHA-256 anonymized UUID hash and events list | ||||
|         data = {"client_id": SETTINGS["uuid"], "events": self.events}  # SHA-256 anonymized UUID hash and events list | ||||
| 
 | ||||
|         # POST equivalent to requests.post(self.url, json=data) | ||||
|         smart_request('post', self.url, json=data, retry=0, verbose=False) | ||||
|         smart_request("post", self.url, json=data, retry=0, verbose=False) | ||||
| 
 | ||||
|         # Reset events and rate limit timer | ||||
|         self.events = [] | ||||
|  | ||||
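`requests_with_progress()` and `_show_upload_progress()` both drive a byte-scaled TQDM bar from 1024-byte chunks, sizing it from the `content-length` header when available. A self-contained sketch of that streaming idiom using plain `requests` and `tqdm` (the URL passed in is a placeholder):

import requests
from tqdm import tqdm

def stream_with_progress(url, method="get"):
    """Stream a response and report progress against the declared content length."""
    response = requests.request(method, url, stream=True)
    total = int(response.headers.get("content-length", 0))  # 0 if the server omits the header
    with tqdm(total=total, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
        for chunk in response.iter_content(chunk_size=1024):
            pbar.update(len(chunk))
    return response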
| @ -4,4 +4,4 @@ from .rtdetr import RTDETR | ||||
| from .sam import SAM | ||||
| from .yolo import YOLO | ||||
| 
 | ||||
| __all__ = 'YOLO', 'RTDETR', 'SAM'  # allow simpler import | ||||
| __all__ = "YOLO", "RTDETR", "SAM"  # allow simpler import | ||||
|  | ||||
| @ -5,4 +5,4 @@ from .predict import FastSAMPredictor | ||||
| from .prompt import FastSAMPrompt | ||||
| from .val import FastSAMValidator | ||||
| 
 | ||||
| __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator' | ||||
| __all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator" | ||||
|  | ||||
| @ -21,14 +21,14 @@ class FastSAM(Model): | ||||
|         ``` | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model='FastSAM-x.pt'): | ||||
|     def __init__(self, model="FastSAM-x.pt"): | ||||
|         """Call the __init__ method of the parent class (YOLO) with the updated default model.""" | ||||
|         if str(model) == 'FastSAM.pt': | ||||
|             model = 'FastSAM-x.pt' | ||||
|         assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.' | ||||
|         super().__init__(model=model, task='segment') | ||||
|         if str(model) == "FastSAM.pt": | ||||
|             model = "FastSAM-x.pt" | ||||
|         assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models." | ||||
|         super().__init__(model=model, task="segment") | ||||
| 
 | ||||
|     @property | ||||
|     def task_map(self): | ||||
|         """Returns a dictionary mapping segment task to corresponding predictor and validator classes.""" | ||||
|         return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}} | ||||
|         return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}} | ||||
|  | ||||
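The FastSAM wrapper enforces two conventions seen throughout these model classes: a suffix guard that rejects YAML configs (only pre-trained weights are supported) and a `task_map` registry routing each task name to its predictor/validator classes. A minimal sketch of that guard-plus-registry pattern; `SegPredictor`/`SegValidator` are placeholder classes, not Ultralytics names:

from pathlib import Path

class SegPredictor: ...  # placeholder for a task's predictor class
class SegValidator: ...  # placeholder for a task's validator class

TASK_MAP = {"segment": {"predictor": SegPredictor, "validator": SegValidator}}

def resolve(model="FastSAM-x.pt", task="segment"):
    """Reject config files, then look the task's classes up in the registry."""
    assert Path(model).suffix not in (".yaml", ".yml"), "only pre-trained weights are supported"
    return TASK_MAP[task]

print(resolve()["predictor"].__name__)  # SegPredictor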
| @ -33,7 +33,7 @@ class FastSAMPredictor(DetectionPredictor): | ||||
|             _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. | ||||
|         """ | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
|         self.args.task = 'segment' | ||||
|         self.args.task = "segment" | ||||
| 
 | ||||
|     def postprocess(self, preds, img, orig_imgs): | ||||
|         """ | ||||
| @ -55,7 +55,8 @@ class FastSAMPredictor(DetectionPredictor): | ||||
|             agnostic=self.args.agnostic_nms, | ||||
|             max_det=self.args.max_det, | ||||
|             nc=1,  # set to 1 class since SAM has no class predictions | ||||
|             classes=self.args.classes) | ||||
|             classes=self.args.classes, | ||||
|         ) | ||||
|         full_box = torch.zeros(p[0].shape[1], device=p[0].device) | ||||
|         full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 | ||||
|         full_box = full_box.view(1, -1) | ||||
|  | ||||
| @ -23,7 +23,7 @@ class FastSAMPrompt: | ||||
|         clip: CLIP model for linear assignment. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, source, results, device='cuda') -> None: | ||||
|     def __init__(self, source, results, device="cuda") -> None: | ||||
|         """Initializes FastSAMPrompt with the given source, results, and device, and assigns clip for linear assignment.""" | ||||
|         self.device = device | ||||
|         self.results = results | ||||
| @ -34,7 +34,8 @@ class FastSAMPrompt: | ||||
|             import clip  # for linear_assignment | ||||
|         except ImportError: | ||||
|             from ultralytics.utils.checks import check_requirements | ||||
|             check_requirements('git+https://github.com/openai/CLIP.git') | ||||
| 
 | ||||
|             check_requirements("git+https://github.com/openai/CLIP.git") | ||||
|             import clip | ||||
|         self.clip = clip | ||||
| 
 | ||||
| @ -46,11 +47,11 @@ class FastSAMPrompt: | ||||
|         x1, y1, x2, y2 = bbox | ||||
|         segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] | ||||
|         segmented_image = Image.fromarray(segmented_image_array) | ||||
|         black_image = Image.new('RGB', image.size, (255, 255, 255)) | ||||
|         black_image = Image.new("RGB", image.size, (255, 255, 255)) | ||||
|         # transparency_mask = np.zeros_like((), dtype=np.uint8) | ||||
|         transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) | ||||
|         transparency_mask[y1:y2, x1:x2] = 255 | ||||
|         transparency_mask_image = Image.fromarray(transparency_mask, mode='L') | ||||
|         transparency_mask_image = Image.fromarray(transparency_mask, mode="L") | ||||
|         black_image.paste(segmented_image, mask=transparency_mask_image) | ||||
|         return black_image | ||||
| 
 | ||||
| @ -65,11 +66,12 @@ class FastSAMPrompt: | ||||
|             mask = result.masks.data[i] == 1.0 | ||||
|             if torch.sum(mask) >= filter: | ||||
|                 annotation = { | ||||
|                     'id': i, | ||||
|                     'segmentation': mask.cpu().numpy(), | ||||
|                     'bbox': result.boxes.data[i], | ||||
|                     'score': result.boxes.conf[i]} | ||||
|                 annotation['area'] = annotation['segmentation'].sum() | ||||
|                     "id": i, | ||||
|                     "segmentation": mask.cpu().numpy(), | ||||
|                     "bbox": result.boxes.data[i], | ||||
|                     "score": result.boxes.conf[i], | ||||
|                 } | ||||
|                 annotation["area"] = annotation["segmentation"].sum() | ||||
|                 annotations.append(annotation) | ||||
|         return annotations | ||||
| 
 | ||||
| @ -91,7 +93,8 @@ class FastSAMPrompt: | ||||
|                 y2 = max(y2, y_t + h_t) | ||||
|         return [x1, y1, x2, y2] | ||||
| 
 | ||||
|     def plot(self, | ||||
|     def plot( | ||||
|         self, | ||||
|         annotations, | ||||
|         output, | ||||
|         bbox=None, | ||||
| @ -100,7 +103,8 @@ class FastSAMPrompt: | ||||
|         mask_random_color=True, | ||||
|         better_quality=True, | ||||
|         retina=False, | ||||
|              with_contours=True): | ||||
|         with_contours=True, | ||||
|     ): | ||||
|         """ | ||||
|         Plots annotations, bounding boxes, and points on images and saves the output. | ||||
| 
 | ||||
| @ -139,7 +143,8 @@ class FastSAMPrompt: | ||||
|                         mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) | ||||
|                         masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) | ||||
| 
 | ||||
|                 self.fast_show_mask(masks, | ||||
|                 self.fast_show_mask( | ||||
|                     masks, | ||||
|                     plt.gca(), | ||||
|                     random_color=mask_random_color, | ||||
|                     bbox=bbox, | ||||
| @ -147,7 +152,8 @@ class FastSAMPrompt: | ||||
|                     pointlabel=point_label, | ||||
|                     retinamask=retina, | ||||
|                     target_height=original_h, | ||||
|                                     target_width=original_w) | ||||
|                     target_width=original_w, | ||||
|                 ) | ||||
| 
 | ||||
|                 if with_contours: | ||||
|                     contour_all = [] | ||||
| @ -166,10 +172,10 @@ class FastSAMPrompt: | ||||
|             # Save the figure | ||||
|             save_path = Path(output) / result_name | ||||
|             save_path.parent.mkdir(exist_ok=True, parents=True) | ||||
|             plt.axis('off') | ||||
|             plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True) | ||||
|             plt.axis("off") | ||||
|             plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True) | ||||
|             plt.close() | ||||
|             pbar.set_description(f'Saving {result_name} to {save_path}') | ||||
|             pbar.set_description(f"Saving {result_name} to {save_path}") | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def fast_show_mask( | ||||
| @ -212,26 +218,26 @@ class FastSAMPrompt: | ||||
|         mask_image = np.expand_dims(annotation, -1) * visual | ||||
| 
 | ||||
|         show = np.zeros((h, w, 4)) | ||||
|         h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij') | ||||
|         h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij") | ||||
|         indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) | ||||
| 
 | ||||
|         show[h_indices, w_indices, :] = mask_image[indices] | ||||
|         if bbox is not None: | ||||
|             x1, y1, x2, y2 = bbox | ||||
|             ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) | ||||
|             ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1)) | ||||
|         # Draw point | ||||
|         if points is not None: | ||||
|             plt.scatter( | ||||
|                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], | ||||
|                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], | ||||
|                 s=20, | ||||
|                 c='y', | ||||
|                 c="y", | ||||
|             ) | ||||
|             plt.scatter( | ||||
|                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], | ||||
|                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], | ||||
|                 s=20, | ||||
|                 c='m', | ||||
|                 c="m", | ||||
|             ) | ||||
| 
 | ||||
|         if not retinamask: | ||||
| @ -258,7 +264,7 @@ class FastSAMPrompt: | ||||
|         image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB)) | ||||
|         ori_w, ori_h = image.size | ||||
|         annotations = format_results | ||||
|         mask_h, mask_w = annotations[0]['segmentation'].shape | ||||
|         mask_h, mask_w = annotations[0]["segmentation"].shape | ||||
|         if ori_w != mask_w or ori_h != mask_h: | ||||
|             image = image.resize((mask_w, mask_h)) | ||||
|         cropped_boxes = [] | ||||
| @ -266,19 +272,19 @@ class FastSAMPrompt: | ||||
|         not_crop = [] | ||||
|         filter_id = [] | ||||
|         for _, mask in enumerate(annotations): | ||||
|             if np.sum(mask['segmentation']) <= 100: | ||||
|             if np.sum(mask["segmentation"]) <= 100: | ||||
|                 filter_id.append(_) | ||||
|                 continue | ||||
|             bbox = self._get_bbox_from_mask(mask['segmentation'])  # mask 的 bbox | ||||
|             cropped_boxes.append(self._segment_image(image, bbox))  # 保存裁剪的图片 | ||||
|             cropped_images.append(bbox)  # 保存裁剪的图片的bbox | ||||
|             bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask | ||||
|             cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image | ||||
|             cropped_images.append(bbox)  # save cropped image bbox | ||||
| 
 | ||||
|         return cropped_boxes, cropped_images, not_crop, filter_id, annotations | ||||
| 
 | ||||
|     def box_prompt(self, bbox): | ||||
|         """Modifies the bounding box properties and calculates IoU between masks and bounding box.""" | ||||
|         if self.results[0].masks is not None: | ||||
|             assert (bbox[2] != 0 and bbox[3] != 0) | ||||
|             assert bbox[2] != 0 and bbox[3] != 0 | ||||
|             if os.path.isdir(self.source): | ||||
|                 raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.") | ||||
|             masks = self.results[0].masks.data | ||||
| @ -290,7 +296,8 @@ class FastSAMPrompt: | ||||
|                     int(bbox[0] * w / target_width), | ||||
|                     int(bbox[1] * h / target_height), | ||||
|                     int(bbox[2] * w / target_width), | ||||
|                     int(bbox[3] * h / target_height), ] | ||||
|                     int(bbox[3] * h / target_height), | ||||
|                 ] | ||||
|             bbox[0] = max(round(bbox[0]), 0) | ||||
|             bbox[1] = max(round(bbox[1]), 0) | ||||
|             bbox[2] = min(round(bbox[2]), w) | ||||
| @ -299,7 +306,7 @@ class FastSAMPrompt: | ||||
|             # IoUs = torch.zeros(len(masks), dtype=torch.float32) | ||||
|             bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) | ||||
| 
 | ||||
|             masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) | ||||
|             masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2)) | ||||
|             orig_masks_area = torch.sum(masks, dim=(1, 2)) | ||||
| 
 | ||||
|             union = bbox_area + orig_masks_area - masks_area | ||||
| @ -316,13 +323,13 @@ class FastSAMPrompt: | ||||
|                 raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.") | ||||
|             masks = self._format_results(self.results[0], 0) | ||||
|             target_height, target_width = self.results[0].orig_shape | ||||
|             h = masks[0]['segmentation'].shape[0] | ||||
|             w = masks[0]['segmentation'].shape[1] | ||||
|             h = masks[0]["segmentation"].shape[0] | ||||
|             w = masks[0]["segmentation"].shape[1] | ||||
|             if h != target_height or w != target_width: | ||||
|                 points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] | ||||
|             onemask = np.zeros((h, w)) | ||||
|             for annotation in masks: | ||||
|                 mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation | ||||
|                 mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation | ||||
|                 for i, point in enumerate(points): | ||||
|                     if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: | ||||
|                         onemask += mask | ||||
| @ -337,12 +344,12 @@ class FastSAMPrompt: | ||||
|         if self.results[0].masks is not None: | ||||
|             format_results = self._format_results(self.results[0], 0) | ||||
|             cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) | ||||
|             clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device) | ||||
|             clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device) | ||||
|             scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) | ||||
|             max_idx = scores.argsort() | ||||
|             max_idx = max_idx[-1] | ||||
|             max_idx += sum(np.array(filter_id) <= int(max_idx)) | ||||
|             self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]['segmentation']])) | ||||
|             self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]])) | ||||
|         return self.results | ||||
| 
 | ||||
|     def everything_prompt(self): | ||||
|  | ||||
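`box_prompt()` above ranks predicted masks by IoU against the user-supplied box: the intersection is simply the mask area inside the box slice, and the union follows from box area + mask area - intersection. A small torch sketch of that scoring on toy binary masks (shapes and values are illustrative only):

import torch

def masks_box_iou(masks: torch.Tensor, bbox) -> torch.Tensor:
    """IoU between each binary mask in (N, H, W) and a single xyxy pixel box."""
    x1, y1, x2, y2 = bbox
    box_area = (x2 - x1) * (y2 - y1)
    inter = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))  # mask pixels inside the box
    union = box_area + torch.sum(masks, dim=(1, 2)) - inter
    return inter / union

masks = torch.zeros(2, 8, 8)
masks[0, 2:6, 2:6] = 1  # first mask is exactly the 4x4 box region
print(masks_box_iou(masks, [2, 2, 6, 6]))  # tensor([1., 0.])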
| @ -35,6 +35,6 @@ class FastSAMValidator(SegmentationValidator): | ||||
|             Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors. | ||||
|         """ | ||||
|         super().__init__(dataloader, save_dir, pbar, args, _callbacks) | ||||
|         self.args.task = 'segment' | ||||
|         self.args.task = "segment" | ||||
|         self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors | ||||
|         self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) | ||||
|  | ||||
| @ -4,4 +4,4 @@ from .model import NAS | ||||
| from .predict import NASPredictor | ||||
| from .val import NASValidator | ||||
| 
 | ||||
| __all__ = 'NASPredictor', 'NASValidator', 'NAS' | ||||
| __all__ = "NASPredictor", "NASValidator", "NAS" | ||||
|  | ||||
| @ -44,20 +44,21 @@ class NAS(Model): | ||||
|         YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model='yolo_nas_s.pt') -> None: | ||||
|     def __init__(self, model="yolo_nas_s.pt") -> None: | ||||
|         """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model.""" | ||||
|         assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.' | ||||
|         super().__init__(model, task='detect') | ||||
|         assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models." | ||||
|         super().__init__(model, task="detect") | ||||
| 
 | ||||
|     @smart_inference_mode() | ||||
|     def _load(self, weights: str, task: str): | ||||
|         """Loads existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" | ||||
|         import super_gradients | ||||
| 
 | ||||
|         suffix = Path(weights).suffix | ||||
|         if suffix == '.pt': | ||||
|         if suffix == ".pt": | ||||
|             self.model = torch.load(weights) | ||||
|         elif suffix == '': | ||||
|             self.model = super_gradients.training.models.get(weights, pretrained_weights='coco') | ||||
|         elif suffix == "": | ||||
|             self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") | ||||
|         # Standardize model | ||||
|         self.model.fuse = lambda verbose=True: self.model | ||||
|         self.model.stride = torch.tensor([32]) | ||||
| @ -65,7 +66,7 @@ class NAS(Model): | ||||
|         self.model.is_fused = lambda: False  # for info() | ||||
|         self.model.yaml = {}  # for info() | ||||
|         self.model.pt_path = weights  # for export() | ||||
|         self.model.task = 'detect'  # for export() | ||||
|         self.model.task = "detect"  # for export() | ||||
| 
 | ||||
|     def info(self, detailed=False, verbose=True): | ||||
|         """ | ||||
| @ -80,4 +81,4 @@ class NAS(Model): | ||||
|     @property | ||||
|     def task_map(self): | ||||
|         """Returns a dictionary mapping tasks to respective predictor and validator classes.""" | ||||
|         return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}} | ||||
|         return {"detect": {"predictor": NASPredictor, "validator": NASValidator}} | ||||
|  | ||||
| @ -39,12 +39,14 @@ class NASPredictor(BasePredictor): | ||||
|         boxes = ops.xyxy2xywh(preds_in[0][0]) | ||||
|         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) | ||||
| 
 | ||||
|         preds = ops.non_max_suppression(preds, | ||||
|         preds = ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             agnostic=self.args.agnostic_nms, | ||||
|             max_det=self.args.max_det, | ||||
|                                         classes=self.args.classes) | ||||
|             classes=self.args.classes, | ||||
|         ) | ||||
| 
 | ||||
|         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list | ||||
|             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) | ||||
|  | ||||
| @ -5,7 +5,7 @@ import torch | ||||
| from ultralytics.models.yolo.detect import DetectionValidator | ||||
| from ultralytics.utils import ops | ||||
| 
 | ||||
| __all__ = ['NASValidator'] | ||||
| __all__ = ["NASValidator"] | ||||
| 
 | ||||
| 
 | ||||
| class NASValidator(DetectionValidator): | ||||
| @ -38,11 +38,13 @@ class NASValidator(DetectionValidator): | ||||
|         """Apply Non-maximum suppression to prediction outputs.""" | ||||
|         boxes = ops.xyxy2xywh(preds_in[0][0]) | ||||
|         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) | ||||
|         return ops.non_max_suppression(preds, | ||||
|         return ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             labels=self.lb, | ||||
|             multi_label=False, | ||||
|             agnostic=self.args.single_cls, | ||||
|             max_det=self.args.max_det, | ||||
|                                        max_time_img=0.5) | ||||
|             max_time_img=0.5, | ||||
|         ) | ||||
|  | ||||
| @ -4,4 +4,4 @@ from .model import RTDETR | ||||
| from .predict import RTDETRPredictor | ||||
| from .val import RTDETRValidator | ||||
| 
 | ||||
| __all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR' | ||||
| __all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR" | ||||
|  | ||||
| @ -24,7 +24,7 @@ class RTDETR(Model): | ||||
|         model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model='rtdetr-l.pt') -> None: | ||||
|     def __init__(self, model="rtdetr-l.pt") -> None: | ||||
|         """ | ||||
|         Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats. | ||||
| 
 | ||||
| @ -34,9 +34,9 @@ class RTDETR(Model): | ||||
|         Raises: | ||||
|             NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. | ||||
|         """ | ||||
|         if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'): | ||||
|             raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.') | ||||
|         super().__init__(model=model, task='detect') | ||||
|         if model and model.split(".")[-1] not in ("pt", "yaml", "yml"): | ||||
|             raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.") | ||||
|         super().__init__(model=model, task="detect") | ||||
| 
 | ||||
|     @property | ||||
|     def task_map(self) -> dict: | ||||
| @ -47,8 +47,10 @@ class RTDETR(Model): | ||||
|             dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model. | ||||
|         """ | ||||
|         return { | ||||
|             'detect': { | ||||
|                 'predictor': RTDETRPredictor, | ||||
|                 'validator': RTDETRValidator, | ||||
|                 'trainer': RTDETRTrainer, | ||||
|                 'model': RTDETRDetectionModel}} | ||||
|             "detect": { | ||||
|                 "predictor": RTDETRPredictor, | ||||
|                 "validator": RTDETRValidator, | ||||
|                 "trainer": RTDETRTrainer, | ||||
|                 "model": RTDETRDetectionModel, | ||||
|             } | ||||
|         } | ||||
|  | ||||
| @ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer): | ||||
|         Returns: | ||||
|             (RTDETRDetectionModel): Initialized model. | ||||
|         """ | ||||
|         model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) | ||||
|         model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
|         return model | ||||
| 
 | ||||
|     def build_dataset(self, img_path, mode='val', batch=None): | ||||
|     def build_dataset(self, img_path, mode="val", batch=None): | ||||
|         """ | ||||
|         Build and return an RT-DETR dataset for training or validation. | ||||
| 
 | ||||
| @ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer): | ||||
|         Returns: | ||||
|             (RTDETRDataset): Dataset object for the specific mode. | ||||
|         """ | ||||
|         return RTDETRDataset(img_path=img_path, | ||||
|         return RTDETRDataset( | ||||
|             img_path=img_path, | ||||
|             imgsz=self.args.imgsz, | ||||
|             batch_size=batch, | ||||
|                              augment=mode == 'train', | ||||
|             augment=mode == "train", | ||||
|             hyp=self.args, | ||||
|             rect=False, | ||||
|             cache=self.args.cache or None, | ||||
|                              prefix=colorstr(f'{mode}: '), | ||||
|                              data=self.data) | ||||
|             prefix=colorstr(f"{mode}: "), | ||||
|             data=self.data, | ||||
|         ) | ||||
| 
 | ||||
|     def get_validator(self): | ||||
|         """ | ||||
| @ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer): | ||||
|         Returns: | ||||
|             (RTDETRValidator): Validator object for model validation. | ||||
|         """ | ||||
|         self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' | ||||
|         self.loss_names = "giou_loss", "cls_loss", "l1_loss" | ||||
|         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) | ||||
| 
 | ||||
|     def preprocess_batch(self, batch): | ||||
| @ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer): | ||||
|             (dict): Preprocessed batch. | ||||
|         """ | ||||
|         batch = super().preprocess_batch(batch) | ||||
|         bs = len(batch['img']) | ||||
|         batch_idx = batch['batch_idx'] | ||||
|         bs = len(batch["img"]) | ||||
|         batch_idx = batch["batch_idx"] | ||||
|         gt_bbox, gt_class = [], [] | ||||
|         for i in range(bs): | ||||
|             gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device)) | ||||
|             gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) | ||||
|             gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device)) | ||||
|             gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) | ||||
|         return batch | ||||
|  | ||||
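`preprocess_batch()` regroups the flat label tensors into per-image lists by masking on the `batch_idx` column, the standard trick for variable-count targets in a collated batch. A torch sketch of that grouping with made-up shapes:

import torch

batch_idx = torch.tensor([0, 0, 1, 1, 1])  # 5 boxes spread over 2 images
bboxes = torch.rand(5, 4)                  # normalized xywh targets
cls = torch.randint(0, 80, (5, 1))

gt_bbox, gt_class = [], []
for i in range(int(batch_idx.max()) + 1):
    gt_bbox.append(bboxes[batch_idx == i])               # boxes belonging to image i
    gt_class.append(cls[batch_idx == i].to(torch.long))  # matching class ids

print([b.shape[0] for b in gt_bbox])  # [2, 3]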
| @ -7,7 +7,7 @@ from ultralytics.data.augment import Compose, Format, v8_transforms | ||||
| from ultralytics.models.yolo.detect import DetectionValidator | ||||
| from ultralytics.utils import colorstr, ops | ||||
| 
 | ||||
| __all__ = 'RTDETRValidator',  # tuple or list | ||||
| __all__ = ("RTDETRValidator",)  # tuple or list | ||||
| 
 | ||||
| 
 | ||||
| class RTDETRDataset(YOLODataset): | ||||
| @ -37,13 +37,16 @@ class RTDETRDataset(YOLODataset): | ||||
|             # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)]) | ||||
|             transforms = Compose([]) | ||||
|         transforms.append( | ||||
|             Format(bbox_format='xywh', | ||||
|             Format( | ||||
|                 bbox_format="xywh", | ||||
|                 normalize=True, | ||||
|                 return_mask=self.use_segments, | ||||
|                 return_keypoint=self.use_keypoints, | ||||
|                 batch_idx=True, | ||||
|                 mask_ratio=hyp.mask_ratio, | ||||
|                    mask_overlap=hyp.overlap_mask)) | ||||
|                 mask_overlap=hyp.overlap_mask, | ||||
|             ) | ||||
|         ) | ||||
|         return transforms | ||||
| 
 | ||||
| 
 | ||||
| @ -68,7 +71,7 @@ class RTDETRValidator(DetectionValidator): | ||||
|         For further details on the attributes and methods, refer to the parent DetectionValidator class. | ||||
|     """ | ||||
| 
 | ||||
|     def build_dataset(self, img_path, mode='val', batch=None): | ||||
|     def build_dataset(self, img_path, mode="val", batch=None): | ||||
|         """ | ||||
|         Build an RTDETR Dataset. | ||||
| 
 | ||||
| @ -85,8 +88,9 @@ class RTDETRValidator(DetectionValidator): | ||||
|             hyp=self.args, | ||||
|             rect=False,  # no rect | ||||
|             cache=self.args.cache or None, | ||||
|             prefix=colorstr(f'{mode}: '), | ||||
|             data=self.data) | ||||
|             prefix=colorstr(f"{mode}: "), | ||||
|             data=self.data, | ||||
|         ) | ||||
| 
 | ||||
|     def postprocess(self, preds): | ||||
|         """Apply Non-maximum suppression to prediction outputs.""" | ||||
| @ -108,12 +112,12 @@ class RTDETRValidator(DetectionValidator): | ||||
| 
 | ||||
|     def _prepare_batch(self, si, batch): | ||||
|         """Prepares a batch for training or inference by applying transformations.""" | ||||
|         idx = batch['batch_idx'] == si | ||||
|         cls = batch['cls'][idx].squeeze(-1) | ||||
|         bbox = batch['bboxes'][idx] | ||||
|         ori_shape = batch['ori_shape'][si] | ||||
|         imgsz = batch['img'].shape[2:] | ||||
|         ratio_pad = batch['ratio_pad'][si] | ||||
|         idx = batch["batch_idx"] == si | ||||
|         cls = batch["cls"][idx].squeeze(-1) | ||||
|         bbox = batch["bboxes"][idx] | ||||
|         ori_shape = batch["ori_shape"][si] | ||||
|         imgsz = batch["img"].shape[2:] | ||||
|         ratio_pad = batch["ratio_pad"][si] | ||||
|         if len(cls): | ||||
|             bbox = ops.xywh2xyxy(bbox)  # target boxes | ||||
|             bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred | ||||
| @ -124,6 +128,6 @@ class RTDETRValidator(DetectionValidator): | ||||
|     def _prepare_pred(self, pred, pbatch): | ||||
|         """Prepares and returns a batch with transformed bounding boxes and class labels.""" | ||||
|         predn = pred.clone() | ||||
|         predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz  # native-space pred | ||||
|         predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz  # native-space pred | ||||
|         predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred | ||||
|         predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred | ||||
|         return predn.float() | ||||
|  | ||||
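`_prepare_pred()` maps square-`imgsz` predictions back to native image space by scaling x coordinates with `ori_shape[1] / imgsz` and y coordinates with `ori_shape[0] / imgsz`. A compact sketch of that rescale; the 640 default and tensor values are illustrative:

import torch

def to_native_space(pred: torch.Tensor, ori_shape, imgsz: int = 640) -> torch.Tensor:
    """Scale xyxy predictions from imgsz x imgsz space to the original (h, w) image."""
    predn = pred.clone().float()
    predn[..., [0, 2]] *= ori_shape[1] / imgsz  # x coordinates by the width ratio
    predn[..., [1, 3]] *= ori_shape[0] / imgsz  # y coordinates by the height ratio
    return predn

pred = torch.tensor([[320.0, 320.0, 640.0, 640.0, 0.9, 0.0]])  # xyxy, conf, cls
print(to_native_space(pred, (480, 1280)))  # box becomes [640, 240, 1280, 480, ...]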
| @ -3,4 +3,4 @@ | ||||
| from .model import SAM | ||||
| from .predict import Predictor | ||||
| 
 | ||||
| __all__ = 'SAM', 'Predictor'  # tuple or list | ||||
| __all__ = "SAM", "Predictor"  # tuple or list | ||||
|  | ||||
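These `__init__` modules write `__all__` as a bare tuple, and the `# tuple or list` comments hint at the one subtlety: with a single export, the trailing comma is what makes it a tuple, which is why `rtdetr/val.py` above spells out `("RTDETRValidator",)`. A short illustration:

exports = "SAM", "Predictor"       # parentheses optional: already a 2-tuple
single = ("RTDETRValidator",)      # the comma makes this a 1-tuple
not_a_tuple = ("RTDETRValidator")  # without the comma, just a parenthesized str
assert isinstance(single, tuple) and isinstance(not_a_tuple, str)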
| @ -8,10 +8,9 @@ import numpy as np | ||||
| import torch | ||||
| 
 | ||||
| 
 | ||||
| def is_box_near_crop_edge(boxes: torch.Tensor, | ||||
|                           crop_box: List[int], | ||||
|                           orig_box: List[int], | ||||
|                           atol: float = 20.0) -> torch.Tensor: | ||||
| def is_box_near_crop_edge( | ||||
|     boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 | ||||
| ) -> torch.Tensor: | ||||
|     """Return a boolean tensor indicating if boxes are near the crop edge.""" | ||||
|     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) | ||||
|     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) | ||||
| @ -24,10 +23,10 @@ def is_box_near_crop_edge(boxes: torch.Tensor, | ||||
| 
 | ||||
| def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: | ||||
|     """Yield batches of data from the input arguments.""" | ||||
|     assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.' | ||||
|     assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs." | ||||
|     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) | ||||
|     for b in range(n_batches): | ||||
|         yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args] | ||||
|         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] | ||||
| 
 | ||||
| 
 | ||||
| def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor: | ||||
| @ -39,9 +38,8 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh | ||||
|     """ | ||||
|     # One mask is always contained inside the other. | ||||
|     # Save memory by preventing unnecessary cast to torch.int64 | ||||
|     intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, | ||||
|                                                                                                   dtype=torch.int32)) | ||||
|     unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)) | ||||
|     intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) | ||||
|     unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) | ||||
|     return intersections / unions | ||||
| 
 | ||||
| 
 | ||||
| @ -56,11 +54,12 @@ def build_point_grid(n_per_side: int) -> np.ndarray: | ||||
| 
 | ||||
| def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: | ||||
|     """Generate point grids for all crop layers.""" | ||||
|     return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)] | ||||
|     return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)] | ||||
| 
 | ||||
| 
 | ||||
| def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int, | ||||
|                         overlap_ratio: float) -> Tuple[List[List[int]], List[int]]: | ||||
| def generate_crop_boxes( | ||||
|     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float | ||||
| ) -> Tuple[List[List[int]], List[int]]: | ||||
|     """ | ||||
|     Generates a list of crop boxes of different sizes. | ||||
| 
 | ||||
| @ -132,8 +131,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup | ||||
|     """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator.""" | ||||
|     import cv2  # type: ignore | ||||
| 
 | ||||
|     assert mode in {'holes', 'islands'} | ||||
|     correct_holes = mode == 'holes' | ||||
|     assert mode in {"holes", "islands"} | ||||
|     correct_holes = mode == "holes" | ||||
|     working_mask = (correct_holes ^ mask).astype(np.uint8) | ||||
|     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) | ||||
|     sizes = stats[:, -1][1:]  # Row 0 is background label | ||||
|  | ||||
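`calculate_stability_score()` above is an IoU of the same logit map binarized at `threshold ± offset`; because the tighter mask is always contained in the looser one, the two pixel counts are exactly the intersection and union. A torch sketch on a toy logit map:

import torch

def stability_score(mask_logits: torch.Tensor, thresh: float = 0.0, offset: float = 1.0) -> torch.Tensor:
    """IoU of masks thresholded at thresh+offset vs thresh-offset, per mask."""
    hi = (mask_logits > (thresh + offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    lo = (mask_logits > (thresh - offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    return hi / lo  # hi is contained in lo, so this is intersection / union

logits = torch.tensor([[[3.0, 3.0], [0.5, -3.0]]])  # one 2x2 mask of logits
print(stability_score(logits))  # tensor([0.6667]): 2 confident pixels vs 3 loose ones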
| @ -64,18 +64,16 @@ def build_mobile_sam(checkpoint=None): | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def _build_sam(encoder_embed_dim, | ||||
|                encoder_depth, | ||||
|                encoder_num_heads, | ||||
|                encoder_global_attn_indexes, | ||||
|                checkpoint=None, | ||||
|                mobile_sam=False): | ||||
| def _build_sam( | ||||
|     encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False | ||||
| ): | ||||
|     """Builds the selected SAM model architecture.""" | ||||
|     prompt_embed_dim = 256 | ||||
|     image_size = 1024 | ||||
|     vit_patch_size = 16 | ||||
|     image_embedding_size = image_size // vit_patch_size | ||||
|     image_encoder = (TinyViT( | ||||
|     image_encoder = ( | ||||
|         TinyViT( | ||||
|             img_size=1024, | ||||
|             in_chans=3, | ||||
|             num_classes=1000, | ||||
| @ -90,7 +88,9 @@ def _build_sam(encoder_embed_dim, | ||||
|             mbconv_expand_ratio=4.0, | ||||
|             local_conv_size=3, | ||||
|             layer_lr_decay=0.8, | ||||
|     ) if mobile_sam else ImageEncoderViT( | ||||
|         ) | ||||
|         if mobile_sam | ||||
|         else ImageEncoderViT( | ||||
|             depth=encoder_depth, | ||||
|             embed_dim=encoder_embed_dim, | ||||
|             img_size=image_size, | ||||
| @ -103,7 +103,8 @@ def _build_sam(encoder_embed_dim, | ||||
|             global_attn_indexes=encoder_global_attn_indexes, | ||||
|             window_size=14, | ||||
|             out_chans=prompt_embed_dim, | ||||
|     )) | ||||
|         ) | ||||
|     ) | ||||
|     sam = Sam( | ||||
|         image_encoder=image_encoder, | ||||
|         prompt_encoder=PromptEncoder( | ||||
| @ -129,7 +130,7 @@ def _build_sam(encoder_embed_dim, | ||||
|     ) | ||||
|     if checkpoint is not None: | ||||
|         checkpoint = attempt_download_asset(checkpoint) | ||||
|         with open(checkpoint, 'rb') as f: | ||||
|         with open(checkpoint, "rb") as f: | ||||
|             state_dict = torch.load(f) | ||||
|         sam.load_state_dict(state_dict) | ||||
|     sam.eval() | ||||
| @ -139,13 +140,14 @@ def _build_sam(encoder_embed_dim, | ||||
| 
 | ||||
| 
 | ||||
| sam_model_map = { | ||||
|     'sam_h.pt': build_sam_vit_h, | ||||
|     'sam_l.pt': build_sam_vit_l, | ||||
|     'sam_b.pt': build_sam_vit_b, | ||||
|     'mobile_sam.pt': build_mobile_sam, } | ||||
|     "sam_h.pt": build_sam_vit_h, | ||||
|     "sam_l.pt": build_sam_vit_l, | ||||
|     "sam_b.pt": build_sam_vit_b, | ||||
|     "mobile_sam.pt": build_mobile_sam, | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| def build_sam(ckpt='sam_b.pt'): | ||||
| def build_sam(ckpt="sam_b.pt"): | ||||
|     """Build a SAM model specified by ckpt.""" | ||||
|     model_builder = None | ||||
|     ckpt = str(ckpt)  # to allow Path ckpt types | ||||
| @ -154,6 +156,6 @@ def build_sam(ckpt='sam_b.pt'): | ||||
|             model_builder = sam_model_map.get(k) | ||||
| 
 | ||||
|     if not model_builder: | ||||
|         raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}') | ||||
|         raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") | ||||
| 
 | ||||
|     return model_builder(ckpt) | ||||
|  | ||||
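`build_sam()` dispatches on the checkpoint filename: `sam_model_map` pairs known weight names with builder functions, and the loop picks the builder whose key the filename ends with. A stripped-down sketch of that lookup with stub builders (the stubs return strings purely for demonstration):

def build_b(ckpt): return f"vit-b from {ckpt}"  # stub builder
def build_h(ckpt): return f"vit-h from {ckpt}"  # stub builder

model_map = {"sam_b.pt": build_b, "sam_h.pt": build_h}

def build(ckpt="sam_b.pt"):
    """Select a builder by filename suffix, mirroring the lookup above."""
    ckpt = str(ckpt)  # allow Path inputs
    builder = next((fn for key, fn in model_map.items() if ckpt.endswith(key)), None)
    if builder is None:
        raise FileNotFoundError(f"{ckpt} is not supported. Available: {list(model_map)}")
    return builder(ckpt)

print(build("weights/sam_b.pt"))  # vit-b from weights/sam_b.pt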
| @ -32,7 +32,7 @@ class SAM(Model): | ||||
|     dataset. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model='sam_b.pt') -> None: | ||||
|     def __init__(self, model="sam_b.pt") -> None: | ||||
|         """ | ||||
|         Initializes the SAM model with a pre-trained model file. | ||||
| 
 | ||||
| @ -42,9 +42,9 @@ class SAM(Model): | ||||
|         Raises: | ||||
|             NotImplementedError: If the model file extension is not .pt or .pth. | ||||
|         """ | ||||
|         if model and Path(model).suffix not in ('.pt', '.pth'): | ||||
|             raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.') | ||||
|         super().__init__(model=model, task='segment') | ||||
|         if model and Path(model).suffix not in (".pt", ".pth"): | ||||
|             raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") | ||||
|         super().__init__(model=model, task="segment") | ||||
| 
 | ||||
|     def _load(self, weights: str, task=None): | ||||
|         """ | ||||
| @ -70,7 +70,7 @@ class SAM(Model): | ||||
|         Returns: | ||||
|             (list): The model predictions. | ||||
|         """ | ||||
|         overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) | ||||
|         overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024) | ||||
|         kwargs.update(overrides) | ||||
|         prompts = dict(bboxes=bboxes, points=points, labels=labels) | ||||
|         return super().predict(source, stream, prompts=prompts, **kwargs) | ||||
| @ -112,4 +112,4 @@ class SAM(Model): | ||||
|         Returns: | ||||
|             (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'. | ||||
|         """ | ||||
|         return {'segment': {'predictor': Predictor}} | ||||
|         return {"segment": {"predictor": Predictor}} | ||||
|  | ||||
| @ -64,8 +64,9 @@ class MaskDecoder(nn.Module): | ||||
|             nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), | ||||
|             activation(), | ||||
|         ) | ||||
|         self.output_hypernetworks_mlps = nn.ModuleList([ | ||||
|             MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]) | ||||
|         self.output_hypernetworks_mlps = nn.ModuleList( | ||||
|             [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] | ||||
|         ) | ||||
| 
 | ||||
|         self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) | ||||
| 
 | ||||
| @ -132,13 +133,14 @@ class MaskDecoder(nn.Module): | ||||
|         # Run the transformer | ||||
|         hs, src = self.transformer(src, pos_src, tokens) | ||||
|         iou_token_out = hs[:, 0, :] | ||||
|         mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :] | ||||
|         mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] | ||||
| 
 | ||||
|         # Upscale mask embeddings and predict masks using the mask tokens | ||||
|         src = src.transpose(1, 2).view(b, c, h, w) | ||||
|         upscaled_embedding = self.output_upscaling(src) | ||||
|         hyper_in_list: List[torch.Tensor] = [ | ||||
|             self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)] | ||||
|             self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens) | ||||
|         ] | ||||
|         hyper_in = torch.stack(hyper_in_list, dim=1) | ||||
|         b, c, h, w = upscaled_embedding.shape | ||||
|         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) | ||||
|  | ||||
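The decoder's hypernetwork trick is all in the final matmul: each mask token yields a `C`-dimensional weight vector, and multiplying those against the flattened `(C, H*W)` image embedding produces one logit map per token. The shape bookkeeping, sketched with random tensors of illustrative sizes:

import torch

b, n_tokens, c, h, w = 2, 4, 32, 64, 64
hyper_in = torch.randn(b, n_tokens, c)  # one weight vector per mask token (MLP outputs)
upscaled = torch.randn(b, c, h, w)      # upscaled image embedding

masks = (hyper_in @ upscaled.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)  # torch.Size([2, 4, 64, 64]): one logit map per token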
| @ -283,9 +283,9 @@ class PromptEncoder(nn.Module): | ||||
|         if masks is not None: | ||||
|             dense_embeddings = self._embed_masks(masks) | ||||
|         else: | ||||
|             dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, | ||||
|                                                                  1).expand(bs, -1, self.image_embedding_size[0], | ||||
|                                                                            self.image_embedding_size[1]) | ||||
|             dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( | ||||
|                 bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] | ||||
|             ) | ||||
| 
 | ||||
|         return sparse_embeddings, dense_embeddings | ||||
| 
 | ||||
| @ -298,7 +298,7 @@ class PositionEmbeddingRandom(nn.Module): | ||||
|         super().__init__() | ||||
|         if scale is None or scale <= 0.0: | ||||
|             scale = 1.0 | ||||
|         self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats))) | ||||
|         self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats))) | ||||
| 
 | ||||
|         # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation' | ||||
|         torch.use_deterministic_algorithms(False) | ||||
| @ -425,14 +425,14 @@ class Attention(nn.Module): | ||||
|         super().__init__() | ||||
|         self.num_heads = num_heads | ||||
|         head_dim = dim // num_heads | ||||
|         self.scale = head_dim ** -0.5 | ||||
|         self.scale = head_dim**-0.5 | ||||
| 
 | ||||
|         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | ||||
|         self.proj = nn.Linear(dim, dim) | ||||
| 
 | ||||
|         self.use_rel_pos = use_rel_pos | ||||
|         if self.use_rel_pos: | ||||
|             assert (input_size is not None), 'Input size must be provided if using relative positional encoding.' | ||||
|             assert input_size is not None, "Input size must be provided if using relative positional encoding." | ||||
|             # Initialize relative positional embeddings | ||||
|             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) | ||||
|             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) | ||||
| @ -479,8 +479,9 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T | ||||
|     return windows, (Hp, Wp) | ||||
| 
 | ||||
| 
 | ||||
| def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], | ||||
|                        hw: Tuple[int, int]) -> torch.Tensor: | ||||
| def window_unpartition( | ||||
|     windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] | ||||
| ) -> torch.Tensor: | ||||
|     """ | ||||
|     Window unpartition into original sequences, removing padding. | ||||
| 
 | ||||
| @ -523,7 +524,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor | ||||
|         rel_pos_resized = F.interpolate( | ||||
|             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), | ||||
|             size=max_rel_dist, | ||||
|             mode='linear', | ||||
|             mode="linear", | ||||
|         ) | ||||
|         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) | ||||
|     else: | ||||
| @ -567,11 +568,12 @@ def add_decomposed_rel_pos( | ||||
| 
 | ||||
|     B, _, dim = q.shape | ||||
|     r_q = q.reshape(B, q_h, q_w, dim) | ||||
|     rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) | ||||
|     rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) | ||||
|     rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) | ||||
|     rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) | ||||
| 
 | ||||
|     attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( | ||||
|         B, q_h * q_w, k_h * k_w) | ||||
|         B, q_h * q_w, k_h * k_w | ||||
|     ) | ||||
| 
 | ||||
|     return attn | ||||
| 
 | ||||
|  | ||||
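The window_partition/window_unpartition pair reformatted above reshapes between padded feature maps and fixed-size attention windows. A minimal round-trip sketch under assumed sizes, following the reshape pattern visible in these hunks:

```python
import torch

# Round trip of window partition/unpartition; B, H, W, C and the window
# size are illustrative assumptions.
B, H, W, C, ws = 1, 10, 10, 8, 7
x = torch.randn(B, H, W, C)

# Pad H and W up to multiples of the window size, as window_partition does
pad_h, pad_w = (ws - H % ws) % ws, (ws - W % ws) % ws
x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w

windows = x.reshape(B, Hp // ws, ws, Wp // ws, ws, C).permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)
print(windows.shape)  # torch.Size([4, 7, 7, 8])

# window_unpartition inverts the reshape and strips the padding
x2 = windows.reshape(B, Hp // ws, Wp // ws, ws, ws, C).permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, C)
x2 = x2[:, :H, :W, :].contiguous()
print(x2.shape)  # torch.Size([1, 10, 10, 8])
```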
| @ -30,8 +30,9 @@ class Sam(nn.Module): | ||||
|         pixel_mean (List[float]): Mean pixel values for image normalization. | ||||
|         pixel_std (List[float]): Standard deviation values for image normalization. | ||||
|     """ | ||||
| 
 | ||||
|     mask_threshold: float = 0.0 | ||||
|     image_format: str = 'RGB' | ||||
|     image_format: str = "RGB" | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -39,7 +40,7 @@ class Sam(nn.Module): | ||||
|         prompt_encoder: PromptEncoder, | ||||
|         mask_decoder: MaskDecoder, | ||||
|         pixel_mean: List[float] = (123.675, 116.28, 103.53), | ||||
|         pixel_std: List[float] = (58.395, 57.12, 57.375) | ||||
|         pixel_std: List[float] = (58.395, 57.12, 57.375), | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Initialize the Sam class to predict object masks from an image and input prompts. | ||||
| @ -60,5 +61,5 @@ class Sam(nn.Module): | ||||
|         self.image_encoder = image_encoder | ||||
|         self.prompt_encoder = prompt_encoder | ||||
|         self.mask_decoder = mask_decoder | ||||
|         self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False) | ||||
|         self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False) | ||||
|         self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) | ||||
|         self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) | ||||
|  | ||||
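Sam registers its normalization constants with persistent=False (the final positional argument to register_buffer above), so they move with the module across devices but stay out of checkpoints. A small sketch of that pattern, with the SAM default mean/std values from the hunk:

```python
import torch
import torch.nn as nn


class Normalizer(nn.Module):
    """Minimal sketch of the non-persistent normalization buffers registered above."""

    def __init__(self, mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375)):
        super().__init__()
        # persistent=False keeps these constants out of the state_dict while
        # still moving them with .to()/.cuda() alongside the model
        self.register_buffer("pixel_mean", torch.tensor(mean).view(-1, 1, 1), persistent=False)
        self.register_buffer("pixel_std", torch.tensor(std).view(-1, 1, 1), persistent=False)

    def forward(self, x):  # x: (B, 3, H, W) in 0-255 range
        return (x - self.pixel_mean) / self.pixel_std


norm = Normalizer()
print(norm.state_dict().keys())  # odict_keys([]) -> buffers excluded from checkpoints
```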
| @ -28,11 +28,11 @@ class Conv2d_BN(torch.nn.Sequential): | ||||
|         drop path. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) | ||||
|         self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) | ||||
|         bn = torch.nn.BatchNorm2d(b) | ||||
|         torch.nn.init.constant_(bn.weight, bn_weight_init) | ||||
|         torch.nn.init.constant_(bn.bias, 0) | ||||
|         self.add_module('bn', bn) | ||||
|         self.add_module("bn", bn) | ||||
| 
 | ||||
| 
 | ||||
| class PatchEmbed(nn.Module): | ||||
| @ -146,11 +146,11 @@ class ConvLayer(nn.Module): | ||||
|         input_resolution, | ||||
|         depth, | ||||
|         activation, | ||||
|         drop_path=0., | ||||
|         drop_path=0.0, | ||||
|         downsample=None, | ||||
|         use_checkpoint=False, | ||||
|         out_dim=None, | ||||
|         conv_expand_ratio=4., | ||||
|         conv_expand_ratio=4.0, | ||||
|     ): | ||||
|         """ | ||||
|         Initializes the ConvLayer with the given dimensions and settings. | ||||
| @ -173,18 +173,25 @@ class ConvLayer(nn.Module): | ||||
|         self.use_checkpoint = use_checkpoint | ||||
| 
 | ||||
|         # Build blocks | ||||
|         self.blocks = nn.ModuleList([ | ||||
|         self.blocks = nn.ModuleList( | ||||
|             [ | ||||
|                 MBConv( | ||||
|                     dim, | ||||
|                     dim, | ||||
|                     conv_expand_ratio, | ||||
|                     activation, | ||||
|                     drop_path[i] if isinstance(drop_path, list) else drop_path, | ||||
|             ) for i in range(depth)]) | ||||
|                 ) | ||||
|                 for i in range(depth) | ||||
|             ] | ||||
|         ) | ||||
| 
 | ||||
|         # Patch merging layer | ||||
|         self.downsample = None if downsample is None else downsample( | ||||
|             input_resolution, dim=dim, out_dim=out_dim, activation=activation) | ||||
|         self.downsample = ( | ||||
|             None | ||||
|             if downsample is None | ||||
|             else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) | ||||
|         ) | ||||
| 
 | ||||
|     def forward(self, x): | ||||
|         """Processes the input through a series of convolutional layers and returns the activated output.""" | ||||
| @ -200,7 +207,7 @@ class Mlp(nn.Module): | ||||
|     This layer takes an input with in_features, applies layer normalization, and passes it through two fully-connected layers. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): | ||||
|     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): | ||||
|         """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc.""" | ||||
|         super().__init__() | ||||
|         out_features = out_features or in_features | ||||
| @ -256,7 +263,7 @@ class Attention(torch.nn.Module): | ||||
| 
 | ||||
|         assert isinstance(resolution, tuple) and len(resolution) == 2 | ||||
|         self.num_heads = num_heads | ||||
|         self.scale = key_dim ** -0.5 | ||||
|         self.scale = key_dim**-0.5 | ||||
|         self.key_dim = key_dim | ||||
|         self.nh_kd = nh_kd = key_dim * num_heads | ||||
|         self.d = int(attn_ratio * key_dim) | ||||
| @ -279,13 +286,13 @@ class Attention(torch.nn.Module): | ||||
|                     attention_offsets[offset] = len(attention_offsets) | ||||
|                 idxs.append(attention_offsets[offset]) | ||||
|         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) | ||||
|         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) | ||||
|         self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False) | ||||
| 
 | ||||
|     @torch.no_grad() | ||||
|     def train(self, mode=True): | ||||
|         """Sets the module in training mode and handles attribute 'ab' based on the mode.""" | ||||
|         super().train(mode) | ||||
|         if mode and hasattr(self, 'ab'): | ||||
|         if mode and hasattr(self, "ab"): | ||||
|             del self.ab | ||||
|         else: | ||||
|             self.ab = self.attention_biases[:, self.attention_bias_idxs] | ||||
| @ -306,8 +313,9 @@ class Attention(torch.nn.Module): | ||||
|         v = v.permute(0, 2, 1, 3) | ||||
|         self.ab = self.ab.to(self.attention_biases.device) | ||||
| 
 | ||||
|         attn = ((q @ k.transpose(-2, -1)) * self.scale + | ||||
|                 (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab)) | ||||
|         attn = (q @ k.transpose(-2, -1)) * self.scale + ( | ||||
|             self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab | ||||
|         ) | ||||
|         attn = attn.softmax(dim=-1) | ||||
|         x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) | ||||
|         return self.proj(x) | ||||
| @ -322,9 +330,9 @@ class TinyViTBlock(nn.Module): | ||||
|         input_resolution, | ||||
|         num_heads, | ||||
|         window_size=7, | ||||
|         mlp_ratio=4., | ||||
|         drop=0., | ||||
|         drop_path=0., | ||||
|         mlp_ratio=4.0, | ||||
|         drop=0.0, | ||||
|         drop_path=0.0, | ||||
|         local_conv_size=3, | ||||
|         activation=nn.GELU, | ||||
|     ): | ||||
| @ -350,7 +358,7 @@ class TinyViTBlock(nn.Module): | ||||
|         self.dim = dim | ||||
|         self.input_resolution = input_resolution | ||||
|         self.num_heads = num_heads | ||||
|         assert window_size > 0, 'window_size must be greater than 0' | ||||
|         assert window_size > 0, "window_size must be greater than 0" | ||||
|         self.window_size = window_size | ||||
|         self.mlp_ratio = mlp_ratio | ||||
| 
 | ||||
| @ -358,7 +366,7 @@ class TinyViTBlock(nn.Module): | ||||
|         # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() | ||||
|         self.drop_path = nn.Identity() | ||||
| 
 | ||||
|         assert dim % num_heads == 0, 'dim must be divisible by num_heads' | ||||
|         assert dim % num_heads == 0, "dim must be divisible by num_heads" | ||||
|         head_dim = dim // num_heads | ||||
| 
 | ||||
|         window_resolution = (window_size, window_size) | ||||
| @ -377,7 +385,7 @@ class TinyViTBlock(nn.Module): | ||||
|         """ | ||||
|         H, W = self.input_resolution | ||||
|         B, L, C = x.shape | ||||
|         assert L == H * W, 'input feature has wrong size' | ||||
|         assert L == H * W, "input feature has wrong size" | ||||
|         res_x = x | ||||
|         if H == self.window_size and W == self.window_size: | ||||
|             x = self.attn(x) | ||||
| @ -394,8 +402,11 @@ class TinyViTBlock(nn.Module): | ||||
|             nH = pH // self.window_size | ||||
|             nW = pW // self.window_size | ||||
|             # Window partition | ||||
|             x = x.view(B, nH, self.window_size, nW, self.window_size, | ||||
|                        C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C) | ||||
|             x = ( | ||||
|                 x.view(B, nH, self.window_size, nW, self.window_size, C) | ||||
|                 .transpose(2, 3) | ||||
|                 .reshape(B * nH * nW, self.window_size * self.window_size, C) | ||||
|             ) | ||||
|             x = self.attn(x) | ||||
|             # Window reverse | ||||
|             x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C) | ||||
| @ -417,8 +428,10 @@ class TinyViTBlock(nn.Module): | ||||
|         """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of | ||||
|         attentions heads, window size, and MLP ratio. | ||||
|         """ | ||||
|         return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \ | ||||
|                f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}' | ||||
|         return ( | ||||
|             f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " | ||||
|             f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class BasicLayer(nn.Module): | ||||
| @ -431,9 +444,9 @@ class BasicLayer(nn.Module): | ||||
|         depth, | ||||
|         num_heads, | ||||
|         window_size, | ||||
|         mlp_ratio=4., | ||||
|         drop=0., | ||||
|         drop_path=0., | ||||
|         mlp_ratio=4.0, | ||||
|         drop=0.0, | ||||
|         drop_path=0.0, | ||||
|         downsample=None, | ||||
|         use_checkpoint=False, | ||||
|         local_conv_size=3, | ||||
| @ -468,7 +481,8 @@ class BasicLayer(nn.Module): | ||||
|         self.use_checkpoint = use_checkpoint | ||||
| 
 | ||||
|         # Build blocks | ||||
|         self.blocks = nn.ModuleList([ | ||||
|         self.blocks = nn.ModuleList( | ||||
|             [ | ||||
|                 TinyViTBlock( | ||||
|                     dim=dim, | ||||
|                     input_resolution=input_resolution, | ||||
| @ -479,11 +493,17 @@ class BasicLayer(nn.Module): | ||||
|                     drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, | ||||
|                     local_conv_size=local_conv_size, | ||||
|                     activation=activation, | ||||
|             ) for i in range(depth)]) | ||||
|                 ) | ||||
|                 for i in range(depth) | ||||
|             ] | ||||
|         ) | ||||
| 
 | ||||
|         # Patch merging layer | ||||
|         self.downsample = None if downsample is None else downsample( | ||||
|             input_resolution, dim=dim, out_dim=out_dim, activation=activation) | ||||
|         self.downsample = ( | ||||
|             None | ||||
|             if downsample is None | ||||
|             else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) | ||||
|         ) | ||||
| 
 | ||||
|     def forward(self, x): | ||||
|         """Performs forward propagation on the input tensor and returns a normalized tensor.""" | ||||
| @ -493,7 +513,7 @@ class BasicLayer(nn.Module): | ||||
| 
 | ||||
|     def extra_repr(self) -> str: | ||||
|         """Returns a string representation of the extra_repr function with the layer's parameters.""" | ||||
|         return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' | ||||
|         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" | ||||
| 
 | ||||
| 
 | ||||
| class LayerNorm2d(nn.Module): | ||||
| @ -549,8 +569,8 @@ class TinyViT(nn.Module): | ||||
|         depths=[2, 2, 6, 2], | ||||
|         num_heads=[3, 6, 12, 24], | ||||
|         window_sizes=[7, 7, 14, 7], | ||||
|         mlp_ratio=4., | ||||
|         drop_rate=0., | ||||
|         mlp_ratio=4.0, | ||||
|         drop_rate=0.0, | ||||
|         drop_path_rate=0.1, | ||||
|         use_checkpoint=False, | ||||
|         mbconv_expand_ratio=4.0, | ||||
| @ -585,10 +605,9 @@ class TinyViT(nn.Module): | ||||
| 
 | ||||
|         activation = nn.GELU | ||||
| 
 | ||||
|         self.patch_embed = PatchEmbed(in_chans=in_chans, | ||||
|                                       embed_dim=embed_dims[0], | ||||
|                                       resolution=img_size, | ||||
|                                       activation=activation) | ||||
|         self.patch_embed = PatchEmbed( | ||||
|             in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation | ||||
|         ) | ||||
| 
 | ||||
|         patches_resolution = self.patch_embed.patches_resolution | ||||
|         self.patches_resolution = patches_resolution | ||||
| @ -601,27 +620,30 @@ class TinyViT(nn.Module): | ||||
|         for i_layer in range(self.num_layers): | ||||
|             kwargs = dict( | ||||
|                 dim=embed_dims[i_layer], | ||||
|                 input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), | ||||
|                                   patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))), | ||||
|                 input_resolution=( | ||||
|                     patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), | ||||
|                     patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), | ||||
|                 ), | ||||
|                 #   input_resolution=(patches_resolution[0] // (2 ** i_layer), | ||||
|                 #                     patches_resolution[1] // (2 ** i_layer)), | ||||
|                 depth=depths[i_layer], | ||||
|                 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], | ||||
|                 drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], | ||||
|                 downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, | ||||
|                 use_checkpoint=use_checkpoint, | ||||
|                 out_dim=embed_dims[min(i_layer + 1, | ||||
|                                        len(embed_dims) - 1)], | ||||
|                 out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)], | ||||
|                 activation=activation, | ||||
|             ) | ||||
|             if i_layer == 0: | ||||
|                 layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs) | ||||
|             else: | ||||
|                 layer = BasicLayer(num_heads=num_heads[i_layer], | ||||
|                 layer = BasicLayer( | ||||
|                     num_heads=num_heads[i_layer], | ||||
|                     window_size=window_sizes[i_layer], | ||||
|                     mlp_ratio=self.mlp_ratio, | ||||
|                     drop=drop_rate, | ||||
|                     local_conv_size=local_conv_size, | ||||
|                                    **kwargs) | ||||
|                     **kwargs, | ||||
|                 ) | ||||
|             self.layers.append(layer) | ||||
| 
 | ||||
|         # Classifier head | ||||
| @ -680,7 +702,7 @@ class TinyViT(nn.Module): | ||||
|         def _check_lr_scale(m): | ||||
|             """Checks if the learning rate scale attribute is present in module's parameters.""" | ||||
|             for p in m.parameters(): | ||||
|                 assert hasattr(p, 'lr_scale'), p.param_name | ||||
|                 assert hasattr(p, "lr_scale"), p.param_name | ||||
| 
 | ||||
|         self.apply(_check_lr_scale) | ||||
| 
 | ||||
| @ -698,7 +720,7 @@ class TinyViT(nn.Module): | ||||
|     @torch.jit.ignore | ||||
|     def no_weight_decay_keywords(self): | ||||
|         """Returns a dictionary of parameter names where weight decay should not be applied.""" | ||||
|         return {'attention_biases'} | ||||
|         return {"attention_biases"} | ||||
| 
 | ||||
|     def forward_features(self, x): | ||||
|         """Runs the input through the model layers and returns the transformed output.""" | ||||
|  | ||||
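The TinyViT Attention class above precomputes a table of relative-offset indices (registered as the 'attention_bias_idxs' buffer) so every (query, key) pair at the same spatial offset shares one learned bias. A sketch of that index construction for an assumed 2x2 resolution:

```python
import itertools

import torch

# Sketch of the relative attention-bias index table built in Attention above,
# shown for a tiny 2x2 grid (illustrative size).
resolution = (2, 2)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)

attention_offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        idxs.append(attention_offsets[offset])

bias_idxs = torch.LongTensor(idxs).view(N, N)
print(len(attention_offsets))  # 4 distinct offsets for a 2x2 grid
print(bias_idxs)  # each (query, key) pair indexes one shared learned bias
```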
| @ -62,7 +62,8 @@ class TwoWayTransformer(nn.Module): | ||||
|                     activation=activation, | ||||
|                     attention_downsample_rate=attention_downsample_rate, | ||||
|                     skip_first_layer_pe=(i == 0), | ||||
|                 )) | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) | ||||
|         self.norm_final_attn = nn.LayerNorm(embedding_dim) | ||||
| @ -227,7 +228,7 @@ class Attention(nn.Module): | ||||
|         self.embedding_dim = embedding_dim | ||||
|         self.internal_dim = embedding_dim // downsample_rate | ||||
|         self.num_heads = num_heads | ||||
|         assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.' | ||||
|         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." | ||||
| 
 | ||||
|         self.q_proj = nn.Linear(embedding_dim, self.internal_dim) | ||||
|         self.k_proj = nn.Linear(embedding_dim, self.internal_dim) | ||||
|  | ||||
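The assert reformatted above guards the downsampled attention projection: embedding_dim // downsample_rate must split evenly across heads before the per-head reshape. A brief sketch of the shapes involved, with illustrative dimensions:

```python
import torch
import torch.nn as nn

# Shape walk-through for the downsampled attention checked above; the
# dimensions are assumptions chosen for illustration.
embedding_dim, downsample_rate, num_heads = 256, 2, 8
internal_dim = embedding_dim // downsample_rate
assert internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

q_proj = nn.Linear(embedding_dim, internal_dim)
x = torch.randn(1, 10, embedding_dim)
q = q_proj(x)  # (1, 10, 128): project down by downsample_rate
# Separate heads: (B, N, C) -> (B, num_heads, N, C // num_heads)
q = q.reshape(1, 10, num_heads, internal_dim // num_heads).transpose(1, 2)
print(q.shape)  # torch.Size([1, 8, 10, 16])
```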
| @ -19,8 +19,17 @@ from ultralytics.engine.results import Results | ||||
| from ultralytics.utils import DEFAULT_CFG, ops | ||||
| from ultralytics.utils.torch_utils import select_device | ||||
| 
 | ||||
| from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, | ||||
|                   generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks) | ||||
| from .amg import ( | ||||
|     batch_iterator, | ||||
|     batched_mask_to_box, | ||||
|     build_all_layer_point_grids, | ||||
|     calculate_stability_score, | ||||
|     generate_crop_boxes, | ||||
|     is_box_near_crop_edge, | ||||
|     remove_small_regions, | ||||
|     uncrop_boxes_xyxy, | ||||
|     uncrop_masks, | ||||
| ) | ||||
| from .build import build_sam | ||||
| 
 | ||||
| 
 | ||||
| @ -58,7 +67,7 @@ class Predictor(BasePredictor): | ||||
|         """ | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides.update(dict(task='segment', mode='predict', imgsz=1024)) | ||||
|         overrides.update(dict(task="segment", mode="predict", imgsz=1024)) | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
|         self.args.retina_masks = True | ||||
|         self.im = None | ||||
| @ -107,7 +116,7 @@ class Predictor(BasePredictor): | ||||
|         Returns: | ||||
|             (List[np.ndarray]): List of transformed images. | ||||
|         """ | ||||
|         assert len(im) == 1, 'SAM model does not currently support batched inference' | ||||
|         assert len(im) == 1, "SAM model does not currently support batched inference" | ||||
|         letterbox = LetterBox(self.args.imgsz, auto=False, center=False) | ||||
|         return [letterbox(image=x) for x in im] | ||||
| 
 | ||||
| @ -132,9 +141,9 @@ class Predictor(BasePredictor): | ||||
|                 - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. | ||||
|         """ | ||||
|         # Override prompts if any stored in self.prompts | ||||
|         bboxes = self.prompts.pop('bboxes', bboxes) | ||||
|         points = self.prompts.pop('points', points) | ||||
|         masks = self.prompts.pop('masks', masks) | ||||
|         bboxes = self.prompts.pop("bboxes", bboxes) | ||||
|         points = self.prompts.pop("points", points) | ||||
|         masks = self.prompts.pop("masks", masks) | ||||
| 
 | ||||
|         if all(i is None for i in [bboxes, points, masks]): | ||||
|             return self.generate(im, *args, **kwargs) | ||||
| @ -199,7 +208,8 @@ class Predictor(BasePredictor): | ||||
|         # `d` could be 1 or 3, depending on `multimask_output`. | ||||
|         return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) | ||||
| 
 | ||||
|     def generate(self, | ||||
|     def generate( | ||||
|         self, | ||||
|         im, | ||||
|         crop_n_layers=0, | ||||
|         crop_overlap_ratio=512 / 1500, | ||||
| @ -210,7 +220,8 @@ class Predictor(BasePredictor): | ||||
|         conf_thres=0.88, | ||||
|         stability_score_thresh=0.95, | ||||
|         stability_score_offset=0.95, | ||||
|                  crop_nms_thresh=0.7): | ||||
|         crop_nms_thresh=0.7, | ||||
|     ): | ||||
|         """ | ||||
|         Perform image segmentation using the Segment Anything Model (SAM). | ||||
| 
 | ||||
| @ -248,19 +259,20 @@ class Predictor(BasePredictor): | ||||
|             area = torch.tensor(w * h, device=im.device) | ||||
|             points_scale = np.array([[w, h]])  # w, h | ||||
|             # Crop image and interpolate to input size | ||||
|             crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False) | ||||
|             crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) | ||||
|             # (num_points, 2) | ||||
|             points_for_image = point_grids[layer_idx] * points_scale | ||||
|             crop_masks, crop_scores, crop_bboxes = [], [], [] | ||||
|             for (points, ) in batch_iterator(points_batch_size, points_for_image): | ||||
|             for (points,) in batch_iterator(points_batch_size, points_for_image): | ||||
|                 pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) | ||||
|                 # Interpolate predicted masks to input size | ||||
|                 pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0] | ||||
|                 pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] | ||||
|                 idx = pred_score > conf_thres | ||||
|                 pred_mask, pred_score = pred_mask[idx], pred_score[idx] | ||||
| 
 | ||||
|                 stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold, | ||||
|                                                             stability_score_offset) | ||||
|                 stability_score = calculate_stability_score( | ||||
|                     pred_mask, self.model.mask_threshold, stability_score_offset | ||||
|                 ) | ||||
|                 idx = stability_score > stability_score_thresh | ||||
|                 pred_mask, pred_score = pred_mask[idx], pred_score[idx] | ||||
|                 # Bool type is much more memory-efficient. | ||||
| @ -404,7 +416,7 @@ class Predictor(BasePredictor): | ||||
|             model = build_sam(self.args.model) | ||||
|             self.setup_model(model) | ||||
|         self.setup_source(image) | ||||
|         assert len(self.dataset) == 1, '`set_image` only supports setting one image!' | ||||
|         assert len(self.dataset) == 1, "`set_image` only supports setting one image!" | ||||
|         for batch in self.dataset: | ||||
|             im = self.preprocess(batch[1]) | ||||
|             self.features = self.model.image_encoder(im) | ||||
| @ -446,9 +458,9 @@ class Predictor(BasePredictor): | ||||
|         scores = [] | ||||
|         for mask in masks: | ||||
|             mask = mask.cpu().numpy().astype(np.uint8) | ||||
|             mask, changed = remove_small_regions(mask, min_area, mode='holes') | ||||
|             mask, changed = remove_small_regions(mask, min_area, mode="holes") | ||||
|             unchanged = not changed | ||||
|             mask, changed = remove_small_regions(mask, min_area, mode='islands') | ||||
|             mask, changed = remove_small_regions(mask, min_area, mode="islands") | ||||
|             unchanged = unchanged and not changed | ||||
| 
 | ||||
|             new_masks.append(torch.as_tensor(mask).unsqueeze(0)) | ||||
|  | ||||
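For context, the prompt flow implemented by this Predictor is normally reached through the SAM model wrapper rather than called directly; a hedged usage sketch in which the weight and image file names are placeholders:

```python
from ultralytics import SAM

# Hedged usage sketch of the prompt flow above; "sam_b.pt" and "image.jpg"
# are placeholder names, not values taken from this diff.
model = SAM("sam_b.pt")

# Box and point prompts are routed to Predictor.prompt_inference()
results = model("image.jpg", bboxes=[100, 100, 400, 400])
results = model("image.jpg", points=[250, 250], labels=[1])

# With no prompts, Predictor.generate() segments the whole image over a
# point grid, filtering by conf_thres and stability_score_thresh as above
results = model("image.jpg")
```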
| @ -30,14 +30,9 @@ class DETRLoss(nn.Module): | ||||
|         device (torch.device): Device on which tensors are stored. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, | ||||
|                  nc=80, | ||||
|                  loss_gain=None, | ||||
|                  aux_loss=True, | ||||
|                  use_fl=True, | ||||
|                  use_vfl=False, | ||||
|                  use_uni_match=False, | ||||
|                  uni_match_ind=0): | ||||
|     def __init__( | ||||
|         self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0 | ||||
|     ): | ||||
|         """ | ||||
|         DETR loss function. | ||||
| 
 | ||||
| @ -52,9 +47,9 @@ class DETRLoss(nn.Module): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         if loss_gain is None: | ||||
|             loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1} | ||||
|             loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1} | ||||
|         self.nc = nc | ||||
|         self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2}) | ||||
|         self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2}) | ||||
|         self.loss_gain = loss_gain | ||||
|         self.aux_loss = aux_loss | ||||
|         self.fl = FocalLoss() if use_fl else None | ||||
| @ -64,10 +59,10 @@ class DETRLoss(nn.Module): | ||||
|         self.uni_match_ind = uni_match_ind | ||||
|         self.device = None | ||||
| 
 | ||||
|     def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''): | ||||
|     def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""): | ||||
|         """Computes the classification loss based on predictions, target values, and ground truth scores.""" | ||||
|         # Logits: [b, query, num_classes], gt_class: list[[n, 1]] | ||||
|         name_class = f'loss_class{postfix}' | ||||
|         name_class = f"loss_class{postfix}" | ||||
|         bs, nq = pred_scores.shape[:2] | ||||
|         # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1]  # (bs, num_queries, num_classes) | ||||
|         one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device) | ||||
| @ -82,28 +77,28 @@ class DETRLoss(nn.Module): | ||||
|                 loss_cls = self.fl(pred_scores, one_hot.float()) | ||||
|             loss_cls /= max(num_gts, 1) / nq | ||||
|         else: | ||||
|             loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss | ||||
|             loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss | ||||
| 
 | ||||
|         return {name_class: loss_cls.squeeze() * self.loss_gain['class']} | ||||
|         return {name_class: loss_cls.squeeze() * self.loss_gain["class"]} | ||||
| 
 | ||||
|     def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''): | ||||
|     def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""): | ||||
|         """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding | ||||
|         boxes. | ||||
|         """ | ||||
|         # Boxes: [b, query, 4], gt_bbox: list[[n, 4]] | ||||
|         name_bbox = f'loss_bbox{postfix}' | ||||
|         name_giou = f'loss_giou{postfix}' | ||||
|         name_bbox = f"loss_bbox{postfix}" | ||||
|         name_giou = f"loss_giou{postfix}" | ||||
| 
 | ||||
|         loss = {} | ||||
|         if len(gt_bboxes) == 0: | ||||
|             loss[name_bbox] = torch.tensor(0., device=self.device) | ||||
|             loss[name_giou] = torch.tensor(0., device=self.device) | ||||
|             loss[name_bbox] = torch.tensor(0.0, device=self.device) | ||||
|             loss[name_giou] = torch.tensor(0.0, device=self.device) | ||||
|             return loss | ||||
| 
 | ||||
|         loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes) | ||||
|         loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes) | ||||
|         loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True) | ||||
|         loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes) | ||||
|         loss[name_giou] = self.loss_gain['giou'] * loss[name_giou] | ||||
|         loss[name_giou] = self.loss_gain["giou"] * loss[name_giou] | ||||
|         return {k: v.squeeze() for k, v in loss.items()} | ||||
| 
 | ||||
|     # This function is for future RT-DETR Segment models | ||||
| @ -137,30 +132,35 @@ class DETRLoss(nn.Module): | ||||
|     #     loss = 1 - (numerator + 1) / (denominator + 1) | ||||
|     #     return loss.sum() / num_gts | ||||
| 
 | ||||
|     def _get_loss_aux(self, | ||||
|     def _get_loss_aux( | ||||
|         self, | ||||
|         pred_bboxes, | ||||
|         pred_scores, | ||||
|         gt_bboxes, | ||||
|         gt_cls, | ||||
|         gt_groups, | ||||
|         match_indices=None, | ||||
|                       postfix='', | ||||
|         postfix="", | ||||
|         masks=None, | ||||
|                       gt_mask=None): | ||||
|         gt_mask=None, | ||||
|     ): | ||||
|         """Get auxiliary losses.""" | ||||
|         # NOTE: loss class, bbox, giou, mask, dice | ||||
|         loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device) | ||||
|         if match_indices is None and self.use_uni_match: | ||||
|             match_indices = self.matcher(pred_bboxes[self.uni_match_ind], | ||||
|             match_indices = self.matcher( | ||||
|                 pred_bboxes[self.uni_match_ind], | ||||
|                 pred_scores[self.uni_match_ind], | ||||
|                 gt_bboxes, | ||||
|                 gt_cls, | ||||
|                 gt_groups, | ||||
|                 masks=masks[self.uni_match_ind] if masks is not None else None, | ||||
|                                          gt_mask=gt_mask) | ||||
|                 gt_mask=gt_mask, | ||||
|             ) | ||||
|         for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)): | ||||
|             aux_masks = masks[i] if masks is not None else None | ||||
|             loss_ = self._get_loss(aux_bboxes, | ||||
|             loss_ = self._get_loss( | ||||
|                 aux_bboxes, | ||||
|                 aux_scores, | ||||
|                 gt_bboxes, | ||||
|                 gt_cls, | ||||
| @ -168,19 +168,21 @@ class DETRLoss(nn.Module): | ||||
|                 masks=aux_masks, | ||||
|                 gt_mask=gt_mask, | ||||
|                 postfix=postfix, | ||||
|                                    match_indices=match_indices) | ||||
|             loss[0] += loss_[f'loss_class{postfix}'] | ||||
|             loss[1] += loss_[f'loss_bbox{postfix}'] | ||||
|             loss[2] += loss_[f'loss_giou{postfix}'] | ||||
|                 match_indices=match_indices, | ||||
|             ) | ||||
|             loss[0] += loss_[f"loss_class{postfix}"] | ||||
|             loss[1] += loss_[f"loss_bbox{postfix}"] | ||||
|             loss[2] += loss_[f"loss_giou{postfix}"] | ||||
|             # if masks is not None and gt_mask is not None: | ||||
|             #     loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix) | ||||
|             #     loss[3] += loss_[f'loss_mask{postfix}'] | ||||
|             #     loss[4] += loss_[f'loss_dice{postfix}'] | ||||
| 
 | ||||
|         loss = { | ||||
|             f'loss_class_aux{postfix}': loss[0], | ||||
|             f'loss_bbox_aux{postfix}': loss[1], | ||||
|             f'loss_giou_aux{postfix}': loss[2]} | ||||
|             f"loss_class_aux{postfix}": loss[0], | ||||
|             f"loss_bbox_aux{postfix}": loss[1], | ||||
|             f"loss_giou_aux{postfix}": loss[2], | ||||
|         } | ||||
|         # if masks is not None and gt_mask is not None: | ||||
|         #     loss[f'loss_mask_aux{postfix}'] = loss[3] | ||||
|         #     loss[f'loss_dice_aux{postfix}'] = loss[4] | ||||
| @ -196,15 +198,22 @@ class DETRLoss(nn.Module): | ||||
| 
 | ||||
|     def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices): | ||||
|         """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices.""" | ||||
|         pred_assigned = torch.cat([ | ||||
|         pred_assigned = torch.cat( | ||||
|             [ | ||||
|                 t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device) | ||||
|             for t, (I, _) in zip(pred_bboxes, match_indices)]) | ||||
|         gt_assigned = torch.cat([ | ||||
|                 for t, (I, _) in zip(pred_bboxes, match_indices) | ||||
|             ] | ||||
|         ) | ||||
|         gt_assigned = torch.cat( | ||||
|             [ | ||||
|                 t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device) | ||||
|             for t, (_, J) in zip(gt_bboxes, match_indices)]) | ||||
|                 for t, (_, J) in zip(gt_bboxes, match_indices) | ||||
|             ] | ||||
|         ) | ||||
|         return pred_assigned, gt_assigned | ||||
| 
 | ||||
|     def _get_loss(self, | ||||
|     def _get_loss( | ||||
|         self, | ||||
|         pred_bboxes, | ||||
|         pred_scores, | ||||
|         gt_bboxes, | ||||
| @ -212,17 +221,14 @@ class DETRLoss(nn.Module): | ||||
|         gt_groups, | ||||
|         masks=None, | ||||
|         gt_mask=None, | ||||
|                   postfix='', | ||||
|                   match_indices=None): | ||||
|         postfix="", | ||||
|         match_indices=None, | ||||
|     ): | ||||
|         """Get losses.""" | ||||
|         if match_indices is None: | ||||
|             match_indices = self.matcher(pred_bboxes, | ||||
|                                          pred_scores, | ||||
|                                          gt_bboxes, | ||||
|                                          gt_cls, | ||||
|                                          gt_groups, | ||||
|                                          masks=masks, | ||||
|                                          gt_mask=gt_mask) | ||||
|             match_indices = self.matcher( | ||||
|                 pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask | ||||
|             ) | ||||
| 
 | ||||
|         idx, gt_idx = self._get_index(match_indices) | ||||
|         pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx] | ||||
| @ -242,7 +248,7 @@ class DETRLoss(nn.Module): | ||||
|         #     loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix)) | ||||
|         return loss | ||||
| 
 | ||||
|     def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs): | ||||
|     def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs): | ||||
|         """ | ||||
|         Args: | ||||
|             pred_bboxes (torch.Tensor): [l, b, query, 4] | ||||
| @ -254,21 +260,19 @@ class DETRLoss(nn.Module): | ||||
|             postfix (str): postfix of loss name. | ||||
|         """ | ||||
|         self.device = pred_bboxes.device | ||||
|         match_indices = kwargs.get('match_indices', None) | ||||
|         gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups'] | ||||
|         match_indices = kwargs.get("match_indices", None) | ||||
|         gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"] | ||||
| 
 | ||||
|         total_loss = self._get_loss(pred_bboxes[-1], | ||||
|                                     pred_scores[-1], | ||||
|                                     gt_bboxes, | ||||
|                                     gt_cls, | ||||
|                                     gt_groups, | ||||
|                                     postfix=postfix, | ||||
|                                     match_indices=match_indices) | ||||
|         total_loss = self._get_loss( | ||||
|             pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices | ||||
|         ) | ||||
| 
 | ||||
|         if self.aux_loss: | ||||
|             total_loss.update( | ||||
|                 self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, | ||||
|                                    postfix)) | ||||
|                 self._get_loss_aux( | ||||
|                     pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         return total_loss | ||||
| 
 | ||||
| @ -300,18 +304,18 @@ class RTDETRDetectionLoss(DETRLoss): | ||||
| 
 | ||||
|         # Check for denoising metadata to compute denoising training loss | ||||
|         if dn_meta is not None: | ||||
|             dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group'] | ||||
|             assert len(batch['gt_groups']) == len(dn_pos_idx) | ||||
|             dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"] | ||||
|             assert len(batch["gt_groups"]) == len(dn_pos_idx) | ||||
| 
 | ||||
|             # Get the match indices for denoising | ||||
|             match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups']) | ||||
|             match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"]) | ||||
| 
 | ||||
|             # Compute the denoising training loss | ||||
|             dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices) | ||||
|             dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices) | ||||
|             total_loss.update(dn_loss) | ||||
|         else: | ||||
|             # If no denoising metadata is provided, set denoising loss to zero | ||||
|             total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()}) | ||||
|             total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()}) | ||||
| 
 | ||||
|         return total_loss | ||||
| 
 | ||||
| @ -334,8 +338,8 @@ class RTDETRDetectionLoss(DETRLoss): | ||||
|             if num_gt > 0: | ||||
|                 gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i] | ||||
|                 gt_idx = gt_idx.repeat(dn_num_group) | ||||
|                 assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, ' | ||||
|                 f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.' | ||||
|                 assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, " | ||||
|                 f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively." | ||||
|                 dn_match_indices.append((dn_pos_idx[i], gt_idx)) | ||||
|             else: | ||||
|                 dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long))) | ||||
|  | ||||
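The _get_loss_bbox method reformatted above combines an L1 box term and a GIoU term, each scaled by its loss_gain entry. A minimal worked sketch with made-up boxes and a stand-in GIoU value (the source computes it with bbox_iou(..., xywh=True, GIoU=True)):

```python
import torch
import torch.nn.functional as F

# Weighted bbox/GIoU terms as in _get_loss_bbox; loss_gain matches the
# defaults in the hunk, the boxes and GIoU value are illustrative.
loss_gain = {"class": 1, "bbox": 5, "giou": 2}
pred_bboxes = torch.tensor([[0.50, 0.50, 0.20, 0.20]])  # matched (n, 4) xywh
gt_bboxes = torch.tensor([[0.52, 0.48, 0.22, 0.18]])

loss = {}
loss["loss_bbox"] = loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
giou = torch.tensor([0.70])  # stand-in for bbox_iou(pred, gt, xywh=True, GIoU=True)
loss["loss_giou"] = loss_gain["giou"] * (1.0 - giou).sum() / len(gt_bboxes)
print({k: float(v) for k, v in loss.items()})  # {'loss_bbox': 0.4, 'loss_giou': 0.6}
```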
| @ -37,7 +37,7 @@ class HungarianMatcher(nn.Module): | ||||
|         """ | ||||
|         super().__init__() | ||||
|         if cost_gain is None: | ||||
|             cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1} | ||||
|             cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1} | ||||
|         self.cost_gain = cost_gain | ||||
|         self.use_fl = use_fl | ||||
|         self.with_mask = with_mask | ||||
| @ -86,7 +86,7 @@ class HungarianMatcher(nn.Module): | ||||
|         # Compute the classification cost | ||||
|         pred_scores = pred_scores[:, gt_cls] | ||||
|         if self.use_fl: | ||||
|             neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log()) | ||||
|             neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log()) | ||||
|             pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log()) | ||||
|             cost_class = pos_cost_class - neg_cost_class | ||||
|         else: | ||||
| @ -99,9 +99,11 @@ class HungarianMatcher(nn.Module): | ||||
|         cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1) | ||||
| 
 | ||||
|         # Final cost matrix | ||||
|         C = self.cost_gain['class'] * cost_class + \ | ||||
|             self.cost_gain['bbox'] * cost_bbox + \ | ||||
|             self.cost_gain['giou'] * cost_giou | ||||
|         C = ( | ||||
|             self.cost_gain["class"] * cost_class | ||||
|             + self.cost_gain["bbox"] * cost_bbox | ||||
|             + self.cost_gain["giou"] * cost_giou | ||||
|         ) | ||||
|         # Compute the mask cost and dice cost | ||||
|         if self.with_mask: | ||||
|             C += self._cost_mask(bs, gt_groups, masks, gt_mask) | ||||
| @ -111,10 +113,11 @@ class HungarianMatcher(nn.Module): | ||||
| 
 | ||||
|         C = C.view(bs, nq, -1).cpu() | ||||
|         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))] | ||||
|         gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) | ||||
|         # (idx for queries, idx for gt) | ||||
|         return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) | ||||
|                 for k, (i, j) in enumerate(indices)] | ||||
|         gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # (idx for queries, idx for gt) | ||||
|         return [ | ||||
|             (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) | ||||
|             for k, (i, j) in enumerate(indices) | ||||
|         ] | ||||
| 
 | ||||
|     # This function is for future RT-DETR Segment models | ||||
|     # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None): | ||||
| @ -147,14 +150,9 @@ class HungarianMatcher(nn.Module): | ||||
|     #     return C | ||||
| 
 | ||||
| 
 | ||||
| def get_cdn_group(batch, | ||||
|                   num_classes, | ||||
|                   num_queries, | ||||
|                   class_embed, | ||||
|                   num_dn=100, | ||||
|                   cls_noise_ratio=0.5, | ||||
|                   box_noise_scale=1.0, | ||||
|                   training=False): | ||||
| def get_cdn_group( | ||||
|     batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False | ||||
| ): | ||||
|     """ | ||||
|     Get contrastive denoising training group. This function creates a contrastive denoising training group with positive | ||||
|     and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates, | ||||
| @ -180,7 +178,7 @@ def get_cdn_group(batch, | ||||
| 
 | ||||
|     if (not training) or num_dn <= 0: | ||||
|         return None, None, None, None | ||||
|     gt_groups = batch['gt_groups'] | ||||
|     gt_groups = batch["gt_groups"] | ||||
|     total_num = sum(gt_groups) | ||||
|     max_nums = max(gt_groups) | ||||
|     if max_nums == 0: | ||||
| @ -190,9 +188,9 @@ def get_cdn_group(batch, | ||||
|     num_group = 1 if num_group == 0 else num_group | ||||
|     # Pad gt to max_num of a batch | ||||
|     bs = len(gt_groups) | ||||
|     gt_cls = batch['cls']  # (bs*num, ) | ||||
|     gt_bbox = batch['bboxes']  # bs*num, 4 | ||||
|     b_idx = batch['batch_idx'] | ||||
|     gt_cls = batch["cls"]  # (bs*num, ) | ||||
|     gt_bbox = batch["bboxes"]  # bs*num, 4 | ||||
|     b_idx = batch["batch_idx"] | ||||
| 
 | ||||
|     # Each group has positive and negative queries. | ||||
|     dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, ) | ||||
| @ -245,16 +243,21 @@ def get_cdn_group(batch, | ||||
|     # Reconstruct cannot see each other | ||||
|     for i in range(num_group): | ||||
|         if i == 0: | ||||
|             attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True | ||||
|             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True | ||||
|         if i == num_group - 1: | ||||
|             attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True | ||||
|             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True | ||||
|         else: | ||||
|             attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True | ||||
|             attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True | ||||
|             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True | ||||
|             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True | ||||
|     dn_meta = { | ||||
|         'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], | ||||
|         'dn_num_group': num_group, | ||||
|         'dn_num_split': [num_dn, num_queries]} | ||||
|         "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], | ||||
|         "dn_num_group": num_group, | ||||
|         "dn_num_split": [num_dn, num_queries], | ||||
|     } | ||||
| 
 | ||||
|     return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to( | ||||
|         class_embed.device), dn_meta | ||||
|     return ( | ||||
|         padding_cls.to(class_embed.device), | ||||
|         padding_bbox.to(class_embed.device), | ||||
|         attn_mask.to(class_embed.device), | ||||
|         dn_meta, | ||||
|     ) | ||||
|  | ||||
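HungarianMatcher.forward ends by running scipy's linear_sum_assignment on the per-image cost matrices split by gt_groups. A tiny sketch on a made-up 3x2 cost matrix (3 queries, 2 ground truths):

```python
import torch
from scipy.optimize import linear_sum_assignment

# One-image Hungarian assignment as in forward() above; the cost values
# are made up for illustration.
C = torch.tensor([[0.9, 0.1],
                  [0.4, 0.6],
                  [0.2, 0.8]])  # (num_queries, num_gts)
q_idx, gt_idx = linear_sum_assignment(C.numpy())  # minimizes total cost
print(q_idx, gt_idx)  # [0 2] [1 0]: query 0 -> gt 1, query 2 -> gt 0
```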
| @ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment | ||||
| 
 | ||||
| from .model import YOLO | ||||
| 
 | ||||
| __all__ = 'classify', 'segment', 'detect', 'pose', 'obb', 'YOLO' | ||||
| __all__ = "classify", "segment", "detect", "pose", "obb", "YOLO" | ||||
|  | ||||
| @ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor | ||||
| from ultralytics.models.yolo.classify.train import ClassificationTrainer | ||||
| from ultralytics.models.yolo.classify.val import ClassificationValidator | ||||
| 
 | ||||
| __all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator' | ||||
| __all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator" | ||||
|  | ||||
| @ -30,19 +30,21 @@ class ClassificationPredictor(BasePredictor): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): | ||||
|         """Initializes ClassificationPredictor setting the task to 'classify'.""" | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
|         self.args.task = 'classify' | ||||
|         self._legacy_transform_name = 'ultralytics.yolo.data.augment.ToTensor' | ||||
|         self.args.task = "classify" | ||||
|         self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor" | ||||
| 
 | ||||
|     def preprocess(self, img): | ||||
|         """Converts input image to model-compatible data type.""" | ||||
|         if not isinstance(img, torch.Tensor): | ||||
|             is_legacy_transform = any(self._legacy_transform_name in str(transform) | ||||
|                                       for transform in self.transforms.transforms) | ||||
|             is_legacy_transform = any( | ||||
|                 self._legacy_transform_name in str(transform) for transform in self.transforms.transforms | ||||
|             ) | ||||
|             if is_legacy_transform:  # to handle legacy transforms | ||||
|                 img = torch.stack([self.transforms(im) for im in img], dim=0) | ||||
|             else: | ||||
|                 img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], | ||||
|                                   dim=0) | ||||
|                 img = torch.stack( | ||||
|                     [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0 | ||||
|                 ) | ||||
|         img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) | ||||
|         return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32 | ||||
| 
 | ||||
|  | ||||
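The non-legacy preprocess() branch above converts BGR numpy frames to RGB PIL images, applies the model's transforms, and stacks the results into one batch tensor. A sketch of that path; the 224-pixel transform pipeline and the image path are assumptions:

```python
import cv2
import torch
from PIL import Image
from torchvision import transforms

# Sketch of the non-legacy preprocess() path above; the transform pipeline
# stands in for self.transforms and "image.jpg" is a placeholder path.
tfm = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

frames = [cv2.imread("image.jpg")]  # cv2 loads BGR uint8 arrays
batch = torch.stack(
    [tfm(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in frames], dim=0
)
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```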
| @ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer): | ||||
|         """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.""" | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides['task'] = 'classify' | ||||
|         if overrides.get('imgsz') is None: | ||||
|             overrides['imgsz'] = 224 | ||||
|         overrides["task"] = "classify" | ||||
|         if overrides.get("imgsz") is None: | ||||
|             overrides["imgsz"] = 224 | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
| 
 | ||||
|     def set_model_attributes(self): | ||||
|         """Set the YOLO model's class names from the loaded dataset.""" | ||||
|         self.model.names = self.data['names'] | ||||
|         self.model.names = self.data["names"] | ||||
| 
 | ||||
|     def get_model(self, cfg=None, weights=None, verbose=True): | ||||
|         """Returns a modified PyTorch model configured for training YOLO.""" | ||||
|         model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) | ||||
|         model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
| 
 | ||||
|         for m in model.modules(): | ||||
|             if not self.args.pretrained and hasattr(m, 'reset_parameters'): | ||||
|             if not self.args.pretrained and hasattr(m, "reset_parameters"): | ||||
|                 m.reset_parameters() | ||||
|             if isinstance(m, torch.nn.Dropout) and self.args.dropout: | ||||
|                 m.p = self.args.dropout  # set dropout | ||||
| @ -64,32 +64,32 @@ class ClassificationTrainer(BaseTrainer): | ||||
| 
 | ||||
|         model, ckpt = str(self.model), None | ||||
|         # Load a YOLO model locally, from torchvision, or from Ultralytics assets | ||||
|         if model.endswith('.pt'): | ||||
|             self.model, ckpt = attempt_load_one_weight(model, device='cpu') | ||||
|         if model.endswith(".pt"): | ||||
|             self.model, ckpt = attempt_load_one_weight(model, device="cpu") | ||||
|             for p in self.model.parameters(): | ||||
|                 p.requires_grad = True  # for training | ||||
|         elif model.split('.')[-1] in ('yaml', 'yml'): | ||||
|         elif model.split(".")[-1] in ("yaml", "yml"): | ||||
|             self.model = self.get_model(cfg=model) | ||||
|         elif model in torchvision.models.__dict__: | ||||
|             self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None) | ||||
|             self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None) | ||||
|         else: | ||||
|             FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') | ||||
|         ClassificationModel.reshape_outputs(self.model, self.data['nc']) | ||||
|             FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.") | ||||
|         ClassificationModel.reshape_outputs(self.model, self.data["nc"]) | ||||
| 
 | ||||
|         return ckpt | ||||
| 
 | ||||
|     def build_dataset(self, img_path, mode='train', batch=None): | ||||
|     def build_dataset(self, img_path, mode="train", batch=None): | ||||
|         """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.).""" | ||||
|         return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode) | ||||
|         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode) | ||||
| 
 | ||||
|     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): | ||||
|     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): | ||||
|         """Returns PyTorch DataLoader with transforms to preprocess images for inference.""" | ||||
|         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP | ||||
|             dataset = self.build_dataset(dataset_path, mode) | ||||
| 
 | ||||
|         loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank) | ||||
|         # Attach inference transforms | ||||
|         if mode != 'train': | ||||
|         if mode != "train": | ||||
|             if is_parallel(self.model): | ||||
|                 self.model.module.transforms = loader.dataset.torch_transforms | ||||
|             else: | ||||
| @ -98,27 +98,32 @@ class ClassificationTrainer(BaseTrainer): | ||||
| 
 | ||||
|     def preprocess_batch(self, batch): | ||||
|         """Preprocesses a batch of images and classes.""" | ||||
|         batch['img'] = batch['img'].to(self.device) | ||||
|         batch['cls'] = batch['cls'].to(self.device) | ||||
|         batch["img"] = batch["img"].to(self.device) | ||||
|         batch["cls"] = batch["cls"].to(self.device) | ||||
|         return batch | ||||
| 
 | ||||
|     def progress_string(self): | ||||
|         """Returns a formatted string showing training progress.""" | ||||
|         return ('\n' + '%11s' * (4 + len(self.loss_names))) % \ | ||||
|             ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') | ||||
|         return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( | ||||
|             "Epoch", | ||||
|             "GPU_mem", | ||||
|             *self.loss_names, | ||||
|             "Instances", | ||||
|             "Size", | ||||
|         ) | ||||
| 
 | ||||
|     def get_validator(self): | ||||
|         """Returns an instance of ClassificationValidator for validation.""" | ||||
|         self.loss_names = ['loss'] | ||||
|         self.loss_names = ["loss"] | ||||
|         return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) | ||||
| 
 | ||||
|     def label_loss_items(self, loss_items=None, prefix='train'): | ||||
|     def label_loss_items(self, loss_items=None, prefix="train"): | ||||
|         """ | ||||
|         Returns a loss dict with labelled training loss items tensor. | ||||
| 
 | ||||
|         Not needed for classification but necessary for segmentation & detection. | ||||
|         """ | ||||
|         keys = [f'{prefix}/{x}' for x in self.loss_names] | ||||
|         keys = [f"{prefix}/{x}" for x in self.loss_names] | ||||
|         if loss_items is None: | ||||
|             return keys | ||||
|         loss_items = [round(float(loss_items), 5)] | ||||
| @ -134,19 +139,20 @@ class ClassificationTrainer(BaseTrainer): | ||||
|             if f.exists(): | ||||
|                 strip_optimizer(f)  # strip optimizers | ||||
|                 if f is self.best: | ||||
|                     LOGGER.info(f'\nValidating {f}...') | ||||
|                     LOGGER.info(f"\nValidating {f}...") | ||||
|                     self.validator.args.data = self.args.data | ||||
|                     self.validator.args.plots = self.args.plots | ||||
|                     self.metrics = self.validator(model=f) | ||||
|                     self.metrics.pop('fitness', None) | ||||
|                     self.run_callbacks('on_fit_epoch_end') | ||||
|                     self.metrics.pop("fitness", None) | ||||
|                     self.run_callbacks("on_fit_epoch_end") | ||||
|         LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") | ||||
| 
 | ||||
|     def plot_training_samples(self, batch, ni): | ||||
|         """Plots training samples with their annotations.""" | ||||
|         plot_images( | ||||
|             images=batch['img'], | ||||
|             batch_idx=torch.arange(len(batch['img'])), | ||||
|             cls=batch['cls'].view(-1),  # warning: use .view(), not .squeeze() for Classify models | ||||
|             fname=self.save_dir / f'train_batch{ni}.jpg', | ||||
|             on_plot=self.on_plot) | ||||
|             images=batch["img"], | ||||
|             batch_idx=torch.arange(len(batch["img"])), | ||||
|             cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models | ||||
|             fname=self.save_dir / f"train_batch{ni}.jpg", | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
|  | ||||
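The hunks above only restyle `ClassificationTrainer` (quote style and argument layout); behaviour is unchanged. For reference, a minimal training sketch, assuming the standard Ultralytics entry points; the dataset key and epoch count are illustrative:

```python
from ultralytics.models.yolo.classify import ClassificationTrainer

# "imagenet10" is an illustrative dataset key; any classification dataset works
args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
trainer = ClassificationTrainer(overrides=args)
trainer.train()
```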
| @ -31,43 +31,42 @@ class ClassificationValidator(BaseValidator): | ||||
|         super().__init__(dataloader, save_dir, pbar, args, _callbacks) | ||||
|         self.targets = None | ||||
|         self.pred = None | ||||
|         self.args.task = 'classify' | ||||
|         self.args.task = "classify" | ||||
|         self.metrics = ClassifyMetrics() | ||||
| 
 | ||||
|     def get_desc(self): | ||||
|         """Returns a formatted string summarizing classification metrics.""" | ||||
|         return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc') | ||||
|         return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc") | ||||
| 
 | ||||
|     def init_metrics(self, model): | ||||
|         """Initialize confusion matrix, class names, and top-1 and top-5 accuracy.""" | ||||
|         self.names = model.names | ||||
|         self.nc = len(model.names) | ||||
|         self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify') | ||||
|         self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify") | ||||
|         self.pred = [] | ||||
|         self.targets = [] | ||||
| 
 | ||||
|     def preprocess(self, batch): | ||||
|         """Preprocesses input batch and returns it.""" | ||||
|         batch['img'] = batch['img'].to(self.device, non_blocking=True) | ||||
|         batch['img'] = batch['img'].half() if self.args.half else batch['img'].float() | ||||
|         batch['cls'] = batch['cls'].to(self.device) | ||||
|         batch["img"] = batch["img"].to(self.device, non_blocking=True) | ||||
|         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float() | ||||
|         batch["cls"] = batch["cls"].to(self.device) | ||||
|         return batch | ||||
| 
 | ||||
|     def update_metrics(self, preds, batch): | ||||
|         """Updates running metrics with model predictions and batch targets.""" | ||||
|         n5 = min(len(self.names), 5) | ||||
|         self.pred.append(preds.argsort(1, descending=True)[:, :n5]) | ||||
|         self.targets.append(batch['cls']) | ||||
|         self.targets.append(batch["cls"]) | ||||
| 
 | ||||
|     def finalize_metrics(self, *args, **kwargs): | ||||
|         """Finalizes metrics of the model such as confusion_matrix and speed.""" | ||||
|         self.confusion_matrix.process_cls_preds(self.pred, self.targets) | ||||
|         if self.args.plots: | ||||
|             for normalize in True, False: | ||||
|                 self.confusion_matrix.plot(save_dir=self.save_dir, | ||||
|                                            names=self.names.values(), | ||||
|                                            normalize=normalize, | ||||
|                                            on_plot=self.on_plot) | ||||
|                 self.confusion_matrix.plot( | ||||
|                     save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot | ||||
|                 ) | ||||
|         self.metrics.speed = self.speed | ||||
|         self.metrics.confusion_matrix = self.confusion_matrix | ||||
|         self.metrics.save_dir = self.save_dir | ||||
| @ -88,24 +87,27 @@ class ClassificationValidator(BaseValidator): | ||||
| 
 | ||||
|     def print_results(self): | ||||
|         """Prints evaluation metrics for YOLO object detection model.""" | ||||
|         pf = '%22s' + '%11.3g' * len(self.metrics.keys)  # print format | ||||
|         LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) | ||||
|         pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format | ||||
|         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5)) | ||||
| 
 | ||||
|     def plot_val_samples(self, batch, ni): | ||||
|         """Plot validation image samples.""" | ||||
|         plot_images( | ||||
|             images=batch['img'], | ||||
|             batch_idx=torch.arange(len(batch['img'])), | ||||
|             cls=batch['cls'].view(-1),  # warning: use .view(), not .squeeze() for Classify models | ||||
|             fname=self.save_dir / f'val_batch{ni}_labels.jpg', | ||||
|             images=batch["img"], | ||||
|             batch_idx=torch.arange(len(batch["img"])), | ||||
|             cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models | ||||
|             fname=self.save_dir / f"val_batch{ni}_labels.jpg", | ||||
|             names=self.names, | ||||
|             on_plot=self.on_plot) | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
| 
 | ||||
|     def plot_predictions(self, batch, preds, ni): | ||||
|         """Plots predicted bounding boxes on input images and saves the result.""" | ||||
|         plot_images(batch['img'], | ||||
|                     batch_idx=torch.arange(len(batch['img'])), | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             batch_idx=torch.arange(len(batch["img"])), | ||||
|             cls=torch.argmax(preds, dim=1), | ||||
|                     fname=self.save_dir / f'val_batch{ni}_pred.jpg', | ||||
|             fname=self.save_dir / f"val_batch{ni}_pred.jpg", | ||||
|             names=self.names, | ||||
|                     on_plot=self.on_plot)  # pred | ||||
|             on_plot=self.on_plot, | ||||
|         )  # pred | ||||
|  | ||||
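`update_metrics` keeps only the top `n5` class indices per sample via `argsort`; top-1 and top-5 accuracy then reduce to a comparison against the targets. A self-contained sketch of that reduction on toy tensors (not the validator's real buffers):

```python
import torch

logits = torch.randn(8, 10)            # toy logits: 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))   # toy ground-truth labels

n5 = min(logits.shape[1], 5)           # mirrors min(len(self.names), 5)
pred = logits.argsort(1, descending=True)[:, :n5]  # top-5 class indices per sample

correct = (pred == targets.view(-1, 1)).float()
top1 = correct[:, 0].mean().item()          # hit in the single best prediction
top5 = correct.max(1).values.mean().item()  # hit anywhere in the top five
print(f"top1={top1:.3f} top5={top5:.3f}")
```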
| @ -4,4 +4,4 @@ from .predict import DetectionPredictor | ||||
| from .train import DetectionTrainer | ||||
| from .val import DetectionValidator | ||||
| 
 | ||||
| __all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator' | ||||
| __all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator" | ||||
|  | ||||
| @ -22,12 +22,14 @@ class DetectionPredictor(BasePredictor): | ||||
| 
 | ||||
|     def postprocess(self, preds, img, orig_imgs): | ||||
|         """Post-processes predictions and returns a list of Results objects.""" | ||||
|         preds = ops.non_max_suppression(preds, | ||||
|         preds = ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             agnostic=self.args.agnostic_nms, | ||||
|             max_det=self.args.max_det, | ||||
|                                         classes=self.args.classes) | ||||
|             classes=self.args.classes, | ||||
|         ) | ||||
| 
 | ||||
|         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list | ||||
|             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) | ||||
|  | ||||
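The reflowed call is Ultralytics' own `ops.non_max_suppression`. To illustrate just the two thresholds it receives here (`conf`, then `iou`), a simpler class-agnostic stand-in using `torchvision.ops.nms` — not the real implementation, which also handles `agnostic`, `max_det`, and class filtering:

```python
import torch
from torchvision.ops import nms

# Three candidate boxes (xyxy) with scores; the first two overlap at IoU ~ 0.82
boxes = torch.tensor([[0.0, 0.0, 100.0, 100.0],
                      [5.0, 5.0, 105.0, 105.0],
                      [200.0, 200.0, 300.0, 300.0]])
scores = torch.tensor([0.9, 0.8, 0.75])

conf_thres, iou_thres = 0.25, 0.7
keep = scores > conf_thres                        # confidence filter (args.conf)
idx = nms(boxes[keep], scores[keep], iou_thres)   # IoU suppression (args.iou)
print(boxes[keep][idx])                           # the lower-scoring duplicate is dropped
```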
| @ -30,7 +30,7 @@ class DetectionTrainer(BaseTrainer): | ||||
|         ``` | ||||
|     """ | ||||
| 
 | ||||
|     def build_dataset(self, img_path, mode='train', batch=None): | ||||
|     def build_dataset(self, img_path, mode="train", batch=None): | ||||
|         """ | ||||
|         Build YOLO Dataset. | ||||
| 
 | ||||
| @ -40,33 +40,37 @@ class DetectionTrainer(BaseTrainer): | ||||
|             batch (int, optional): Size of batches, this is for `rect`. Defaults to None. | ||||
|         """ | ||||
|         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) | ||||
|         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs) | ||||
|         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) | ||||
| 
 | ||||
|     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): | ||||
|     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): | ||||
|         """Construct and return dataloader.""" | ||||
|         assert mode in ['train', 'val'] | ||||
|         assert mode in ["train", "val"] | ||||
|         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP | ||||
|             dataset = self.build_dataset(dataset_path, mode, batch_size) | ||||
|         shuffle = mode == 'train' | ||||
|         if getattr(dataset, 'rect', False) and shuffle: | ||||
|         shuffle = mode == "train" | ||||
|         if getattr(dataset, "rect", False) and shuffle: | ||||
|             LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") | ||||
|             shuffle = False | ||||
|         workers = self.args.workers if mode == 'train' else self.args.workers * 2 | ||||
|         workers = self.args.workers if mode == "train" else self.args.workers * 2 | ||||
|         return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader | ||||
| 
 | ||||
|     def preprocess_batch(self, batch): | ||||
|         """Preprocesses a batch of images by scaling and converting to float.""" | ||||
|         batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255 | ||||
|         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 | ||||
|         if self.args.multi_scale: | ||||
|             imgs = batch['img'] | ||||
|             sz = (random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) // self.stride * | ||||
|                   self.stride)  # size | ||||
|             imgs = batch["img"] | ||||
|             sz = ( | ||||
|                 random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) | ||||
|                 // self.stride | ||||
|                 * self.stride | ||||
|             )  # size | ||||
|             sf = sz / max(imgs.shape[2:])  # scale factor | ||||
|             if sf != 1: | ||||
|                 ns = [math.ceil(x * sf / self.stride) * self.stride | ||||
|                       for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple) | ||||
|                 imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) | ||||
|             batch['img'] = imgs | ||||
|                 ns = [ | ||||
|                     math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:] | ||||
|                 ]  # new shape (stretched to gs-multiple) | ||||
|                 imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) | ||||
|             batch["img"] = imgs | ||||
|         return batch | ||||
| 
 | ||||
|     def set_model_attributes(self): | ||||
| @ -74,33 +78,32 @@ class DetectionTrainer(BaseTrainer): | ||||
|         # self.args.box *= 3 / nl  # scale to layers | ||||
|         # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers | ||||
|         # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers | ||||
|         self.model.nc = self.data['nc']  # attach number of classes to model | ||||
|         self.model.names = self.data['names']  # attach class names to model | ||||
|         self.model.nc = self.data["nc"]  # attach number of classes to model | ||||
|         self.model.names = self.data["names"]  # attach class names to model | ||||
|         self.model.args = self.args  # attach hyperparameters to model | ||||
|         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc | ||||
| 
 | ||||
|     def get_model(self, cfg=None, weights=None, verbose=True): | ||||
|         """Return a YOLO detection model.""" | ||||
|         model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) | ||||
|         model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
|         return model | ||||
| 
 | ||||
|     def get_validator(self): | ||||
|         """Returns a DetectionValidator for YOLO model validation.""" | ||||
|         self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' | ||||
|         return yolo.detect.DetectionValidator(self.test_loader, | ||||
|                                               save_dir=self.save_dir, | ||||
|                                               args=copy(self.args), | ||||
|                                               _callbacks=self.callbacks) | ||||
|         self.loss_names = "box_loss", "cls_loss", "dfl_loss" | ||||
|         return yolo.detect.DetectionValidator( | ||||
|             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks | ||||
|         ) | ||||
| 
 | ||||
|     def label_loss_items(self, loss_items=None, prefix='train'): | ||||
|     def label_loss_items(self, loss_items=None, prefix="train"): | ||||
|         """ | ||||
|         Returns a loss dict with labelled training loss items. | ||||
| 
 | ||||
|         Not needed for classification but necessary for segmentation & detection. | ||||
|         """ | ||||
|         keys = [f'{prefix}/{x}' for x in self.loss_names] | ||||
|         keys = [f"{prefix}/{x}" for x in self.loss_names] | ||||
|         if loss_items is not None: | ||||
|             loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats | ||||
|             return dict(zip(keys, loss_items)) | ||||
| @ -109,18 +112,25 @@ class DetectionTrainer(BaseTrainer): | ||||
| 
 | ||||
|     def progress_string(self): | ||||
|         """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.""" | ||||
|         return ('\n' + '%11s' * | ||||
|                 (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') | ||||
|         return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( | ||||
|             "Epoch", | ||||
|             "GPU_mem", | ||||
|             *self.loss_names, | ||||
|             "Instances", | ||||
|             "Size", | ||||
|         ) | ||||
| 
 | ||||
|     def plot_training_samples(self, batch, ni): | ||||
|         """Plots training samples with their annotations.""" | ||||
|         plot_images(images=batch['img'], | ||||
|                     batch_idx=batch['batch_idx'], | ||||
|                     cls=batch['cls'].squeeze(-1), | ||||
|                     bboxes=batch['bboxes'], | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'train_batch{ni}.jpg', | ||||
|                     on_plot=self.on_plot) | ||||
|         plot_images( | ||||
|             images=batch["img"], | ||||
|             batch_idx=batch["batch_idx"], | ||||
|             cls=batch["cls"].squeeze(-1), | ||||
|             bboxes=batch["bboxes"], | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"train_batch{ni}.jpg", | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
| 
 | ||||
|     def plot_metrics(self): | ||||
|         """Plots metrics from a CSV file.""" | ||||
| @ -128,6 +138,6 @@ class DetectionTrainer(BaseTrainer): | ||||
| 
 | ||||
|     def plot_training_labels(self): | ||||
|         """Create a labeled training plot of the YOLO model.""" | ||||
|         boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0) | ||||
|         cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0) | ||||
|         plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot) | ||||
|         boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0) | ||||
|         cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0) | ||||
|         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot) | ||||
|  | ||||
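The `multi_scale` branch of `preprocess_batch`, reflowed above, draws a random size in [0.5x, 1.5x] of `imgsz`, snaps it down to a stride multiple, and bilinearly resizes the whole batch. A standalone sketch of the same arithmetic; integer bounds are used here because `random.randrange` rejects float arguments on recent Python versions:

```python
import math
import random

import torch
import torch.nn.functional as F

imgsz, stride = 640, 32
imgs = torch.zeros(2, 3, imgsz, imgsz)  # dummy batch

sz = random.randrange(imgsz // 2, imgsz * 3 // 2 + stride) // stride * stride  # random stride multiple
sf = sz / max(imgs.shape[2:])  # scale factor
if sf != 1:
    ns = [math.ceil(x * sf / stride) * stride for x in imgs.shape[2:]]  # new shape, stride-aligned
    imgs = F.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
print(imgs.shape)  # e.g. torch.Size([2, 3, 416, 416])
```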
| @ -34,7 +34,7 @@ class DetectionValidator(BaseValidator): | ||||
|         self.nt_per_class = None | ||||
|         self.is_coco = False | ||||
|         self.class_map = None | ||||
|         self.args.task = 'detect' | ||||
|         self.args.task = "detect" | ||||
|         self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) | ||||
|         self.iouv = torch.linspace(0.5, 0.95, 10)  # iou vector for mAP@0.5:0.95 | ||||
|         self.niou = self.iouv.numel() | ||||
| @ -42,25 +42,30 @@ class DetectionValidator(BaseValidator): | ||||
| 
 | ||||
|     def preprocess(self, batch): | ||||
|         """Preprocesses batch of images for YOLO training.""" | ||||
|         batch['img'] = batch['img'].to(self.device, non_blocking=True) | ||||
|         batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255 | ||||
|         for k in ['batch_idx', 'cls', 'bboxes']: | ||||
|         batch["img"] = batch["img"].to(self.device, non_blocking=True) | ||||
|         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 | ||||
|         for k in ["batch_idx", "cls", "bboxes"]: | ||||
|             batch[k] = batch[k].to(self.device) | ||||
| 
 | ||||
|         if self.args.save_hybrid: | ||||
|             height, width = batch['img'].shape[2:] | ||||
|             nb = len(batch['img']) | ||||
|             bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device) | ||||
|             self.lb = [ | ||||
|                 torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1) | ||||
|                 for i in range(nb)] if self.args.save_hybrid else []  # for autolabelling | ||||
|             height, width = batch["img"].shape[2:] | ||||
|             nb = len(batch["img"]) | ||||
|             bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device) | ||||
|             self.lb = ( | ||||
|                 [ | ||||
|                     torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1) | ||||
|                     for i in range(nb) | ||||
|                 ] | ||||
|                 if self.args.save_hybrid | ||||
|                 else [] | ||||
|             )  # for autolabelling | ||||
| 
 | ||||
|         return batch | ||||
| 
 | ||||
|     def init_metrics(self, model): | ||||
|         """Initialize evaluation metrics for YOLO.""" | ||||
|         val = self.data.get(self.args.split, '')  # validation path | ||||
|         self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt')  # is COCO | ||||
|         val = self.data.get(self.args.split, "")  # validation path | ||||
|         self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO | ||||
|         self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) | ||||
|         self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO | ||||
|         self.names = model.names | ||||
| @ -74,26 +79,28 @@ class DetectionValidator(BaseValidator): | ||||
| 
 | ||||
|     def get_desc(self): | ||||
|         """Return a formatted string summarizing class metrics of YOLO model.""" | ||||
|         return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)') | ||||
|         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)") | ||||
| 
 | ||||
|     def postprocess(self, preds): | ||||
|         """Apply Non-maximum suppression to prediction outputs.""" | ||||
|         return ops.non_max_suppression(preds, | ||||
|         return ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             labels=self.lb, | ||||
|             multi_label=True, | ||||
|             agnostic=self.args.single_cls, | ||||
|                                        max_det=self.args.max_det) | ||||
|             max_det=self.args.max_det, | ||||
|         ) | ||||
| 
 | ||||
|     def _prepare_batch(self, si, batch): | ||||
|         """Prepares a batch of images and annotations for validation.""" | ||||
|         idx = batch['batch_idx'] == si | ||||
|         cls = batch['cls'][idx].squeeze(-1) | ||||
|         bbox = batch['bboxes'][idx] | ||||
|         ori_shape = batch['ori_shape'][si] | ||||
|         imgsz = batch['img'].shape[2:] | ||||
|         ratio_pad = batch['ratio_pad'][si] | ||||
|         idx = batch["batch_idx"] == si | ||||
|         cls = batch["cls"][idx].squeeze(-1) | ||||
|         bbox = batch["bboxes"][idx] | ||||
|         ori_shape = batch["ori_shape"][si] | ||||
|         imgsz = batch["img"].shape[2:] | ||||
|         ratio_pad = batch["ratio_pad"][si] | ||||
|         if len(cls): | ||||
|             bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes | ||||
|             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels | ||||
| @ -103,8 +110,9 @@ class DetectionValidator(BaseValidator): | ||||
|     def _prepare_pred(self, pred, pbatch): | ||||
|         """Prepares a batch of images and annotations for validation.""" | ||||
|         predn = pred.clone() | ||||
|         ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], | ||||
|                         ratio_pad=pbatch['ratio_pad'])  # native-space pred | ||||
|         ops.scale_boxes( | ||||
|             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"] | ||||
|         )  # native-space pred | ||||
|         return predn | ||||
| 
 | ||||
|     def update_metrics(self, preds, batch): | ||||
| @ -112,19 +120,21 @@ class DetectionValidator(BaseValidator): | ||||
|         for si, pred in enumerate(preds): | ||||
|             self.seen += 1 | ||||
|             npr = len(pred) | ||||
|             stat = dict(conf=torch.zeros(0, device=self.device), | ||||
|             stat = dict( | ||||
|                 conf=torch.zeros(0, device=self.device), | ||||
|                 pred_cls=torch.zeros(0, device=self.device), | ||||
|                         tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)) | ||||
|                 tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), | ||||
|             ) | ||||
|             pbatch = self._prepare_batch(si, batch) | ||||
|             cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox') | ||||
|             cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") | ||||
|             nl = len(cls) | ||||
|             stat['target_cls'] = cls | ||||
|             stat["target_cls"] = cls | ||||
|             if npr == 0: | ||||
|                 if nl: | ||||
|                     for k in self.stats.keys(): | ||||
|                         self.stats[k].append(stat[k]) | ||||
|                     # TODO: obb does not support confusion_matrix yet. | ||||
|                     if self.args.plots and self.args.task != 'obb': | ||||
|                     if self.args.plots and self.args.task != "obb": | ||||
|                         self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) | ||||
|                 continue | ||||
| 
 | ||||
| @ -132,24 +142,24 @@ class DetectionValidator(BaseValidator): | ||||
|             if self.args.single_cls: | ||||
|                 pred[:, 5] = 0 | ||||
|             predn = self._prepare_pred(pred, pbatch) | ||||
|             stat['conf'] = predn[:, 4] | ||||
|             stat['pred_cls'] = predn[:, 5] | ||||
|             stat["conf"] = predn[:, 4] | ||||
|             stat["pred_cls"] = predn[:, 5] | ||||
| 
 | ||||
|             # Evaluate | ||||
|             if nl: | ||||
|                 stat['tp'] = self._process_batch(predn, bbox, cls) | ||||
|                 stat["tp"] = self._process_batch(predn, bbox, cls) | ||||
|                 # TODO: obb does not support confusion_matrix yet. | ||||
|                 if self.args.plots and self.args.task != 'obb': | ||||
|                 if self.args.plots and self.args.task != "obb": | ||||
|                     self.confusion_matrix.process_batch(predn, bbox, cls) | ||||
|             for k in self.stats.keys(): | ||||
|                 self.stats[k].append(stat[k]) | ||||
| 
 | ||||
|             # Save | ||||
|             if self.args.save_json: | ||||
|                 self.pred_to_json(predn, batch['im_file'][si]) | ||||
|                 self.pred_to_json(predn, batch["im_file"][si]) | ||||
|             if self.args.save_txt: | ||||
|                 file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt' | ||||
|                 self.save_one_txt(predn, self.args.save_conf, pbatch['ori_shape'], file) | ||||
|                 file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt' | ||||
|                 self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file) | ||||
| 
 | ||||
|     def finalize_metrics(self, *args, **kwargs): | ||||
|         """Set final values for metrics speed and confusion matrix.""" | ||||
| @ -159,19 +169,19 @@ class DetectionValidator(BaseValidator): | ||||
|     def get_stats(self): | ||||
|         """Returns metrics statistics and results dictionary.""" | ||||
|         stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy | ||||
|         if len(stats) and stats['tp'].any(): | ||||
|         if len(stats) and stats["tp"].any(): | ||||
|             self.metrics.process(**stats) | ||||
|         self.nt_per_class = np.bincount(stats['target_cls'].astype(int), | ||||
|                                         minlength=self.nc)  # number of targets per class | ||||
|         self.nt_per_class = np.bincount( | ||||
|             stats["target_cls"].astype(int), minlength=self.nc | ||||
|         )  # number of targets per class | ||||
|         return self.metrics.results_dict | ||||
| 
 | ||||
|     def print_results(self): | ||||
|         """Prints training/validation set metrics per class.""" | ||||
|         pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys)  # print format | ||||
|         LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) | ||||
|         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format | ||||
|         LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) | ||||
|         if self.nt_per_class.sum() == 0: | ||||
|             LOGGER.warning( | ||||
|                 f'WARNING ⚠️ no labels found in {self.args.task} set, cannot compute metrics without labels') | ||||
|             LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, cannot compute metrics without labels") | ||||
| 
 | ||||
|         # Print results per class | ||||
|         if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): | ||||
| @ -180,10 +190,9 @@ class DetectionValidator(BaseValidator): | ||||
| 
 | ||||
|         if self.args.plots: | ||||
|             for normalize in True, False: | ||||
|                 self.confusion_matrix.plot(save_dir=self.save_dir, | ||||
|                                            names=self.names.values(), | ||||
|                                            normalize=normalize, | ||||
|                                            on_plot=self.on_plot) | ||||
|                 self.confusion_matrix.plot( | ||||
|                     save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot | ||||
|                 ) | ||||
| 
 | ||||
|     def _process_batch(self, detections, gt_bboxes, gt_cls): | ||||
|         """ | ||||
| @ -201,7 +210,7 @@ class DetectionValidator(BaseValidator): | ||||
|         iou = box_iou(gt_bboxes, detections[:, :4]) | ||||
|         return self.match_predictions(detections[:, 5], gt_cls, iou) | ||||
| 
 | ||||
|     def build_dataset(self, img_path, mode='val', batch=None): | ||||
|     def build_dataset(self, img_path, mode="val", batch=None): | ||||
|         """ | ||||
|         Build YOLO Dataset. | ||||
| 
 | ||||
| @ -214,28 +223,32 @@ class DetectionValidator(BaseValidator): | ||||
| 
 | ||||
|     def get_dataloader(self, dataset_path, batch_size): | ||||
|         """Construct and return dataloader.""" | ||||
|         dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val') | ||||
|         dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val") | ||||
|         return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader | ||||
| 
 | ||||
|     def plot_val_samples(self, batch, ni): | ||||
|         """Plot validation image samples.""" | ||||
|         plot_images(batch['img'], | ||||
|                     batch['batch_idx'], | ||||
|                     batch['cls'].squeeze(-1), | ||||
|                     batch['bboxes'], | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'val_batch{ni}_labels.jpg', | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             batch["batch_idx"], | ||||
|             batch["cls"].squeeze(-1), | ||||
|             batch["bboxes"], | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"val_batch{ni}_labels.jpg", | ||||
|             names=self.names, | ||||
|                     on_plot=self.on_plot) | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
| 
 | ||||
|     def plot_predictions(self, batch, preds, ni): | ||||
|         """Plots predicted bounding boxes on input images and saves the result.""" | ||||
|         plot_images(batch['img'], | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             *output_to_target(preds, max_det=self.args.max_det), | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'val_batch{ni}_pred.jpg', | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"val_batch{ni}_pred.jpg", | ||||
|             names=self.names, | ||||
|                     on_plot=self.on_plot)  # pred | ||||
|             on_plot=self.on_plot, | ||||
|         )  # pred | ||||
| 
 | ||||
|     def save_one_txt(self, predn, save_conf, shape, file): | ||||
|         """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" | ||||
| @ -243,8 +256,8 @@ class DetectionValidator(BaseValidator): | ||||
|         for *xyxy, conf, cls in predn.tolist(): | ||||
|             xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh | ||||
|             line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format | ||||
|             with open(file, 'a') as f: | ||||
|                 f.write(('%g ' * len(line)).rstrip() % line + '\n') | ||||
|             with open(file, "a") as f: | ||||
|                 f.write(("%g " * len(line)).rstrip() % line + "\n") | ||||
| 
 | ||||
|     def pred_to_json(self, predn, filename): | ||||
|         """Serialize YOLO predictions to COCO json format.""" | ||||
| @ -253,28 +266,31 @@ class DetectionValidator(BaseValidator): | ||||
|         box = ops.xyxy2xywh(predn[:, :4])  # xywh | ||||
|         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner | ||||
|         for p, b in zip(predn.tolist(), box.tolist()): | ||||
|             self.jdict.append({ | ||||
|                 'image_id': image_id, | ||||
|                 'category_id': self.class_map[int(p[5])], | ||||
|                 'bbox': [round(x, 3) for x in b], | ||||
|                 'score': round(p[4], 5)}) | ||||
|             self.jdict.append( | ||||
|                 { | ||||
|                     "image_id": image_id, | ||||
|                     "category_id": self.class_map[int(p[5])], | ||||
|                     "bbox": [round(x, 3) for x in b], | ||||
|                     "score": round(p[4], 5), | ||||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|     def eval_json(self, stats): | ||||
|         """Evaluates YOLO output in JSON format and returns performance statistics.""" | ||||
|         if self.args.save_json and self.is_coco and len(self.jdict): | ||||
|             anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations | ||||
|             pred_json = self.save_dir / 'predictions.json'  # predictions | ||||
|             LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') | ||||
|             anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations | ||||
|             pred_json = self.save_dir / "predictions.json"  # predictions | ||||
|             LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") | ||||
|             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb | ||||
|                 check_requirements('pycocotools>=2.0.6') | ||||
|                 check_requirements("pycocotools>=2.0.6") | ||||
|                 from pycocotools.coco import COCO  # noqa | ||||
|                 from pycocotools.cocoeval import COCOeval  # noqa | ||||
| 
 | ||||
|                 for x in anno_json, pred_json: | ||||
|                     assert x.is_file(), f'{x} file not found' | ||||
|                     assert x.is_file(), f"{x} file not found" | ||||
|                 anno = COCO(str(anno_json))  # init annotations api | ||||
|                 pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path) | ||||
|                 eval = COCOeval(anno, pred, 'bbox') | ||||
|                 eval = COCOeval(anno, pred, "bbox") | ||||
|                 if self.is_coco: | ||||
|                     eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval | ||||
|                 eval.evaluate() | ||||
| @ -282,5 +298,5 @@ class DetectionValidator(BaseValidator): | ||||
|                 eval.summarize() | ||||
|                 stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50 | ||||
|             except Exception as e: | ||||
|                 LOGGER.warning(f'pycocotools unable to run: {e}') | ||||
|                 LOGGER.warning(f"pycocotools unable to run: {e}") | ||||
|         return stats | ||||
|  | ||||
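`eval_json` above drives pycocotools directly. Stripped of the trainer plumbing, the wrapped evaluation loop looks like the sketch below; both file paths are illustrative and must already exist on disk:

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno = COCO("annotations/instances_val2017.json")        # ground-truth annotations
pred = anno.loadRes("runs/detect/val/predictions.json")  # detections written by save_json

coco_eval = COCOeval(anno, pred, "bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
map50_95, map50 = coco_eval.stats[:2]  # mAP50-95 and mAP50, as read back in eval_json
```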
| @ -12,28 +12,34 @@ class YOLO(Model): | ||||
|     def task_map(self): | ||||
|         """Map head to model, trainer, validator, and predictor classes.""" | ||||
|         return { | ||||
|             'classify': { | ||||
|                 'model': ClassificationModel, | ||||
|                 'trainer': yolo.classify.ClassificationTrainer, | ||||
|                 'validator': yolo.classify.ClassificationValidator, | ||||
|                 'predictor': yolo.classify.ClassificationPredictor, }, | ||||
|             'detect': { | ||||
|                 'model': DetectionModel, | ||||
|                 'trainer': yolo.detect.DetectionTrainer, | ||||
|                 'validator': yolo.detect.DetectionValidator, | ||||
|                 'predictor': yolo.detect.DetectionPredictor, }, | ||||
|             'segment': { | ||||
|                 'model': SegmentationModel, | ||||
|                 'trainer': yolo.segment.SegmentationTrainer, | ||||
|                 'validator': yolo.segment.SegmentationValidator, | ||||
|                 'predictor': yolo.segment.SegmentationPredictor, }, | ||||
|             'pose': { | ||||
|                 'model': PoseModel, | ||||
|                 'trainer': yolo.pose.PoseTrainer, | ||||
|                 'validator': yolo.pose.PoseValidator, | ||||
|                 'predictor': yolo.pose.PosePredictor, }, | ||||
|             'obb': { | ||||
|                 'model': OBBModel, | ||||
|                 'trainer': yolo.obb.OBBTrainer, | ||||
|                 'validator': yolo.obb.OBBValidator, | ||||
|                 'predictor': yolo.obb.OBBPredictor, }, } | ||||
|             "classify": { | ||||
|                 "model": ClassificationModel, | ||||
|                 "trainer": yolo.classify.ClassificationTrainer, | ||||
|                 "validator": yolo.classify.ClassificationValidator, | ||||
|                 "predictor": yolo.classify.ClassificationPredictor, | ||||
|             }, | ||||
|             "detect": { | ||||
|                 "model": DetectionModel, | ||||
|                 "trainer": yolo.detect.DetectionTrainer, | ||||
|                 "validator": yolo.detect.DetectionValidator, | ||||
|                 "predictor": yolo.detect.DetectionPredictor, | ||||
|             }, | ||||
|             "segment": { | ||||
|                 "model": SegmentationModel, | ||||
|                 "trainer": yolo.segment.SegmentationTrainer, | ||||
|                 "validator": yolo.segment.SegmentationValidator, | ||||
|                 "predictor": yolo.segment.SegmentationPredictor, | ||||
|             }, | ||||
|             "pose": { | ||||
|                 "model": PoseModel, | ||||
|                 "trainer": yolo.pose.PoseTrainer, | ||||
|                 "validator": yolo.pose.PoseValidator, | ||||
|                 "predictor": yolo.pose.PosePredictor, | ||||
|             }, | ||||
|             "obb": { | ||||
|                 "model": OBBModel, | ||||
|                 "trainer": yolo.obb.OBBTrainer, | ||||
|                 "validator": yolo.obb.OBBValidator, | ||||
|                 "predictor": yolo.obb.OBBPredictor, | ||||
|             }, | ||||
|         } | ||||
|  | ||||
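`task_map` is now laid out one entry per line: a plain nested dict keyed first by task, then by role. A small sketch of it resolving at runtime, assuming the `yolov8n.pt` weights can be fetched:

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # task is inferred from the weights ("detect")
trainer_cls = model.task_map[model.task]["trainer"]
validator_cls = model.task_map[model.task]["validator"]
print(model.task, trainer_cls.__name__, validator_cls.__name__)
```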
| @ -4,4 +4,4 @@ from .predict import OBBPredictor | ||||
| from .train import OBBTrainer | ||||
| from .val import OBBValidator | ||||
| 
 | ||||
| __all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator' | ||||
| __all__ = "OBBPredictor", "OBBTrainer", "OBBValidator" | ||||
|  | ||||
| @ -25,26 +25,27 @@ class OBBPredictor(DetectionPredictor): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): | ||||
|         """Initializes OBBPredictor with optional model and data configuration overrides.""" | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
|         self.args.task = 'obb' | ||||
|         self.args.task = "obb" | ||||
| 
 | ||||
|     def postprocess(self, preds, img, orig_imgs): | ||||
|         """Post-processes predictions and returns a list of Results objects.""" | ||||
|         preds = ops.non_max_suppression(preds, | ||||
|         preds = ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             agnostic=self.args.agnostic_nms, | ||||
|             max_det=self.args.max_det, | ||||
|             nc=len(self.model.names), | ||||
|             classes=self.args.classes, | ||||
|                                         rotated=True) | ||||
|             rotated=True, | ||||
|         ) | ||||
| 
 | ||||
|         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list | ||||
|             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) | ||||
| 
 | ||||
|         results = [] | ||||
|         for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)): | ||||
|         for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])): | ||||
|             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True) | ||||
|             img_path = self.batch[0][i] | ||||
|             # xywh, r, conf, cls | ||||
|             obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1) | ||||
|             results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb)) | ||||
|  | ||||
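Besides the quote restyling, the loop now pulls `img_path` from `self.batch[0]` inside the `zip` instead of indexing it separately. The `torch.cat` reorders each rotated-NMS row from `[xywh, conf, cls, angle]` into the `[xywh, r, conf, cls]` layout that `Results(obb=...)` expects; a toy row makes the shuffle visible:

```python
import torch

# One fake rotated-NMS row per detection: [x, y, w, h, conf, cls, angle]
pred = torch.tensor([[320.0, 240.0, 80.0, 40.0, 0.91, 3.0, 0.52]])

# Reorder to the Results obb layout: xywh, r, conf, cls
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
print(obb)  # tensor([[320.00, 240.00, 80.00, 40.00, 0.52, 0.91, 3.00]])
```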
| @ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer): | ||||
|         """Initialize a OBBTrainer object with given arguments.""" | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides['task'] = 'obb' | ||||
|         overrides["task"] = "obb" | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
| 
 | ||||
|     def get_model(self, cfg=None, weights=None, verbose=True): | ||||
|         """Return OBBModel initialized with specified config and weights.""" | ||||
|         model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) | ||||
|         model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
| 
 | ||||
| @ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer): | ||||
| 
 | ||||
|     def get_validator(self): | ||||
|         """Return an instance of OBBValidator for validation of YOLO model.""" | ||||
|         self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' | ||||
|         self.loss_names = "box_loss", "cls_loss", "dfl_loss" | ||||
|         return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) | ||||
|  | ||||
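For reference, `OBBTrainer` is driven like any other task trainer. A minimal sketch mirroring the usage shown in the class docstring; the dataset key and epoch count are illustrative:

```python
from ultralytics.models.yolo.obb import OBBTrainer

args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3)
trainer = OBBTrainer(overrides=args)
trainer.train()
```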
| @ -27,18 +27,19 @@ class OBBValidator(DetectionValidator): | ||||
|     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): | ||||
|         """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.""" | ||||
|         super().__init__(dataloader, save_dir, pbar, args, _callbacks) | ||||
|         self.args.task = 'obb' | ||||
|         self.args.task = "obb" | ||||
|         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot) | ||||
| 
 | ||||
|     def init_metrics(self, model): | ||||
|         """Initialize evaluation metrics for YOLO.""" | ||||
|         super().init_metrics(model) | ||||
|         val = self.data.get(self.args.split, '')  # validation path | ||||
|         self.is_dota = isinstance(val, str) and 'DOTA' in val  # is DOTA | ||||
|         val = self.data.get(self.args.split, "")  # validation path | ||||
|         self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO | ||||
| 
 | ||||
|     def postprocess(self, preds): | ||||
|         """Apply Non-maximum suppression to prediction outputs.""" | ||||
|         return ops.non_max_suppression(preds, | ||||
|         return ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             labels=self.lb, | ||||
| @ -46,7 +47,8 @@ class OBBValidator(DetectionValidator): | ||||
|             multi_label=True, | ||||
|             agnostic=self.args.single_cls, | ||||
|             max_det=self.args.max_det, | ||||
|                                        rotated=True) | ||||
|             rotated=True, | ||||
|         ) | ||||
| 
 | ||||
|     def _process_batch(self, detections, gt_bboxes, gt_cls): | ||||
|         """ | ||||
| @ -66,12 +68,12 @@ class OBBValidator(DetectionValidator): | ||||
| 
 | ||||
|     def _prepare_batch(self, si, batch): | ||||
|         """Prepares and returns a batch for OBB validation.""" | ||||
|         idx = batch['batch_idx'] == si | ||||
|         cls = batch['cls'][idx].squeeze(-1) | ||||
|         bbox = batch['bboxes'][idx] | ||||
|         ori_shape = batch['ori_shape'][si] | ||||
|         imgsz = batch['img'].shape[2:] | ||||
|         ratio_pad = batch['ratio_pad'][si] | ||||
|         idx = batch["batch_idx"] == si | ||||
|         cls = batch["cls"][idx].squeeze(-1) | ||||
|         bbox = batch["bboxes"][idx] | ||||
|         ori_shape = batch["ori_shape"][si] | ||||
|         imgsz = batch["img"].shape[2:] | ||||
|         ratio_pad = batch["ratio_pad"][si] | ||||
|         if len(cls): | ||||
|             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes | ||||
|             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels | ||||
| @ -81,18 +83,21 @@ class OBBValidator(DetectionValidator): | ||||
|     def _prepare_pred(self, pred, pbatch): | ||||
|         """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes.""" | ||||
|         predn = pred.clone() | ||||
|         ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'], | ||||
|                         xywh=True)  # native-space pred | ||||
|         ops.scale_boxes( | ||||
|             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True | ||||
|         )  # native-space pred | ||||
|         return predn | ||||
| 
 | ||||
|     def plot_predictions(self, batch, preds, ni): | ||||
|         """Plots predicted bounding boxes on input images and saves the result.""" | ||||
|         plot_images(batch['img'], | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             *output_to_rotated_target(preds, max_det=self.args.max_det), | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'val_batch{ni}_pred.jpg', | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"val_batch{ni}_pred.jpg", | ||||
|             names=self.names, | ||||
|                     on_plot=self.on_plot)  # pred | ||||
|             on_plot=self.on_plot, | ||||
|         )  # pred | ||||
| 
 | ||||
|     def pred_to_json(self, predn, filename): | ||||
|         """Serialize YOLO predictions to COCO json format.""" | ||||
| @ -101,12 +106,15 @@ class OBBValidator(DetectionValidator): | ||||
|         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1) | ||||
|         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8) | ||||
|         for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())): | ||||
|             self.jdict.append({ | ||||
|                 'image_id': image_id, | ||||
|                 'category_id': self.class_map[int(predn[i, 5].item())], | ||||
|                 'score': round(predn[i, 4].item(), 5), | ||||
|                 'rbox': [round(x, 3) for x in r], | ||||
|                 'poly': [round(x, 3) for x in b]}) | ||||
|             self.jdict.append( | ||||
|                 { | ||||
|                     "image_id": image_id, | ||||
|                     "category_id": self.class_map[int(predn[i, 5].item())], | ||||
|                     "score": round(predn[i, 4].item(), 5), | ||||
|                     "rbox": [round(x, 3) for x in r], | ||||
|                     "poly": [round(x, 3) for x in b], | ||||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|     def save_one_txt(self, predn, save_conf, shape, file): | ||||
|         """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" | ||||
| @ -116,8 +124,8 @@ class OBBValidator(DetectionValidator): | ||||
|             xywha[:, :4] /= gn | ||||
|             xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist()  # normalized 8-point polygon | ||||
|             line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy)  # label format | ||||
|             with open(file, 'a') as f: | ||||
|                 f.write(('%g ' * len(line)).rstrip() % line + '\n') | ||||
|             with open(file, "a") as f: | ||||
|                 f.write(("%g " * len(line)).rstrip() % line + "\n") | ||||
| 
 | ||||
|     def eval_json(self, stats): | ||||
|         """Evaluates YOLO output in JSON format and returns performance statistics.""" | ||||
| @ -125,42 +133,43 @@ class OBBValidator(DetectionValidator): | ||||
|             import json | ||||
|             import re | ||||
|             from collections import defaultdict | ||||
|             pred_json = self.save_dir / 'predictions.json'  # predictions | ||||
|             pred_txt = self.save_dir / 'predictions_txt'  # predictions | ||||
| 
 | ||||
|             pred_json = self.save_dir / "predictions.json"  # predictions | ||||
|             pred_txt = self.save_dir / "predictions_txt"  # predictions | ||||
|             pred_txt.mkdir(parents=True, exist_ok=True) | ||||
|             data = json.load(open(pred_json)) | ||||
|             # Save split results | ||||
|             LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...') | ||||
|             LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...") | ||||
|             for d in data: | ||||
|                 image_id = d['image_id'] | ||||
|                 score = d['score'] | ||||
|                 classname = self.names[d['category_id']].replace(' ', '-') | ||||
|                 image_id = d["image_id"] | ||||
|                 score = d["score"] | ||||
|                 classname = self.names[d["category_id"]].replace(" ", "-") | ||||
| 
 | ||||
|                 lines = '{} {} {} {} {} {} {} {} {} {}\n'.format( | ||||
|                 lines = "{} {} {} {} {} {} {} {} {} {}\n".format( | ||||
|                     image_id, | ||||
|                     score, | ||||
|                     d['poly'][0], | ||||
|                     d['poly'][1], | ||||
|                     d['poly'][2], | ||||
|                     d['poly'][3], | ||||
|                     d['poly'][4], | ||||
|                     d['poly'][5], | ||||
|                     d['poly'][6], | ||||
|                     d['poly'][7], | ||||
|                     d["poly"][0], | ||||
|                     d["poly"][1], | ||||
|                     d["poly"][2], | ||||
|                     d["poly"][3], | ||||
|                     d["poly"][4], | ||||
|                     d["poly"][5], | ||||
|                     d["poly"][6], | ||||
|                     d["poly"][7], | ||||
|                 ) | ||||
|                 with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f: | ||||
|                 with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f: | ||||
|                     f.writelines(lines) | ||||
|             # Save merged results. This could give a slightly lower mAP than the official merging script, | ||||
|             # because of the probiou calculation. | ||||
|             pred_merged_txt = self.save_dir / 'predictions_merged_txt'  # predictions | ||||
|             pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions | ||||
|             pred_merged_txt.mkdir(parents=True, exist_ok=True) | ||||
|             merged_results = defaultdict(list) | ||||
|             LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...') | ||||
|             LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...") | ||||
|             for d in data: | ||||
|                 image_id = d['image_id'].split('__')[0] | ||||
|                 pattern = re.compile(r'\d+___\d+') | ||||
|                 x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___')) | ||||
|                 bbox, score, cls = d['rbox'], d['score'], d['category_id'] | ||||
|                 image_id = d["image_id"].split("__")[0] | ||||
|                 pattern = re.compile(r"\d+___\d+") | ||||
|                 x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___")) | ||||
|                 bbox, score, cls = d["rbox"], d["score"], d["category_id"] | ||||
|                 bbox[0] += x | ||||
|                 bbox[1] += y | ||||
|                 bbox.extend([score, cls]) | ||||
| @ -178,11 +187,11 @@ class OBBValidator(DetectionValidator): | ||||
| 
 | ||||
|                 b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8) | ||||
|                 for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist(): | ||||
|                     classname = self.names[int(x[-1])].replace(' ', '-') | ||||
|                     classname = self.names[int(x[-1])].replace(" ", "-") | ||||
|                     poly = [round(i, 3) for i in x[:-2]] | ||||
|                     score = round(x[-2], 3) | ||||
| 
 | ||||
|                     lines = '{} {} {} {} {} {} {} {} {} {}\n'.format( | ||||
|                     lines = "{} {} {} {} {} {} {} {} {} {}\n".format( | ||||
|                         image_id, | ||||
|                         score, | ||||
|                         poly[0], | ||||
| @ -194,7 +203,7 @@ class OBBValidator(DetectionValidator): | ||||
|                         poly[6], | ||||
|                         poly[7], | ||||
|                     ) | ||||
|                     with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f: | ||||
|                     with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f: | ||||
|                         f.writelines(lines) | ||||
| 
 | ||||
|         return stats | ||||
|  | ||||
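The merging pass above recovers full-image coordinates from DOTA patch ids of the form `<image>__<size>__<x>___<y>`: the regex extracts the crop offset, which is added back to the rotated-box centre before probiou-based merging. The same parsing in isolation (the id and box values are made up):

```python
import re

# DOTA-style patch id: original image id, then the x___y crop offset
image_id_full = "P0006__1024__0___660"
image_id = image_id_full.split("__")[0]  # "P0006"
x, y = (int(c) for c in re.findall(r"\d+___\d+", image_id_full)[0].split("___"))

rbox = [100.0, 50.0, 40.0, 20.0, 0.3]  # xywhr in patch coordinates
rbox[0] += x  # shift the centre back into full-image coordinates
rbox[1] += y
print(image_id, rbox)
```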
| @ -4,4 +4,4 @@ from .predict import PosePredictor | ||||
| from .train import PoseTrainer | ||||
| from .val import PoseValidator | ||||
| 
 | ||||
| __all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor' | ||||
| __all__ = "PoseTrainer", "PoseValidator", "PosePredictor" | ||||
|  | ||||
| @ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): | ||||
|         """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device.""" | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
|         self.args.task = 'pose' | ||||
|         if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': | ||||
|             LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " | ||||
|                            'See https://github.com/ultralytics/ultralytics/issues/4031.') | ||||
|         self.args.task = "pose" | ||||
|         if isinstance(self.args.device, str) and self.args.device.lower() == "mps": | ||||
|             LOGGER.warning( | ||||
|                 "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " | ||||
|                 "See https://github.com/ultralytics/ultralytics/issues/4031." | ||||
|             ) | ||||
| 
 | ||||
|     def postprocess(self, preds, img, orig_imgs): | ||||
|         """Return detection results for a given input image or list of images.""" | ||||
|         preds = ops.non_max_suppression(preds, | ||||
|         preds = ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             agnostic=self.args.agnostic_nms, | ||||
|             max_det=self.args.max_det, | ||||
|             classes=self.args.classes, | ||||
|                                         nc=len(self.model.names)) | ||||
|             nc=len(self.model.names), | ||||
|         ) | ||||
| 
 | ||||
|         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list | ||||
|             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) | ||||
| @ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor): | ||||
|             pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape) | ||||
|             img_path = self.batch[0][i] | ||||
|             results.append( | ||||
|                 Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)) | ||||
|                 Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts) | ||||
|             ) | ||||
|         return results | ||||
|  | ||||
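A minimal usage sketch, assuming the standard Ultralytics entry point and a local image file; it exercises the PosePredictor postprocess path shown above.

    from ultralytics import YOLO

    model = YOLO("yolov8n-pose.pt")                    # pose checkpoint
    results = model.predict("bus.jpg", device="cpu")   # 'cpu' sidesteps the MPS pose bug warned about above
    for r in results:
        print(r.boxes.xyxy.shape)       # (n, 4) boxes after non_max_suppression
        print(r.keypoints.data.shape)   # (n, K, D) keypoints scaled to the original image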
| @ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer): | ||||
|         """Initialize a PoseTrainer object with specified configurations and overrides.""" | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides['task'] = 'pose' | ||||
|         overrides["task"] = "pose" | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
| 
 | ||||
|         if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': | ||||
|             LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " | ||||
|                            'See https://github.com/ultralytics/ultralytics/issues/4031.') | ||||
|         if isinstance(self.args.device, str) and self.args.device.lower() == "mps": | ||||
|             LOGGER.warning( | ||||
|                 "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " | ||||
|                 "See https://github.com/ultralytics/ultralytics/issues/4031." | ||||
|             ) | ||||
| 
 | ||||
|     def get_model(self, cfg=None, weights=None, verbose=True): | ||||
|         """Get pose estimation model with specified configuration and weights.""" | ||||
|         model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose) | ||||
|         model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
| 
 | ||||
| @ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer): | ||||
|     def set_model_attributes(self): | ||||
|         """Sets keypoints shape attribute of PoseModel.""" | ||||
|         super().set_model_attributes() | ||||
|         self.model.kpt_shape = self.data['kpt_shape'] | ||||
|         self.model.kpt_shape = self.data["kpt_shape"] | ||||
| 
 | ||||
|     def get_validator(self): | ||||
|         """Returns an instance of the PoseValidator class for validation.""" | ||||
|         self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' | ||||
|         return yolo.pose.PoseValidator(self.test_loader, | ||||
|                                        save_dir=self.save_dir, | ||||
|                                        args=copy(self.args), | ||||
|                                        _callbacks=self.callbacks) | ||||
|         self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss" | ||||
|         return yolo.pose.PoseValidator( | ||||
|             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks | ||||
|         ) | ||||
| 
 | ||||
|     def plot_training_samples(self, batch, ni): | ||||
|         """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" | ||||
|         images = batch['img'] | ||||
|         kpts = batch['keypoints'] | ||||
|         cls = batch['cls'].squeeze(-1) | ||||
|         bboxes = batch['bboxes'] | ||||
|         paths = batch['im_file'] | ||||
|         batch_idx = batch['batch_idx'] | ||||
|         plot_images(images, | ||||
|         images = batch["img"] | ||||
|         kpts = batch["keypoints"] | ||||
|         cls = batch["cls"].squeeze(-1) | ||||
|         bboxes = batch["bboxes"] | ||||
|         paths = batch["im_file"] | ||||
|         batch_idx = batch["batch_idx"] | ||||
|         plot_images( | ||||
|             images, | ||||
|             batch_idx, | ||||
|             cls, | ||||
|             bboxes, | ||||
|             kpts=kpts, | ||||
|             paths=paths, | ||||
|                     fname=self.save_dir / f'train_batch{ni}.jpg', | ||||
|                     on_plot=self.on_plot) | ||||
|             fname=self.save_dir / f"train_batch{ni}.jpg", | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
| 
 | ||||
|     def plot_metrics(self): | ||||
|         """Plots training/val metrics.""" | ||||
|  | ||||
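A hedged sketch of driving PoseTrainer directly with the overrides dict pattern shown above; 'coco8-pose.yaml' is the small bundled sample dataset, and epochs=1 keeps the run trivial.

    from ultralytics.models.yolo.pose import PoseTrainer

    args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=1, device="cpu")
    trainer = PoseTrainer(overrides=args)  # __init__ forces overrides["task"] = "pose"
    trainer.train()                        # uses the get_model/get_validator defined above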
| @ -31,38 +31,53 @@ class PoseValidator(DetectionValidator): | ||||
|         super().__init__(dataloader, save_dir, pbar, args, _callbacks) | ||||
|         self.sigma = None | ||||
|         self.kpt_shape = None | ||||
|         self.args.task = 'pose' | ||||
|         self.args.task = "pose" | ||||
|         self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot) | ||||
|         if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': | ||||
|             LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " | ||||
|                            'See https://github.com/ultralytics/ultralytics/issues/4031.') | ||||
|         if isinstance(self.args.device, str) and self.args.device.lower() == "mps": | ||||
|             LOGGER.warning( | ||||
|                 "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " | ||||
|                 "See https://github.com/ultralytics/ultralytics/issues/4031." | ||||
|             ) | ||||
| 
 | ||||
|     def preprocess(self, batch): | ||||
|         """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device.""" | ||||
|         batch = super().preprocess(batch) | ||||
|         batch['keypoints'] = batch['keypoints'].to(self.device).float() | ||||
|         batch["keypoints"] = batch["keypoints"].to(self.device).float() | ||||
|         return batch | ||||
| 
 | ||||
|     def get_desc(self): | ||||
|         """Returns description of evaluation metrics in string format.""" | ||||
|         return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P', | ||||
|                                          'R', 'mAP50', 'mAP50-95)') | ||||
|         return ("%22s" + "%11s" * 10) % ( | ||||
|             "Class", | ||||
|             "Images", | ||||
|             "Instances", | ||||
|             "Box(P", | ||||
|             "R", | ||||
|             "mAP50", | ||||
|             "mAP50-95)", | ||||
|             "Pose(P", | ||||
|             "R", | ||||
|             "mAP50", | ||||
|             "mAP50-95)", | ||||
|         ) | ||||
| 
 | ||||
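For reference, the header string get_desc builds above can be reproduced standalone; the format is one 22-character column followed by ten 11-character columns:

    header = ("%22s" + "%11s" * 10) % (
        "Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)",
        "Pose(P", "R", "mAP50", "mAP50-95)",
    )
    print(header)  # right-aligned column titles for the validation table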
|     def postprocess(self, preds): | ||||
|         """Apply non-maximum suppression and return detections with high confidence scores.""" | ||||
|         return ops.non_max_suppression(preds, | ||||
|         return ops.non_max_suppression( | ||||
|             preds, | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             labels=self.lb, | ||||
|             multi_label=True, | ||||
|             agnostic=self.args.single_cls, | ||||
|             max_det=self.args.max_det, | ||||
|                                        nc=self.nc) | ||||
|             nc=self.nc, | ||||
|         ) | ||||
| 
 | ||||
|     def init_metrics(self, model): | ||||
|         """Initiate pose estimation metrics for YOLO model.""" | ||||
|         super().init_metrics(model) | ||||
|         self.kpt_shape = self.data['kpt_shape'] | ||||
|         self.kpt_shape = self.data["kpt_shape"] | ||||
|         is_pose = self.kpt_shape == [17, 3] | ||||
|         nkpt = self.kpt_shape[0] | ||||
|         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt | ||||
| @ -71,21 +86,21 @@ class PoseValidator(DetectionValidator): | ||||
|     def _prepare_batch(self, si, batch): | ||||
|         """Prepares a batch for processing by converting keypoints to float and moving to device.""" | ||||
|         pbatch = super()._prepare_batch(si, batch) | ||||
|         kpts = batch['keypoints'][batch['batch_idx'] == si] | ||||
|         h, w = pbatch['imgsz'] | ||||
|         kpts = batch["keypoints"][batch["batch_idx"] == si] | ||||
|         h, w = pbatch["imgsz"] | ||||
|         kpts = kpts.clone() | ||||
|         kpts[..., 0] *= w | ||||
|         kpts[..., 1] *= h | ||||
|         kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad']) | ||||
|         pbatch['kpts'] = kpts | ||||
|         kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) | ||||
|         pbatch["kpts"] = kpts | ||||
|         return pbatch | ||||
| 
 | ||||
|     def _prepare_pred(self, pred, pbatch): | ||||
|         """Prepares and scales keypoints in a batch for pose processing.""" | ||||
|         predn = super()._prepare_pred(pred, pbatch) | ||||
|         nk = pbatch['kpts'].shape[1] | ||||
|         nk = pbatch["kpts"].shape[1] | ||||
|         pred_kpts = predn[:, 6:].view(len(predn), nk, -1) | ||||
|         ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad']) | ||||
|         ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) | ||||
|         return predn, pred_kpts | ||||
| 
 | ||||
|     def update_metrics(self, preds, batch): | ||||
| @ -93,14 +108,16 @@ class PoseValidator(DetectionValidator): | ||||
|         for si, pred in enumerate(preds): | ||||
|             self.seen += 1 | ||||
|             npr = len(pred) | ||||
|             stat = dict(conf=torch.zeros(0, device=self.device), | ||||
|             stat = dict( | ||||
|                 conf=torch.zeros(0, device=self.device), | ||||
|                 pred_cls=torch.zeros(0, device=self.device), | ||||
|                 tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), | ||||
|                         tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)) | ||||
|                 tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), | ||||
|             ) | ||||
|             pbatch = self._prepare_batch(si, batch) | ||||
|             cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox') | ||||
|             cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") | ||||
|             nl = len(cls) | ||||
|             stat['target_cls'] = cls | ||||
|             stat["target_cls"] = cls | ||||
|             if npr == 0: | ||||
|                 if nl: | ||||
|                     for k in self.stats.keys(): | ||||
| @ -113,13 +130,13 @@ class PoseValidator(DetectionValidator): | ||||
|             if self.args.single_cls: | ||||
|                 pred[:, 5] = 0 | ||||
|             predn, pred_kpts = self._prepare_pred(pred, pbatch) | ||||
|             stat['conf'] = predn[:, 4] | ||||
|             stat['pred_cls'] = predn[:, 5] | ||||
|             stat["conf"] = predn[:, 4] | ||||
|             stat["pred_cls"] = predn[:, 5] | ||||
| 
 | ||||
|             # Evaluate | ||||
|             if nl: | ||||
|                 stat['tp'] = self._process_batch(predn, bbox, cls) | ||||
|                 stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts']) | ||||
|                 stat["tp"] = self._process_batch(predn, bbox, cls) | ||||
|                 stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"]) | ||||
|                 if self.args.plots: | ||||
|                     self.confusion_matrix.process_batch(predn, bbox, cls) | ||||
| 
 | ||||
| @ -128,7 +145,7 @@ class PoseValidator(DetectionValidator): | ||||
| 
 | ||||
|             # Save | ||||
|             if self.args.save_json: | ||||
|                 self.pred_to_json(predn, batch['im_file'][si]) | ||||
|                 self.pred_to_json(predn, batch["im_file"][si]) | ||||
|             # if self.args.save_txt: | ||||
|             #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') | ||||
| 
 | ||||
| @ -159,26 +176,30 @@ class PoseValidator(DetectionValidator): | ||||
| 
 | ||||
|     def plot_val_samples(self, batch, ni): | ||||
|         """Plots and saves validation set samples with predicted bounding boxes and keypoints.""" | ||||
|         plot_images(batch['img'], | ||||
|                     batch['batch_idx'], | ||||
|                     batch['cls'].squeeze(-1), | ||||
|                     batch['bboxes'], | ||||
|                     kpts=batch['keypoints'], | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'val_batch{ni}_labels.jpg', | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             batch["batch_idx"], | ||||
|             batch["cls"].squeeze(-1), | ||||
|             batch["bboxes"], | ||||
|             kpts=batch["keypoints"], | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"val_batch{ni}_labels.jpg", | ||||
|             names=self.names, | ||||
|                     on_plot=self.on_plot) | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
| 
 | ||||
|     def plot_predictions(self, batch, preds, ni): | ||||
|         """Plots predictions for YOLO model.""" | ||||
|         pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0) | ||||
|         plot_images(batch['img'], | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             *output_to_target(preds, max_det=self.args.max_det), | ||||
|             kpts=pred_kpts, | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'val_batch{ni}_pred.jpg', | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"val_batch{ni}_pred.jpg", | ||||
|             names=self.names, | ||||
|                     on_plot=self.on_plot)  # pred | ||||
|             on_plot=self.on_plot, | ||||
|         )  # pred | ||||
| 
 | ||||
|     def pred_to_json(self, predn, filename): | ||||
|         """Converts YOLO predictions to COCO JSON format.""" | ||||
| @ -187,37 +208,41 @@ class PoseValidator(DetectionValidator): | ||||
|         box = ops.xyxy2xywh(predn[:, :4])  # xywh | ||||
|         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner | ||||
|         for p, b in zip(predn.tolist(), box.tolist()): | ||||
|             self.jdict.append({ | ||||
|                 'image_id': image_id, | ||||
|                 'category_id': self.class_map[int(p[5])], | ||||
|                 'bbox': [round(x, 3) for x in b], | ||||
|                 'keypoints': p[6:], | ||||
|                 'score': round(p[4], 5)}) | ||||
|             self.jdict.append( | ||||
|                 { | ||||
|                     "image_id": image_id, | ||||
|                     "category_id": self.class_map[int(p[5])], | ||||
|                     "bbox": [round(x, 3) for x in b], | ||||
|                     "keypoints": p[6:], | ||||
|                     "score": round(p[4], 5), | ||||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|     def eval_json(self, stats): | ||||
|         """Evaluates object detection model using COCO JSON format.""" | ||||
|         if self.args.save_json and self.is_coco and len(self.jdict): | ||||
|             anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json'  # annotations | ||||
|             pred_json = self.save_dir / 'predictions.json'  # predictions | ||||
|             LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') | ||||
|             anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations | ||||
|             pred_json = self.save_dir / "predictions.json"  # predictions | ||||
|             LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") | ||||
|             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb | ||||
|                 check_requirements('pycocotools>=2.0.6') | ||||
|                 check_requirements("pycocotools>=2.0.6") | ||||
|                 from pycocotools.coco import COCO  # noqa | ||||
|                 from pycocotools.cocoeval import COCOeval  # noqa | ||||
| 
 | ||||
|                 for x in anno_json, pred_json: | ||||
|                     assert x.is_file(), f'{x} file not found' | ||||
|                     assert x.is_file(), f"{x} file not found" | ||||
|                 anno = COCO(str(anno_json))  # init annotations api | ||||
|                 pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path) | ||||
|                 for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]): | ||||
|                 for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]): | ||||
|                     if self.is_coco: | ||||
|                         eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval | ||||
|                     eval.evaluate() | ||||
|                     eval.accumulate() | ||||
|                     eval.summarize() | ||||
|                     idx = i * 4 + 2 | ||||
|                     stats[self.metrics.keys[idx + 1]], stats[ | ||||
|                         self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50 | ||||
|                     stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ | ||||
|                         :2 | ||||
|                     ]  # update mAP50-95 and mAP50 | ||||
|             except Exception as e: | ||||
|                 LOGGER.warning(f'pycocotools unable to run: {e}') | ||||
|                 LOGGER.warning(f"pycocotools unable to run: {e}") | ||||
|         return stats | ||||
|  | ||||
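A standalone sketch of the pycocotools keypoint evaluation that eval_json wraps above. Paths are hypothetical; predictions.json must follow the COCO results format built by pred_to_json (image_id, category_id, bbox, keypoints, score).

    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    anno = COCO("annotations/person_keypoints_val2017.json")  # ground truth
    pred = anno.loadRes("predictions.json")                   # must be a str path, not Path
    ev = COCOeval(anno, pred, "keypoints")
    ev.evaluate()
    ev.accumulate()
    ev.summarize()
    map50_95, map50 = ev.stats[:2]  # the same two numbers eval_json copies into stats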
| @ -4,4 +4,4 @@ from .predict import SegmentationPredictor | ||||
| from .train import SegmentationTrainer | ||||
| from .val import SegmentationValidator | ||||
| 
 | ||||
| __all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator' | ||||
| __all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator" | ||||
|  | ||||
| @ -23,17 +23,19 @@ class SegmentationPredictor(DetectionPredictor): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): | ||||
|         """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks.""" | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
|         self.args.task = 'segment' | ||||
|         self.args.task = "segment" | ||||
| 
 | ||||
|     def postprocess(self, preds, img, orig_imgs): | ||||
|         """Applies non-max suppression and processes detections for each image in an input batch.""" | ||||
|         p = ops.non_max_suppression(preds[0], | ||||
|         p = ops.non_max_suppression( | ||||
|             preds[0], | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             agnostic=self.args.agnostic_nms, | ||||
|             max_det=self.args.max_det, | ||||
|             nc=len(self.model.names), | ||||
|                                     classes=self.args.classes) | ||||
|             classes=self.args.classes, | ||||
|         ) | ||||
| 
 | ||||
|         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list | ||||
|             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) | ||||
|  | ||||
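A minimal usage sketch for the segmentation path (assumptions: default weights, a local image), exercising the SegmentationPredictor postprocess above.

    from ultralytics import YOLO

    model = YOLO("yolov8n-seg.pt")
    results = model.predict("bus.jpg")
    masks = results[0].masks            # None if nothing was detected
    if masks is not None:
        print(masks.data.shape)         # (n, H, W) binary masks built from the proto output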
| @ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): | ||||
|         """Initialize a SegmentationTrainer object with given arguments.""" | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides['task'] = 'segment' | ||||
|         overrides["task"] = "segment" | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
| 
 | ||||
|     def get_model(self, cfg=None, weights=None, verbose=True): | ||||
|         """Return SegmentationModel initialized with specified config and weights.""" | ||||
|         model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) | ||||
|         model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
| 
 | ||||
| @ -39,22 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): | ||||
| 
 | ||||
|     def get_validator(self): | ||||
|         """Return an instance of SegmentationValidator for validation of YOLO model.""" | ||||
|         self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' | ||||
|         return yolo.segment.SegmentationValidator(self.test_loader, | ||||
|                                                   save_dir=self.save_dir, | ||||
|                                                   args=copy(self.args), | ||||
|                                                   _callbacks=self.callbacks) | ||||
|         self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss" | ||||
|         return yolo.segment.SegmentationValidator( | ||||
|             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks | ||||
|         ) | ||||
| 
 | ||||
|     def plot_training_samples(self, batch, ni): | ||||
|         """Creates a plot of training sample images with labels and box coordinates.""" | ||||
|         plot_images(batch['img'], | ||||
|                     batch['batch_idx'], | ||||
|                     batch['cls'].squeeze(-1), | ||||
|                     batch['bboxes'], | ||||
|                     masks=batch['masks'], | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'train_batch{ni}.jpg', | ||||
|                     on_plot=self.on_plot) | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             batch["batch_idx"], | ||||
|             batch["cls"].squeeze(-1), | ||||
|             batch["bboxes"], | ||||
|             masks=batch["masks"], | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"train_batch{ni}.jpg", | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
| 
 | ||||
|     def plot_metrics(self): | ||||
|         """Plots training/val metrics.""" | ||||
|  | ||||
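A sketch of a short segmentation training run using the trainer above; 'coco8-seg.yaml' is the bundled sample dataset, and the loss names come from get_validator.

    from ultralytics.models.yolo.segment import SegmentationTrainer

    trainer = SegmentationTrainer(overrides=dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=1))
    trainer.train()
    print(trainer.loss_names)  # ('box_loss', 'seg_loss', 'cls_loss', 'dfl_loss')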
| @ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator): | ||||
|         super().__init__(dataloader, save_dir, pbar, args, _callbacks) | ||||
|         self.plot_masks = None | ||||
|         self.process = None | ||||
|         self.args.task = 'segment' | ||||
|         self.args.task = "segment" | ||||
|         self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) | ||||
| 
 | ||||
|     def preprocess(self, batch): | ||||
|         """Preprocesses batch by converting masks to float and sending to device.""" | ||||
|         batch = super().preprocess(batch) | ||||
|         batch['masks'] = batch['masks'].to(self.device).float() | ||||
|         batch["masks"] = batch["masks"].to(self.device).float() | ||||
|         return batch | ||||
| 
 | ||||
|     def init_metrics(self, model): | ||||
| @ -47,7 +47,7 @@ class SegmentationValidator(DetectionValidator): | ||||
|         super().init_metrics(model) | ||||
|         self.plot_masks = [] | ||||
|         if self.args.save_json: | ||||
|             check_requirements('pycocotools>=2.0.6') | ||||
|             check_requirements("pycocotools>=2.0.6") | ||||
|             self.process = ops.process_mask_upsample  # more accurate | ||||
|         else: | ||||
|             self.process = ops.process_mask  # faster | ||||
| @ -55,33 +55,46 @@ class SegmentationValidator(DetectionValidator): | ||||
| 
 | ||||
|     def get_desc(self): | ||||
|         """Return a formatted description of evaluation metrics.""" | ||||
|         return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', | ||||
|                                          'R', 'mAP50', 'mAP50-95)') | ||||
|         return ("%22s" + "%11s" * 10) % ( | ||||
|             "Class", | ||||
|             "Images", | ||||
|             "Instances", | ||||
|             "Box(P", | ||||
|             "R", | ||||
|             "mAP50", | ||||
|             "mAP50-95)", | ||||
|             "Mask(P", | ||||
|             "R", | ||||
|             "mAP50", | ||||
|             "mAP50-95)", | ||||
|         ) | ||||
| 
 | ||||
|     def postprocess(self, preds): | ||||
|         """Post-processes YOLO predictions and returns output detections with proto.""" | ||||
|         p = ops.non_max_suppression(preds[0], | ||||
|         p = ops.non_max_suppression( | ||||
|             preds[0], | ||||
|             self.args.conf, | ||||
|             self.args.iou, | ||||
|             labels=self.lb, | ||||
|             multi_label=True, | ||||
|             agnostic=self.args.single_cls, | ||||
|             max_det=self.args.max_det, | ||||
|                                     nc=self.nc) | ||||
|             nc=self.nc, | ||||
|         ) | ||||
|         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported | ||||
|         return p, proto | ||||
| 
 | ||||
|     def _prepare_batch(self, si, batch): | ||||
|         """Prepares a batch for training or inference by processing images and targets.""" | ||||
|         prepared_batch = super()._prepare_batch(si, batch) | ||||
|         midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si | ||||
|         prepared_batch['masks'] = batch['masks'][midx] | ||||
|         midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si | ||||
|         prepared_batch["masks"] = batch["masks"][midx] | ||||
|         return prepared_batch | ||||
| 
 | ||||
|     def _prepare_pred(self, pred, pbatch, proto): | ||||
|         """Prepares a batch for training or inference by processing images and targets.""" | ||||
|         predn = super()._prepare_pred(pred, pbatch) | ||||
|         pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz']) | ||||
|         pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"]) | ||||
|         return predn, pred_masks | ||||
| 
 | ||||
|     def update_metrics(self, preds, batch): | ||||
| @ -89,14 +102,16 @@ class SegmentationValidator(DetectionValidator): | ||||
|         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): | ||||
|             self.seen += 1 | ||||
|             npr = len(pred) | ||||
|             stat = dict(conf=torch.zeros(0, device=self.device), | ||||
|             stat = dict( | ||||
|                 conf=torch.zeros(0, device=self.device), | ||||
|                 pred_cls=torch.zeros(0, device=self.device), | ||||
|                 tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), | ||||
|                         tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)) | ||||
|                 tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), | ||||
|             ) | ||||
|             pbatch = self._prepare_batch(si, batch) | ||||
|             cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox') | ||||
|             cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") | ||||
|             nl = len(cls) | ||||
|             stat['target_cls'] = cls | ||||
|             stat["target_cls"] = cls | ||||
|             if npr == 0: | ||||
|                 if nl: | ||||
|                     for k in self.stats.keys(): | ||||
| @ -106,24 +121,20 @@ class SegmentationValidator(DetectionValidator): | ||||
|                 continue | ||||
| 
 | ||||
|             # Masks | ||||
|             gt_masks = pbatch.pop('masks') | ||||
|             gt_masks = pbatch.pop("masks") | ||||
|             # Predictions | ||||
|             if self.args.single_cls: | ||||
|                 pred[:, 5] = 0 | ||||
|             predn, pred_masks = self._prepare_pred(pred, pbatch, proto) | ||||
|             stat['conf'] = predn[:, 4] | ||||
|             stat['pred_cls'] = predn[:, 5] | ||||
|             stat["conf"] = predn[:, 4] | ||||
|             stat["pred_cls"] = predn[:, 5] | ||||
| 
 | ||||
|             # Evaluate | ||||
|             if nl: | ||||
|                 stat['tp'] = self._process_batch(predn, bbox, cls) | ||||
|                 stat['tp_m'] = self._process_batch(predn, | ||||
|                                                    bbox, | ||||
|                                                    cls, | ||||
|                                                    pred_masks, | ||||
|                                                    gt_masks, | ||||
|                                                    self.args.overlap_mask, | ||||
|                                                    masks=True) | ||||
|                 stat["tp"] = self._process_batch(predn, bbox, cls) | ||||
|                 stat["tp_m"] = self._process_batch( | ||||
|                     predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True | ||||
|                 ) | ||||
|                 if self.args.plots: | ||||
|                     self.confusion_matrix.process_batch(predn, bbox, cls) | ||||
| 
 | ||||
| @ -136,10 +147,12 @@ class SegmentationValidator(DetectionValidator): | ||||
| 
 | ||||
|             # Save | ||||
|             if self.args.save_json: | ||||
|                 pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), | ||||
|                                              pbatch['ori_shape'], | ||||
|                                              ratio_pad=batch['ratio_pad'][si]) | ||||
|                 self.pred_to_json(predn, batch['im_file'][si], pred_masks) | ||||
|                 pred_masks = ops.scale_image( | ||||
|                     pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), | ||||
|                     pbatch["ori_shape"], | ||||
|                     ratio_pad=batch["ratio_pad"][si], | ||||
|                 ) | ||||
|                 self.pred_to_json(predn, batch["im_file"][si], pred_masks) | ||||
|             # if self.args.save_txt: | ||||
|             #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') | ||||
| 
 | ||||
| @ -166,7 +179,7 @@ class SegmentationValidator(DetectionValidator): | ||||
|                 gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640) | ||||
|                 gt_masks = torch.where(gt_masks == index, 1.0, 0.0) | ||||
|             if gt_masks.shape[1:] != pred_masks.shape[1:]: | ||||
|                 gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0] | ||||
|                 gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] | ||||
|                 gt_masks = gt_masks.gt_(0.5) | ||||
|             iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) | ||||
|         else:  # boxes | ||||
| @ -176,26 +189,29 @@ class SegmentationValidator(DetectionValidator): | ||||
| 
 | ||||
|     def plot_val_samples(self, batch, ni): | ||||
|         """Plots validation samples with bounding box labels.""" | ||||
|         plot_images(batch['img'], | ||||
|                     batch['batch_idx'], | ||||
|                     batch['cls'].squeeze(-1), | ||||
|                     batch['bboxes'], | ||||
|                     masks=batch['masks'], | ||||
|                     paths=batch['im_file'], | ||||
|                     fname=self.save_dir / f'val_batch{ni}_labels.jpg', | ||||
|         plot_images( | ||||
|             batch["img"], | ||||
|             batch["batch_idx"], | ||||
|             batch["cls"].squeeze(-1), | ||||
|             batch["bboxes"], | ||||
|             masks=batch["masks"], | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"val_batch{ni}_labels.jpg", | ||||
|             names=self.names, | ||||
|                     on_plot=self.on_plot) | ||||
|             on_plot=self.on_plot, | ||||
|         ) | ||||
| 
 | ||||
|     def plot_predictions(self, batch, preds, ni): | ||||
|         """Plots batch predictions with masks and bounding boxes.""" | ||||
|         plot_images( | ||||
|             batch['img'], | ||||
|             batch["img"], | ||||
|             *output_to_target(preds[0], max_det=15),  # not set to self.args.max_det due to slow plotting speed | ||||
|             torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, | ||||
|             paths=batch['im_file'], | ||||
|             fname=self.save_dir / f'val_batch{ni}_pred.jpg', | ||||
|             paths=batch["im_file"], | ||||
|             fname=self.save_dir / f"val_batch{ni}_pred.jpg", | ||||
|             names=self.names, | ||||
|             on_plot=self.on_plot)  # pred | ||||
|             on_plot=self.on_plot, | ||||
|         )  # pred | ||||
|         self.plot_masks.clear() | ||||
| 
 | ||||
|     def pred_to_json(self, predn, filename, pred_masks): | ||||
| @ -205,8 +221,8 @@ class SegmentationValidator(DetectionValidator): | ||||
| 
 | ||||
|         def single_encode(x): | ||||
|             """Encode predicted masks as RLE and append results to jdict.""" | ||||
|             rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0] | ||||
|             rle['counts'] = rle['counts'].decode('utf-8') | ||||
|             rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] | ||||
|             rle["counts"] = rle["counts"].decode("utf-8") | ||||
|             return rle | ||||
| 
 | ||||
|         stem = Path(filename).stem | ||||
| @ -217,37 +233,41 @@ class SegmentationValidator(DetectionValidator): | ||||
|         with ThreadPool(NUM_THREADS) as pool: | ||||
|             rles = pool.map(single_encode, pred_masks) | ||||
|         for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): | ||||
|             self.jdict.append({ | ||||
|                 'image_id': image_id, | ||||
|                 'category_id': self.class_map[int(p[5])], | ||||
|                 'bbox': [round(x, 3) for x in b], | ||||
|                 'score': round(p[4], 5), | ||||
|                 'segmentation': rles[i]}) | ||||
|             self.jdict.append( | ||||
|                 { | ||||
|                     "image_id": image_id, | ||||
|                     "category_id": self.class_map[int(p[5])], | ||||
|                     "bbox": [round(x, 3) for x in b], | ||||
|                     "score": round(p[4], 5), | ||||
|                     "segmentation": rles[i], | ||||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|     def eval_json(self, stats): | ||||
|         """Return COCO-style object detection evaluation metrics.""" | ||||
|         if self.args.save_json and self.is_coco and len(self.jdict): | ||||
|             anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations | ||||
|             pred_json = self.save_dir / 'predictions.json'  # predictions | ||||
|             LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') | ||||
|             anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations | ||||
|             pred_json = self.save_dir / "predictions.json"  # predictions | ||||
|             LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") | ||||
|             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb | ||||
|                 check_requirements('pycocotools>=2.0.6') | ||||
|                 check_requirements("pycocotools>=2.0.6") | ||||
|                 from pycocotools.coco import COCO  # noqa | ||||
|                 from pycocotools.cocoeval import COCOeval  # noqa | ||||
| 
 | ||||
|                 for x in anno_json, pred_json: | ||||
|                     assert x.is_file(), f'{x} file not found' | ||||
|                     assert x.is_file(), f"{x} file not found" | ||||
|                 anno = COCO(str(anno_json))  # init annotations api | ||||
|                 pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path) | ||||
|                 for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]): | ||||
|                 for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]): | ||||
|                     if self.is_coco: | ||||
|                         eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval | ||||
|                     eval.evaluate() | ||||
|                     eval.accumulate() | ||||
|                     eval.summarize() | ||||
|                     idx = i * 4 + 2 | ||||
|                     stats[self.metrics.keys[idx + 1]], stats[ | ||||
|                         self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50 | ||||
|                     stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ | ||||
|                         :2 | ||||
|                     ]  # update mAP50-95 and mAP50 | ||||
|             except Exception as e: | ||||
|                 LOGGER.warning(f'pycocotools unable to run: {e}') | ||||
|                 LOGGER.warning(f"pycocotools unable to run: {e}") | ||||
|         return stats | ||||
|  | ||||
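A sketch of the RLE encoding used by single_encode above, on a dummy binary mask; pycocotools expects a Fortran-ordered uint8 array of shape (H, W, N).

    import numpy as np
    from pycocotools.mask import encode

    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:6, 2:6] = 1                                    # a 4x4 square "instance"
    rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]
    rle["counts"] = rle["counts"].decode("utf-8")         # make it JSON-serializable, as above
    print(rle["size"])                                    # [8, 8]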
| @ -1,9 +1,29 @@ | ||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||
| 
 | ||||
| from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, | ||||
|                     attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, | ||||
|                     yaml_model_load) | ||||
| from .tasks import ( | ||||
|     BaseModel, | ||||
|     ClassificationModel, | ||||
|     DetectionModel, | ||||
|     SegmentationModel, | ||||
|     attempt_load_one_weight, | ||||
|     attempt_load_weights, | ||||
|     guess_model_scale, | ||||
|     guess_model_task, | ||||
|     parse_model, | ||||
|     torch_safe_load, | ||||
|     yaml_model_load, | ||||
| ) | ||||
| 
 | ||||
| __all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task', | ||||
|            'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel', | ||||
|            'BaseModel') | ||||
| __all__ = ( | ||||
|     "attempt_load_one_weight", | ||||
|     "attempt_load_weights", | ||||
|     "parse_model", | ||||
|     "yaml_model_load", | ||||
|     "guess_model_task", | ||||
|     "guess_model_scale", | ||||
|     "torch_safe_load", | ||||
|     "DetectionModel", | ||||
|     "SegmentationModel", | ||||
|     "ClassificationModel", | ||||
|     "BaseModel", | ||||
| ) | ||||
|  | ||||
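A brief sketch of the re-exported task helpers; assumes yolov8n.pt is available locally (recent releases auto-download it on first use).

    from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task

    model, ckpt = attempt_load_one_weight("yolov8n.pt")  # returns (model, checkpoint dict)
    print(guess_model_task("yolov8n-seg.yaml"))          # 'segment', inferred from the config name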
| @ -32,10 +32,12 @@ def check_class_names(names): | ||||
|         names = {int(k): str(v) for k, v in names.items()} | ||||
|         n = len(names) | ||||
|         if max(names.keys()) >= n: | ||||
|             raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices ' | ||||
|                            f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.') | ||||
|         if isinstance(names[0], str) and names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764' | ||||
|             names_map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map']  # human-readable names | ||||
|             raise KeyError( | ||||
|                 f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices " | ||||
|                 f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML." | ||||
|             ) | ||||
|         if isinstance(names[0], str) and names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764' | ||||
|             names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names | ||||
|             names = {k: names_map[v] for k, v in names.items()} | ||||
|     return names | ||||
| 
 | ||||
| @ -44,8 +46,8 @@ def default_class_names(data=None): | ||||
|     """Applies default class names to an input YAML file or returns numerical class names.""" | ||||
|     if data: | ||||
|         with contextlib.suppress(Exception): | ||||
|             return yaml_load(check_yaml(data))['names'] | ||||
|     return {i: f'class{i}' for i in range(999)}  # return default if above errors | ||||
|             return yaml_load(check_yaml(data))["names"] | ||||
|     return {i: f"class{i}" for i in range(999)}  # return default if above errors | ||||
| 
 | ||||
| 
 | ||||
| class AutoBackend(nn.Module): | ||||
| @ -77,14 +79,16 @@ class AutoBackend(nn.Module): | ||||
|     """ | ||||
| 
 | ||||
|     @torch.no_grad() | ||||
|     def __init__(self, | ||||
|                  weights='yolov8n.pt', | ||||
|                  device=torch.device('cpu'), | ||||
|     def __init__( | ||||
|         self, | ||||
|         weights="yolov8n.pt", | ||||
|         device=torch.device("cpu"), | ||||
|         dnn=False, | ||||
|         data=None, | ||||
|         fp16=False, | ||||
|         fuse=True, | ||||
|                  verbose=True): | ||||
|         verbose=True, | ||||
|     ): | ||||
|         """ | ||||
|         Initialize the AutoBackend for inference. | ||||
| 
 | ||||
| @ -100,17 +104,31 @@ class AutoBackend(nn.Module): | ||||
|         super().__init__() | ||||
|         w = str(weights[0] if isinstance(weights, list) else weights) | ||||
|         nn_module = isinstance(weights, torch.nn.Module) | ||||
|         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \ | ||||
|             self._model_type(w) | ||||
|         ( | ||||
|             pt, | ||||
|             jit, | ||||
|             onnx, | ||||
|             xml, | ||||
|             engine, | ||||
|             coreml, | ||||
|             saved_model, | ||||
|             pb, | ||||
|             tflite, | ||||
|             edgetpu, | ||||
|             tfjs, | ||||
|             paddle, | ||||
|             ncnn, | ||||
|             triton, | ||||
|         ) = self._model_type(w) | ||||
|         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16 | ||||
|         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH) | ||||
|         stride = 32  # default stride | ||||
|         model, metadata = None, None | ||||
| 
 | ||||
|         # Set device | ||||
|         cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA | ||||
|         cuda = torch.cuda.is_available() and device.type != "cpu"  # use CUDA | ||||
|         if cuda and not any([nn_module, pt, jit, engine, onnx]):  # GPU dataloader formats | ||||
|             device = torch.device('cpu') | ||||
|             device = torch.device("cpu") | ||||
|             cuda = False | ||||
| 
 | ||||
|         # Download if not local | ||||
| @ -121,77 +139,79 @@ class AutoBackend(nn.Module): | ||||
|         if nn_module:  # in-memory PyTorch model | ||||
|             model = weights.to(device) | ||||
|             model = model.fuse(verbose=verbose) if fuse else model | ||||
|             if hasattr(model, 'kpt_shape'): | ||||
|             if hasattr(model, "kpt_shape"): | ||||
|                 kpt_shape = model.kpt_shape  # pose-only | ||||
|             stride = max(int(model.stride.max()), 32)  # model stride | ||||
|             names = model.module.names if hasattr(model, 'module') else model.names  # get class names | ||||
|             names = model.module.names if hasattr(model, "module") else model.names  # get class names | ||||
|             model.half() if fp16 else model.float() | ||||
|             self.model = model  # explicitly assign for to(), cpu(), cuda(), half() | ||||
|             pt = True | ||||
|         elif pt:  # PyTorch | ||||
|             from ultralytics.nn.tasks import attempt_load_weights | ||||
|             model = attempt_load_weights(weights if isinstance(weights, list) else w, | ||||
|                                          device=device, | ||||
|                                          inplace=True, | ||||
|                                          fuse=fuse) | ||||
|             if hasattr(model, 'kpt_shape'): | ||||
| 
 | ||||
|             model = attempt_load_weights( | ||||
|                 weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse | ||||
|             ) | ||||
|             if hasattr(model, "kpt_shape"): | ||||
|                 kpt_shape = model.kpt_shape  # pose-only | ||||
|             stride = max(int(model.stride.max()), 32)  # model stride | ||||
|             names = model.module.names if hasattr(model, 'module') else model.names  # get class names | ||||
|             names = model.module.names if hasattr(model, "module") else model.names  # get class names | ||||
|             model.half() if fp16 else model.float() | ||||
|             self.model = model  # explicitly assign for to(), cpu(), cuda(), half() | ||||
|         elif jit:  # TorchScript | ||||
|             LOGGER.info(f'Loading {w} for TorchScript inference...') | ||||
|             extra_files = {'config.txt': ''}  # model metadata | ||||
|             LOGGER.info(f"Loading {w} for TorchScript inference...") | ||||
|             extra_files = {"config.txt": ""}  # model metadata | ||||
|             model = torch.jit.load(w, _extra_files=extra_files, map_location=device) | ||||
|             model.half() if fp16 else model.float() | ||||
|             if extra_files['config.txt']:  # load metadata dict | ||||
|                 metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items())) | ||||
|             if extra_files["config.txt"]:  # load metadata dict | ||||
|                 metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) | ||||
|         elif dnn:  # ONNX OpenCV DNN | ||||
|             LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...') | ||||
|             check_requirements('opencv-python>=4.5.4') | ||||
|             LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") | ||||
|             check_requirements("opencv-python>=4.5.4") | ||||
|             net = cv2.dnn.readNetFromONNX(w) | ||||
|         elif onnx:  # ONNX Runtime | ||||
|             LOGGER.info(f'Loading {w} for ONNX Runtime inference...') | ||||
|             check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime')) | ||||
|             LOGGER.info(f"Loading {w} for ONNX Runtime inference...") | ||||
|             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) | ||||
|             import onnxruntime | ||||
|             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] | ||||
| 
 | ||||
|             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"] | ||||
|             session = onnxruntime.InferenceSession(w, providers=providers) | ||||
|             output_names = [x.name for x in session.get_outputs()] | ||||
|             metadata = session.get_modelmeta().custom_metadata_map  # metadata | ||||
|         elif xml:  # OpenVINO | ||||
|             LOGGER.info(f'Loading {w} for OpenVINO inference...') | ||||
|             check_requirements('openvino>=2023.0')  # requires openvino-dev: https://pypi.org/project/openvino-dev/ | ||||
|             LOGGER.info(f"Loading {w} for OpenVINO inference...") | ||||
|             check_requirements("openvino>=2023.0")  # requires openvino-dev: https://pypi.org/project/openvino-dev/ | ||||
|             from openvino.runtime import Core, Layout, get_batch  # noqa | ||||
| 
 | ||||
|             core = Core() | ||||
|             w = Path(w) | ||||
|             if not w.is_file():  # if not *.xml | ||||
|                 w = next(w.glob('*.xml'))  # get *.xml file from *_openvino_model dir | ||||
|             ov_model = core.read_model(model=str(w), weights=w.with_suffix('.bin')) | ||||
|                 w = next(w.glob("*.xml"))  # get *.xml file from *_openvino_model dir | ||||
|             ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) | ||||
|             if ov_model.get_parameters()[0].get_layout().empty: | ||||
|                 ov_model.get_parameters()[0].set_layout(Layout('NCHW')) | ||||
|                 ov_model.get_parameters()[0].set_layout(Layout("NCHW")) | ||||
|             batch_dim = get_batch(ov_model) | ||||
|             if batch_dim.is_static: | ||||
|                 batch_size = batch_dim.get_length() | ||||
|             ov_compiled_model = core.compile_model(ov_model, device_name='AUTO')  # AUTO selects best available device | ||||
|             metadata = w.parent / 'metadata.yaml' | ||||
|             ov_compiled_model = core.compile_model(ov_model, device_name="AUTO")  # AUTO selects best available device | ||||
|             metadata = w.parent / "metadata.yaml" | ||||
|         elif engine:  # TensorRT | ||||
|             LOGGER.info(f'Loading {w} for TensorRT inference...') | ||||
|             LOGGER.info(f"Loading {w} for TensorRT inference...") | ||||
|             try: | ||||
|                 import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download | ||||
|             except ImportError: | ||||
|                 if LINUX: | ||||
|                     check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com') | ||||
|                     check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com") | ||||
|                 import tensorrt as trt  # noqa | ||||
|             check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0 | ||||
|             if device.type == 'cpu': | ||||
|                 device = torch.device('cuda:0') | ||||
|             Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) | ||||
|             check_version(trt.__version__, "7.0.0", hard=True)  # require tensorrt>=7.0.0 | ||||
|             if device.type == "cpu": | ||||
|                 device = torch.device("cuda:0") | ||||
|             Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) | ||||
|             logger = trt.Logger(trt.Logger.INFO) | ||||
|             # Read file | ||||
|             with open(w, 'rb') as f, trt.Runtime(logger) as runtime: | ||||
|                 meta_len = int.from_bytes(f.read(4), byteorder='little')  # read metadata length | ||||
|                 metadata = json.loads(f.read(meta_len).decode('utf-8'))  # read metadata | ||||
|             with open(w, "rb") as f, trt.Runtime(logger) as runtime: | ||||
|                 meta_len = int.from_bytes(f.read(4), byteorder="little")  # read metadata length | ||||
|                 metadata = json.loads(f.read(meta_len).decode("utf-8"))  # read metadata | ||||
|                 model = runtime.deserialize_cuda_engine(f.read())  # read engine | ||||
|             context = model.create_execution_context() | ||||
|             bindings = OrderedDict() | ||||
| @ -213,116 +233,124 @@ class AutoBackend(nn.Module): | ||||
|                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) | ||||
|                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) | ||||
|             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) | ||||
|             batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size | ||||
|             batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size | ||||
|         elif coreml:  # CoreML | ||||
|             LOGGER.info(f'Loading {w} for CoreML inference...') | ||||
|             LOGGER.info(f"Loading {w} for CoreML inference...") | ||||
|             import coremltools as ct | ||||
| 
 | ||||
|             model = ct.models.MLModel(w) | ||||
|             metadata = dict(model.user_defined_metadata) | ||||
|         elif saved_model:  # TF SavedModel | ||||
|             LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...') | ||||
|             LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") | ||||
|             import tensorflow as tf | ||||
| 
 | ||||
|             keras = False  # assume TF1 saved_model | ||||
|             model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) | ||||
|             metadata = Path(w) / 'metadata.yaml' | ||||
|             metadata = Path(w) / "metadata.yaml" | ||||
|         elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt | ||||
|             LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') | ||||
|             LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") | ||||
|             import tensorflow as tf | ||||
| 
 | ||||
|             from ultralytics.engine.exporter import gd_outputs | ||||
| 
 | ||||
|             def wrap_frozen_graph(gd, inputs, outputs): | ||||
|                 """Wrap frozen graphs for deployment.""" | ||||
|                 x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped | ||||
|                 x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped | ||||
|                 ge = x.graph.as_graph_element | ||||
|                 return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) | ||||
| 
 | ||||
|             gd = tf.Graph().as_graph_def()  # TF GraphDef | ||||
|             with open(w, 'rb') as f: | ||||
|             with open(w, "rb") as f: | ||||
|                 gd.ParseFromString(f.read()) | ||||
|             frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd)) | ||||
|             frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) | ||||
|         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python | ||||
|             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu | ||||
|                 from tflite_runtime.interpreter import Interpreter, load_delegate | ||||
|             except ImportError: | ||||
|                 import tensorflow as tf | ||||
| 
 | ||||
|                 Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate | ||||
|             if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime | ||||
|                 LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...') | ||||
|                 delegate = { | ||||
|                     'Linux': 'libedgetpu.so.1', | ||||
|                     'Darwin': 'libedgetpu.1.dylib', | ||||
|                     'Windows': 'edgetpu.dll'}[platform.system()] | ||||
|                 LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...") | ||||
|                 delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ | ||||
|                     platform.system() | ||||
|                 ] | ||||
|                 interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)]) | ||||
|             else:  # TFLite | ||||
|                 LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') | ||||
|                 LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") | ||||
|                 interpreter = Interpreter(model_path=w)  # load TFLite model | ||||
|             interpreter.allocate_tensors()  # allocate | ||||
|             input_details = interpreter.get_input_details()  # inputs | ||||
|             output_details = interpreter.get_output_details()  # outputs | ||||
|             # Load metadata | ||||
|             with contextlib.suppress(zipfile.BadZipFile): | ||||
|                 with zipfile.ZipFile(w, 'r') as model: | ||||
|                 with zipfile.ZipFile(w, "r") as model: | ||||
|                     meta_file = model.namelist()[0] | ||||
|                     metadata = ast.literal_eval(model.read(meta_file).decode('utf-8')) | ||||
|                     metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) | ||||
|         elif tfjs:  # TF.js | ||||
|             raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.') | ||||
|             raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.") | ||||
|         elif paddle:  # PaddlePaddle | ||||
|             LOGGER.info(f'Loading {w} for PaddlePaddle inference...') | ||||
|             check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle') | ||||
|             LOGGER.info(f"Loading {w} for PaddlePaddle inference...") | ||||
|             check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") | ||||
|             import paddle.inference as pdi  # noqa | ||||
| 
 | ||||
|             w = Path(w) | ||||
|             if not w.is_file():  # if not *.pdmodel | ||||
|                 w = next(w.rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir | ||||
|             config = pdi.Config(str(w), str(w.with_suffix('.pdiparams'))) | ||||
|                 w = next(w.rglob("*.pdmodel"))  # get *.pdmodel file from *_paddle_model dir | ||||
|             config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) | ||||
|             if cuda: | ||||
|                 config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) | ||||
|             predictor = pdi.create_predictor(config) | ||||
|             input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) | ||||
|             output_names = predictor.get_output_names() | ||||
|             metadata = w.parents[1] / 'metadata.yaml' | ||||
|             metadata = w.parents[1] / "metadata.yaml" | ||||
|         elif ncnn:  # ncnn | ||||
|             LOGGER.info(f'Loading {w} for ncnn inference...') | ||||
|             check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires ncnn | ||||
|             LOGGER.info(f"Loading {w} for ncnn inference...") | ||||
|             check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires ncnn | ||||
|             import ncnn as pyncnn | ||||
| 
 | ||||
|             net = pyncnn.Net() | ||||
|             net.opt.use_vulkan_compute = cuda | ||||
|             w = Path(w) | ||||
|             if not w.is_file():  # if not *.param | ||||
|                 w = next(w.glob('*.param'))  # get *.param file from *_ncnn_model dir | ||||
|                 w = next(w.glob("*.param"))  # get *.param file from *_ncnn_model dir | ||||
|             net.load_param(str(w)) | ||||
|             net.load_model(str(w.with_suffix('.bin'))) | ||||
|             metadata = w.parent / 'metadata.yaml' | ||||
|             net.load_model(str(w.with_suffix(".bin"))) | ||||
|             metadata = w.parent / "metadata.yaml" | ||||
|         elif triton:  # NVIDIA Triton Inference Server | ||||
|             check_requirements('tritonclient[all]') | ||||
|             check_requirements("tritonclient[all]") | ||||
|             from ultralytics.utils.triton import TritonRemoteModel | ||||
| 
 | ||||
|             model = TritonRemoteModel(w) | ||||
|         else: | ||||
|             from ultralytics.engine.exporter import export_formats | ||||
|             raise TypeError(f"model='{w}' is not a supported model format. " | ||||
|                             'See https://docs.ultralytics.com/modes/predict for help.' | ||||
|                             f'\n\n{export_formats()}') | ||||
| 
 | ||||
|             raise TypeError( | ||||
|                 f"model='{w}' is not a supported model format. " | ||||
|                 "See https://docs.ultralytics.com/modes/predict for help." | ||||
|                 f"\n\n{export_formats()}" | ||||
|             ) | ||||
| 
 | ||||
|         # Load external metadata YAML | ||||
|         if isinstance(metadata, (str, Path)) and Path(metadata).exists(): | ||||
|             metadata = yaml_load(metadata) | ||||
|         if metadata: | ||||
|             for k, v in metadata.items(): | ||||
|                 if k in ('stride', 'batch'): | ||||
|                 if k in ("stride", "batch"): | ||||
|                     metadata[k] = int(v) | ||||
|                 elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str): | ||||
|                 elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str): | ||||
|                     metadata[k] = eval(v) | ||||
|             stride = metadata['stride'] | ||||
|             task = metadata['task'] | ||||
|             batch = metadata['batch'] | ||||
|             imgsz = metadata['imgsz'] | ||||
|             names = metadata['names'] | ||||
|             kpt_shape = metadata.get('kpt_shape') | ||||
|             stride = metadata["stride"] | ||||
|             task = metadata["task"] | ||||
|             batch = metadata["batch"] | ||||
|             imgsz = metadata["imgsz"] | ||||
|             names = metadata["names"] | ||||
|             kpt_shape = metadata.get("kpt_shape") | ||||
|         elif not (pt or triton or nn_module): | ||||
|             LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") | ||||
| 
 | ||||
|         # Check names | ||||
|         if 'names' not in locals():  # names missing | ||||
|         if "names" not in locals():  # names missing | ||||
|             names = default_class_names(data) | ||||
|         names = check_class_names(names) | ||||
| 
 | ||||
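The coercion loop above converts stringified metadata values back into Python objects. A minimal standalone sketch of the same idea, using a hypothetical metadata dict (with the safer ast.literal_eval standing in for eval):

    import ast

    # Hypothetical values as they might arrive from an exported model's metadata.yaml
    metadata = {"stride": "32", "batch": "1", "imgsz": "[640, 640]", "task": "detect"}

    for k, v in metadata.items():
        if k in ("stride", "batch"):
            metadata[k] = int(v)
        elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
            metadata[k] = ast.literal_eval(v)  # "[640, 640]" -> [640, 640]

    print(metadata["stride"], metadata["imgsz"])  # 32 [640, 640]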
| @ -367,26 +395,28 @@ class AutoBackend(nn.Module): | ||||
|             im = im.cpu().numpy()  # FP32 | ||||
|             y = list(self.ov_compiled_model(im).values()) | ||||
|         elif self.engine:  # TensorRT | ||||
|             if self.dynamic and im.shape != self.bindings['images'].shape: | ||||
|                 i = self.model.get_binding_index('images') | ||||
|             if self.dynamic and im.shape != self.bindings["images"].shape: | ||||
|                 i = self.model.get_binding_index("images") | ||||
|                 self.context.set_binding_shape(i, im.shape)  # reshape if dynamic | ||||
|                 self.bindings['images'] = self.bindings['images']._replace(shape=im.shape) | ||||
|                 self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) | ||||
|                 for name in self.output_names: | ||||
|                     i = self.model.get_binding_index(name) | ||||
|                     self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) | ||||
|             s = self.bindings['images'].shape | ||||
|             s = self.bindings["images"].shape | ||||
|             assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" | ||||
|             self.binding_addrs['images'] = int(im.data_ptr()) | ||||
|             self.binding_addrs["images"] = int(im.data_ptr()) | ||||
|             self.context.execute_v2(list(self.binding_addrs.values())) | ||||
|             y = [self.bindings[x].data for x in sorted(self.output_names)] | ||||
|         elif self.coreml:  # CoreML | ||||
|             im = im[0].cpu().numpy() | ||||
|             im_pil = Image.fromarray((im * 255).astype('uint8')) | ||||
|             im_pil = Image.fromarray((im * 255).astype("uint8")) | ||||
|             # im = im.resize((192, 320), Image.BILINEAR) | ||||
|             y = self.model.predict({'image': im_pil})  # coordinates are xywh normalized | ||||
|             if 'confidence' in y: | ||||
|                 raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with ' | ||||
|                                 f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.") | ||||
|             y = self.model.predict({"image": im_pil})  # coordinates are xywh normalized | ||||
|             if "confidence" in y: | ||||
|                 raise TypeError( | ||||
|                     "Ultralytics only supports inference of non-pipelined CoreML models exported with " | ||||
|                     f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." | ||||
|                 ) | ||||
|                 # TODO: CoreML NMS inference handling | ||||
|                 # from ultralytics.utils.ops import xywh2xyxy | ||||
|                 # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels | ||||
| @ -425,20 +455,20 @@ class AutoBackend(nn.Module): | ||||
|                 if len(y) == 2 and len(self.names) == 999:  # segments and names not defined | ||||
|                     ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0)  # index of protos, boxes | ||||
|                     nc = y[ib].shape[1] - y[ip].shape[3] - 4  # y = (1, 160, 160, 32), (1, 116, 8400) | ||||
|                     self.names = {i: f'class{i}' for i in range(nc)} | ||||
|                     self.names = {i: f"class{i}" for i in range(nc)} | ||||
|             else:  # Lite or Edge TPU | ||||
|                 details = self.input_details[0] | ||||
|                 integer = details['dtype'] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model | ||||
|                 integer = details["dtype"] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model | ||||
|                 if integer: | ||||
|                     scale, zero_point = details['quantization'] | ||||
|                     im = (im / scale + zero_point).astype(details['dtype'])  # de-scale | ||||
|                 self.interpreter.set_tensor(details['index'], im) | ||||
|                     scale, zero_point = details["quantization"] | ||||
|                     im = (im / scale + zero_point).astype(details["dtype"])  # de-scale | ||||
|                 self.interpreter.set_tensor(details["index"], im) | ||||
|                 self.interpreter.invoke() | ||||
|                 y = [] | ||||
|                 for output in self.output_details: | ||||
|                     x = self.interpreter.get_tensor(output['index']) | ||||
|                     x = self.interpreter.get_tensor(output["index"]) | ||||
|                     if integer: | ||||
|                         scale, zero_point = output['quantization'] | ||||
|                         scale, zero_point = output["quantization"] | ||||
|                         x = (x.astype(np.float32) - zero_point) * scale  # re-scale | ||||
|                     if x.ndim > 2:  # if task is not classification | ||||
|                         # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 | ||||
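In the quantized TFLite branch above, the de-scale/re-scale pair is the standard affine quantization round trip, real = (q - zero_point) * scale. A small sketch with made-up quantization parameters:

    import numpy as np

    scale, zero_point = 0.02, 5  # hypothetical values from details['quantization']
    x = np.array([0.1, 0.5, -0.2], dtype=np.float32)

    q = (x / scale + zero_point).astype(np.int8)         # de-scale: float32 -> int8
    x_hat = (q.astype(np.float32) - zero_point) * scale  # re-scale: int8 -> float32

    print(q)      # [10 30 -5]
    print(x_hat)  # values round-trip up to quantization error: [ 0.1  0.5 -0.2]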
| @ -483,13 +513,13 @@ class AutoBackend(nn.Module): | ||||
|             (None): This method runs the forward pass and does not return any value | ||||
|         """ | ||||
|         warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module | ||||
|         if any(warmup_types) and (self.device.type != 'cpu' or self.triton): | ||||
|         if any(warmup_types) and (self.device.type != "cpu" or self.triton): | ||||
|             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input | ||||
|             for _ in range(2 if self.jit else 1): | ||||
|                 self.forward(im)  # warmup | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _model_type(p='path/to/model.pt'): | ||||
|     def _model_type(p="path/to/model.pt"): | ||||
|         """ | ||||
|         This function takes a path to a model file and returns the model type. | ||||
| 
 | ||||
| @ -499,18 +529,20 @@ class AutoBackend(nn.Module): | ||||
|         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx | ||||
|         # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] | ||||
|         from ultralytics.engine.exporter import export_formats | ||||
| 
 | ||||
|         sf = list(export_formats().Suffix)  # export suffixes | ||||
|         if not is_url(p, check=False) and not isinstance(p, str): | ||||
|             check_suffix(p, sf)  # checks | ||||
|         name = Path(p).name | ||||
|         types = [s in name for s in sf] | ||||
|         types[5] |= name.endswith('.mlmodel')  # retain support for older Apple CoreML *.mlmodel formats | ||||
|         types[5] |= name.endswith(".mlmodel")  # retain support for older Apple CoreML *.mlmodel formats | ||||
|         types[8] &= not types[9]  # tflite &= not edgetpu | ||||
|         if any(types): | ||||
|             triton = False | ||||
|         else: | ||||
|             from urllib.parse import urlsplit | ||||
| 
 | ||||
|             url = urlsplit(p) | ||||
|             triton = url.netloc and url.path and url.scheme in {'http', 'grpc'} | ||||
|             triton = url.netloc and url.path and url.scheme in {"http", "grpc"} | ||||
| 
 | ||||
|         return types + [triton] | ||||
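A rough sketch of the routing that _model_type performs, separating Triton URLs from local file suffixes (the paths and URLs below are illustrative only):

    from pathlib import Path
    from urllib.parse import urlsplit

    def looks_like_triton(p: str) -> bool:
        """Mirror the check above: an http/grpc URL with both a host and a path."""
        url = urlsplit(p)
        return bool(url.netloc and url.path and url.scheme in {"http", "grpc"})

    print(looks_like_triton("http://localhost:8000/yolov8n"))  # True
    print(looks_like_triton("weights/yolov8n.onnx"))           # False
    print(Path("weights/yolov8n.onnx").suffix)                 # '.onnx' -> ONNX branch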
|  | ||||
| @ -17,18 +17,101 @@ Example: | ||||
|     ``` | ||||
| """ | ||||
| 
 | ||||
| from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck, | ||||
|                     HGBlock, HGStem, Proto, RepC3, ResNetLayer) | ||||
| from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, | ||||
|                    GhostConv, LightConv, RepConv, SpatialAttention) | ||||
| from .block import ( | ||||
|     C1, | ||||
|     C2, | ||||
|     C3, | ||||
|     C3TR, | ||||
|     DFL, | ||||
|     SPP, | ||||
|     SPPF, | ||||
|     Bottleneck, | ||||
|     BottleneckCSP, | ||||
|     C2f, | ||||
|     C3Ghost, | ||||
|     C3x, | ||||
|     GhostBottleneck, | ||||
|     HGBlock, | ||||
|     HGStem, | ||||
|     Proto, | ||||
|     RepC3, | ||||
|     ResNetLayer, | ||||
| ) | ||||
| from .conv import ( | ||||
|     CBAM, | ||||
|     ChannelAttention, | ||||
|     Concat, | ||||
|     Conv, | ||||
|     Conv2, | ||||
|     ConvTranspose, | ||||
|     DWConv, | ||||
|     DWConvTranspose2d, | ||||
|     Focus, | ||||
|     GhostConv, | ||||
|     LightConv, | ||||
|     RepConv, | ||||
|     SpatialAttention, | ||||
| ) | ||||
| from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment | ||||
| from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, | ||||
|                           MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer) | ||||
| from .transformer import ( | ||||
|     AIFI, | ||||
|     MLP, | ||||
|     DeformableTransformerDecoder, | ||||
|     DeformableTransformerDecoderLayer, | ||||
|     LayerNorm2d, | ||||
|     MLPBlock, | ||||
|     MSDeformAttn, | ||||
|     TransformerBlock, | ||||
|     TransformerEncoderLayer, | ||||
|     TransformerLayer, | ||||
| ) | ||||
| 
 | ||||
| __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', | ||||
|            'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', | ||||
|            'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', | ||||
|            'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', | ||||
|            'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', | ||||
|            'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer', | ||||
|            'OBB') | ||||
| __all__ = ( | ||||
|     "Conv", | ||||
|     "Conv2", | ||||
|     "LightConv", | ||||
|     "RepConv", | ||||
|     "DWConv", | ||||
|     "DWConvTranspose2d", | ||||
|     "ConvTranspose", | ||||
|     "Focus", | ||||
|     "GhostConv", | ||||
|     "ChannelAttention", | ||||
|     "SpatialAttention", | ||||
|     "CBAM", | ||||
|     "Concat", | ||||
|     "TransformerLayer", | ||||
|     "TransformerBlock", | ||||
|     "MLPBlock", | ||||
|     "LayerNorm2d", | ||||
|     "DFL", | ||||
|     "HGBlock", | ||||
|     "HGStem", | ||||
|     "SPP", | ||||
|     "SPPF", | ||||
|     "C1", | ||||
|     "C2", | ||||
|     "C3", | ||||
|     "C2f", | ||||
|     "C3x", | ||||
|     "C3TR", | ||||
|     "C3Ghost", | ||||
|     "GhostBottleneck", | ||||
|     "Bottleneck", | ||||
|     "BottleneckCSP", | ||||
|     "Proto", | ||||
|     "Detect", | ||||
|     "Segment", | ||||
|     "Pose", | ||||
|     "Classify", | ||||
|     "TransformerEncoderLayer", | ||||
|     "RepC3", | ||||
|     "RTDETRDecoder", | ||||
|     "AIFI", | ||||
|     "DeformableTransformerDecoder", | ||||
|     "DeformableTransformerDecoderLayer", | ||||
|     "MSDeformAttn", | ||||
|     "MLP", | ||||
|     "ResNetLayer", | ||||
|     "OBB", | ||||
| ) | ||||
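An explicit __all__ like the one above pins the package's public, star-importable surface. A quick consistency check, assuming an installed ultralytics package:

    import ultralytics.nn.modules as modules

    missing = [n for n in modules.__all__ if not hasattr(modules, n)]
    assert not missing, f"__all__ declares undefined names: {missing}"
    print(len(modules.__all__), "public symbols")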
|  | ||||
| @ -8,8 +8,26 @@ import torch.nn.functional as F | ||||
| from .conv import Conv, DWConv, GhostConv, LightConv, RepConv | ||||
| from .transformer import TransformerBlock | ||||
| 
 | ||||
| __all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost', | ||||
|            'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3', 'ResNetLayer') | ||||
| __all__ = ( | ||||
|     "DFL", | ||||
|     "HGBlock", | ||||
|     "HGStem", | ||||
|     "SPP", | ||||
|     "SPPF", | ||||
|     "C1", | ||||
|     "C2", | ||||
|     "C3", | ||||
|     "C2f", | ||||
|     "C3x", | ||||
|     "C3TR", | ||||
|     "C3Ghost", | ||||
|     "GhostBottleneck", | ||||
|     "Bottleneck", | ||||
|     "BottleneckCSP", | ||||
|     "Proto", | ||||
|     "RepC3", | ||||
|     "ResNetLayer", | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class DFL(nn.Module): | ||||
| @ -284,9 +302,11 @@ class GhostBottleneck(nn.Module): | ||||
|         self.conv = nn.Sequential( | ||||
|             GhostConv(c1, c_, 1, 1),  # pw | ||||
|             DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw | ||||
|             GhostConv(c_, c2, 1, 1, act=False))  # pw-linear | ||||
|         self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, | ||||
|                                                                             act=False)) if s == 2 else nn.Identity() | ||||
|             GhostConv(c_, c2, 1, 1, act=False),  # pw-linear | ||||
|         ) | ||||
|         self.shortcut = ( | ||||
|             nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() | ||||
|         ) | ||||
| 
 | ||||
|     def forward(self, x): | ||||
|         """Applies skip connection and concatenation to input tensor.""" | ||||
| @ -359,8 +379,9 @@ class ResNetLayer(nn.Module): | ||||
|         self.is_first = is_first | ||||
| 
 | ||||
|         if self.is_first: | ||||
|             self.layer = nn.Sequential(Conv(c1, c2, k=7, s=2, p=3, act=True), | ||||
|                                        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) | ||||
|             self.layer = nn.Sequential( | ||||
|                 Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|             ) | ||||
|         else: | ||||
|             blocks = [ResNetBlock(c1, c2, s, e=e)] | ||||
|             blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)]) | ||||
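The is_first branch above is the classic ResNet stem. A plain-PyTorch sketch of the equivalent downsampling (Conv in the snippet above also bundles BatchNorm and an activation):

    import torch
    import torch.nn as nn

    # 7x7/2 conv followed by 3x3/2 max-pool: a 4x spatial downsample
    stem = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
        nn.BatchNorm2d(64),
        nn.SiLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    )
    print(stem(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 56, 56])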
|  | ||||
| @ -7,8 +7,21 @@ import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| 
 | ||||
| __all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv', | ||||
|            'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv') | ||||
| __all__ = ( | ||||
|     "Conv", | ||||
|     "Conv2", | ||||
|     "LightConv", | ||||
|     "DWConv", | ||||
|     "DWConvTranspose2d", | ||||
|     "ConvTranspose", | ||||
|     "Focus", | ||||
|     "GhostConv", | ||||
|     "ChannelAttention", | ||||
|     "SpatialAttention", | ||||
|     "CBAM", | ||||
|     "Concat", | ||||
|     "RepConv", | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def autopad(k, p=None, d=1):  # kernel, padding, dilation | ||||
| @ -22,6 +35,7 @@ def autopad(k, p=None, d=1):  # kernel, padding, dilation | ||||
| 
 | ||||
| class Conv(nn.Module): | ||||
|     """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation).""" | ||||
| 
 | ||||
|     default_act = nn.SiLU()  # default activation | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): | ||||
| @ -60,9 +74,9 @@ class Conv2(Conv): | ||||
|         """Fuse parallel convolutions.""" | ||||
|         w = torch.zeros_like(self.conv.weight.data) | ||||
|         i = [x // 2 for x in w.shape[2:]] | ||||
|         w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone() | ||||
|         w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone() | ||||
|         self.conv.weight.data += w | ||||
|         self.__delattr__('cv2') | ||||
|         self.__delattr__("cv2") | ||||
|         self.forward = self.forward_fuse | ||||
| 
 | ||||
| 
 | ||||
| @ -102,6 +116,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d): | ||||
| 
 | ||||
| class ConvTranspose(nn.Module): | ||||
|     """Convolution transpose 2d layer.""" | ||||
| 
 | ||||
|     default_act = nn.SiLU()  # default activation | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): | ||||
| @ -164,6 +179,7 @@ class RepConv(nn.Module): | ||||
|     This module is used in RT-DETR. | ||||
|     Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py | ||||
|     """ | ||||
| 
 | ||||
|     default_act = nn.SiLU()  # default activation | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False): | ||||
| @ -214,7 +230,7 @@ class RepConv(nn.Module): | ||||
|             beta = branch.bn.bias | ||||
|             eps = branch.bn.eps | ||||
|         elif isinstance(branch, nn.BatchNorm2d): | ||||
|             if not hasattr(self, 'id_tensor'): | ||||
|             if not hasattr(self, "id_tensor"): | ||||
|                 input_dim = self.c1 // self.g | ||||
|                 kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32) | ||||
|                 for i in range(self.c1): | ||||
| @ -232,29 +248,31 @@ class RepConv(nn.Module): | ||||
| 
 | ||||
|     def fuse_convs(self): | ||||
|         """Combines two convolution layers into a single layer and removes unused attributes from the class.""" | ||||
|         if hasattr(self, 'conv'): | ||||
|         if hasattr(self, "conv"): | ||||
|             return | ||||
|         kernel, bias = self.get_equivalent_kernel_bias() | ||||
|         self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels, | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_channels=self.conv1.conv.in_channels, | ||||
|             out_channels=self.conv1.conv.out_channels, | ||||
|             kernel_size=self.conv1.conv.kernel_size, | ||||
|             stride=self.conv1.conv.stride, | ||||
|             padding=self.conv1.conv.padding, | ||||
|             dilation=self.conv1.conv.dilation, | ||||
|             groups=self.conv1.conv.groups, | ||||
|                               bias=True).requires_grad_(False) | ||||
|             bias=True, | ||||
|         ).requires_grad_(False) | ||||
|         self.conv.weight.data = kernel | ||||
|         self.conv.bias.data = bias | ||||
|         for para in self.parameters(): | ||||
|             para.detach_() | ||||
|         self.__delattr__('conv1') | ||||
|         self.__delattr__('conv2') | ||||
|         if hasattr(self, 'nm'): | ||||
|             self.__delattr__('nm') | ||||
|         if hasattr(self, 'bn'): | ||||
|             self.__delattr__('bn') | ||||
|         if hasattr(self, 'id_tensor'): | ||||
|             self.__delattr__('id_tensor') | ||||
|         self.__delattr__("conv1") | ||||
|         self.__delattr__("conv2") | ||||
|         if hasattr(self, "nm"): | ||||
|             self.__delattr__("nm") | ||||
|         if hasattr(self, "bn"): | ||||
|             self.__delattr__("bn") | ||||
|         if hasattr(self, "id_tensor"): | ||||
|             self.__delattr__("id_tensor") | ||||
| 
 | ||||
| 
 | ||||
| class ChannelAttention(nn.Module): | ||||
| @ -278,7 +296,7 @@ class SpatialAttention(nn.Module): | ||||
|     def __init__(self, kernel_size=7): | ||||
|         """Initialize Spatial-attention module with kernel size argument.""" | ||||
|         super().__init__() | ||||
|         assert kernel_size in (3, 7), 'kernel size must be 3 or 7' | ||||
|         assert kernel_size in (3, 7), "kernel size must be 3 or 7" | ||||
|         padding = 3 if kernel_size == 7 else 1 | ||||
|         self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) | ||||
|         self.act = nn.Sigmoid() | ||||
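The forward pass of SpatialAttention is not shown in this hunk; below is a self-contained sketch of the usual CBAM-style recipe consistent with the __init__ above (an approximation, not necessarily the exact Ultralytics implementation):

    import torch
    import torch.nn as nn

    class SpatialAttentionSketch(nn.Module):
        def __init__(self, kernel_size=7):
            super().__init__()
            assert kernel_size in (3, 7), "kernel size must be 3 or 7"
            padding = 3 if kernel_size == 7 else 1
            self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
            self.act = nn.Sigmoid()

        def forward(self, x):
            # stack channel-wise mean and max, convolve to one map, gate the input
            s = torch.cat([x.mean(1, keepdim=True), x.amax(1, keepdim=True)], 1)
            return x * self.act(self.cv1(s))

    x = torch.randn(1, 8, 32, 32)
    print(SpatialAttentionSketch()(x).shape)  # torch.Size([1, 8, 32, 32])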
|  | ||||
| @ -14,11 +14,12 @@ from .conv import Conv | ||||
| from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer | ||||
| from .utils import bias_init_with_prob, linear_init_ | ||||
| 
 | ||||
| __all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'OBB', 'RTDETRDecoder' | ||||
| __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder" | ||||
| 
 | ||||
| 
 | ||||
| class Detect(nn.Module): | ||||
|     """YOLOv8 Detect head for detection models.""" | ||||
| 
 | ||||
|     dynamic = False  # force grid reconstruction | ||||
|     export = False  # export mode | ||||
|     shape = None | ||||
| @ -35,7 +36,8 @@ class Detect(nn.Module): | ||||
|         self.stride = torch.zeros(self.nl)  # strides computed during build | ||||
|         c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels | ||||
|         self.cv2 = nn.ModuleList( | ||||
|             nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch) | ||||
|             nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch | ||||
|         ) | ||||
|         self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) | ||||
|         self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() | ||||
| 
 | ||||
| @ -53,14 +55,14 @@ class Detect(nn.Module): | ||||
|             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) | ||||
|             self.shape = shape | ||||
| 
 | ||||
|         if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops | ||||
|             box = x_cat[:, :self.reg_max * 4] | ||||
|             cls = x_cat[:, self.reg_max * 4:] | ||||
|         if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops | ||||
|             box = x_cat[:, : self.reg_max * 4] | ||||
|             cls = x_cat[:, self.reg_max * 4 :] | ||||
|         else: | ||||
|             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) | ||||
|         dbox = self.decode_bboxes(box) | ||||
| 
 | ||||
|         if self.export and self.format in ('tflite', 'edgetpu'): | ||||
|         if self.export and self.format in ("tflite", "edgetpu"): | ||||
|             # Precompute normalization factor to increase numerical stability | ||||
|             # See https://github.com/ultralytics/ultralytics/issues/7371 | ||||
|             img_h = shape[2] | ||||
| @ -79,7 +81,7 @@ class Detect(nn.Module): | ||||
|         # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency | ||||
|         for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from | ||||
|             a[-1].bias.data[:] = 1.0  # box | ||||
|             b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img) | ||||
|             b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img) | ||||
| 
 | ||||
|     def decode_bboxes(self, bboxes): | ||||
|         """Decode bounding boxes.""" | ||||
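The export branch above slices x_cat instead of calling split() so that TensorFlow converters do not emit FlexSplitV ops; the two forms are numerically identical:

    import torch

    reg_max, nc = 16, 80  # illustrative values
    x_cat = torch.randn(1, 4 * reg_max + nc, 8400)

    box_a, cls_a = x_cat[:, : reg_max * 4], x_cat[:, reg_max * 4 :]
    box_b, cls_b = x_cat.split((reg_max * 4, nc), 1)

    assert torch.equal(box_a, box_b) and torch.equal(cls_a, cls_b)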
| @ -214,6 +216,7 @@ class RTDETRDecoder(nn.Module): | ||||
|     and class labels for objects in an image. It integrates features from multiple layers and runs through a series of | ||||
|     Transformer decoder layers to output the final predictions. | ||||
|     """ | ||||
| 
 | ||||
|     export = False  # export mode | ||||
| 
 | ||||
|     def __init__( | ||||
| @ -226,14 +229,15 @@ class RTDETRDecoder(nn.Module): | ||||
|         nh=8,  # num head | ||||
|         ndl=6,  # num decoder layers | ||||
|         d_ffn=1024,  # dim of feedforward | ||||
|             dropout=0., | ||||
|         dropout=0.0, | ||||
|         act=nn.ReLU(), | ||||
|         eval_idx=-1, | ||||
|         # Training args | ||||
|         nd=100,  # num denoising | ||||
|         label_noise_ratio=0.5, | ||||
|         box_noise_scale=1.0, | ||||
|             learnt_init_query=False): | ||||
|         learnt_init_query=False, | ||||
|     ): | ||||
|         """ | ||||
|         Initializes the RTDETRDecoder module with the given parameters. | ||||
| 
 | ||||
| @ -302,28 +306,30 @@ class RTDETRDecoder(nn.Module): | ||||
|         feats, shapes = self._get_encoder_input(x) | ||||
| 
 | ||||
|         # Prepare denoising training | ||||
|         dn_embed, dn_bbox, attn_mask, dn_meta = \ | ||||
|             get_cdn_group(batch, | ||||
|         dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group( | ||||
|             batch, | ||||
|             self.nc, | ||||
|             self.num_queries, | ||||
|             self.denoising_class_embed.weight, | ||||
|             self.num_denoising, | ||||
|             self.label_noise_ratio, | ||||
|             self.box_noise_scale, | ||||
|                           self.training) | ||||
|             self.training, | ||||
|         ) | ||||
| 
 | ||||
|         embed, refer_bbox, enc_bboxes, enc_scores = \ | ||||
|             self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) | ||||
|         embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) | ||||
| 
 | ||||
|         # Decoder | ||||
|         dec_bboxes, dec_scores = self.decoder(embed, | ||||
|         dec_bboxes, dec_scores = self.decoder( | ||||
|             embed, | ||||
|             refer_bbox, | ||||
|             feats, | ||||
|             shapes, | ||||
|             self.dec_bbox_head, | ||||
|             self.dec_score_head, | ||||
|             self.query_pos_head, | ||||
|                                               attn_mask=attn_mask) | ||||
|             attn_mask=attn_mask, | ||||
|         ) | ||||
|         x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta | ||||
|         if self.training: | ||||
|             return x | ||||
| @ -331,24 +337,24 @@ class RTDETRDecoder(nn.Module): | ||||
|         y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1) | ||||
|         return y if self.export else (y, x) | ||||
| 
 | ||||
|     def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2): | ||||
|     def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2): | ||||
|         """Generates anchor bounding boxes for given shapes with specific grid size and validates them.""" | ||||
|         anchors = [] | ||||
|         for i, (h, w) in enumerate(shapes): | ||||
|             sy = torch.arange(end=h, dtype=dtype, device=device) | ||||
|             sx = torch.arange(end=w, dtype=dtype, device=device) | ||||
|             grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx) | ||||
|             grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) | ||||
|             grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2) | ||||
| 
 | ||||
|             valid_WH = torch.tensor([w, h], dtype=dtype, device=device) | ||||
|             grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2) | ||||
|             wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i) | ||||
|             wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i) | ||||
|             anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4) | ||||
| 
 | ||||
|         anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4) | ||||
|         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1 | ||||
|         anchors = torch.log(anchors / (1 - anchors)) | ||||
|         anchors = anchors.masked_fill(~valid_mask, float('inf')) | ||||
|         anchors = anchors.masked_fill(~valid_mask, float("inf")) | ||||
|         return anchors, valid_mask | ||||
| 
 | ||||
|     def _get_encoder_input(self, x): | ||||
| @ -415,13 +421,13 @@ class RTDETRDecoder(nn.Module): | ||||
|         # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets. | ||||
|         # linear_init_(self.enc_score_head) | ||||
|         constant_(self.enc_score_head.bias, bias_cls) | ||||
|         constant_(self.enc_bbox_head.layers[-1].weight, 0.) | ||||
|         constant_(self.enc_bbox_head.layers[-1].bias, 0.) | ||||
|         constant_(self.enc_bbox_head.layers[-1].weight, 0.0) | ||||
|         constant_(self.enc_bbox_head.layers[-1].bias, 0.0) | ||||
|         for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): | ||||
|             # linear_init_(cls_) | ||||
|             constant_(cls_.bias, bias_cls) | ||||
|             constant_(reg_.layers[-1].weight, 0.) | ||||
|             constant_(reg_.layers[-1].bias, 0.) | ||||
|             constant_(reg_.layers[-1].weight, 0.0) | ||||
|             constant_(reg_.layers[-1].bias, 0.0) | ||||
| 
 | ||||
|         linear_init_(self.enc_output[0]) | ||||
|         xavier_uniform_(self.enc_output[0].weight) | ||||
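In _generate_anchors above, anchors in (0, 1) are pushed through the inverse sigmoid, log(a / (1 - a)), so the decoder refines them in unconstrained logit space, and near-boundary anchors are masked to inf. A two-anchor sketch:

    import torch

    eps = 1e-2
    anchors = torch.tensor([[0.250, 0.50, 0.10, 0.10],   # valid
                            [0.999, 0.50, 0.10, 0.10]])  # too close to 1 -> masked
    valid = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)
    logits = torch.log(anchors / (1 - anchors)).masked_fill(~valid, float("inf"))
    print(valid.squeeze(-1))  # tensor([ True, False])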
|  | ||||
| @ -11,8 +11,18 @@ from torch.nn.init import constant_, xavier_uniform_ | ||||
| from .conv import Conv | ||||
| from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch | ||||
| 
 | ||||
| __all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI', | ||||
|            'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP') | ||||
| __all__ = ( | ||||
|     "TransformerEncoderLayer", | ||||
|     "TransformerLayer", | ||||
|     "TransformerBlock", | ||||
|     "MLPBlock", | ||||
|     "LayerNorm2d", | ||||
|     "AIFI", | ||||
|     "DeformableTransformerDecoder", | ||||
|     "DeformableTransformerDecoderLayer", | ||||
|     "MSDeformAttn", | ||||
|     "MLP", | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class TransformerEncoderLayer(nn.Module): | ||||
| @ -22,9 +32,11 @@ class TransformerEncoderLayer(nn.Module): | ||||
|         """Initialize the TransformerEncoderLayer with specified parameters.""" | ||||
|         super().__init__() | ||||
|         from ...utils.torch_utils import TORCH_1_9 | ||||
| 
 | ||||
|         if not TORCH_1_9: | ||||
|             raise ModuleNotFoundError( | ||||
|                 'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).') | ||||
|                 "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)." | ||||
|             ) | ||||
|         self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True) | ||||
|         # Implementation of Feedforward model | ||||
|         self.fc1 = nn.Linear(c1, cm) | ||||
| @ -91,12 +103,11 @@ class AIFI(TransformerEncoderLayer): | ||||
|         """Builds 2D sine-cosine position embedding.""" | ||||
|         grid_w = torch.arange(int(w), dtype=torch.float32) | ||||
|         grid_h = torch.arange(int(h), dtype=torch.float32) | ||||
|         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') | ||||
|         assert embed_dim % 4 == 0, \ | ||||
|             'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' | ||||
|         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") | ||||
|         assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" | ||||
|         pos_dim = embed_dim // 4 | ||||
|         omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim | ||||
|         omega = 1. / (temperature ** omega) | ||||
|         omega = 1.0 / (temperature**omega) | ||||
| 
 | ||||
|         out_w = grid_w.flatten()[..., None] @ omega[None] | ||||
|         out_h = grid_h.flatten()[..., None] @ omega[None] | ||||
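The position embedding above interleaves sine and cosine features over a frequency ladder omega = 1 / temperature^(i / pos_dim). A miniature version for a 2x2 grid (torch>=1.10 for indexing="ij"):

    import torch

    w, h, embed_dim, temperature = 2, 2, 8, 10000.0
    grid_w, grid_h = torch.meshgrid(
        torch.arange(w, dtype=torch.float32),
        torch.arange(h, dtype=torch.float32),
        indexing="ij",
    )
    pos_dim = embed_dim // 4
    omega = 1.0 / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
    out_w = grid_w.flatten()[..., None] @ omega[None]
    out_h = grid_h.flatten()[..., None] @ omega[None]
    pe = torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], 1)[None]
    print(pe.shape)  # torch.Size([1, 4, 8])  ->  (1, w*h, embed_dim)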
| @ -213,10 +224,10 @@ class MSDeformAttn(nn.Module): | ||||
|         """Initialize MSDeformAttn with the given parameters.""" | ||||
|         super().__init__() | ||||
|         if d_model % n_heads != 0: | ||||
|             raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}') | ||||
|             raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}") | ||||
|         _d_per_head = d_model // n_heads | ||||
|         # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation | ||||
|         assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`' | ||||
|         assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`" | ||||
| 
 | ||||
|         self.im2col_step = 64 | ||||
| 
 | ||||
| @ -234,21 +245,24 @@ class MSDeformAttn(nn.Module): | ||||
| 
 | ||||
|     def _reset_parameters(self): | ||||
|         """Reset module parameters.""" | ||||
|         constant_(self.sampling_offsets.weight.data, 0.) | ||||
|         constant_(self.sampling_offsets.weight.data, 0.0) | ||||
|         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) | ||||
|         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) | ||||
|         grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat( | ||||
|             1, self.n_levels, self.n_points, 1) | ||||
|         grid_init = ( | ||||
|             (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) | ||||
|             .view(self.n_heads, 1, 1, 2) | ||||
|             .repeat(1, self.n_levels, self.n_points, 1) | ||||
|         ) | ||||
|         for i in range(self.n_points): | ||||
|             grid_init[:, :, i, :] *= i + 1 | ||||
|         with torch.no_grad(): | ||||
|             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) | ||||
|         constant_(self.attention_weights.weight.data, 0.) | ||||
|         constant_(self.attention_weights.bias.data, 0.) | ||||
|         constant_(self.attention_weights.weight.data, 0.0) | ||||
|         constant_(self.attention_weights.bias.data, 0.0) | ||||
|         xavier_uniform_(self.value_proj.weight.data) | ||||
|         constant_(self.value_proj.bias.data, 0.) | ||||
|         constant_(self.value_proj.bias.data, 0.0) | ||||
|         xavier_uniform_(self.output_proj.weight.data) | ||||
|         constant_(self.output_proj.bias.data, 0.) | ||||
|         constant_(self.output_proj.bias.data, 0.0) | ||||
| 
 | ||||
|     def forward(self, query, refer_bbox, value, value_shapes, value_mask=None): | ||||
|         """ | ||||
| @ -288,7 +302,7 @@ class MSDeformAttn(nn.Module): | ||||
|             add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5 | ||||
|             sampling_locations = refer_bbox[:, :, None, :, None, :2] + add | ||||
|         else: | ||||
|             raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.') | ||||
|             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.") | ||||
|         output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights) | ||||
|         return self.output_proj(output) | ||||
| 
 | ||||
| @ -301,7 +315,7 @@ class DeformableTransformerDecoderLayer(nn.Module): | ||||
|     https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4): | ||||
|     def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4): | ||||
|         """Initialize the DeformableTransformerDecoderLayer with the given parameters.""" | ||||
|         super().__init__() | ||||
| 
 | ||||
| @ -339,14 +353,16 @@ class DeformableTransformerDecoderLayer(nn.Module): | ||||
| 
 | ||||
|         # Self attention | ||||
|         q = k = self.with_pos_embed(embed, query_pos) | ||||
|         tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), | ||||
|                              attn_mask=attn_mask)[0].transpose(0, 1) | ||||
|         tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[ | ||||
|             0 | ||||
|         ].transpose(0, 1) | ||||
|         embed = embed + self.dropout1(tgt) | ||||
|         embed = self.norm1(embed) | ||||
| 
 | ||||
|         # Cross attention | ||||
|         tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, | ||||
|                               padding_mask) | ||||
|         tgt = self.cross_attn( | ||||
|             self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask | ||||
|         ) | ||||
|         embed = embed + self.dropout2(tgt) | ||||
|         embed = self.norm2(embed) | ||||
| 
 | ||||
| @ -379,7 +395,8 @@ class DeformableTransformerDecoder(nn.Module): | ||||
|         score_head, | ||||
|         pos_mlp, | ||||
|         attn_mask=None, | ||||
|             padding_mask=None): | ||||
|         padding_mask=None, | ||||
|     ): | ||||
|         """Perform the forward pass through the entire decoder.""" | ||||
|         output = embed | ||||
|         dec_bboxes = [] | ||||
|  | ||||
| @ -10,7 +10,7 @@ import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from torch.nn.init import uniform_ | ||||
| 
 | ||||
| __all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid' | ||||
| __all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid" | ||||
| 
 | ||||
| 
 | ||||
| def _get_clones(module, n): | ||||
| @ -27,7 +27,7 @@ def linear_init_(module): | ||||
|     """Initialize the weights and biases of a linear module.""" | ||||
|     bound = 1 / math.sqrt(module.weight.shape[0]) | ||||
|     uniform_(module.weight, -bound, bound) | ||||
|     if hasattr(module, 'bias') and module.bias is not None: | ||||
|     if hasattr(module, "bias") and module.bias is not None: | ||||
|         uniform_(module.bias, -bound, bound) | ||||
| 
 | ||||
| 
 | ||||
| @ -39,9 +39,12 @@ def inverse_sigmoid(x, eps=1e-5): | ||||
|     return torch.log(x1 / x2) | ||||
| 
 | ||||
| 
 | ||||
| def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor, | ||||
| def multi_scale_deformable_attn_pytorch( | ||||
|     value: torch.Tensor, | ||||
|     value_spatial_shapes: torch.Tensor, | ||||
|     sampling_locations: torch.Tensor, | ||||
|                                         attention_weights: torch.Tensor) -> torch.Tensor: | ||||
|     attention_weights: torch.Tensor, | ||||
| ) -> torch.Tensor: | ||||
|     """ | ||||
|     Multi-scale deformable attention. | ||||
| 
 | ||||
| @ -58,23 +61,25 @@ def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shape | ||||
|         # bs, H_*W_, num_heads*embed_dims -> | ||||
|         # bs, num_heads*embed_dims, H_*W_ -> | ||||
|         # bs*num_heads, embed_dims, H_, W_ | ||||
|         value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)) | ||||
|         value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) | ||||
|         # bs, num_queries, num_heads, num_points, 2 -> | ||||
|         # bs, num_heads, num_queries, num_points, 2 -> | ||||
|         # bs*num_heads, num_queries, num_points, 2 | ||||
|         sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) | ||||
|         # bs*num_heads, embed_dims, num_queries, num_points | ||||
|         sampling_value_l_ = F.grid_sample(value_l_, | ||||
|                                           sampling_grid_l_, | ||||
|                                           mode='bilinear', | ||||
|                                           padding_mode='zeros', | ||||
|                                           align_corners=False) | ||||
|         sampling_value_l_ = F.grid_sample( | ||||
|             value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False | ||||
|         ) | ||||
|         sampling_value_list.append(sampling_value_l_) | ||||
|     # (bs, num_queries, num_heads, num_levels, num_points) -> | ||||
|     # (bs, num_heads, num_queries, num_levels, num_points) -> | ||||
|     # (bs, num_heads, 1, num_queries, num_levels*num_points) | ||||
|     attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, | ||||
|                                                                   num_levels * num_points) | ||||
|     output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view( | ||||
|         bs, num_heads * embed_dims, num_queries)) | ||||
|     attention_weights = attention_weights.transpose(1, 2).reshape( | ||||
|         bs * num_heads, 1, num_queries, num_levels * num_points | ||||
|     ) | ||||
|     output = ( | ||||
|         (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) | ||||
|         .sum(-1) | ||||
|         .view(bs, num_heads * embed_dims, num_queries) | ||||
|     ) | ||||
|     return output.transpose(1, 2).contiguous() | ||||
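A quick smoke test of multi_scale_deformable_attn_pytorch with dummy tensors; the shapes are inferred from the reshapes in the function body (assumes an installed ultralytics package):

    import torch
    from ultralytics.nn.modules.utils import multi_scale_deformable_attn_pytorch

    bs, heads, dims, queries, points = 1, 8, 32, 10, 4
    shapes = [(16, 16), (8, 8)]                   # two feature levels
    num_keys = sum(hh * ww for hh, ww in shapes)  # 320
    value = torch.randn(bs, num_keys, heads, dims)
    locs = torch.rand(bs, queries, heads, len(shapes), points, 2)
    weights = torch.softmax(torch.randn(bs, queries, heads, len(shapes) * points), -1)
    weights = weights.view(bs, queries, heads, len(shapes), points)

    out = multi_scale_deformable_attn_pytorch(value, shapes, locs, weights)
    print(out.shape)  # torch.Size([1, 10, 256])  ->  (bs, queries, heads * dims)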
|  | ||||
| @ -7,16 +7,54 @@ from pathlib import Path | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| 
 | ||||
| from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, | ||||
|                                     C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, | ||||
|                                     DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, | ||||
|                                     RepConv, ResNetLayer, RTDETRDecoder, Segment) | ||||
| from ultralytics.nn.modules import ( | ||||
|     AIFI, | ||||
|     C1, | ||||
|     C2, | ||||
|     C3, | ||||
|     C3TR, | ||||
|     OBB, | ||||
|     SPP, | ||||
|     SPPF, | ||||
|     Bottleneck, | ||||
|     BottleneckCSP, | ||||
|     C2f, | ||||
|     C3Ghost, | ||||
|     C3x, | ||||
|     Classify, | ||||
|     Concat, | ||||
|     Conv, | ||||
|     Conv2, | ||||
|     ConvTranspose, | ||||
|     Detect, | ||||
|     DWConv, | ||||
|     DWConvTranspose2d, | ||||
|     Focus, | ||||
|     GhostBottleneck, | ||||
|     GhostConv, | ||||
|     HGBlock, | ||||
|     HGStem, | ||||
|     Pose, | ||||
|     RepC3, | ||||
|     RepConv, | ||||
|     ResNetLayer, | ||||
|     RTDETRDecoder, | ||||
|     Segment, | ||||
| ) | ||||
| from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load | ||||
| from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml | ||||
| from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss | ||||
| from ultralytics.utils.plotting import feature_visualization | ||||
| from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, | ||||
|                                            make_divisible, model_info, scale_img, time_sync) | ||||
| from ultralytics.utils.torch_utils import ( | ||||
|     fuse_conv_and_bn, | ||||
|     fuse_deconv_and_bn, | ||||
|     initialize_weights, | ||||
|     intersect_dicts, | ||||
|     make_divisible, | ||||
|     model_info, | ||||
|     scale_img, | ||||
|     time_sync, | ||||
| ) | ||||
| 
 | ||||
| try: | ||||
|     import thop | ||||
| @ -90,8 +128,10 @@ class BaseModel(nn.Module): | ||||
| 
 | ||||
|     def _predict_augment(self, x): | ||||
|         """Perform augmentations on input image x and return augmented inference.""" | ||||
|         LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. ' | ||||
|                        f'Reverting to single-scale inference instead.') | ||||
|         LOGGER.warning( | ||||
|             f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. " | ||||
|             f"Reverting to single-scale inference instead." | ||||
|         ) | ||||
|         return self._predict_once(x) | ||||
| 
 | ||||
|     def _profile_one_layer(self, m, x, dt): | ||||
| @ -108,14 +148,14 @@ class BaseModel(nn.Module): | ||||
|             None | ||||
|         """ | ||||
|         c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix | ||||
|         flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs | ||||
|         flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPs | ||||
|         t = time_sync() | ||||
|         for _ in range(10): | ||||
|             m(x.copy() if c else x) | ||||
|         dt.append((time_sync() - t) * 100) | ||||
|         if m == self.model[0]: | ||||
|             LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module") | ||||
|         LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}') | ||||
|         LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}") | ||||
|         if c: | ||||
|             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total") | ||||
| 
 | ||||
| @ -129,15 +169,15 @@ class BaseModel(nn.Module): | ||||
|         """ | ||||
|         if not self.is_fused(): | ||||
|             for m in self.model.modules(): | ||||
|                 if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'): | ||||
|                 if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"): | ||||
|                     if isinstance(m, Conv2): | ||||
|                         m.fuse_convs() | ||||
|                     m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv | ||||
|                     delattr(m, 'bn')  # remove batchnorm | ||||
|                     delattr(m, "bn")  # remove batchnorm | ||||
|                     m.forward = m.forward_fuse  # update forward | ||||
|                 if isinstance(m, ConvTranspose) and hasattr(m, 'bn'): | ||||
|                 if isinstance(m, ConvTranspose) and hasattr(m, "bn"): | ||||
|                     m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn) | ||||
|                     delattr(m, 'bn')  # remove batchnorm | ||||
|                     delattr(m, "bn")  # remove batchnorm | ||||
|                     m.forward = m.forward_fuse  # update forward | ||||
|                 if isinstance(m, RepConv): | ||||
|                     m.fuse_convs() | ||||
| @ -156,7 +196,7 @@ class BaseModel(nn.Module): | ||||
|         Returns: | ||||
|             (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise. | ||||
|         """ | ||||
|         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d() | ||||
|         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d() | ||||
|         return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model | ||||
| 
 | ||||
|     def info(self, detailed=False, verbose=True, imgsz=640): | ||||
| @ -196,12 +236,12 @@ class BaseModel(nn.Module): | ||||
|             weights (dict | torch.nn.Module): The pre-trained weights to be loaded. | ||||
|             verbose (bool, optional): Whether to log the transfer progress. Defaults to True. | ||||
|         """ | ||||
|         model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts | ||||
|         model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts | ||||
|         csd = model.float().state_dict()  # checkpoint state_dict as FP32 | ||||
|         csd = intersect_dicts(csd, self.state_dict())  # intersect | ||||
|         self.load_state_dict(csd, strict=False)  # load | ||||
|         if verbose: | ||||
|             LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights') | ||||
|             LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights") | ||||
| 
 | ||||
|     def loss(self, batch, preds=None): | ||||
|         """ | ||||
| @ -211,33 +251,33 @@ class BaseModel(nn.Module): | ||||
|             batch (dict): Batch to compute loss on | ||||
|             preds (torch.Tensor | List[torch.Tensor]): Predictions. | ||||
|         """ | ||||
|         if not hasattr(self, 'criterion'): | ||||
|         if not hasattr(self, "criterion"): | ||||
|             self.criterion = self.init_criterion() | ||||
| 
 | ||||
|         preds = self.forward(batch['img']) if preds is None else preds | ||||
|         preds = self.forward(batch["img"]) if preds is None else preds | ||||
|         return self.criterion(preds, batch) | ||||
| 
 | ||||
|     def init_criterion(self): | ||||
|         """Initialize the loss criterion for the BaseModel.""" | ||||
|         raise NotImplementedError('compute_loss() needs to be implemented by task heads') | ||||
|         raise NotImplementedError("compute_loss() needs to be implemented by task heads") | ||||
| 
 | ||||
| 
 | ||||
| class DetectionModel(BaseModel): | ||||
|     """YOLOv8 detection model.""" | ||||
| 
 | ||||
|     def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True):  # model, input channels, number of classes | ||||
|     def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True):  # model, input channels, number of classes | ||||
|         """Initialize the YOLOv8 detection model with the given config and parameters.""" | ||||
|         super().__init__() | ||||
|         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict | ||||
| 
 | ||||
|         # Define model | ||||
|         ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels | ||||
|         if nc and nc != self.yaml['nc']: | ||||
|         ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels | ||||
|         if nc and nc != self.yaml["nc"]: | ||||
|             LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") | ||||
|             self.yaml['nc'] = nc  # override YAML value | ||||
|             self.yaml["nc"] = nc  # override YAML value | ||||
|         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist | ||||
|         self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict | ||||
|         self.inplace = self.yaml.get('inplace', True) | ||||
|         self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict | ||||
|         self.inplace = self.yaml.get("inplace", True) | ||||
| 
 | ||||
|         # Build strides | ||||
|         m = self.model[-1]  # Detect() | ||||
| @ -255,7 +295,7 @@ class DetectionModel(BaseModel): | ||||
|         initialize_weights(self) | ||||
|         if verbose: | ||||
|             self.info() | ||||
|             LOGGER.info('') | ||||
|             LOGGER.info("") | ||||
| 
 | ||||
|     def _predict_augment(self, x): | ||||
|         """Perform augmentations on input image x and return augmented inference and train outputs.""" | ||||
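Given the constructor shown above, a DetectionModel can be built straight from its YAML definition; a minimal sketch (assumes an installed ultralytics package, which bundles 'yolov8n.yaml'):

    import torch
    from ultralytics.nn.tasks import DetectionModel

    model = DetectionModel(cfg="yolov8n.yaml", ch=3, nc=80, verbose=False)
    preds = model(torch.zeros(1, 3, 640, 640))  # dummy forward pass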
| @ -285,9 +325,9 @@ class DetectionModel(BaseModel): | ||||
|     def _clip_augmented(self, y): | ||||
|         """Clip YOLO augmented inference tails.""" | ||||
|         nl = self.model[-1].nl  # number of detection layers (P3-P5) | ||||
|         g = sum(4 ** x for x in range(nl))  # grid points | ||||
|         g = sum(4**x for x in range(nl))  # grid points | ||||
|         e = 1  # exclude layer count | ||||
|         i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e))  # indices | ||||
|         i = (y[0].shape[-1] // g) * sum(4**x for x in range(e))  # indices | ||||
|         y[0] = y[0][..., :-i]  # large | ||||
|         i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices | ||||
|         y[-1] = y[-1][..., i:]  # small | ||||
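Editor's note: a worked example (numbers assumed, not from the diff) of the `_clip_augmented()` index arithmetic, taking `nl=3` detection layers and a hypothetical flattened prediction width of 8400:

nl = 3
g = sum(4**x for x in range(nl))  # 1 + 4 + 16 = 21 grid points
e = 1  # layers excluded at each end
width = 8400  # hypothetical flattened prediction width
i_tail = (width // g) * sum(4**x for x in range(e))  # 400 columns clipped from the large-scale tail
i_head = (width // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # 6400 columns clipped from the small-scale head
print(g, i_tail, i_head)  # 21 400 6400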
| @ -301,7 +341,7 @@ class DetectionModel(BaseModel): | ||||
| class OBBModel(DetectionModel): | ||||
|     """"YOLOv8 Oriented Bounding Box (OBB) model.""" | ||||
| 
 | ||||
|     def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True): | ||||
|     def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True): | ||||
|         """Initialize YOLOv8 OBB model with given config and parameters.""" | ||||
|         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) | ||||
| 
 | ||||
| @ -313,7 +353,7 @@ class OBBModel(DetectionModel): | ||||
| class SegmentationModel(DetectionModel): | ||||
|     """YOLOv8 segmentation model.""" | ||||
| 
 | ||||
|     def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True): | ||||
|     def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True): | ||||
|         """Initialize YOLOv8 segmentation model with given config and parameters.""" | ||||
|         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) | ||||
| 
 | ||||
| @ -325,13 +365,13 @@ class SegmentationModel(DetectionModel): | ||||
| class PoseModel(DetectionModel): | ||||
|     """YOLOv8 pose model.""" | ||||
| 
 | ||||
|     def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): | ||||
|     def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): | ||||
|         """Initialize YOLOv8 Pose model.""" | ||||
|         if not isinstance(cfg, dict): | ||||
|             cfg = yaml_model_load(cfg)  # load model YAML | ||||
|         if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']): | ||||
|         if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]): | ||||
|             LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}") | ||||
|             cfg['kpt_shape'] = data_kpt_shape | ||||
|             cfg["kpt_shape"] = data_kpt_shape | ||||
|         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) | ||||
| 
 | ||||
|     def init_criterion(self): | ||||
| @ -342,7 +382,7 @@ class PoseModel(DetectionModel): | ||||
| class ClassificationModel(BaseModel): | ||||
|     """YOLOv8 classification model.""" | ||||
| 
 | ||||
|     def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True): | ||||
|     def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True): | ||||
|         """Init ClassificationModel with YAML, channels, number of classes, verbose flag.""" | ||||
|         super().__init__() | ||||
|         self._from_yaml(cfg, ch, nc, verbose) | ||||
| @ -352,21 +392,21 @@ class ClassificationModel(BaseModel): | ||||
|         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict | ||||
| 
 | ||||
|         # Define model | ||||
|         ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels | ||||
|         if nc and nc != self.yaml['nc']: | ||||
|         ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels | ||||
|         if nc and nc != self.yaml["nc"]: | ||||
|             LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") | ||||
|             self.yaml['nc'] = nc  # override YAML value | ||||
|         elif not nc and not self.yaml.get('nc', None): | ||||
|             raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.') | ||||
|             self.yaml["nc"] = nc  # override YAML value | ||||
|         elif not nc and not self.yaml.get("nc", None): | ||||
|             raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.") | ||||
|         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist | ||||
|         self.stride = torch.Tensor([1])  # no stride constraints | ||||
|         self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict | ||||
|         self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict | ||||
|         self.info() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def reshape_outputs(model, nc): | ||||
|         """Update a TorchVision classification model to class count 'n' if required.""" | ||||
|         name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module | ||||
|         name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module | ||||
|         if isinstance(m, Classify):  # YOLO Classify() head | ||||
|             if m.linear.out_features != nc: | ||||
|                 m.linear = nn.Linear(m.linear.in_features, nc) | ||||
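Editor's note: a hedged sketch of the head-reshaping idea in `reshape_outputs()`, showing only the generic `nn.Linear` branch with a hypothetical two-layer classifier: the final `Linear` is swapped whenever its `out_features` disagree with the requested class count.

import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(64, 1000))  # hypothetical 1000-class classifier
nc = 10
name, m = list(model.named_children())[-1]  # last module, as in reshape_outputs()
if isinstance(m, nn.Linear) and m.out_features != nc:
    setattr(model, name, nn.Linear(m.in_features, nc))  # keep in_features, change class count
print(model[-1])  # Linear(in_features=64, out_features=10, bias=True)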
| @ -409,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel): | ||||
|         predict: Performs a forward pass through the network and returns the output. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True): | ||||
|     def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True): | ||||
|         """ | ||||
|         Initialize the RTDETRDetectionModel. | ||||
| 
 | ||||
| @ -438,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel): | ||||
|         Returns: | ||||
|             (tuple): A tuple containing the total loss and main three losses in a tensor. | ||||
|         """ | ||||
|         if not hasattr(self, 'criterion'): | ||||
|         if not hasattr(self, "criterion"): | ||||
|             self.criterion = self.init_criterion() | ||||
| 
 | ||||
|         img = batch['img'] | ||||
|         img = batch["img"] | ||||
|         # NOTE: preprocess gt_bbox and gt_labels to list. | ||||
|         bs = len(img) | ||||
|         batch_idx = batch['batch_idx'] | ||||
|         batch_idx = batch["batch_idx"] | ||||
|         gt_groups = [(batch_idx == i).sum().item() for i in range(bs)] | ||||
|         targets = { | ||||
|             'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1), | ||||
|             'bboxes': batch['bboxes'].to(device=img.device), | ||||
|             'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1), | ||||
|             'gt_groups': gt_groups} | ||||
|             "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1), | ||||
|             "bboxes": batch["bboxes"].to(device=img.device), | ||||
|             "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1), | ||||
|             "gt_groups": gt_groups, | ||||
|         } | ||||
| 
 | ||||
|         preds = self.predict(img, batch=targets) if preds is None else preds | ||||
|         dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1] | ||||
|         if dn_meta is None: | ||||
|             dn_bboxes, dn_scores = None, None | ||||
|         else: | ||||
|             dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2) | ||||
|             dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2) | ||||
|             dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2) | ||||
|             dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2) | ||||
| 
 | ||||
|         dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4) | ||||
|         dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores]) | ||||
| 
 | ||||
|         loss = self.criterion((dec_bboxes, dec_scores), | ||||
|                               targets, | ||||
|                               dn_bboxes=dn_bboxes, | ||||
|                               dn_scores=dn_scores, | ||||
|                               dn_meta=dn_meta) | ||||
|         loss = self.criterion( | ||||
|             (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta | ||||
|         ) | ||||
|         # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses. | ||||
|         return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']], | ||||
|                                                    device=img.device) | ||||
|         return sum(loss.values()), torch.as_tensor( | ||||
|             [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device | ||||
|         ) | ||||
| 
 | ||||
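Editor's note: to make the NOTE above concrete, here is a hedged sketch of the return convention with invented loss values (the real RT-DETR criterion returns roughly a dozen terms): backpropagate the sum of everything, but report only the three headline losses.

import torch

loss = {  # hypothetical loss dict standing in for the criterion's output
    "loss_giou": torch.tensor(0.5),
    "loss_class": torch.tensor(1.2),
    "loss_bbox": torch.tensor(0.3),
    "loss_giou_aux_0": torch.tensor(0.4),  # auxiliary terms are summed but not displayed
}
total = sum(loss.values())  # scalar used for backward()
shown = torch.as_tensor([loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]])
print(total, shown)  # tensor(2.4000) tensor([0.5000, 1.2000, 0.3000])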
|     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None): | ||||
|         """ | ||||
| @ -553,6 +593,7 @@ def temporary_modules(modules=None): | ||||
| 
 | ||||
|     import importlib | ||||
|     import sys | ||||
| 
 | ||||
|     try: | ||||
|         # Set modules in sys.modules under their old name | ||||
|         for old, new in modules.items(): | ||||
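Editor's note: the `sys.modules` aliasing that `temporary_modules()` performs can be illustrated with the standard library alone; `legacy_json` below is a made-up old module name mapped onto `json`, mirroring how legacy `ultralytics.yolo.*` paths resolve to their new locations while a checkpoint unpickles:

import importlib
import sys

sys.modules["legacy_json"] = importlib.import_module("json")  # old name now resolves to the new module
import legacy_json  # succeeds even though no such package exists on disk

print(legacy_json.dumps({"ok": True}))  # {"ok": true}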
| @ -580,30 +621,38 @@ def torch_safe_load(weight): | ||||
|     """ | ||||
|     from ultralytics.utils.downloads import attempt_download_asset | ||||
| 
 | ||||
|     check_suffix(file=weight, suffix='.pt') | ||||
|     check_suffix(file=weight, suffix=".pt") | ||||
|     file = attempt_download_asset(weight)  # search online if missing locally | ||||
|     try: | ||||
|         with temporary_modules({ | ||||
|                 'ultralytics.yolo.utils': 'ultralytics.utils', | ||||
|                 'ultralytics.yolo.v8': 'ultralytics.models.yolo', | ||||
|                 'ultralytics.yolo.data': 'ultralytics.data'}):  # for legacy 8.0 Classify and Pose models | ||||
|             return torch.load(file, map_location='cpu'), file  # load | ||||
|         with temporary_modules( | ||||
|             { | ||||
|                 "ultralytics.yolo.utils": "ultralytics.utils", | ||||
|                 "ultralytics.yolo.v8": "ultralytics.models.yolo", | ||||
|                 "ultralytics.yolo.data": "ultralytics.data", | ||||
|             } | ||||
|         ):  # for legacy 8.0 Classify and Pose models | ||||
|             return torch.load(file, map_location="cpu"), file  # load | ||||
| 
 | ||||
|     except ModuleNotFoundError as e:  # e.name is missing module name | ||||
|         if e.name == 'models': | ||||
|         if e.name == "models": | ||||
|             raise TypeError( | ||||
|                 emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained ' | ||||
|                        f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with ' | ||||
|                        f'YOLOv8 at https://github.com/ultralytics/ultralytics.' | ||||
|                        f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " | ||||
|                        f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e | ||||
|                 emojis( | ||||
|                     f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained " | ||||
|                     f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with " | ||||
|                     f"YOLOv8 at https://github.com/ultralytics/ultralytics." | ||||
|                     f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to " | ||||
|                     f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'" | ||||
|                 ) | ||||
|             ) from e | ||||
|         LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements." | ||||
|                        f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." | ||||
|                        f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " | ||||
|                        f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'") | ||||
|         LOGGER.warning( | ||||
|             f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements." | ||||
|             f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." | ||||
|             f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to " | ||||
|             f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'" | ||||
|         ) | ||||
|         check_requirements(e.name)  # install missing module | ||||
| 
 | ||||
|         return torch.load(file, map_location='cpu'), file  # load | ||||
|         return torch.load(file, map_location="cpu"), file  # load | ||||
| 
 | ||||
| 
 | ||||
| def attempt_load_weights(weights, device=None, inplace=True, fuse=False): | ||||
| @ -612,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): | ||||
|     ensemble = Ensemble() | ||||
|     for w in weights if isinstance(weights, list) else [weights]: | ||||
|         ckpt, w = torch_safe_load(w)  # load ckpt | ||||
|         args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None  # combined args | ||||
|         model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model | ||||
|         args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None  # combined args | ||||
|         model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model | ||||
| 
 | ||||
|         # Model compatibility updates | ||||
|         model.args = args  # attach args to model | ||||
|         model.pt_path = w  # attach *.pt file path to model | ||||
|         model.task = guess_model_task(model) | ||||
|         if not hasattr(model, 'stride'): | ||||
|             model.stride = torch.tensor([32.]) | ||||
|         if not hasattr(model, "stride"): | ||||
|             model.stride = torch.tensor([32.0]) | ||||
| 
 | ||||
|         # Append | ||||
|         ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval())  # model in eval mode | ||||
|         ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval())  # model in eval mode | ||||
| 
 | ||||
|     # Module updates | ||||
|     for m in ensemble.modules(): | ||||
|         t = type(m) | ||||
|         if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB): | ||||
|             m.inplace = inplace | ||||
|         elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'): | ||||
|         elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"): | ||||
|             m.recompute_scale_factor = None  # torch 1.11.0 compatibility | ||||
| 
 | ||||
|     # Return model | ||||
| @ -638,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): | ||||
|         return ensemble[-1] | ||||
| 
 | ||||
|     # Return ensemble | ||||
|     LOGGER.info(f'Ensemble created with {weights}\n') | ||||
|     for k in 'names', 'nc', 'yaml': | ||||
|     LOGGER.info(f"Ensemble created with {weights}\n") | ||||
|     for k in "names", "nc", "yaml": | ||||
|         setattr(ensemble, k, getattr(ensemble[0], k)) | ||||
|     ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride | ||||
|     assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}' | ||||
|     assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}" | ||||
|     return ensemble | ||||
| 
 | ||||
| 
 | ||||
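Editor's note: the `ensemble.stride` line above selects the stride tensor of whichever member reports the largest maximum stride. A worked example with two assumed stride tensors, using the same `torch.tensor`/`argmax` pattern as the source:

import torch

strides = [torch.tensor([8.0, 16.0, 32.0]), torch.tensor([8.0, 16.0, 32.0, 64.0])]  # two hypothetical members
idx = torch.argmax(torch.tensor([s.max() for s in strides])).int()
print(int(idx), strides[int(idx)])  # 1 tensor([ 8., 16., 32., 64.])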
| def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): | ||||
|     """Loads a single model weights.""" | ||||
|     ckpt, weight = torch_safe_load(weight)  # load ckpt | ||||
|     args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args | ||||
|     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model | ||||
|     args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # combine model and default args, preferring model args | ||||
|     model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model | ||||
| 
 | ||||
|     # Model compatibility updates | ||||
|     model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model | ||||
|     model.pt_path = weight  # attach *.pt file path to model | ||||
|     model.task = guess_model_task(model) | ||||
|     if not hasattr(model, 'stride'): | ||||
|         model.stride = torch.tensor([32.]) | ||||
|     if not hasattr(model, "stride"): | ||||
|         model.stride = torch.tensor([32.0]) | ||||
| 
 | ||||
|     model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()  # model in eval mode | ||||
|     model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()  # model in eval mode | ||||
| 
 | ||||
|     # Module updates | ||||
|     for m in model.modules(): | ||||
|         t = type(m) | ||||
|         if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB): | ||||
|             m.inplace = inplace | ||||
|         elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'): | ||||
|         elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"): | ||||
|             m.recompute_scale_factor = None  # torch 1.11.0 compatibility | ||||
| 
 | ||||
|     # Return model and ckpt | ||||
| @ -678,11 +727,11 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3) | ||||
|     import ast | ||||
| 
 | ||||
|     # Args | ||||
|     max_channels = float('inf') | ||||
|     nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales')) | ||||
|     depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape')) | ||||
|     max_channels = float("inf") | ||||
|     nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales")) | ||||
|     depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")) | ||||
|     if scales: | ||||
|         scale = d.get('scale') | ||||
|         scale = d.get("scale") | ||||
|         if not scale: | ||||
|             scale = tuple(scales.keys())[0] | ||||
|             LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.") | ||||
| @ -697,16 +746,37 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3) | ||||
|         LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}") | ||||
|     ch = [ch] | ||||
|     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out | ||||
|     for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args | ||||
|         m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module | ||||
|     for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args | ||||
|         m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module | ||||
|         for j, a in enumerate(args): | ||||
|             if isinstance(a, str): | ||||
|                 with contextlib.suppress(ValueError): | ||||
|                     args[j] = locals()[a] if a in locals() else ast.literal_eval(a) | ||||
| 
 | ||||
|         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain | ||||
|         if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, | ||||
|                  BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3): | ||||
|         if m in ( | ||||
|             Classify, | ||||
|             Conv, | ||||
|             ConvTranspose, | ||||
|             GhostConv, | ||||
|             Bottleneck, | ||||
|             GhostBottleneck, | ||||
|             SPP, | ||||
|             SPPF, | ||||
|             DWConv, | ||||
|             Focus, | ||||
|             BottleneckCSP, | ||||
|             C1, | ||||
|             C2, | ||||
|             C2f, | ||||
|             C3, | ||||
|             C3TR, | ||||
|             C3Ghost, | ||||
|             nn.ConvTranspose2d, | ||||
|             DWConvTranspose2d, | ||||
|             C3x, | ||||
|             RepC3, | ||||
|         ): | ||||
|             c1, c2 = ch[f], args[0] | ||||
|             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output) | ||||
|                 c2 = make_divisible(min(c2, max_channels) * width, 8) | ||||
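Editor's note: a worked example of the depth and width scaling applied in `parse_model()`, with scale values assumed to roughly match an 'n'-size model; `make_divisible` below is a simplified stand-in of mine, not the ultralytics helper itself:

import math

def make_divisible(x, divisor=8):  # simplified stand-in for the ultralytics helper
    return math.ceil(x / divisor) * divisor

depth, width, max_channels = 0.33, 0.25, float("inf")
n = max(round(3 * depth), 1)  # 3 repeats in the YAML shrink to 1 at this depth gain
c2 = make_divisible(min(256, max_channels) * width, 8)  # 256 output channels shrink to 64
print(n, c2)  # 1 64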
| @ -739,11 +809,11 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3) | ||||
|             c2 = ch[f] | ||||
| 
 | ||||
|         m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module | ||||
|         t = str(m)[8:-2].replace('__main__.', '')  # module type | ||||
|         t = str(m)[8:-2].replace("__main__.", "")  # module type | ||||
|         m.np = sum(x.numel() for x in m_.parameters())  # number params | ||||
|         m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type | ||||
|         if verbose: | ||||
|             LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}')  # print | ||||
|             LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")  # print | ||||
|         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist | ||||
|         layers.append(m_) | ||||
|         if i == 0: | ||||
| @ -757,16 +827,16 @@ def yaml_model_load(path): | ||||
|     import re | ||||
| 
 | ||||
|     path = Path(path) | ||||
|     if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)): | ||||
|         new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem) | ||||
|         LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.') | ||||
|     if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)): | ||||
|         new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem) | ||||
|         LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.") | ||||
|         path = path.with_name(new_stem + path.suffix) | ||||
| 
 | ||||
|     unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path))  # i.e. yolov8x.yaml -> yolov8.yaml | ||||
|     unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml | ||||
|     yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path) | ||||
|     d = yaml_load(yaml_file)  # model dict | ||||
|     d['scale'] = guess_model_scale(path) | ||||
|     d['yaml_file'] = str(path) | ||||
|     d["scale"] = guess_model_scale(path) | ||||
|     d["yaml_file"] = str(path) | ||||
|     return d | ||||
| 
 | ||||
| 
 | ||||
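Editor's note: a worked example (inputs assumed) of the two renaming regexes in `yaml_model_load()`: the first maps legacy P6 stems such as yolov8x6 onto the -p6 suffix, and the second strips the scale letter so every scale shares one unified YAML:

import re

print(re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", "yolov8x6"))  # yolov8x-p6
print(re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", "yolov8x.yaml"))  # yolov8.yaml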
| @ -784,8 +854,9 @@ def guess_model_scale(model_path): | ||||
|     """ | ||||
|     with contextlib.suppress(AttributeError): | ||||
|         import re | ||||
|         return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1)  # n, s, m, l, or x | ||||
|     return '' | ||||
| 
 | ||||
|         return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x | ||||
|     return "" | ||||
| 
 | ||||
| 
 | ||||
| def guess_model_task(model): | ||||
| @ -804,17 +875,17 @@ def guess_model_task(model): | ||||
| 
 | ||||
|     def cfg2task(cfg): | ||||
|         """Guess from YAML dictionary.""" | ||||
|         m = cfg['head'][-1][-2].lower()  # output module name | ||||
|         if m in ('classify', 'classifier', 'cls', 'fc'): | ||||
|             return 'classify' | ||||
|         if m == 'detect': | ||||
|             return 'detect' | ||||
|         if m == 'segment': | ||||
|             return 'segment' | ||||
|         if m == 'pose': | ||||
|             return 'pose' | ||||
|         if m == 'obb': | ||||
|             return 'obb' | ||||
|         m = cfg["head"][-1][-2].lower()  # output module name | ||||
|         if m in ("classify", "classifier", "cls", "fc"): | ||||
|             return "classify" | ||||
|         if m == "detect": | ||||
|             return "detect" | ||||
|         if m == "segment": | ||||
|             return "segment" | ||||
|         if m == "pose": | ||||
|             return "pose" | ||||
|         if m == "obb": | ||||
|             return "obb" | ||||
| 
 | ||||
|     # Guess from model cfg | ||||
|     if isinstance(model, dict): | ||||
| @ -823,40 +894,42 @@ def guess_model_task(model): | ||||
| 
 | ||||
|     # Guess from PyTorch model | ||||
|     if isinstance(model, nn.Module):  # PyTorch model | ||||
|         for x in 'model.args', 'model.model.args', 'model.model.model.args': | ||||
|         for x in "model.args", "model.model.args", "model.model.model.args": | ||||
|             with contextlib.suppress(Exception): | ||||
|                 return eval(x)['task'] | ||||
|         for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml': | ||||
|                 return eval(x)["task"] | ||||
|         for x in "model.yaml", "model.model.yaml", "model.model.model.yaml": | ||||
|             with contextlib.suppress(Exception): | ||||
|                 return cfg2task(eval(x)) | ||||
| 
 | ||||
|         for m in model.modules(): | ||||
|             if isinstance(m, Detect): | ||||
|                 return 'detect' | ||||
|                 return "detect" | ||||
|             elif isinstance(m, Segment): | ||||
|                 return 'segment' | ||||
|                 return "segment" | ||||
|             elif isinstance(m, Classify): | ||||
|                 return 'classify' | ||||
|                 return "classify" | ||||
|             elif isinstance(m, Pose): | ||||
|                 return 'pose' | ||||
|                 return "pose" | ||||
|             elif isinstance(m, OBB): | ||||
|                 return 'obb' | ||||
|                 return "obb" | ||||
| 
 | ||||
|     # Guess from model filename | ||||
|     if isinstance(model, (str, Path)): | ||||
|         model = Path(model) | ||||
|         if '-seg' in model.stem or 'segment' in model.parts: | ||||
|             return 'segment' | ||||
|         elif '-cls' in model.stem or 'classify' in model.parts: | ||||
|             return 'classify' | ||||
|         elif '-pose' in model.stem or 'pose' in model.parts: | ||||
|             return 'pose' | ||||
|         elif '-obb' in model.stem or 'obb' in model.parts: | ||||
|             return 'obb' | ||||
|         elif 'detect' in model.parts: | ||||
|             return 'detect' | ||||
|         if "-seg" in model.stem or "segment" in model.parts: | ||||
|             return "segment" | ||||
|         elif "-cls" in model.stem or "classify" in model.parts: | ||||
|             return "classify" | ||||
|         elif "-pose" in model.stem or "pose" in model.parts: | ||||
|             return "pose" | ||||
|         elif "-obb" in model.stem or "obb" in model.parts: | ||||
|             return "obb" | ||||
|         elif "detect" in model.parts: | ||||
|             return "detect" | ||||
| 
 | ||||
|     # Unable to determine task from model | ||||
|     LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " | ||||
|                    "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.") | ||||
|     return 'detect'  # assume detect | ||||
|     LOGGER.warning( | ||||
|         "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " | ||||
|         "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'." | ||||
|     ) | ||||
|     return "detect"  # assume detect | ||||
|  | ||||
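Editor's note: the filename fallback in `guess_model_task()` can be exercised standalone. A minimal sketch mirroring the suffix checks above; the function name `task_from_name` is mine, not the library's:

from pathlib import Path

def task_from_name(p):
    stem, parts = Path(p).stem, Path(p).parts
    if "-seg" in stem or "segment" in parts:
        return "segment"
    if "-cls" in stem or "classify" in parts:
        return "classify"
    if "-pose" in stem or "pose" in parts:
        return "pose"
    if "-obb" in stem or "obb" in parts:
        return "obb"
    return "detect"  # same assumption as the source: default to detect

print(task_from_name("weights/yolov8n-seg.pt"))  # segment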
| @ -26,7 +26,7 @@ class AIGym: | ||||
|         self.angle = None | ||||
|         self.count = None | ||||
|         self.stage = None | ||||
|         self.pose_type = 'pushup' | ||||
|         self.pose_type = "pushup" | ||||
|         self.kpts_to_check = None | ||||
| 
 | ||||
|         # Visual Information | ||||
| @ -36,13 +36,15 @@ class AIGym: | ||||
|         # Check if environment support imshow | ||||
|         self.env_check = check_imshow(warn=True) | ||||
| 
 | ||||
|     def set_args(self, | ||||
|     def set_args( | ||||
|         self, | ||||
|         kpts_to_check, | ||||
|         line_thickness=2, | ||||
|         view_img=False, | ||||
|         pose_up_angle=145.0, | ||||
|         pose_down_angle=90.0, | ||||
|                  pose_type='pullup'): | ||||
|         pose_type="pullup", | ||||
|     ): | ||||
|         """ | ||||
|         Configures the AIGym line_thickness, save image, and view image parameters. | ||||
|         Args: | ||||
| @ -72,65 +74,75 @@ class AIGym: | ||||
|         if frame_count == 1: | ||||
|             self.count = [0] * len(results[0]) | ||||
|             self.angle = [0] * len(results[0]) | ||||
|             self.stage = ['-' for _ in results[0]] | ||||
|             self.stage = ["-" for _ in results[0]] | ||||
|         self.keypoints = results[0].keypoints.data | ||||
|         self.annotator = Annotator(im0, line_width=2) | ||||
| 
 | ||||
|         for ind, k in enumerate(reversed(self.keypoints)): | ||||
|             if self.pose_type == 'pushup' or self.pose_type == 'pullup': | ||||
|                 self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(), | ||||
|             if self.pose_type == "pushup" or self.pose_type == "pullup": | ||||
|                 self.angle[ind] = self.annotator.estimate_pose_angle( | ||||
|                     k[int(self.kpts_to_check[0])].cpu(), | ||||
|                     k[int(self.kpts_to_check[1])].cpu(), | ||||
|                                                                      k[int(self.kpts_to_check[2])].cpu()) | ||||
|                     k[int(self.kpts_to_check[2])].cpu(), | ||||
|                 ) | ||||
|                 self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10) | ||||
| 
 | ||||
|             if self.pose_type == 'abworkout': | ||||
|                 self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(), | ||||
|             if self.pose_type == "abworkout": | ||||
|                 self.angle[ind] = self.annotator.estimate_pose_angle( | ||||
|                     k[int(self.kpts_to_check[0])].cpu(), | ||||
|                     k[int(self.kpts_to_check[1])].cpu(), | ||||
|                                                                      k[int(self.kpts_to_check[2])].cpu()) | ||||
|                     k[int(self.kpts_to_check[2])].cpu(), | ||||
|                 ) | ||||
|                 self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10) | ||||
|                 if self.angle[ind] > self.poseup_angle: | ||||
|                     self.stage[ind] = 'down' | ||||
|                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down': | ||||
|                     self.stage[ind] = 'up' | ||||
|                     self.stage[ind] = "down" | ||||
|                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down": | ||||
|                     self.stage[ind] = "up" | ||||
|                     self.count[ind] += 1 | ||||
|                 self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind], | ||||
|                 self.annotator.plot_angle_and_count_and_stage( | ||||
|                     angle_text=self.angle[ind], | ||||
|                     count_text=self.count[ind], | ||||
|                     stage_text=self.stage[ind], | ||||
|                     center_kpt=k[int(self.kpts_to_check[1])], | ||||
|                                                               line_thickness=self.tf) | ||||
|                     line_thickness=self.tf, | ||||
|                 ) | ||||
| 
 | ||||
|             if self.pose_type == 'pushup': | ||||
|             if self.pose_type == "pushup": | ||||
|                 if self.angle[ind] > self.poseup_angle: | ||||
|                     self.stage[ind] = 'up' | ||||
|                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'up': | ||||
|                     self.stage[ind] = 'down' | ||||
|                     self.stage[ind] = "up" | ||||
|                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up": | ||||
|                     self.stage[ind] = "down" | ||||
|                     self.count[ind] += 1 | ||||
|                 self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind], | ||||
|                 self.annotator.plot_angle_and_count_and_stage( | ||||
|                     angle_text=self.angle[ind], | ||||
|                     count_text=self.count[ind], | ||||
|                     stage_text=self.stage[ind], | ||||
|                     center_kpt=k[int(self.kpts_to_check[1])], | ||||
|                                                               line_thickness=self.tf) | ||||
|             if self.pose_type == 'pullup': | ||||
|                     line_thickness=self.tf, | ||||
|                 ) | ||||
|             if self.pose_type == "pullup": | ||||
|                 if self.angle[ind] > self.poseup_angle: | ||||
|                     self.stage[ind] = 'down' | ||||
|                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down': | ||||
|                     self.stage[ind] = 'up' | ||||
|                     self.stage[ind] = "down" | ||||
|                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down": | ||||
|                     self.stage[ind] = "up" | ||||
|                     self.count[ind] += 1 | ||||
|                 self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind], | ||||
|                 self.annotator.plot_angle_and_count_and_stage( | ||||
|                     angle_text=self.angle[ind], | ||||
|                     count_text=self.count[ind], | ||||
|                     stage_text=self.stage[ind], | ||||
|                     center_kpt=k[int(self.kpts_to_check[1])], | ||||
|                                                               line_thickness=self.tf) | ||||
|                     line_thickness=self.tf, | ||||
|                 ) | ||||
| 
 | ||||
|             self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True) | ||||
| 
 | ||||
|         if self.env_check and self.view_img: | ||||
|             cv2.imshow('Ultralytics YOLOv8 AI GYM', self.im0) | ||||
|             if cv2.waitKey(1) & 0xFF == ord('q'): | ||||
|             cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0) | ||||
|             if cv2.waitKey(1) & 0xFF == ord("q"): | ||||
|                 return | ||||
| 
 | ||||
|         return self.im0 | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     AIGym() | ||||
|  | ||||
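Editor's note: each pose branch above is a two-state machine: crossing the up-angle arms one stage, and crossing the down-angle while armed flips the stage and counts a rep. A standalone sketch of the "pullup" branch with the default thresholds and a hypothetical per-frame angle series:

pose_up_angle, pose_down_angle = 145.0, 90.0  # defaults from set_args()
stage, count = "-", 0
for angle in [150, 160, 80, 85, 155, 70]:  # hypothetical elbow angles per frame
    if angle > pose_up_angle:
        stage = "down"
    if angle < pose_down_angle and stage == "down":
        stage, count = "up", count + 1  # count the rep on the down->up transition
print(count)  # 2 pullups counted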
| @ -41,13 +41,15 @@ class DistanceCalculation: | ||||
|         # Check if environment support imshow | ||||
|         self.env_check = check_imshow(warn=True) | ||||
| 
 | ||||
|     def set_args(self, | ||||
|     def set_args( | ||||
|         self, | ||||
|         names, | ||||
|         pixels_per_meter=10, | ||||
|         view_img=False, | ||||
|         line_thickness=2, | ||||
|         line_color=(255, 255, 0), | ||||
|                  centroid_color=(255, 0, 255)): | ||||
|         centroid_color=(255, 0, 255), | ||||
|     ): | ||||
|         """ | ||||
|         Configures the distance calculation and display parameters. | ||||
| 
 | ||||
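Editor's note: `pixels_per_meter` is the only calibration the distance overlay needs. A hedged sketch (centroids assumed) of converting a pixel-space centroid distance into the meters rendered by the `putText` overlay in the following hunk:

import math

pixels_per_meter = 10  # default from set_args()
c0, c1 = (120, 300), (420, 700)  # hypothetical object centroids in pixels
pixel_dist = math.hypot(c1[0] - c0[0], c1[1] - c0[1])  # 500.0 px
print(f"Distance : {pixel_dist / pixels_per_meter:.2f}m")  # Distance : 50.00m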
| @ -129,8 +131,9 @@ class DistanceCalculation: | ||||
|             distance (float): Distance between two centroids | ||||
|         """ | ||||
|         cv2.rectangle(self.im0, (15, 25), (280, 70), (255, 255, 255), -1) | ||||
|         cv2.putText(self.im0, f'Distance : {distance:.2f}m', (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2, | ||||
|                     cv2.LINE_AA) | ||||
|         cv2.putText( | ||||
|             self.im0, f"Distance : {distance:.2f}m", (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2, cv2.LINE_AA | ||||
|         ) | ||||
|         cv2.line(self.im0, self.centroids[0], self.centroids[1], self.line_color, 3) | ||||
|         cv2.circle(self.im0, self.centroids[0], 6, self.centroid_color, -1) | ||||
|         cv2.circle(self.im0, self.centroids[1], 6, self.centroid_color, -1) | ||||
| @ -179,13 +182,13 @@ class DistanceCalculation: | ||||
| 
 | ||||
|     def display_frames(self): | ||||
|         """Display frame.""" | ||||
|         cv2.namedWindow('Ultralytics Distance Estimation') | ||||
|         cv2.setMouseCallback('Ultralytics Distance Estimation', self.mouse_event_for_distance) | ||||
|         cv2.imshow('Ultralytics Distance Estimation', self.im0) | ||||
|         cv2.namedWindow("Ultralytics Distance Estimation") | ||||
|         cv2.setMouseCallback("Ultralytics Distance Estimation", self.mouse_event_for_distance) | ||||
|         cv2.imshow("Ultralytics Distance Estimation", self.im0) | ||||
| 
 | ||||
|         if cv2.waitKey(1) & 0xFF == ord('q'): | ||||
|         if cv2.waitKey(1) & 0xFF == ord("q"): | ||||
|             return | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|     DistanceCalculation() | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher