Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-11-01 07:05:39 +08:00)
ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>

Commit: fe27db2f6e (parent: e795277391)
.github/workflows/ci.yaml (vendored, 2 changes)
```diff
@@ -95,7 +95,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        python-version: ['3.10']
+        python-version: ['3.11']
         model: [yolov8n]
     steps:
       - uses: actions/checkout@v4
```
.github/workflows/format.yml (vendored, new file, 25 additions)
```diff
@@ -0,0 +1,25 @@
+# Ultralytics 🚀 - AGPL-3.0 license
+# Ultralytics Actions https://github.com/ultralytics/actions
+# This workflow automatically formats code and documentation in PRs to official Ultralytics standards
+
+name: Ultralytics Actions
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  format:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Run Ultralytics Formatting
+        uses: ultralytics/actions@main
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }} # automatically generated
+          python: true
+          docstrings: true
+          markdown: true
+          spelling: true
+          links: false
```
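The new workflow is plain GitHub Actions YAML, so its structure can be sanity-checked locally. A minimal sketch with PyYAML, assuming the repo root as working directory; note that YAML 1.1 parses the bare `on` key as the boolean `True`:

```python
import yaml  # pip install pyyaml

# Load the workflow file added in this commit and spot-check a few fields
with open(".github/workflows/format.yml") as f:
    wf = yaml.safe_load(f)

assert wf["name"] == "Ultralytics Actions"
assert wf[True]["pull_request"]["branches"] == ["main"]  # PyYAML reads `on:` as True
print(wf["jobs"]["format"]["steps"][0]["uses"])  # -> ultralytics/actions@main
```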
.pre-commit-config.yaml (file header lost in the mirror; path inferred from the `repos:` hunk context)

```diff
@@ -22,7 +22,6 @@ repos:
       - id: check-case-conflict
       # - id: check-yaml
       - id: check-docstring-first
-      - id: double-quote-string-fixer
       - id: detect-private-key

   - repo: https://github.com/asottile/pyupgrade
@@ -64,7 +63,7 @@ repos:
       - id: codespell
        exclude: 'docs/de|docs/fr|docs/pt|docs/es|docs/mkdocs_de.yml'
        args:
-          - --ignore-words-list=crate,nd,strack,dota,ane,segway,fo,gool,winn
+          - --ignore-words-list=crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall

   - repo: https://github.com/PyCQA/docformatter
     rev: v1.7.5
```
docs/build_docs.py (path inferred; file header lost in the mirror)

```diff
@@ -30,45 +30,47 @@ import subprocess
 from pathlib import Path

 DOCS = Path(__file__).parent.resolve()
-SITE = DOCS.parent / 'site'
+SITE = DOCS.parent / "site"


 def build_docs():
     """Build docs using mkdocs."""
     if SITE.exists():
-        print(f'Removing existing {SITE}')
+        print(f"Removing existing {SITE}")
         shutil.rmtree(SITE)

     # Build the main documentation
-    print(f'Building docs from {DOCS}')
-    subprocess.run(f'mkdocs build -f {DOCS}/mkdocs.yml', check=True, shell=True)
+    print(f"Building docs from {DOCS}")
+    subprocess.run(f"mkdocs build -f {DOCS}/mkdocs.yml", check=True, shell=True)

     # Build other localized documentations
-    for file in DOCS.glob('mkdocs_*.yml'):
-        print(f'Building MkDocs site with configuration file: {file}')
-        subprocess.run(f'mkdocs build -f {file}', check=True, shell=True)
-    print(f'Site built at {SITE}')
+    for file in DOCS.glob("mkdocs_*.yml"):
+        print(f"Building MkDocs site with configuration file: {file}")
+        subprocess.run(f"mkdocs build -f {file}", check=True, shell=True)
+    print(f"Site built at {SITE}")


 def update_html_links():
     """Update href links in HTML files to remove '.md' and '/index.md', excluding links starting with 'https://'."""
-    html_files = Path(SITE).rglob('*.html')
+    html_files = Path(SITE).rglob("*.html")
     total_updated_links = 0

     for html_file in html_files:
-        with open(html_file, 'r+', encoding='utf-8') as file:
+        with open(html_file, "r+", encoding="utf-8") as file:
             content = file.read()
             # Find all links to be updated, excluding those starting with 'https://'
             links_to_update = re.findall(r'href="(?!https://)([^"]+?)(/index)?\.md"', content)

             # Update the content and count the number of links updated
-            updated_content, number_of_links_updated = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"',
-                                                               r'href="\1"', content)
+            updated_content, number_of_links_updated = re.subn(
+                r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', content
+            )
             total_updated_links += number_of_links_updated

             # Special handling for '/index' links
-            updated_content, number_of_index_links_updated = re.subn(r'href="([^"]+)/index"', r'href="\1/"',
-                                                                     updated_content)
+            updated_content, number_of_index_links_updated = re.subn(
+                r'href="([^"]+)/index"', r'href="\1/"', updated_content
+            )
             total_updated_links += number_of_index_links_updated

             # Write the updated content back to the file
@@ -78,23 +80,23 @@ def update_html_links():

             # Print updated links for this file
             for link in links_to_update:
-                print(f'Updated link in {html_file}: {link[0]}')
+                print(f"Updated link in {html_file}: {link[0]}")

-    print(f'Total number of links updated: {total_updated_links}')
+    print(f"Total number of links updated: {total_updated_links}")


 def update_page_title(file_path: Path, new_title: str):
     """Update the title of an HTML file."""

     # Read the content of the file
-    with open(file_path, encoding='utf-8') as file:
+    with open(file_path, encoding="utf-8") as file:
         content = file.read()

     # Replace the existing title with the new title
-    updated_content = re.sub(r'<title>.*?</title>', f'<title>{new_title}</title>', content)
+    updated_content = re.sub(r"<title>.*?</title>", f"<title>{new_title}</title>", content)

     # Write the updated content back to the file
-    with open(file_path, 'w', encoding='utf-8') as file:
+    with open(file_path, "w", encoding="utf-8") as file:
         file.write(updated_content)


@@ -109,8 +111,8 @@ def main():
     print('Serve site at http://localhost:8000 with "python -m http.server --directory site"')

     # Update titles
-    update_page_title(SITE / '404.html', new_title='Ultralytics Docs - Not Found')
+    update_page_title(SITE / "404.html", new_title="Ultralytics Docs - Not Found")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
```
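The two reflowed `re.subn` passes above are easier to follow on a concrete string. A minimal sketch with the regexes copied from the diff; the sample HTML is invented for illustration:

```python
import re

# Invented sample: local .md links, a local dir/index link, and no https:// links
content = '<a href="modes/train.md">Train</a> <a href="usage/index.md">Usage</a> <a href="help/index">Help</a>'

# Pass 1: strip ".md" (and "/index.md") from local hrefs, skipping absolute https:// links
content, n = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', content)

# Pass 2: rewrite the remaining "dir/index" hrefs as "dir/"
content, m = re.subn(r'href="([^"]+)/index"', r'href="\1/"', content)

print(content)  # <a href="modes/train">Train</a> <a href="usage">Usage</a> <a href="help/">Help</a>
print(n + m)    # 3 substitutions total
```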
docs/build_reference.py (path inferred; file header lost in the mirror)

```diff
@@ -14,14 +14,14 @@ from ultralytics.utils import ROOT

 NEW_YAML_DIR = ROOT.parent
 CODE_DIR = ROOT
-REFERENCE_DIR = ROOT.parent / 'docs/en/reference'
+REFERENCE_DIR = ROOT.parent / "docs/en/reference"


 def extract_classes_and_functions(filepath: Path) -> tuple:
     """Extracts class and function names from a given Python file."""
     content = filepath.read_text()
-    class_pattern = r'(?:^|\n)class\s(\w+)(?:\(|:)'
-    func_pattern = r'(?:^|\n)def\s(\w+)\('
+    class_pattern = r"(?:^|\n)class\s(\w+)(?:\(|:)"
+    func_pattern = r"(?:^|\n)def\s(\w+)\("

     classes = re.findall(class_pattern, content)
     functions = re.findall(func_pattern, content)
@@ -31,31 +31,31 @@ def extract_classes_and_functions(filepath: Path) -> tuple:

 def create_markdown(py_filepath: Path, module_path: str, classes: list, functions: list):
     """Creates a Markdown file containing the API reference for the given Python module."""
-    md_filepath = py_filepath.with_suffix('.md')
+    md_filepath = py_filepath.with_suffix(".md")

     # Read existing content and keep header content between first two ---
-    header_content = ''
+    header_content = ""
     if md_filepath.exists():
         existing_content = md_filepath.read_text()
-        header_parts = existing_content.split('---')
+        header_parts = existing_content.split("---")
         for part in header_parts:
-            if 'description:' in part or 'comments:' in part:
-                header_content += f'---{part}---\n\n'
+            if "description:" in part or "comments:" in part:
+                header_content += f"---{part}---\n\n"

-    module_name = module_path.replace('.__init__', '')
-    module_path = module_path.replace('.', '/')
-    url = f'https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py'
-    edit = f'https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py'
+    module_name = module_path.replace(".__init__", "")
+    module_path = module_path.replace(".", "/")
+    url = f"https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py"
+    edit = f"https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py"
     title_content = (
-        f'# Reference for `{module_path}.py`\n\n'
-        f'!!! Note\n\n'
-        f'    This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n'
+        f"# Reference for `{module_path}.py`\n\n"
+        f"!!! Note\n\n"
+        f"    This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n"
     )
-    md_content = ['<br><br>\n'] + [f'## ::: {module_name}.{class_name}\n\n<br><br>\n' for class_name in classes]
-    md_content.extend(f'## ::: {module_name}.{func_name}\n\n<br><br>\n' for func_name in functions)
-    md_content = header_content + title_content + '\n'.join(md_content)
-    if not md_content.endswith('\n'):
-        md_content += '\n'
+    md_content = ["<br><br>\n"] + [f"## ::: {module_name}.{class_name}\n\n<br><br>\n" for class_name in classes]
+    md_content.extend(f"## ::: {module_name}.{func_name}\n\n<br><br>\n" for func_name in functions)
+    md_content = header_content + title_content + "\n".join(md_content)
+    if not md_content.endswith("\n"):
+        md_content += "\n"

     md_filepath.parent.mkdir(parents=True, exist_ok=True)
     md_filepath.write_text(md_content)
@@ -80,28 +80,28 @@ def create_nav_menu_yaml(nav_items: list):
     for item_str in nav_items:
         item = Path(item_str)
         parts = item.parts
-        current_level = nav_tree['reference']
+        current_level = nav_tree["reference"]
         for part in parts[2:-1]:  # skip the first two parts (docs and reference) and the last part (filename)
             current_level = current_level[part]

-        md_file_name = parts[-1].replace('.md', '')
+        md_file_name = parts[-1].replace(".md", "")
         current_level[md_file_name] = item

     nav_tree_sorted = sort_nested_dict(nav_tree)

     def _dict_to_yaml(d, level=0):
         """Converts a nested dictionary to a YAML-formatted string with indentation."""
-        yaml_str = ''
-        indent = ' ' * level
+        yaml_str = ""
+        indent = " " * level
         for k, v in d.items():
             if isinstance(v, dict):
-                yaml_str += f'{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}'
+                yaml_str += f"{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}"
             else:
                 yaml_str += f"{indent}- {k}: {str(v).replace('docs/en/', '')}\n"
         return yaml_str

     # Print updated YAML reference section
-    print('Scan complete, new mkdocs.yaml reference section is:\n\n', _dict_to_yaml(nav_tree_sorted))
+    print("Scan complete, new mkdocs.yaml reference section is:\n\n", _dict_to_yaml(nav_tree_sorted))

     # Save new YAML reference section
     # (NEW_YAML_DIR / 'nav_menu_updated.yml').write_text(_dict_to_yaml(nav_tree_sorted))
@@ -111,7 +111,7 @@ def main():
     """Main function to extract class and function names, create Markdown files, and generate a YAML navigation menu."""
     nav_items = []

-    for py_filepath in CODE_DIR.rglob('*.py'):
+    for py_filepath in CODE_DIR.rglob("*.py"):
         classes, functions = extract_classes_and_functions(py_filepath)

         if classes or functions:
@@ -124,5 +124,5 @@ def main():
     create_nav_menu_yaml(nav_items)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
```
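The two extraction patterns above only change quote style, so their behavior is unchanged. A small sketch of what they match; the source snippet is invented for illustration, and note that the anchoring on `^|\n` deliberately skips indented (nested) definitions:

```python
import re

# Invented module source; only the two patterns come from the diff above
content = """
class Model(nn.Module):
    def forward(self, x):
        return x

def parse_model(cfg):
    pass
"""

class_pattern = r"(?:^|\n)class\s(\w+)(?:\(|:)"
func_pattern = r"(?:^|\n)def\s(\w+)\("

print(re.findall(class_pattern, content))  # ['Model']
print(re.findall(func_pattern, content))   # ['parse_model']; nested `forward` is indented, so skipped
```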
docs/update_translations.py (path inferred; file header lost in the mirror)

```diff
@@ -22,69 +22,232 @@ class MarkdownLinkFixer:
         self.base_dir = Path(base_dir)
         self.update_links = update_links
         self.update_text = update_text
-        self.md_link_regex = re.compile(r'\[([^]]+)]\(([^:)]+)\.md\)')
+        self.md_link_regex = re.compile(r"\[([^]]+)]\(([^:)]+)\.md\)")

     @staticmethod
     def replace_front_matter(content, lang_dir):
         """Ensure front matter keywords remain in English."""
-        english = ['comments', 'description', 'keywords']
+        english = ["comments", "description", "keywords"]
         translations = {
-            'zh': ['评论', '描述', '关键词'],  # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字
-            'es': ['comentarios', 'descripción', 'palabras clave'],  # Spanish
-            'ru': ['комментарии', 'описание', 'ключевые слова'],  # Russian
-            'pt': ['comentários', 'descrição', 'palavras-chave'],  # Portuguese
-            'fr': ['commentaires', 'description', 'mots-clés'],  # French
-            'de': ['kommentare', 'beschreibung', 'schlüsselwörter'],  # German
-            'ja': ['コメント', '説明', 'キーワード'],  # Japanese
-            'ko': ['댓글', '설명', '키워드'],  # Korean
-            'hi': ['टिप्पणियाँ', 'विवरण', 'कीवर्ड'],  # Hindi
-            'ar': ['التعليقات', 'الوصف', 'الكلمات الرئيسية']  # Arabic
+            "zh": ["评论", "描述", "关键词"],  # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字
+            "es": ["comentarios", "descripción", "palabras clave"],  # Spanish
+            "ru": ["комментарии", "описание", "ключевые слова"],  # Russian
+            "pt": ["comentários", "descrição", "palavras-chave"],  # Portuguese
+            "fr": ["commentaires", "description", "mots-clés"],  # French
+            "de": ["kommentare", "beschreibung", "schlüsselwörter"],  # German
+            "ja": ["コメント", "説明", "キーワード"],  # Japanese
+            "ko": ["댓글", "설명", "키워드"],  # Korean
+            "hi": ["टिप्पणियाँ", "विवरण", "कीवर्ड"],  # Hindi
+            "ar": ["التعليقات", "الوصف", "الكلمات الرئيسية"],  # Arabic
         }  # front matter translations for comments, description, keyword

         for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
-            content = re.sub(rf'{term} *[:：].*', f'{eng_key}: true', content, flags=re.IGNORECASE) if \
-                eng_key == 'comments' else re.sub(rf'{term} *[:：] *', f'{eng_key}: ', content, flags=re.IGNORECASE)
+            content = (
+                re.sub(rf"{term} *[:：].*", f"{eng_key}: true", content, flags=re.IGNORECASE)
+                if eng_key == "comments"
+                else re.sub(rf"{term} *[:：] *", f"{eng_key}: ", content, flags=re.IGNORECASE)
+            )
         return content

     @staticmethod
     def replace_admonitions(content, lang_dir):
         """Ensure front matter keywords remain in English."""
-        english = [
-            'Note', 'Summary', 'Tip', 'Info', 'Success', 'Question', 'Warning', 'Failure', 'Danger', 'Bug', 'Example',
-            'Quote', 'Abstract', 'Seealso', 'Admonition']
+        english = [
+            "Note",
+            "Summary",
+            "Tip",
+            "Info",
+            "Success",
+            "Question",
+            "Warning",
+            "Failure",
+            "Danger",
+            "Bug",
+            "Example",
+            "Quote",
+            "Abstract",
+            "Seealso",
+            "Admonition",
+        ]
         translations = {
-            'en':
-                english,
-            'zh': ['笔记', '摘要', '提示', '信息', '成功', '问题', '警告', '失败', '危险', '故障', '示例', '引用', '摘要', '另见', '警告'],
-            'es': [
-                'Nota', 'Resumen', 'Consejo', 'Información', 'Éxito', 'Pregunta', 'Advertencia', 'Fracaso', 'Peligro',
-                'Error', 'Ejemplo', 'Cita', 'Abstracto', 'Véase También', 'Amonestación'],
-            'ru': [
-                'Заметка', 'Сводка', 'Совет', 'Информация', 'Успех', 'Вопрос', 'Предупреждение', 'Неудача', 'Опасность',
-                'Ошибка', 'Пример', 'Цитата', 'Абстракт', 'См. Также', 'Предостережение'],
-            'pt': [
-                'Nota', 'Resumo', 'Dica', 'Informação', 'Sucesso', 'Questão', 'Aviso', 'Falha', 'Perigo', 'Bug',
-                'Exemplo', 'Citação', 'Abstrato', 'Veja Também', 'Advertência'],
-            'fr': [
-                'Note', 'Résumé', 'Conseil', 'Info', 'Succès', 'Question', 'Avertissement', 'Échec', 'Danger', 'Bug',
-                'Exemple', 'Citation', 'Abstrait', 'Voir Aussi', 'Admonestation'],
-            'de': [
-                'Hinweis', 'Zusammenfassung', 'Tipp', 'Info', 'Erfolg', 'Frage', 'Warnung', 'Ausfall', 'Gefahr',
-                'Fehler', 'Beispiel', 'Zitat', 'Abstrakt', 'Siehe Auch', 'Ermahnung'],
-            'ja': ['ノート', '要約', 'ヒント', '情報', '成功', '質問', '警告', '失敗', '危険', 'バグ', '例', '引用', '抄録', '参照', '訓告'],
-            'ko': ['노트', '요약', '팁', '정보', '성공', '질문', '경고', '실패', '위험', '버그', '예제', '인용', '추상', '참조', '경고'],
-            'hi': [
-                'नोट', 'सारांश', 'सुझाव', 'जानकारी', 'सफलता', 'प्रश्न', 'चेतावनी', 'विफलता', 'खतरा', 'बग', 'उदाहरण',
-                'उद्धरण', 'सार', 'देखें भी', 'आगाही'],
-            'ar': [
-                'ملاحظة', 'ملخص', 'نصيحة', 'معلومات', 'نجاح', 'سؤال', 'تحذير', 'فشل', 'خطر', 'عطل', 'مثال', 'اقتباس',
-                'ملخص', 'انظر أيضاً', 'تحذير']}
+            "en": english,
+            "zh": [
+                "笔记",
+                "摘要",
+                "提示",
+                "信息",
+                "成功",
+                "问题",
+                "警告",
+                "失败",
+                "危险",
+                "故障",
+                "示例",
+                "引用",
+                "摘要",
+                "另见",
+                "警告",
+            ],
+            "es": [
+                "Nota",
+                "Resumen",
+                "Consejo",
+                "Información",
+                "Éxito",
+                "Pregunta",
+                "Advertencia",
+                "Fracaso",
+                "Peligro",
+                "Error",
+                "Ejemplo",
+                "Cita",
+                "Abstracto",
+                "Véase También",
+                "Amonestación",
+            ],
+            "ru": [
+                "Заметка",
+                "Сводка",
+                "Совет",
+                "Информация",
+                "Успех",
+                "Вопрос",
+                "Предупреждение",
+                "Неудача",
+                "Опасность",
+                "Ошибка",
+                "Пример",
+                "Цитата",
+                "Абстракт",
+                "См. Также",
+                "Предостережение",
+            ],
+            "pt": [
+                "Nota",
+                "Resumo",
+                "Dica",
+                "Informação",
+                "Sucesso",
+                "Questão",
+                "Aviso",
+                "Falha",
+                "Perigo",
+                "Bug",
+                "Exemplo",
+                "Citação",
+                "Abstrato",
+                "Veja Também",
+                "Advertência",
+            ],
+            "fr": [
+                "Note",
+                "Résumé",
+                "Conseil",
+                "Info",
+                "Succès",
+                "Question",
+                "Avertissement",
+                "Échec",
+                "Danger",
+                "Bug",
+                "Exemple",
+                "Citation",
+                "Abstrait",
+                "Voir Aussi",
+                "Admonestation",
+            ],
+            "de": [
+                "Hinweis",
+                "Zusammenfassung",
+                "Tipp",
+                "Info",
+                "Erfolg",
+                "Frage",
+                "Warnung",
+                "Ausfall",
+                "Gefahr",
+                "Fehler",
+                "Beispiel",
+                "Zitat",
+                "Abstrakt",
+                "Siehe Auch",
+                "Ermahnung",
+            ],
+            "ja": [
+                "ノート",
+                "要約",
+                "ヒント",
+                "情報",
+                "成功",
+                "質問",
+                "警告",
+                "失敗",
+                "危険",
+                "バグ",
+                "例",
+                "引用",
+                "抄録",
+                "参照",
+                "訓告",
+            ],
+            "ko": [
+                "노트",
+                "요약",
+                "팁",
+                "정보",
+                "성공",
+                "질문",
+                "경고",
+                "실패",
+                "위험",
+                "버그",
+                "예제",
+                "인용",
+                "추상",
+                "참조",
+                "경고",
+            ],
+            "hi": [
+                "नोट",
+                "सारांश",
+                "सुझाव",
+                "जानकारी",
+                "सफलता",
+                "प्रश्न",
+                "चेतावनी",
+                "विफलता",
+                "खतरा",
+                "बग",
+                "उदाहरण",
+                "उद्धरण",
+                "सार",
+                "देखें भी",
+                "आगाही",
+            ],
+            "ar": [
+                "ملاحظة",
+                "ملخص",
+                "نصيحة",
+                "معلومات",
+                "نجاح",
+                "سؤال",
+                "تحذير",
+                "فشل",
+                "خطر",
+                "عطل",
+                "مثال",
+                "اقتباس",
+                "ملخص",
+                "انظر أيضاً",
+                "تحذير",
+            ],
+        }

         for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
-            if lang_dir.stem != 'en':
-                content = re.sub(rf'!!! *{eng_key} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
-                content = re.sub(rf'!!! *{term} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
-                content = re.sub(rf'!!! *{term}', f'!!! {eng_key}', content, flags=re.IGNORECASE)
+            if lang_dir.stem != "en":
+                content = re.sub(rf"!!! *{eng_key} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
+                content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
+                content = re.sub(rf"!!! *{term}", f"!!! {eng_key}", content, flags=re.IGNORECASE)
                 content = re.sub(r'!!! *"', '!!! Example "', content, flags=re.IGNORECASE)

         return content
@@ -92,30 +255,30 @@ class MarkdownLinkFixer:
     @staticmethod
     def update_iframe(content):
         """Update the 'allow' attribute of iframe if it does not contain the specific English permissions."""
-        english = 'accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share'
+        english = "accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
         pattern = re.compile(f'allow="(?!{re.escape(english)}).+?"')
         return pattern.sub(f'allow="{english}"', content)

     def link_replacer(self, match, parent_dir, lang_dir, use_abs_link=False):
         """Replace broken links with corresponding links in the /en/ directory."""
         text, path = match.groups()
-        linked_path = (parent_dir / path).resolve().with_suffix('.md')
+        linked_path = (parent_dir / path).resolve().with_suffix(".md")

         if not linked_path.exists():
-            en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / 'en')))
+            en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / "en")))
             if en_linked_path.exists():
                 if use_abs_link:
                     # Use absolute links WARNING: BUGS, DO NOT USE
                     docs_root_relative_path = en_linked_path.relative_to(lang_dir.parent)
-                    updated_path = str(docs_root_relative_path).replace('en/', '/../')
+                    updated_path = str(docs_root_relative_path).replace("en/", "/../")
                 else:
                     # Use relative links
                     steps_up = len(parent_dir.relative_to(self.base_dir).parts)
-                    updated_path = Path('../' * steps_up) / en_linked_path.relative_to(self.base_dir)
-                    updated_path = str(updated_path).replace('/en/', '/')
+                    updated_path = Path("../" * steps_up) / en_linked_path.relative_to(self.base_dir)
+                    updated_path = str(updated_path).replace("/en/", "/")

                 print(f"Redirecting link '[{text}]({path})' from {parent_dir} to {updated_path}")
-                return f'[{text}]({updated_path})'
+                return f"[{text}]({updated_path})"
             else:
                 print(f"Warning: Broken link '[{text}]({path})' found in {parent_dir} does not exist in /docs/en/.")

@@ -124,28 +287,30 @@ class MarkdownLinkFixer:
     @staticmethod
     def update_html_tags(content):
         """Updates HTML tags in docs."""
-        alt_tag = 'MISSING'
+        alt_tag = "MISSING"

         # Remove closing slashes from self-closing HTML tags
-        pattern = re.compile(r'<([^>]+?)\s*/>')
-        content = re.sub(pattern, r'<\1>', content)
+        pattern = re.compile(r"<([^>]+?)\s*/>")
+        content = re.sub(pattern, r"<\1>", content)

         # Find all images without alt tags and add placeholder alt text
-        pattern = re.compile(r'!\[(.*?)\]\((.*?)\)')
-        content, num_replacements = re.subn(pattern, lambda match: f'![{match.group(1) or alt_tag}]({match.group(2)})',
-                                            content)
+        pattern = re.compile(r"!\[(.*?)\]\((.*?)\)")
+        content, num_replacements = re.subn(
+            pattern, lambda match: f"![{match.group(1) or alt_tag}]({match.group(2)})", content
+        )

         # Add missing alt tags to HTML images
         pattern = re.compile(r'<img\s+(?!.*?\balt\b)[^>]*src=["\'](.*?)["\'][^>]*>')
-        content, num_replacements = re.subn(pattern, lambda match: match.group(0).replace('>', f' alt="{alt_tag}">', 1),
-                                            content)
+        content, num_replacements = re.subn(
+            pattern, lambda match: match.group(0).replace(">", f' alt="{alt_tag}">', 1), content
+        )

         return content

     def process_markdown_file(self, md_file_path, lang_dir):
         """Process each markdown file in the language directory."""
-        print(f'Processing file: {md_file_path}')
-        with open(md_file_path, encoding='utf-8') as file:
+        print(f"Processing file: {md_file_path}")
+        with open(md_file_path, encoding="utf-8") as file:
             content = file.read()

         if self.update_links:
@@ -157,23 +322,23 @@ class MarkdownLinkFixer:
             content = self.update_iframe(content)
             content = self.update_html_tags(content)

-        with open(md_file_path, 'w', encoding='utf-8') as file:
+        with open(md_file_path, "w", encoding="utf-8") as file:
             file.write(content)

     def process_language_directory(self, lang_dir):
         """Process each language-specific directory."""
-        print(f'Processing language directory: {lang_dir}')
-        for md_file in lang_dir.rglob('*.md'):
+        print(f"Processing language directory: {lang_dir}")
+        for md_file in lang_dir.rglob("*.md"):
             self.process_markdown_file(md_file, lang_dir)

     def run(self):
         """Run the link fixing and front matter updating process for each language-specific directory."""
         for subdir in self.base_dir.iterdir():
-            if subdir.is_dir() and re.match(r'^\w\w$', subdir.name):
+            if subdir.is_dir() and re.match(r"^\w\w$", subdir.name):
                 self.process_language_directory(subdir)


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Set the path to your MkDocs 'docs' directory here
     docs_dir = str(Path(__file__).parent.resolve())
     fixer = MarkdownLinkFixer(docs_dir, update_links=True, update_text=True)
```
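A minimal sketch of the admonition rewrite above, run for a single (term, eng_key) pair; the Spanish sample line is invented for illustration:

```python
import re

# One (localized term, English keyword) pair, as iterated by replace_admonitions for "es"
term, eng_key = "Consejo", "Tip"
content = "!!! Consejo\n    Usa una GPU para entrenar.\n"

# Localized admonition keyword -> English keyword, keeping the translation as the title
content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
print(content)  # !!! Tip "Consejo" ... so MkDocs recognizes the admonition type again
```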
examples/YOLOv8-ONNXRuntime/main.py (path inferred; file header lost in the mirror)

```diff
@@ -28,7 +28,7 @@ class YOLOv8:
         self.iou_thres = iou_thres

         # Load the class names from the COCO dataset
-        self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+        self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]

         # Generate a color palette for the classes
         self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -57,7 +57,7 @@ class YOLOv8:
         cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

         # Create the label text with class name and score
-        label = f'{self.classes[class_id]}: {score:.2f}'
+        label = f"{self.classes[class_id]}: {score:.2f}"

         # Calculate the dimensions of the label text
         (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -67,8 +67,9 @@ class YOLOv8:
         label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

         # Draw a filled rectangle as the background for the label text
-        cv2.rectangle(img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color,
-                      cv2.FILLED)
+        cv2.rectangle(
+            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
+        )

         # Draw the label text on the image
         cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -182,7 +183,7 @@ class YOLOv8:
             output_img: The output image with drawn detections.
         """
         # Create an inference session using the ONNX model and specify execution providers
-        session = ort.InferenceSession(self.onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

         # Get the model inputs
         model_inputs = session.get_inputs()
@@ -202,17 +203,17 @@ class YOLOv8:
         return self.postprocess(self.img, outputs)  # output image


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Create an argument parser to handle command-line arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, default='yolov8n.onnx', help='Input your ONNX model.')
-    parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
-    parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
-    parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
+    parser.add_argument("--model", type=str, default="yolov8n.onnx", help="Input your ONNX model.")
+    parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
+    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
+    parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
     args = parser.parse_args()

     # Check the requirements and select the appropriate backend (CPU or GPU)
-    check_requirements('onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime')
+    check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime")

     # Create an instance of the YOLOv8 class with the specified arguments
     detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres)
@@ -221,8 +222,8 @@ if __name__ == "__main__":
     output_image = detection.main()

     # Display the output image in a window
-    cv2.namedWindow('Output', cv2.WINDOW_NORMAL)
-    cv2.imshow('Output', output_image)
+    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
+    cv2.imshow("Output", output_image)

     # Wait for a key press to exit
     cv2.waitKey(0)
```
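The provider list reformatted above is an ordered preference: ONNX Runtime tries CUDA first and silently falls back to CPU when it is unavailable. A minimal sketch, with "yolov8n.onnx" as a placeholder model path:

```python
import onnxruntime as ort

# Ordered provider preference; ONNX Runtime uses the first one it can initialize
session = ort.InferenceSession("yolov8n.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
print(session.get_providers())  # e.g. ['CPUExecutionProvider'] on a CPU-only machine

inp = session.get_inputs()[0]
print(inp.name, inp.shape)  # typically 'images', [1, 3, 640, 640] for a YOLOv8 export
```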
examples/YOLOv8-OpenCV-ONNX-Python/main.py (path inferred; file header lost in the mirror)

```diff
@@ -6,7 +6,7 @@ import numpy as np
 from ultralytics.utils import ASSETS, yaml_load
 from ultralytics.utils.checks import check_yaml

-CLASSES = yaml_load(check_yaml('coco128.yaml'))['names']
+CLASSES = yaml_load(check_yaml("coco128.yaml"))["names"]
 colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))
@@ -23,7 +23,7 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
         x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.
         y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.
     """
-    label = f'{CLASSES[class_id]} ({confidence:.2f})'
+    label = f"{CLASSES[class_id]} ({confidence:.2f})"
     color = colors[class_id]
     cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
     cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
@@ -76,8 +76,11 @@ def main(onnx_model, input_image):
             (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores)
             if maxScore >= 0.25:
                 box = [
-                    outputs[0][i][0] - (0.5 * outputs[0][i][2]), outputs[0][i][1] - (0.5 * outputs[0][i][3]),
-                    outputs[0][i][2], outputs[0][i][3]]
+                    outputs[0][i][0] - (0.5 * outputs[0][i][2]),
+                    outputs[0][i][1] - (0.5 * outputs[0][i][3]),
+                    outputs[0][i][2],
+                    outputs[0][i][3],
+                ]
                 boxes.append(box)
                 scores.append(maxScore)
                 class_ids.append(maxClassIndex)
@@ -92,26 +95,34 @@ def main(onnx_model, input_image):
         index = result_boxes[i]
         box = boxes[index]
         detection = {
-            'class_id': class_ids[index],
-            'class_name': CLASSES[class_ids[index]],
-            'confidence': scores[index],
-            'box': box,
-            'scale': scale}
+            "class_id": class_ids[index],
+            "class_name": CLASSES[class_ids[index]],
+            "confidence": scores[index],
+            "box": box,
+            "scale": scale,
+        }
         detections.append(detection)
-        draw_bounding_box(original_image, class_ids[index], scores[index], round(box[0] * scale), round(box[1] * scale),
-                          round((box[0] + box[2]) * scale), round((box[1] + box[3]) * scale))
+        draw_bounding_box(
+            original_image,
+            class_ids[index],
+            scores[index],
+            round(box[0] * scale),
+            round(box[1] * scale),
+            round((box[0] + box[2]) * scale),
+            round((box[1] + box[3]) * scale),
+        )

     # Display the image with bounding boxes
-    cv2.imshow('image', original_image)
+    cv2.imshow("image", original_image)
     cv2.waitKey(0)
     cv2.destroyAllWindows()

     return detections


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', default='yolov8n.onnx', help='Input your ONNX model.')
-    parser.add_argument('--img', default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
+    parser.add_argument("--model", default="yolov8n.onnx", help="Input your ONNX model.")
+    parser.add_argument("--img", default=str(ASSETS / "bus.jpg"), help="Path to input image.")
     args = parser.parse_args()
     main(args.model, args.img)
```
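The `box = [...]` expression reformatted above converts YOLOv8's center-format output (cx, cy, w, h) into the top-left (x, y, w, h) format that `cv2.dnn.NMSBoxes` and `draw_bounding_box` expect. A sketch with an invented sample detection:

```python
# Invented sample detection in center format: center (320, 240), 100 wide, 50 tall
cx, cy, w, h = 320.0, 240.0, 100.0, 50.0

# Same conversion as outputs[0][i][0] - (0.5 * outputs[0][i][2]), etc.
box = [cx - 0.5 * w, cy - 0.5 * h, w, h]
print(box)  # [270.0, 215.0, 100.0, 50.0] -> top-left corner plus width/height
```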
examples/YOLOv8-TFLite-Python/main.py (path inferred; file header lost in the mirror)

```diff
@@ -13,14 +13,9 @@ img_height = 640


 class LetterBox:
-
-    def __init__(self,
-                 new_shape=(img_width, img_height),
-                 auto=False,
-                 scaleFill=False,
-                 scaleup=True,
-                 center=True,
-                 stride=32):
+    def __init__(
+        self, new_shape=(img_width, img_height), auto=False, scaleFill=False, scaleup=True, center=True, stride=32
+    ):
         self.new_shape = new_shape
         self.auto = auto
         self.scaleFill = scaleFill
@@ -33,9 +28,9 @@ class LetterBox:
         if labels is None:
             labels = {}
-        img = labels.get('img') if image is None else image
+        img = labels.get("img") if image is None else image
         shape = img.shape[:2]  # current shape [height, width]
-        new_shape = labels.pop('rect_shape', self.new_shape)
+        new_shape = labels.pop("rect_shape", self.new_shape)
         if isinstance(new_shape, int):
             new_shape = (new_shape, new_shape)
@@ -63,15 +58,16 @@ class LetterBox:
             img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
         top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
         left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
-        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
-                                 value=(114, 114, 114))  # add border
-        if labels.get('ratio_pad'):
-            labels['ratio_pad'] = (labels['ratio_pad'], (left, top))  # for evaluation
+        img = cv2.copyMakeBorder(
+            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+        )  # add border
+        if labels.get("ratio_pad"):
+            labels["ratio_pad"] = (labels["ratio_pad"], (left, top))  # for evaluation

         if len(labels):
             labels = self._update_labels(labels, ratio, dw, dh)
-            labels['img'] = img
-            labels['resized_shape'] = new_shape
+            labels["img"] = img
+            labels["resized_shape"] = new_shape
             return labels
         else:
             return img
@@ -79,15 +75,14 @@ class LetterBox:
     def _update_labels(self, labels, ratio, padw, padh):
         """Update labels."""

-        labels['instances'].convert_bbox(format='xyxy')
-        labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
-        labels['instances'].scale(*ratio)
-        labels['instances'].add_padding(padw, padh)
+        labels["instances"].convert_bbox(format="xyxy")
+        labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
+        labels["instances"].scale(*ratio)
+        labels["instances"].add_padding(padw, padh)
         return labels


 class Yolov8TFLite:
-
     def __init__(self, tflite_model, input_image, confidence_thres, iou_thres):
         """
         Initializes an instance of the Yolov8TFLite class.
@@ -105,7 +100,7 @@ class Yolov8TFLite:
         self.iou_thres = iou_thres

         # Load the class names from the COCO dataset
-        self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+        self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]

         # Generate a color palette for the classes
         self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -134,7 +129,7 @@ class Yolov8TFLite:
         cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

         # Create the label text with class name and score
-        label = f'{self.classes[class_id]}: {score:.2f}'
+        label = f"{self.classes[class_id]}: {score:.2f}"

         # Calculate the dimensions of the label text
         (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -144,8 +139,13 @@ class Yolov8TFLite:
         label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

         # Draw a filled rectangle as the background for the label text
-        cv2.rectangle(img, (int(label_x), int(label_y - label_height)),
-                      (int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED)
+        cv2.rectangle(
+            img,
+            (int(label_x), int(label_y - label_height)),
+            (int(label_x + label_width), int(label_y + label_height)),
+            color,
+            cv2.FILLED,
+        )

         # Draw the label text on the image
         cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -161,7 +161,7 @@ class Yolov8TFLite:
         # Read the input image using OpenCV
         self.img = cv2.imread(self.input_image)

-        print('image befor', self.img)
+        print("image before", self.img)
         # Get the height and width of the input image
         self.img_height, self.img_width = self.img.shape[:2]
@@ -209,8 +209,10 @@ class Yolov8TFLite:
             # Get the box, score, and class ID corresponding to the index
             box = boxes[i]
             gain = min(img_width / self.img_width, img_height / self.img_height)
-            pad = round((img_width - self.img_width * gain) / 2 -
-                        0.1), round((img_height - self.img_height * gain) / 2 - 0.1)
+            pad = (
+                round((img_width - self.img_width * gain) / 2 - 0.1),
+                round((img_height - self.img_height * gain) / 2 - 0.1),
+            )
             box[0] = (box[0] - pad[0]) / gain
             box[1] = (box[1] - pad[1]) / gain
             box[2] = box[2] / gain
@@ -242,7 +244,7 @@ class Yolov8TFLite:
         output_details = interpreter.get_output_details()

         # Store the shape of the input for later use
-        input_shape = input_details[0]['shape']
+        input_shape = input_details[0]["shape"]
         self.input_width = input_shape[1]
         self.input_height = input_shape[2]
@@ -251,19 +253,19 @@ class Yolov8TFLite:
         img_data = img_data
         # img_data = img_data.cpu().numpy()
         # Set the input tensor to the interpreter
-        print(input_details[0]['index'])
+        print(input_details[0]["index"])
         print(img_data.shape)
         img_data = img_data.transpose((0, 2, 3, 1))

-        scale, zero_point = input_details[0]['quantization']
-        interpreter.set_tensor(input_details[0]['index'], img_data)
+        scale, zero_point = input_details[0]["quantization"]
+        interpreter.set_tensor(input_details[0]["index"], img_data)

         # Run inference
         interpreter.invoke()

         # Get the output tensor from the interpreter
-        output = interpreter.get_tensor(output_details[0]['index'])
-        scale, zero_point = output_details[0]['quantization']
+        output = interpreter.get_tensor(output_details[0]["index"])
+        scale, zero_point = output_details[0]["quantization"]
         output = (output.astype(np.float32) - zero_point) * scale

         output[:, [0, 2]] *= img_width
@@ -273,16 +275,15 @@ class Yolov8TFLite:
         return self.postprocess(self.img, output)


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Create an argument parser to handle command-line arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model',
-                        type=str,
-                        default='yolov8n_full_integer_quant.tflite',
-                        help='Input your TFLite model.')
-    parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
-    parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
-    parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
+    parser.add_argument(
+        "--model", type=str, default="yolov8n_full_integer_quant.tflite", help="Input your TFLite model."
+    )
+    parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
+    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
+    parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
     args = parser.parse_args()

     # Create an instance of the Yolov8TFLite class with the specified arguments
@@ -292,7 +293,7 @@ if __name__ == "__main__":
     output_image = detection.main()

     # Display the output image in a window
-    cv2.imshow('Output', output_image)
+    cv2.imshow("Output", output_image)

     # Wait for a key press to exit
     cv2.waitKey(0)
```
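The dequantization line kept intact above is the standard affine mapping for integer-quantized TFLite tensors: each tensor carries a (scale, zero_point) pair, and real_value = (quantized - zero_point) * scale. A sketch with invented quantization parameters:

```python
import numpy as np

# Invented example parameters; a real model reports these in output_details[0]["quantization"]
scale, zero_point = 0.004, 3
output = np.array([3, 53, 253], dtype=np.uint8)  # raw quantized model output

# Same dequantization as in the diff above
real = (output.astype(np.float32) - zero_point) * scale
print(real)  # [0.  0.2 1. ]
```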
@ -16,21 +16,22 @@ track_history = defaultdict(list)
|
|||||||
current_region = None
|
current_region = None
|
||||||
counting_regions = [
|
counting_regions = [
|
||||||
{
|
{
|
||||||
'name': 'YOLOv8 Polygon Region',
|
"name": "YOLOv8 Polygon Region",
|
||||||
'polygon': Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points
|
"polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points
|
||||||
'counts': 0,
|
"counts": 0,
|
||||||
'dragging': False,
|
"dragging": False,
|
||||||
'region_color': (255, 42, 4), # BGR Value
|
"region_color": (255, 42, 4), # BGR Value
|
||||||
'text_color': (255, 255, 255) # Region Text Color
|
"text_color": (255, 255, 255), # Region Text Color
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'YOLOv8 Rectangle Region',
|
"name": "YOLOv8 Rectangle Region",
|
||||||
'polygon': Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points
|
"polygon": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points
|
||||||
'counts': 0,
|
"counts": 0,
|
||||||
'dragging': False,
|
"dragging": False,
|
||||||
'region_color': (37, 255, 225), # BGR Value
|
"region_color": (37, 255, 225), # BGR Value
|
||||||
'text_color': (0, 0, 0), # Region Text Color
|
"text_color": (0, 0, 0), # Region Text Color
|
||||||
}, ]
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def mouse_callback(event, x, y, flags, param):
|
def mouse_callback(event, x, y, flags, param):
|
||||||
@@ -40,32 +41,33 @@ def mouse_callback(event, x, y, flags, param):
     # Mouse left button down event
     if event == cv2.EVENT_LBUTTONDOWN:
         for region in counting_regions:
-            if region['polygon'].contains(Point((x, y))):
+            if region["polygon"].contains(Point((x, y))):
                 current_region = region
-                current_region['dragging'] = True
-                current_region['offset_x'] = x
-                current_region['offset_y'] = y
+                current_region["dragging"] = True
+                current_region["offset_x"] = x
+                current_region["offset_y"] = y

     # Mouse move event
     elif event == cv2.EVENT_MOUSEMOVE:
-        if current_region is not None and current_region['dragging']:
-            dx = x - current_region['offset_x']
-            dy = y - current_region['offset_y']
-            current_region['polygon'] = Polygon([
-                (p[0] + dx, p[1] + dy) for p in current_region['polygon'].exterior.coords])
-            current_region['offset_x'] = x
-            current_region['offset_y'] = y
+        if current_region is not None and current_region["dragging"]:
+            dx = x - current_region["offset_x"]
+            dy = y - current_region["offset_y"]
+            current_region["polygon"] = Polygon(
+                [(p[0] + dx, p[1] + dy) for p in current_region["polygon"].exterior.coords]
+            )
+            current_region["offset_x"] = x
+            current_region["offset_y"] = y

     # Mouse left button up event
     elif event == cv2.EVENT_LBUTTONUP:
-        if current_region is not None and current_region['dragging']:
-            current_region['dragging'] = False
+        if current_region is not None and current_region["dragging"]:
+            current_region["dragging"] = False


 def run(
-    weights='yolov8n.pt',
+    weights="yolov8n.pt",
     source=None,
-    device='cpu',
+    device="cpu",
     view_img=False,
     save_img=False,
     exist_ok=False,
@@ -100,8 +102,8 @@ def run(
         raise FileNotFoundError(f"Source path '{source}' does not exist.")

     # Setup Model
-    model = YOLO(f'{weights}')
-    model.to('cuda') if device == '0' else model.to('cpu')
+    model = YOLO(f"{weights}")
+    model.to("cuda") if device == "0" else model.to("cpu")

     # Extract classes names
     names = model.model.names
@@ -109,12 +111,12 @@ def run(
     # Video setup
     videocapture = cv2.VideoCapture(source)
     frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
-    fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
+    fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")

     # Output setup
-    save_dir = increment_path(Path('ultralytics_rc_output') / 'exp', exist_ok)
+    save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok)
     save_dir.mkdir(parents=True, exist_ok=True)
-    video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
+    video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))

     # Iterate over video frames
     while videocapture.isOpened():
@@ -146,43 +148,48 @@ def run(

                 # Check if detection inside region
                 for region in counting_regions:
-                    if region['polygon'].contains(Point((bbox_center[0], bbox_center[1]))):
-                        region['counts'] += 1
+                    if region["polygon"].contains(Point((bbox_center[0], bbox_center[1]))):
+                        region["counts"] += 1

         # Draw regions (Polygons/Rectangles)
         for region in counting_regions:
-            region_label = str(region['counts'])
-            region_color = region['region_color']
-            region_text_color = region['text_color']
+            region_label = str(region["counts"])
+            region_color = region["region_color"]
+            region_text_color = region["text_color"]

-            polygon_coords = np.array(region['polygon'].exterior.coords, dtype=np.int32)
-            centroid_x, centroid_y = int(region['polygon'].centroid.x), int(region['polygon'].centroid.y)
+            polygon_coords = np.array(region["polygon"].exterior.coords, dtype=np.int32)
+            centroid_x, centroid_y = int(region["polygon"].centroid.x), int(region["polygon"].centroid.y)

-            text_size, _ = cv2.getTextSize(region_label,
-                                           cv2.FONT_HERSHEY_SIMPLEX,
-                                           fontScale=0.7,
-                                           thickness=line_thickness)
+            text_size, _ = cv2.getTextSize(
+                region_label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness
+            )
             text_x = centroid_x - text_size[0] // 2
             text_y = centroid_y + text_size[1] // 2
-            cv2.rectangle(frame, (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5),
-                          region_color, -1)
-            cv2.putText(frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color,
-                        line_thickness)
+            cv2.rectangle(
+                frame,
+                (text_x - 5, text_y - text_size[1] - 5),
+                (text_x + text_size[0] + 5, text_y + 5),
+                region_color,
+                -1,
+            )
+            cv2.putText(
+                frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, line_thickness
+            )
             cv2.polylines(frame, [polygon_coords], isClosed=True, color=region_color, thickness=region_thickness)

         if view_img:
             if vid_frame_count == 1:
-                cv2.namedWindow('Ultralytics YOLOv8 Region Counter Movable')
-                cv2.setMouseCallback('Ultralytics YOLOv8 Region Counter Movable', mouse_callback)
-            cv2.imshow('Ultralytics YOLOv8 Region Counter Movable', frame)
+                cv2.namedWindow("Ultralytics YOLOv8 Region Counter Movable")
+                cv2.setMouseCallback("Ultralytics YOLOv8 Region Counter Movable", mouse_callback)
+            cv2.imshow("Ultralytics YOLOv8 Region Counter Movable", frame)

         if save_img:
             video_writer.write(frame)

         for region in counting_regions:  # Reinitialize count for each region
-            region['counts'] = 0
+            region["counts"] = 0

-        if cv2.waitKey(1) & 0xFF == ord('q'):
+        if cv2.waitKey(1) & 0xFF == ord("q"):
             break

     del vid_frame_count
@@ -194,16 +201,16 @@ def run(
 def parse_opt():
     """Parse command line arguments."""
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
-    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
-    parser.add_argument('--source', type=str, required=True, help='video file path')
-    parser.add_argument('--view-img', action='store_true', help='show results')
-    parser.add_argument('--save-img', action='store_true', help='save results')
-    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
-    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
-    parser.add_argument('--line-thickness', type=int, default=2, help='bounding box thickness')
-    parser.add_argument('--track-thickness', type=int, default=2, help='Tracking line thickness')
-    parser.add_argument('--region-thickness', type=int, default=4, help='Region thickness')
+    parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
+    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
+    parser.add_argument("--source", type=str, required=True, help="video file path")
+    parser.add_argument("--view-img", action="store_true", help="show results")
+    parser.add_argument("--save-img", action="store_true", help="save results")
+    parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
+    parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")
+    parser.add_argument("--line-thickness", type=int, default=2, help="bounding box thickness")
+    parser.add_argument("--track-thickness", type=int, default=2, help="Tracking line thickness")
+    parser.add_argument("--region-thickness", type=int, default=4, help="Region thickness")

     return parser.parse_args()

@@ -213,6 +220,6 @@ def main(opt):
     run(**vars(opt))


-if __name__ == '__main__':
+if __name__ == "__main__":
     opt = parse_opt()
     main(opt)
@@ -9,7 +9,7 @@ from sahi.utils.yolov8 import download_yolov8s_model
 from ultralytics.utils.files import increment_path


-def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False):
+def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False):
     """
     Run object detection on a video using YOLOv8 and SAHI.

@@ -25,41 +25,41 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
     if not Path(source).exists():
         raise FileNotFoundError(f"Source path '{source}' does not exist.")

-    yolov8_model_path = f'models/{weights}'
+    yolov8_model_path = f"models/{weights}"
     download_yolov8s_model(yolov8_model_path)
-    detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8',
-                                                         model_path=yolov8_model_path,
-                                                         confidence_threshold=0.3,
-                                                         device='cpu')
+    detection_model = AutoDetectionModel.from_pretrained(
+        model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu"
+    )

     # Video setup
     videocapture = cv2.VideoCapture(source)
     frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
-    fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
+    fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")

     # Output setup
-    save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok)
+    save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
     save_dir.mkdir(parents=True, exist_ok=True)
-    video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
+    video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))

     while videocapture.isOpened():
         success, frame = videocapture.read()
         if not success:
             break

-        results = get_sliced_prediction(frame,
-                                        detection_model,
-                                        slice_height=512,
-                                        slice_width=512,
-                                        overlap_height_ratio=0.2,
-                                        overlap_width_ratio=0.2)
+        results = get_sliced_prediction(
+            frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2
+        )
         object_prediction_list = results.object_prediction_list

         boxes_list = []
         clss_list = []
         for ind, _ in enumerate(object_prediction_list):
-            boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \
-                object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy
+            boxes = (
+                object_prediction_list[ind].bbox.minx,
+                object_prediction_list[ind].bbox.miny,
+                object_prediction_list[ind].bbox.maxx,
+                object_prediction_list[ind].bbox.maxy,
+            )
             clss = object_prediction_list[ind].category.name
             boxes_list.append(boxes)
             clss_list.append(clss)
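Note on the call reformatted above: SAHI's sliced inference can also be exercised standalone on one image. A minimal sketch with the same slice parameters, assuming an illustrative checkpoint path and input file (both hypothetical, not part of this commit):

    import cv2
    from sahi import AutoDetectionModel
    from sahi.predict import get_sliced_prediction

    # Load a YOLOv8 checkpoint through SAHI (path is an assumed example)
    detection_model = AutoDetectionModel.from_pretrained(
        model_type="yolov8", model_path="models/yolov8s.pt", confidence_threshold=0.3, device="cpu"
    )
    frame = cv2.imread("test.jpg")  # hypothetical input image
    # Run detection over 512x512 slices with 20% overlap, then merge the slice-level predictions
    result = get_sliced_prediction(
        frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2
    )
    print(len(result.object_prediction_list))  # merged detections across all slices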
@@ -69,21 +69,19 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
             cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
             label = str(cls)
             t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
-            cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255),
-                          -1)
-            cv2.putText(frame,
-                        label, (int(x1), int(y1) - 2),
-                        0,
-                        0.6, [255, 255, 255],
-                        thickness=1,
-                        lineType=cv2.LINE_AA)
+            cv2.rectangle(
+                frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1
+            )
+            cv2.putText(
+                frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA
+            )

         if view_img:
             cv2.imshow(Path(source).stem, frame)
         if save_img:
             video_writer.write(frame)

-        if cv2.waitKey(1) & 0xFF == ord('q'):
+        if cv2.waitKey(1) & 0xFF == ord("q"):
             break
     video_writer.release()
     videocapture.release()
@@ -93,11 +91,11 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
 def parse_opt():
     """Parse command line arguments."""
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
-    parser.add_argument('--source', type=str, required=True, help='video file path')
-    parser.add_argument('--view-img', action='store_true', help='show results')
-    parser.add_argument('--save-img', action='store_true', help='save results')
-    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
+    parser.add_argument("--source", type=str, required=True, help="video file path")
+    parser.add_argument("--view-img", action="store_true", help="show results")
+    parser.add_argument("--save-img", action="store_true", help="save results")
+    parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
     return parser.parse_args()


@@ -106,6 +104,6 @@ def main(opt):
     run(**vars(opt))


-if __name__ == '__main__':
+if __name__ == "__main__":
     opt = parse_opt()
     main(opt)
@@ -21,18 +21,21 @@ class YOLOv8Seg:
         """

         # Build Ort session
-        self.session = ort.InferenceSession(onnx_model,
-                                            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
-                                            if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])
+        self.session = ort.InferenceSession(
+            onnx_model,
+            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+            if ort.get_device() == "GPU"
+            else ["CPUExecutionProvider"],
+        )

         # Numpy dtype: support both FP32 and FP16 onnx model
-        self.ndtype = np.half if self.session.get_inputs()[0].type == 'tensor(float16)' else np.single
+        self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single

         # Get model width and height(YOLOv8-seg only has one input)
         self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:]

         # Load COCO class names
-        self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+        self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]

         # Create color palette
         self.color_palette = Colors()
@@ -60,14 +63,16 @@ class YOLOv8Seg:
         preds = self.session.run(None, {self.session.get_inputs()[0].name: im})

         # Post-process
-        boxes, segments, masks = self.postprocess(preds,
-                                                  im0=im0,
-                                                  ratio=ratio,
-                                                  pad_w=pad_w,
-                                                  pad_h=pad_h,
-                                                  conf_threshold=conf_threshold,
-                                                  iou_threshold=iou_threshold,
-                                                  nm=nm)
+        boxes, segments, masks = self.postprocess(
+            preds,
+            im0=im0,
+            ratio=ratio,
+            pad_w=pad_w,
+            pad_h=pad_h,
+            conf_threshold=conf_threshold,
+            iou_threshold=iou_threshold,
+            nm=nm,
+        )
         return boxes, segments, masks

     def preprocess(self, img):
@@ -98,7 +103,7 @@ class YOLOv8Seg:
         img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))

         # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
-        img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0
+        img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0
         img_process = img[None] if len(img.shape) == 3 else img
         return img_process, ratio, (pad_w, pad_h)

@@ -124,7 +129,7 @@ class YOLOv8Seg:
         x, protos = preds[0], preds[1]  # Two outputs: predictions and protos

         # Transpose the first output: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm)
-        x = np.einsum('bcn->bnc', x)
+        x = np.einsum("bcn->bnc", x)

         # Predictions filtering by conf-threshold
         x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold]
@@ -138,7 +143,6 @@ class YOLOv8Seg:

         # Decode and return
         if len(x) > 0:
-
             # Bounding boxes format change: cxcywh -> xyxy
             x[..., [0, 1]] -= x[..., [2, 3]] / 2
             x[..., [2, 3]] += x[..., [0, 1]]
@@ -173,13 +177,13 @@ class YOLOv8Seg:
             segments (List): list of segment masks.
         """
         segments = []
-        for x in masks.astype('uint8'):
+        for x in masks.astype("uint8"):
             c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0]  # CHAIN_APPROX_SIMPLE
             if c:
                 c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
             else:
                 c = np.zeros((0, 2))  # no segments found
-            segments.append(c.astype('float32'))
+            segments.append(c.astype("float32"))
         return segments

     @staticmethod
@@ -219,7 +223,7 @@ class YOLOv8Seg:
         masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0)  # HWN
         masks = np.ascontiguousarray(masks)
         masks = self.scale_mask(masks, im0_shape)  # re-scale mask from P3 shape to original input image shape
-        masks = np.einsum('HWN -> NHW', masks)  # HWN -> NHW
+        masks = np.einsum("HWN -> NHW", masks)  # HWN -> NHW
         masks = self.crop_mask(masks, bboxes)
         return np.greater(masks, 0.5)

@@ -250,8 +254,9 @@ class YOLOv8Seg:
         if len(masks.shape) < 2:
             raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
         masks = masks[top:bottom, left:right]
-        masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]),
-                           interpolation=cv2.INTER_LINEAR)  # INTER_CUBIC would be better
+        masks = cv2.resize(
+            masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR
+        )  # INTER_CUBIC would be better
         if len(masks.shape) == 2:
             masks = masks[:, :, None]
         return masks
@@ -279,32 +284,46 @@ class YOLOv8Seg:
             cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True))

             # draw bbox rectangle
-            cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
-                          self.color_palette(int(cls_), bgr=True), 1, cv2.LINE_AA)
-            cv2.putText(im, f'{self.classes[cls_]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)),
-                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, self.color_palette(int(cls_), bgr=True), 2, cv2.LINE_AA)
+            cv2.rectangle(
+                im,
+                (int(box[0]), int(box[1])),
+                (int(box[2]), int(box[3])),
+                self.color_palette(int(cls_), bgr=True),
+                1,
+                cv2.LINE_AA,
+            )
+            cv2.putText(
+                im,
+                f"{self.classes[cls_]}: {conf:.3f}",
+                (int(box[0]), int(box[1] - 9)),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.7,
+                self.color_palette(int(cls_), bgr=True),
+                2,
+                cv2.LINE_AA,
+            )

         # Mix image
         im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0)

         # Show image
         if vis:
-            cv2.imshow('demo', im)
+            cv2.imshow("demo", im)
             cv2.waitKey(0)
             cv2.destroyAllWindows()

         # Save image
         if save:
-            cv2.imwrite('demo.jpg', im)
+            cv2.imwrite("demo.jpg", im)


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Create an argument parser to handle command-line arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, required=True, help='Path to ONNX model')
-    parser.add_argument('--source', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image')
-    parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold')
-    parser.add_argument('--iou', type=float, default=0.45, help='NMS IoU threshold')
+    parser.add_argument("--model", type=str, required=True, help="Path to ONNX model")
+    parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
+    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
+    parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
     args = parser.parse_args()

     # Build model
@@ -179,5 +179,5 @@ pre-summary-newline = true
 close-quotes-on-newline = true

 [tool.codespell]
-ignore-words-list = "crate,nd,strack,dota,ane,segway,fo,gool,winn,commend"
-skip = '*.csv,*venv*,docs/??/,docs/mkdocs_??.yml'
+ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall"
+skip = "*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml"
@@ -3,6 +3,8 @@ import PIL
 from ultralytics import Explorer
 from ultralytics.utils import ASSETS

+import PIL
+

 def test_similarity():
     """Test similarity calculations and SQL queries for correctness and response length."""
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.238'
+__version__ = "8.0.239"

 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO
@@ -10,4 +10,4 @@ from ultralytics.utils import SETTINGS as settings
 from ultralytics.utils.checks import check_yolo as checks
 from ultralytics.utils.downloads import download

-__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings', 'Explorer'
+__all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer"
@@ -8,34 +8,53 @@ from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Union

-from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, RUNS_DIR,
-                               SETTINGS, SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks,
-                               colorstr, deprecation_warn, yaml_load, yaml_print)
+from ultralytics.utils import (
+    ASSETS,
+    DEFAULT_CFG,
+    DEFAULT_CFG_DICT,
+    DEFAULT_CFG_PATH,
+    LOGGER,
+    RANK,
+    ROOT,
+    RUNS_DIR,
+    SETTINGS,
+    SETTINGS_YAML,
+    TESTS_RUNNING,
+    IterableSimpleNamespace,
+    __version__,
+    checks,
+    colorstr,
+    deprecation_warn,
+    yaml_load,
+    yaml_print,
+)

 # Define valid tasks and modes
-MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
-TASKS = 'detect', 'segment', 'classify', 'pose', 'obb'
+MODES = "train", "val", "predict", "export", "track", "benchmark"
+TASKS = "detect", "segment", "classify", "pose", "obb"
 TASK2DATA = {
-    'detect': 'coco8.yaml',
-    'segment': 'coco8-seg.yaml',
-    'classify': 'imagenet10',
-    'pose': 'coco8-pose.yaml',
-    'obb': 'dota8-obb.yaml'}  # not implemented yet
+    "detect": "coco8.yaml",
+    "segment": "coco8-seg.yaml",
+    "classify": "imagenet10",
+    "pose": "coco8-pose.yaml",
+    "obb": "dota8-obb.yaml",
+}  # not implemented yet
 TASK2MODEL = {
-    'detect': 'yolov8n.pt',
-    'segment': 'yolov8n-seg.pt',
-    'classify': 'yolov8n-cls.pt',
-    'pose': 'yolov8n-pose.pt',
-    'obb': 'yolov8n-obb.pt'}
+    "detect": "yolov8n.pt",
+    "segment": "yolov8n-seg.pt",
+    "classify": "yolov8n-cls.pt",
+    "pose": "yolov8n-pose.pt",
+    "obb": "yolov8n-obb.pt",
+}
 TASK2METRIC = {
-    'detect': 'metrics/mAP50-95(B)',
-    'segment': 'metrics/mAP50-95(M)',
-    'classify': 'metrics/accuracy_top1',
-    'pose': 'metrics/mAP50-95(P)',
-    'obb': 'metrics/mAP50-95(OBB)'}
+    "detect": "metrics/mAP50-95(B)",
+    "segment": "metrics/mAP50-95(M)",
+    "classify": "metrics/accuracy_top1",
+    "pose": "metrics/mAP50-95(P)",
+    "obb": "metrics/mAP50-95(OBB)",
+}

-CLI_HELP_MSG = \
-    f"""
+CLI_HELP_MSG = f"""
     Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:

         yolo TASK MODE ARGS
@@ -74,16 +93,83 @@ CLI_HELP_MSG = \
     """

 # Define keys for arg type checks
-CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'time'
-CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
-                     'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
-                     'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction')  # fraction floats 0.0 - 1.0
-CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
-                'line_width', 'workspace', 'nbs', 'save_period')
-CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
-                 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
-                 'save_frames', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks',
-                 'show_boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile', 'multi_scale')
+CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
+CFG_FRACTION_KEYS = (
+    "dropout",
+    "iou",
+    "lr0",
+    "lrf",
+    "momentum",
+    "weight_decay",
+    "warmup_momentum",
+    "warmup_bias_lr",
+    "label_smoothing",
+    "hsv_h",
+    "hsv_s",
+    "hsv_v",
+    "translate",
+    "scale",
+    "perspective",
+    "flipud",
+    "fliplr",
+    "mosaic",
+    "mixup",
+    "copy_paste",
+    "conf",
+    "iou",
+    "fraction",
+)  # fraction floats 0.0 - 1.0
+CFG_INT_KEYS = (
+    "epochs",
+    "patience",
+    "batch",
+    "workers",
+    "seed",
+    "close_mosaic",
+    "mask_ratio",
+    "max_det",
+    "vid_stride",
+    "line_width",
+    "workspace",
+    "nbs",
+    "save_period",
+)
+CFG_BOOL_KEYS = (
+    "save",
+    "exist_ok",
+    "verbose",
+    "deterministic",
+    "single_cls",
+    "rect",
+    "cos_lr",
+    "overlap_mask",
+    "val",
+    "save_json",
+    "save_hybrid",
+    "half",
+    "dnn",
+    "plots",
+    "show",
+    "save_txt",
+    "save_conf",
+    "save_crop",
+    "save_frames",
+    "show_labels",
+    "show_conf",
+    "visualize",
+    "augment",
+    "agnostic_nms",
+    "retina_masks",
+    "show_boxes",
+    "keras",
+    "optimize",
+    "int8",
+    "dynamic",
+    "simplify",
+    "nms",
+    "profile",
+    "multi_scale",
+)


 def cfg2dict(cfg):
@@ -119,38 +205,44 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
     # Merge overrides
     if overrides:
         overrides = cfg2dict(overrides)
-        if 'save_dir' not in cfg:
-            overrides.pop('save_dir', None)  # special override keys to ignore
+        if "save_dir" not in cfg:
+            overrides.pop("save_dir", None)  # special override keys to ignore
         check_dict_alignment(cfg, overrides)
         cfg = {**cfg, **overrides}  # merge cfg and overrides dicts (prefer overrides)

     # Special handling for numeric project/name
-    for k in 'project', 'name':
+    for k in "project", "name":
         if k in cfg and isinstance(cfg[k], (int, float)):
             cfg[k] = str(cfg[k])
-    if cfg.get('name') == 'model':  # assign model to 'name' arg
-        cfg['name'] = cfg.get('model', '').split('.')[0]
+    if cfg.get("name") == "model":  # assign model to 'name' arg
+        cfg["name"] = cfg.get("model", "").split(".")[0]
         LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")

     # Type and Value checks
     for k, v in cfg.items():
         if v is not None:  # None values may be from optional args
             if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
-                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
-                                f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
+                raise TypeError(
+                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
+                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
+                )
             elif k in CFG_FRACTION_KEYS:
                 if not isinstance(v, (int, float)):
-                    raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
-                                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
+                    raise TypeError(
+                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
+                        f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
+                    )
                 if not (0.0 <= v <= 1.0):
-                    raise ValueError(f"'{k}={v}' is an invalid value. "
-                                     f"Valid '{k}' values are between 0.0 and 1.0.")
+                    raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
             elif k in CFG_INT_KEYS and not isinstance(v, int):
-                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
-                                f"'{k}' must be an int (i.e. '{k}=8')")
+                raise TypeError(
+                    f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
+                )
             elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
-                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
-                                f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
+                raise TypeError(
+                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
+                    f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
+                )

     # Return instance
     return IterableSimpleNamespace(**cfg)
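For context on the validation logic reformatted above, a minimal sketch of how the key groups behave at the call site (the override values are illustrative assumptions):

    from ultralytics.cfg import get_cfg

    cfg = get_cfg(overrides={"epochs": 3, "conf": 0.25})  # int key and 0.0-1.0 fraction key both pass
    try:
        get_cfg(overrides={"conf": 2.5})  # fraction key outside 0.0-1.0 triggers the ValueError branch
    except ValueError as err:
        print(err)  # 'conf=2.5' is an invalid value. Valid 'conf' values are between 0.0 and 1.0.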
@@ -159,13 +251,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
 def get_save_dir(args, name=None):
     """Return save_dir as created from train/val/predict arguments."""

-    if getattr(args, 'save_dir', None):
+    if getattr(args, "save_dir", None):
         save_dir = args.save_dir
     else:
         from ultralytics.utils.files import increment_path

-        project = args.project or (ROOT.parent / 'tests/tmp/runs' if TESTS_RUNNING else RUNS_DIR) / args.task
-        name = name or args.name or f'{args.mode}'
+        project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
+        name = name or args.name or f"{args.mode}"
         save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)

     return Path(save_dir)
@@ -175,18 +267,18 @@ def _handle_deprecation(custom):
     """Hardcoded function to handle deprecated config keys."""

     for key in custom.copy().keys():
-        if key == 'boxes':
-            deprecation_warn(key, 'show_boxes')
-            custom['show_boxes'] = custom.pop('boxes')
-        if key == 'hide_labels':
-            deprecation_warn(key, 'show_labels')
-            custom['show_labels'] = custom.pop('hide_labels') == 'False'
-        if key == 'hide_conf':
-            deprecation_warn(key, 'show_conf')
-            custom['show_conf'] = custom.pop('hide_conf') == 'False'
-        if key == 'line_thickness':
-            deprecation_warn(key, 'line_width')
-            custom['line_width'] = custom.pop('line_thickness')
+        if key == "boxes":
+            deprecation_warn(key, "show_boxes")
+            custom["show_boxes"] = custom.pop("boxes")
+        if key == "hide_labels":
+            deprecation_warn(key, "show_labels")
+            custom["show_labels"] = custom.pop("hide_labels") == "False"
+        if key == "hide_conf":
+            deprecation_warn(key, "show_conf")
+            custom["show_conf"] = custom.pop("hide_conf") == "False"
+        if key == "line_thickness":
+            deprecation_warn(key, "line_width")
+            custom["line_width"] = custom.pop("line_thickness")

     return custom

@@ -207,11 +299,11 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
     if mismatched:
         from difflib import get_close_matches

-        string = ''
+        string = ""
         for x in mismatched:
             matches = get_close_matches(x, base_keys)  # key list
-            matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches]
-            match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
+            matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches]
+            match_str = f"Similar arguments are i.e. {matches}." if matches else ""
             string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
         raise SyntaxError(string + CLI_HELP_MSG) from e

@@ -229,13 +321,13 @@ def merge_equals_args(args: List[str]) -> List[str]:
     """
     new_args = []
     for i, arg in enumerate(args):
-        if arg == '=' and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']
-            new_args[-1] += f'={args[i + 1]}'
+        if arg == "=" and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']
+            new_args[-1] += f"={args[i + 1]}"
             del args[i + 1]
-        elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # merge ['arg=', 'val']
-            new_args.append(f'{arg}{args[i + 1]}')
+        elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]:  # merge ['arg=', 'val']
+            new_args.append(f"{arg}{args[i + 1]}")
             del args[i + 1]
-        elif arg.startswith('=') and i > 0:  # merge ['arg', '=val']
+        elif arg.startswith("=") and i > 0:  # merge ['arg', '=val']
             new_args[-1] += arg
         else:
             new_args.append(arg)
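A minimal sketch of the merging behaviour reformatted above, covering the three tokenizations the branches handle (inputs are illustrative):

    from ultralytics.cfg import merge_equals_args

    print(merge_equals_args(["imgsz", "=", "640"]))  # ['imgsz=640'], merges ['arg', '=', 'val']
    print(merge_equals_args(["imgsz=", "640"]))  # ['imgsz=640'], merges ['arg=', 'val']
    print(merge_equals_args(["imgsz", "=640"]))  # ['imgsz=640'], merges ['arg', '=val']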
@@ -259,11 +351,11 @@ def handle_yolo_hub(args: List[str]) -> None:
     """
     from ultralytics import hub

-    if args[0] == 'login':
-        key = args[1] if len(args) > 1 else ''
+    if args[0] == "login":
+        key = args[1] if len(args) > 1 else ""
         # Log in to Ultralytics HUB using the provided API key
         hub.login(key)
-    elif args[0] == 'logout':
+    elif args[0] == "logout":
         # Log out from Ultralytics HUB
         hub.logout()

@@ -283,19 +375,19 @@ def handle_yolo_settings(args: List[str]) -> None:
         python my_script.py yolo settings reset
         ```
     """
-    url = 'https://docs.ultralytics.com/quickstart/#ultralytics-settings'  # help URL
+    url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings"  # help URL
     try:
         if any(args):
-            if args[0] == 'reset':
+            if args[0] == "reset":
                 SETTINGS_YAML.unlink()  # delete the settings file
                 SETTINGS.reset()  # create new settings
-                LOGGER.info('Settings reset successfully')  # inform the user that settings have been reset
+                LOGGER.info("Settings reset successfully")  # inform the user that settings have been reset
             else:  # save a new setting
                 new = dict(parse_key_value_pair(a) for a in args)
                 check_dict_alignment(SETTINGS, new)
                 SETTINGS.update(new)

-        LOGGER.info(f'💡 Learn about settings at {url}')
+        LOGGER.info(f"💡 Learn about settings at {url}")
         yaml_print(SETTINGS_YAML)  # print the current settings
     except Exception as e:
         LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
@@ -303,13 +395,13 @@ def handle_yolo_settings(args: List[str]) -> None:

 def handle_explorer():
     """Open the Ultralytics Explorer GUI."""
-    checks.check_requirements('streamlit')
-    subprocess.run(['streamlit', 'run', ROOT / 'data/explorer/gui/dash.py', '--server.maxMessageSize', '2048'])
+    checks.check_requirements("streamlit")
+    subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])


 def parse_key_value_pair(pair):
     """Parse one 'key=value' pair and return key and value."""
-    k, v = pair.split('=', 1)  # split on first '=' sign
+    k, v = pair.split("=", 1)  # split on first '=' sign
     k, v = k.strip(), v.strip()  # remove spaces
     assert v, f"missing '{k}' value"
     return k, smart_value(v)
@@ -318,11 +410,11 @@ def parse_key_value_pair(pair):
 def smart_value(v):
     """Convert a string to an underlying type such as int, float, bool, etc."""
     v_lower = v.lower()
-    if v_lower == 'none':
+    if v_lower == "none":
         return None
-    elif v_lower == 'true':
+    elif v_lower == "true":
         return True
-    elif v_lower == 'false':
+    elif v_lower == "false":
         return False
     else:
         with contextlib.suppress(Exception):
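A minimal sketch of the pair parsing and type coercion reformatted above (example strings are assumptions):

    from ultralytics.cfg import parse_key_value_pair, smart_value

    print(parse_key_value_pair("imgsz=640"))  # ('imgsz', 640), value coerced to int
    print(parse_key_value_pair("conf=0.25"))  # ('conf', 0.25), value coerced to float
    print(smart_value("True"))  # True, case-insensitive bool
    print(smart_value("none"))  # None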
@@ -330,7 +422,7 @@ def smart_value(v):
         return v


-def entrypoint(debug=''):
+def entrypoint(debug=""):
     """
     This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
     to the package.
@ -345,139 +437,150 @@ def entrypoint(debug=''):
|
|||||||
It uses the package's default cfg and initializes it using the passed overrides.
|
It uses the package's default cfg and initializes it using the passed overrides.
|
||||||
Then it calls the CLI function with the composed cfg
|
Then it calls the CLI function with the composed cfg
|
||||||
"""
|
"""
|
||||||
args = (debug.split(' ') if debug else sys.argv)[1:]
|
args = (debug.split(" ") if debug else sys.argv)[1:]
|
||||||
if not args: # no arguments passed
|
if not args: # no arguments passed
|
||||||
LOGGER.info(CLI_HELP_MSG)
|
LOGGER.info(CLI_HELP_MSG)
|
||||||
return
|
return
|
||||||
|
|
||||||
special = {
|
special = {
|
||||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
"help": lambda: LOGGER.info(CLI_HELP_MSG),
|
||||||
'checks': checks.collect_system_info,
|
"checks": checks.collect_system_info,
|
||||||
'version': lambda: LOGGER.info(__version__),
|
"version": lambda: LOGGER.info(__version__),
|
||||||
'settings': lambda: handle_yolo_settings(args[1:]),
|
"settings": lambda: handle_yolo_settings(args[1:]),
|
||||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||||
'hub': lambda: handle_yolo_hub(args[1:]),
|
"hub": lambda: handle_yolo_hub(args[1:]),
|
||||||
'login': lambda: handle_yolo_hub(args),
|
"login": lambda: handle_yolo_hub(args),
|
||||||
'copy-cfg': copy_default_cfg,
|
"copy-cfg": copy_default_cfg,
|
||||||
'explorer': lambda: handle_explorer()}
|
"explorer": lambda: handle_explorer(),
|
||||||
|
}
|
||||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||||
|
|
||||||
# Define common misuses of special commands, i.e. -h, -help, --help
|
# Define common misuses of special commands, i.e. -h, -help, --help
|
||||||
special.update({k[0]: v for k, v in special.items()}) # singular
|
special.update({k[0]: v for k, v in special.items()}) # singular
|
||||||
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
|
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular
|
||||||
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
|
special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}}
|
||||||
|
|
||||||
overrides = {} # basic overrides, i.e. imgsz=320
|
overrides = {} # basic overrides, i.e. imgsz=320
|
||||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||||
if a.startswith('--'):
|
if a.startswith("--"):
|
||||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||||
a = a[2:]
|
a = a[2:]
|
||||||
if a.endswith(','):
|
if a.endswith(","):
|
||||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||||
a = a[:-1]
|
a = a[:-1]
|
||||||
if '=' in a:
|
if "=" in a:
|
||||||
try:
|
try:
|
||||||
k, v = parse_key_value_pair(a)
|
k, v = parse_key_value_pair(a)
|
||||||
if k == 'cfg' and v is not None: # custom.yaml passed
|
if k == "cfg" and v is not None: # custom.yaml passed
|
||||||
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
|
LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
|
||||||
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
|
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"}
|
||||||
else:
|
else:
|
||||||
overrides[k] = v
|
overrides[k] = v
|
||||||
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
||||||
check_dict_alignment(full_args_dict, {a: ''}, e)
|
check_dict_alignment(full_args_dict, {a: ""}, e)
|
||||||
|
|
||||||
elif a in TASKS:
|
elif a in TASKS:
|
||||||
overrides['task'] = a
|
overrides["task"] = a
|
||||||
elif a in MODES:
|
elif a in MODES:
|
||||||
overrides['mode'] = a
|
overrides["mode"] = a
|
||||||
elif a.lower() in special:
|
elif a.lower() in special:
|
||||||
special[a.lower()]()
|
special[a.lower()]()
|
||||||
return
|
return
|
||||||
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
||||||
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
||||||
elif a in DEFAULT_CFG_DICT:
|
elif a in DEFAULT_CFG_DICT:
|
||||||
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
raise SyntaxError(
|
||||||
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
||||||
|
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
check_dict_alignment(full_args_dict, {a: ''})
check_dict_alignment(full_args_dict, {a: ""})

# Check keys
check_dict_alignment(full_args_dict, overrides)

# Mode
mode = overrides.get('mode')
mode = overrides.get("mode")
if mode is None:
mode = DEFAULT_CFG.mode or 'predict'
mode = DEFAULT_CFG.mode or "predict"
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
elif mode not in MODES:
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")

# Task
task = overrides.pop('task', None)
task = overrides.pop("task", None)
if task:
if task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
if 'model' not in overrides:
if "model" not in overrides:
overrides['model'] = TASK2MODEL[task]
overrides["model"] = TASK2MODEL[task]

# Model
model = overrides.pop('model', DEFAULT_CFG.model)
model = overrides.pop("model", DEFAULT_CFG.model)
if model is None:
model = 'yolov8n.pt'
model = "yolov8n.pt"
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
overrides['model'] = model
overrides["model"] = model
stem = Path(model).stem.lower()
if 'rtdetr' in stem:  # guess architecture
if "rtdetr" in stem:  # guess architecture
from ultralytics import RTDETR

model = RTDETR(model)  # no task argument
elif 'fastsam' in stem:
elif "fastsam" in stem:
from ultralytics import FastSAM

model = FastSAM(model)
elif 'sam' in stem:
elif "sam" in stem:
from ultralytics import SAM

model = SAM(model)
else:
from ultralytics import YOLO

model = YOLO(model, task=task)
if isinstance(overrides.get('pretrained'), str):
if isinstance(overrides.get("pretrained"), str):
model.load(overrides['pretrained'])
model.load(overrides["pretrained"])

# Task Update
if task != model.task:
if task:
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
LOGGER.warning(
f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model."
)
task = model.task

# Mode
if mode in ('predict', 'track') and 'source' not in overrides:
if mode in ("predict", "track") and "source" not in overrides:
overrides['source'] = DEFAULT_CFG.source or ASSETS
overrides["source"] = DEFAULT_CFG.source or ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
elif mode in ("train", "val"):
if 'data' not in overrides and 'resume' not in overrides:
if "data" not in overrides and "resume" not in overrides:
overrides['data'] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':
elif mode == "export":
if 'format' not in overrides:
if "format" not in overrides:
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
overrides["format"] = DEFAULT_CFG.format or "torchscript"
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")

# Run command in python
getattr(model, mode)(**overrides)  # default args from model

# Show help
LOGGER.info(f'💡 Learn more at https://docs.ultralytics.com/modes/{mode}')
LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}")


# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():
"""Copy and create a new default configuration file with '_copy' appended to its name."""
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8")
LOGGER.info(
f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8"
)


if __name__ == '__main__':
if __name__ == "__main__":
# Example: entrypoint(debug='yolo predict model=yolov8n.pt')
entrypoint(debug='')
entrypoint(debug="")
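The hunk above is the tail of entrypoint(): every missing CLI key falls back to a sane default (mode to 'predict', model to 'yolov8n.pt', source to the packaged ASSETS) with a warning. A rough standalone sketch of that fallback chain (illustrative only; the real function also validates against MODES/TASKS and constructs the model object):

# Minimal sketch of the default-resolution order above; `overrides` mirrors the
# parsed CLI dict, and the literal defaults are taken from the diff.
overrides = {"model": "yolov8n.pt"}  # e.g. parsed from `yolo model=yolov8n.pt`

mode = overrides.get("mode") or "predict"  # missing mode -> 'predict'
model = overrides.pop("model", None) or "yolov8n.pt"  # missing model -> yolov8n.pt
if mode in ("predict", "track") and "source" not in overrides:
    overrides["source"] = "assets/"  # stand-in for the packaged ASSETS directory
print(mode, model, overrides)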
@ -4,5 +4,12 @@ from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset

__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
'build_dataloader', 'load_inference_source')
__all__ = (
"BaseDataset",
"ClassificationDataset",
"SemanticDataset",
"YOLODataset",
"build_yolo_dataset",
"build_dataloader",
"load_inference_source",
)
@ -5,7 +5,7 @@ from pathlib import Path
from ultralytics import SAM, YOLO


def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None):
"""
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.

@ -29,7 +29,7 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',

data = Path(data)
if not output_dir:
output_dir = data.parent / f'{data.stem}_auto_annotate_labels'
output_dir = data.parent / f"{data.stem}_auto_annotate_labels"
Path(output_dir).mkdir(exist_ok=True, parents=True)

det_results = det_model(data, stream=True, device=device)
@ -41,10 +41,10 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
segments = sam_results[0].masks.xyn  # noqa

with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f:
with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f:
for i in range(len(segments)):
s = segments[i]
if len(s) == 0:
continue
segment = map(str, segments[i].reshape(-1).tolist())
f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
f.write(f"{class_ids[i]} " + " ".join(segment) + "\n")
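auto_annotate chains a detector with SAM to bootstrap segmentation labels. A minimal usage sketch; the image directory is a placeholder path:

from ultralytics.data.annotator import auto_annotate

# "path/to/images" is a placeholder; labels are written next to it as
# path/to/images_auto_annotate_labels/<image>.txt, one polygon per line:
# <class_id> x1 y1 x2 y2 ... (normalized coordinates).
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam_b.pt")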
@ -117,11 +117,11 @@ class BaseMixTransform:
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
mix_labels[i] = self.pre_transform(data)
labels['mix_labels'] = mix_labels
labels["mix_labels"] = mix_labels

# Mosaic or MixUp
labels = self._mix_transform(labels)
labels.pop('mix_labels', None)
labels.pop("mix_labels", None)
return labels

def _mix_transform(self, labels):
@ -149,8 +149,8 @@ class Mosaic(BaseMixTransform):

def __init__(self, dataset, imgsz=640, p=1.0, n=4):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
assert n in (4, 9), 'grid must be equal to 4 or 9.'
assert n in (4, 9), "grid must be equal to 4 or 9."
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
@ -166,20 +166,21 @@ class Mosaic(BaseMixTransform):

def _mix_transform(self, labels):
"""Apply mixup transformation to the input image and labels."""
assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
return self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(
labels)  # This code is modified for mosaic3 method.
return (
self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
)  # This code is modified for mosaic3 method.

def _mosaic3(self, labels):
"""Create a 1x3 image mosaic."""
mosaic_labels = []
s = self.imgsz
for i in range(3):
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
img = labels_patch['img']
img = labels_patch["img"]
h, w = labels_patch.pop('resized_shape')
h, w = labels_patch.pop("resized_shape")

# Place img in img3
if i == 0:  # center
@ -194,7 +195,7 @@ class Mosaic(BaseMixTransform):
padw, padh = c[:2]
x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

img3[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:]  # img3[ymin:ymax, xmin:xmax]
img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :]  # img3[ymin:ymax, xmin:xmax]
# hp, wp = h, w  # height, width previous for next iteration

# Labels assuming imgsz*2 mosaic size
@ -202,7 +203,7 @@ class Mosaic(BaseMixTransform):
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)

final_labels['img'] = img3[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
return final_labels

def _mosaic4(self, labels):
@ -211,10 +212,10 @@ class Mosaic(BaseMixTransform):
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
for i in range(4):
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
img = labels_patch['img']
img = labels_patch["img"]
h, w = labels_patch.pop('resized_shape')
h, w = labels_patch.pop("resized_shape")

# Place img in img4
if i == 0:  # top left
@ -238,7 +239,7 @@ class Mosaic(BaseMixTransform):
labels_patch = self._update_labels(labels_patch, padw, padh)
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
final_labels['img'] = img4
final_labels["img"] = img4
return final_labels

def _mosaic9(self, labels):
@ -247,10 +248,10 @@ class Mosaic(BaseMixTransform):
s = self.imgsz
hp, wp = -1, -1  # height, width previous
for i in range(9):
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
img = labels_patch['img']
img = labels_patch["img"]
h, w = labels_patch.pop('resized_shape')
h, w = labels_patch.pop("resized_shape")

# Place img in img9
if i == 0:  # center
@ -278,7 +279,7 @@ class Mosaic(BaseMixTransform):
x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

# Image
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:]  # img9[ymin:ymax, xmin:xmax]
img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :]  # img9[ymin:ymax, xmin:xmax]
hp, wp = h, w  # height, width previous for next iteration

# Labels assuming imgsz*2 mosaic size
@ -286,16 +287,16 @@ class Mosaic(BaseMixTransform):
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)

final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
return final_labels

@staticmethod
def _update_labels(labels, padw, padh):
"""Update labels."""
nh, nw = labels['img'].shape[:2]
nh, nw = labels["img"].shape[:2]
labels['instances'].convert_bbox(format='xyxy')
labels["instances"].convert_bbox(format="xyxy")
labels['instances'].denormalize(nw, nh)
labels["instances"].denormalize(nw, nh)
labels['instances'].add_padding(padw, padh)
labels["instances"].add_padding(padw, padh)
return labels

def _cat_labels(self, mosaic_labels):
@ -306,18 +307,20 @@ class Mosaic(BaseMixTransform):
instances = []
imgsz = self.imgsz * 2  # mosaic imgsz
for labels in mosaic_labels:
cls.append(labels['cls'])
cls.append(labels["cls"])
instances.append(labels['instances'])
instances.append(labels["instances"])
# Final labels
final_labels = {
'im_file': mosaic_labels[0]['im_file'],
"im_file": mosaic_labels[0]["im_file"],
'ori_shape': mosaic_labels[0]['ori_shape'],
"ori_shape": mosaic_labels[0]["ori_shape"],
'resized_shape': (imgsz, imgsz),
"resized_shape": (imgsz, imgsz),
'cls': np.concatenate(cls, 0),
"cls": np.concatenate(cls, 0),
'instances': Instances.concatenate(instances, axis=0),
"instances": Instances.concatenate(instances, axis=0),
'mosaic_border': self.border}  # final_labels
final_labels['instances'].clip(imgsz, imgsz)
good = final_labels['instances'].remove_zero_area_boxes()
final_labels['cls'] = final_labels['cls'][good]
"mosaic_border": self.border,
}
final_labels["instances"].clip(imgsz, imgsz)
good = final_labels["instances"].remove_zero_area_boxes()
final_labels["cls"] = final_labels["cls"][good]
return final_labels
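For intuition on the 4-image case reformatted above: Mosaic builds a canvas of side 2*imgsz, samples a center (yc, xc) inside the border, and pastes each source image into one quadrant, cropping whatever falls outside. A toy numpy sketch of the top-left placement (illustrative, not the Mosaic class itself):

import random

import numpy as np

s = 640  # imgsz
border = (-s // 2, -s // 2)  # same convention as Mosaic(imgsz=640)
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in border)  # mosaic center
img4 = np.full((s * 2, s * 2, 3), 114, dtype=np.uint8)  # gray canvas
patch = np.zeros((s, s, 3), dtype=np.uint8)  # stand-in for a loaded image
h = w = s
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # canvas region (top-left tile)
x1b, y1b = w - (x2a - x1a), h - (y2a - y1a)  # matching crop of the source image
img4[y1a:y2a, x1a:x2a] = patch[y1b:h, x1b:w]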
@ -335,10 +338,10 @@ class MixUp(BaseMixTransform):
def _mix_transform(self, labels):
"""Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf."""
r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
labels2 = labels['mix_labels'][0]
labels2 = labels["mix_labels"][0]
labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
return labels
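MixUp draws the ratio r from Beta(32, 32), so blends cluster tightly around 0.5, then linearly mixes the two images and simply concatenates their instance and class arrays. The pixel math in isolation:

import numpy as np

rng = np.random.default_rng(0)
im1 = rng.integers(0, 255, (640, 640, 3), dtype=np.uint8)  # stand-ins for two samples
im2 = rng.integers(0, 255, (640, 640, 3), dtype=np.uint8)

r = np.random.beta(32.0, 32.0)  # concentrated near 0.5
mixed = (im1 * r + im2 * (1 - r)).astype(np.uint8)  # blend; labels are concatenated, not mixed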
@ -366,14 +369,9 @@ class RandomPerspective:
box_candidates(box1, box2): Filters out bounding boxes that don't meet certain criteria post-transformation.
"""

def __init__(self,
degrees=0.0,
translate=0.1,
scale=0.5,
shear=0.0,
perspective=0.0,
border=(0, 0),
pre_transform=None):
def __init__(
self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
):
"""Initializes RandomPerspective object with transformation parameters."""

self.degrees = degrees
@ -519,18 +517,18 @@ class RandomPerspective:
Args:
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
"""
if self.pre_transform and 'mosaic_border' not in labels:
if self.pre_transform and "mosaic_border" not in labels:
labels = self.pre_transform(labels)
labels.pop('ratio_pad', None)  # do not need ratio pad
labels.pop("ratio_pad", None)  # do not need ratio pad

img = labels['img']
img = labels["img"]
cls = labels['cls']
cls = labels["cls"]
instances = labels.pop('instances')
instances = labels.pop("instances")
# Make sure the coord formats are right
instances.convert_bbox(format='xyxy')
instances.convert_bbox(format="xyxy")
instances.denormalize(*img.shape[:2][::-1])

border = labels.pop('mosaic_border', self.border)
border = labels.pop("mosaic_border", self.border)
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2  # w, h
# M is affine matrix
# Scale for func:`box_candidates`
@ -546,20 +544,20 @@ class RandomPerspective:

if keypoints is not None:
keypoints = self.apply_keypoints(keypoints, M)
new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
# Clip
new_instances.clip(*self.size)

# Filter instances
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
# Make the bboxes have the same scale with new_bboxes
i = self.box_candidates(box1=instances.bboxes.T,
box2=new_instances.bboxes.T,
area_thr=0.01 if len(segments) else 0.10)
i = self.box_candidates(
box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10
)
labels['instances'] = new_instances[i]
labels["instances"] = new_instances[i]
labels['cls'] = cls[i]
labels["cls"] = cls[i]
labels['img'] = img
labels["img"] = img
labels['resized_shape'] = img.shape[:2]
labels["resized_shape"] = img.shape[:2]
return labels

def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
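box_candidates keeps a transformed box only if both sides stay above wh_thr pixels, the aspect ratio stays below ar_thr, and the box retains more than area_thr of its pre-transform area. A numpy sketch of those three criteria (a paraphrase, assuming 4xN xyxy inputs):

import numpy as np

def box_candidates_sketch(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
    """box1/box2: 4xN xyxy boxes before/after the transform; returns a boolean keep-mask."""
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # worst-case aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)

before = np.array([[0, 0, 100, 100]], dtype=float).T  # one 100x100 box
after = np.array([[0, 0, 10, 2]], dtype=float).T  # collapsed by the transform
print(box_candidates_sketch(before, after))  # [False] -> box is dropped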
@ -611,7 +609,7 @@ class RandomHSV:

The modified image replaces the original image in the input 'labels' dict.
"""
img = labels['img']
img = labels["img"]
if self.hgain or self.sgain or self.vgain:
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1  # random gains
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
@ -634,7 +632,7 @@ class RandomFlip:
Also updates any instances (bounding boxes, keypoints, etc.) accordingly.
"""

def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
"""
Initializes the RandomFlip class with probability and direction.

@ -644,7 +642,7 @@ class RandomFlip:
Default is 'horizontal'.
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
"""
assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
assert 0 <= p <= 1.0

self.p = p
@ -662,25 +660,25 @@ class RandomFlip:
Returns:
(dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys.
"""
img = labels['img']
img = labels["img"]
instances = labels.pop('instances')
instances = labels.pop("instances")
instances.convert_bbox(format='xywh')
instances.convert_bbox(format="xywh")
h, w = img.shape[:2]
h = 1 if instances.normalized else h
w = 1 if instances.normalized else w

# Flip up-down
if self.direction == 'vertical' and random.random() < self.p:
if self.direction == "vertical" and random.random() < self.p:
img = np.flipud(img)
instances.flipud(h)
if self.direction == 'horizontal' and random.random() < self.p:
if self.direction == "horizontal" and random.random() < self.p:
img = np.fliplr(img)
instances.fliplr(w)
# For keypoints
if self.flip_idx is not None and instances.keypoints is not None:
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
labels['img'] = np.ascontiguousarray(img)
labels["img"] = np.ascontiguousarray(img)
labels['instances'] = instances
labels["instances"] = instances
return labels
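When an image is mirrored, left/right keypoints must swap identities; that is what flip_idx encodes. A toy sketch with a hypothetical 3-point scheme (0=nose, 1=left eye, 2=right eye):

import numpy as np

flip_idx = [0, 2, 1]  # hypothetical mapping: eyes swap after a horizontal flip
kpts = np.array([[[0.5, 0.4], [0.4, 0.3], [0.6, 0.3]]])  # (n_instances, n_kpts, 2), normalized
kpts[..., 0] = 1.0 - kpts[..., 0]  # mirror x coordinates
kpts = np.ascontiguousarray(kpts[:, flip_idx, :])  # re-label the left/right pairs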
@ -700,9 +698,9 @@ class LetterBox:
"""Return updated labels and image with added border."""
if labels is None:
labels = {}
img = labels.get('img') if image is None else image
img = labels.get("img") if image is None else image
shape = img.shape[:2]  # current shape [height, width]
new_shape = labels.pop('rect_shape', self.new_shape)
new_shape = labels.pop("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)

@ -730,25 +728,26 @@ class LetterBox:
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=(114, 114, 114))  # add border
if labels.get('ratio_pad'):
labels['ratio_pad'] = (labels['ratio_pad'], (left, top))  # for evaluation
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
)  # add border
if labels.get("ratio_pad"):
labels["ratio_pad"] = (labels["ratio_pad"], (left, top))  # for evaluation

if len(labels):
labels = self._update_labels(labels, ratio, dw, dh)
labels['img'] = img
labels["img"] = img
labels['resized_shape'] = new_shape
labels["resized_shape"] = new_shape
return labels
else:
return img

def _update_labels(self, labels, ratio, padw, padh):
"""Update labels."""
labels['instances'].convert_bbox(format='xyxy')
labels["instances"].convert_bbox(format="xyxy")
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
labels['instances'].scale(*ratio)
labels["instances"].scale(*ratio)
labels['instances'].add_padding(padw, padh)
labels["instances"].add_padding(padw, padh)
return labels
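LetterBox scales by r = min(new_h/h, new_w/w) so the long side fits, then splits the leftover space into symmetric gray (114, 114, 114) borders. The arithmetic on its own, assuming a 640x640 target:

import cv2
import numpy as np

im = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in input image, h=480, w=640
new_h = new_w = 640
h, w = im.shape[:2]
r = min(new_h / h, new_w / w)  # scale so the long side fits
unpad_w, unpad_h = round(w * r), round(h * r)
dw, dh = (new_w - unpad_w) / 2, (new_h - unpad_h) / 2  # padding split on both sides
im = cv2.resize(im, (unpad_w, unpad_h), interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
assert im.shape[:2] == (new_h, new_w)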
@ -785,11 +784,11 @@ class CopyPaste:
1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work.
2. This method modifies the input dictionary 'labels' in place.
"""
im = labels['img']
im = labels["img"]
cls = labels['cls']
cls = labels["cls"]
h, w = im.shape[:2]
instances = labels.pop('instances')
instances = labels.pop("instances")
instances.convert_bbox(format='xyxy')
instances.convert_bbox(format="xyxy")
instances.denormalize(w, h)
if self.p and len(instances.segments):
n = len(instances)
@ -812,9 +811,9 @@ class CopyPaste:
i = cv2.flip(im_new, 1).astype(bool)
im[i] = result[i]

labels['img'] = im
labels["img"] = im
labels['cls'] = cls
labels["cls"] = cls
labels['instances'] = instances
labels["instances"] = instances
return labels

@ -831,12 +830,13 @@ class Albumentations:
"""Initialize the transform object for YOLO bbox formatted params."""
self.p = p
self.transform = None
prefix = colorstr('albumentations: ')
prefix = colorstr("albumentations: ")
try:
import albumentations as A

check_version(A.__version__, '1.0.3', hard=True)  # version requirement
check_version(A.__version__, "1.0.3", hard=True)  # version requirement

# Transforms
T = [
A.Blur(p=0.01),
A.MedianBlur(p=0.01),
@ -844,31 +844,32 @@ class Albumentations:
A.CLAHE(p=0.01),
A.RandomBrightnessContrast(p=0.0),
A.RandomGamma(p=0.0),
A.ImageCompression(quality_lower=75, p=0.0)]  # transforms
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
A.ImageCompression(quality_lower=75, p=0.0),
]
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))

LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
except ImportError:  # package not installed, skip
pass
except Exception as e:
LOGGER.info(f'{prefix}{e}')
LOGGER.info(f"{prefix}{e}")

def __call__(self, labels):
"""Generates object detections and returns a dictionary with detection results."""
im = labels['img']
im = labels["img"]
cls = labels['cls']
cls = labels["cls"]
if len(cls):
labels['instances'].convert_bbox('xywh')
labels["instances"].convert_bbox("xywh")
labels['instances'].normalize(*im.shape[:2][::-1])
labels["instances"].normalize(*im.shape[:2][::-1])
bboxes = labels['instances'].bboxes
bboxes = labels["instances"].bboxes
# TODO: add supports of segments and keypoints
if self.transform and random.random() < self.p:
new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed
if len(new['class_labels']) > 0:  # skip update if no bbox in new im
if len(new["class_labels"]) > 0:  # skip update if no bbox in new im
labels['img'] = new['image']
labels["img"] = new["image"]
labels['cls'] = np.array(new['class_labels'])
labels["cls"] = np.array(new["class_labels"])
bboxes = np.array(new['bboxes'], dtype=np.float32)
bboxes = np.array(new["bboxes"], dtype=np.float32)
labels['instances'].update(bboxes=bboxes)
labels["instances"].update(bboxes=bboxes)
return labels
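The wrapper treats albumentations as an optional dependency: an ImportError silently disables the stage, while any other failure is only logged. That guard pattern in isolation:

import logging

transform = None  # stays None when the optional package is unavailable
try:
    import albumentations as A  # optional dependency

    transform = A.Compose([A.Blur(p=0.01), A.CLAHE(p=0.01)])
except ImportError:
    pass  # not installed: the pipeline simply skips this stage
except Exception as e:
    logging.getLogger(__name__).info(f"albumentations disabled: {e}")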
@ -888,15 +889,17 @@ class Format:
batch_idx (bool): Keep batch indexes. Default is True.
"""

def __init__(self,
def __init__(
bbox_format='xywh',
self,
bbox_format="xywh",
normalize=True,
return_mask=False,
return_keypoint=False,
return_obb=False,
mask_ratio=4,
mask_overlap=True,
batch_idx=True):
batch_idx=True,
):
"""Initializes the Format class with given parameters."""
self.bbox_format = bbox_format
self.normalize = normalize
@ -909,10 +912,10 @@ class Format:

def __call__(self, labels):
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
img = labels.pop('img')
img = labels.pop("img")
h, w = img.shape[:2]
cls = labels.pop('cls')
cls = labels.pop("cls")
instances = labels.pop('instances')
instances = labels.pop("instances")
instances.convert_bbox(format=self.bbox_format)
instances.denormalize(w, h)
nl = len(instances)
@ -922,22 +925,24 @@ class Format:
masks, instances, cls = self._format_segments(instances, cls, w, h)
masks = torch.from_numpy(masks)
else:
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
img.shape[1] // self.mask_ratio)
labels['masks'] = masks
masks = torch.zeros(
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
)
labels["masks"] = masks
if self.normalize:
instances.normalize(w, h)
labels['img'] = self._format_img(img)
labels["img"] = self._format_img(img)
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
labels['keypoints'] = torch.from_numpy(instances.keypoints)
labels["keypoints"] = torch.from_numpy(instances.keypoints)
if self.return_obb:
labels['bboxes'] = xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(
instances.segments) else torch.zeros((0, 5))
labels["bboxes"] = (
xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
)
# Then we can use collate_fn
if self.batch_idx:
labels['batch_idx'] = torch.zeros(nl)
labels["batch_idx"] = torch.zeros(nl)
return labels

def _format_img(self, img):
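The batch_idx column emitted here is what lets a collate function flatten a variable number of boxes per image into one tensor while remembering each box's source image. A minimal sketch (illustrative, not the actual collate_fn):

import torch

# Two samples with 2 and 3 boxes respectively, as produced by Format(batch_idx=True)
samples = [
    {"bboxes": torch.rand(2, 4), "batch_idx": torch.zeros(2)},
    {"bboxes": torch.rand(3, 4), "batch_idx": torch.zeros(3)},
]
for i, s in enumerate(samples):
    s["batch_idx"] += i  # stamp each box with its image index
bboxes = torch.cat([s["bboxes"] for s in samples], 0)  # (5, 4)
batch_idx = torch.cat([s["batch_idx"] for s in samples], 0)  # tensor([0., 0., 1., 1., 1.])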
@ -964,7 +969,8 @@ class Format:

def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose([
pre_transform = Compose(
[
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
@ -974,23 +980,28 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
shear=hyp.shear,
perspective=hyp.perspective,
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
)])
flip_idx = dataset.data.get('flip_idx', [])  # for keypoints augmentation
),
]
)
flip_idx = dataset.data.get("flip_idx", [])  # for keypoints augmentation
if dataset.use_keypoints:
kpt_shape = dataset.data.get('kpt_shape', None)
kpt_shape = dataset.data.get("kpt_shape", None)
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
hyp.fliplr = 0.0
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}")

return Compose([
return Compose(
[
pre_transform,
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction='vertical', p=hyp.flipud),
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)])  # transforms
RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
]
)  # transforms


# Classification augmentations -----------------------------------------------------------------------------------------
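The composed pipeline is a plain Compose whose stages each take and return the labels dict. A rough usage sketch, assuming a dataset with get_image_and_label() (as in BaseDataset below) and a hyp namespace carrying the mosaic/mixup/flip values referenced above:

from ultralytics.data.augment import v8_transforms

transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)
labels = dataset.get_image_and_label(0)  # dict with 'img', 'cls', 'instances', ...
labels = transforms(labels)  # augmented dict, ready for Format/collate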
@ -1031,10 +1042,13 @@ def classify_transforms(
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]

tfl += [T.ToTensor(), T.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
)]
tfl += [
T.ToTensor(),
T.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
),
]

return T.Compose(tfl)

@ -1053,7 +1067,7 @@ def classify_augmentations(
hsv_s=0.4,  # image HSV-Saturation augmentation (fraction)
hsv_v=0.4,  # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.,
erasing=0.0,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
):
"""
@ -1080,13 +1094,13 @@ def classify_augmentations(
"""
# Transforms to apply if albumentations not installed
if not isinstance(size, int):
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))  # default imagenet ratio range
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.:
if hflip > 0.0:
primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
if vflip > 0.:
if vflip > 0.0:
primary_tfl += [T.RandomVerticalFlip(p=vflip)]

secondary_tfl = []
@ -1097,27 +1111,29 @@ def classify_augmentations(
# this allows override without breaking old hparm cfgs
disable_color_jitter = not force_color_jitter

if auto_augment == 'randaugment':
if auto_augment == "randaugment":
if TORCHVISION_0_11:
secondary_tfl += [T.RandAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')

elif auto_augment == 'augmix':
elif auto_augment == "augmix":
if TORCHVISION_0_13:
secondary_tfl += [T.AugMix(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')

elif auto_augment == 'autoaugment':
elif auto_augment == "autoaugment":
if TORCHVISION_0_10:
secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')

else:
raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
f'"augmix", "autoaugment" or None')
raise ValueError(
f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
f'"augmix", "autoaugment" or None'
)

if not disable_color_jitter:
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
@ -1125,7 +1141,8 @@ def classify_augmentations(
final_tfl = [
T.ToTensor(),
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
T.RandomErasing(p=erasing, inplace=True)]
T.RandomErasing(p=erasing, inplace=True),
]

return T.Compose(primary_tfl + secondary_tfl + final_tfl)

@ -1177,7 +1194,7 @@ class ClassifyLetterBox:

# Create padded image
im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
return im_out


@ -1205,7 +1222,7 @@ class CenterCrop:
imh, imw = im.shape[:2]
m = min(imh, imw)  # min dimension
top, left = (imh - m) // 2, (imw - m) // 2
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)


# NOTE: keep this class for backward compatibility
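CenterCrop takes the largest centered square (side m = min(h, w)) before resizing. A quick numeric check of the offsets:

import numpy as np

im = np.zeros((480, 640, 3), dtype=np.uint8)  # h=480, w=640
imh, imw = im.shape[:2]
m = min(imh, imw)  # 480: largest centered square
top, left = (imh - m) // 2, (imw - m) // 2  # (0, 80): crop offsets
crop = im[top : top + m, left : left + m]
assert crop.shape[:2] == (480, 480)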
@ -47,20 +47,22 @@ class BaseDataset(Dataset):
transforms (callable): Image transformation function.
"""

def __init__(self,
def __init__(
self,
img_path,
imgsz=640,
cache=False,
augment=True,
hyp=DEFAULT_CFG,
prefix='',
prefix="",
rect=False,
batch_size=16,
stride=32,
pad=0.5,
single_cls=False,
classes=None,
fraction=1.0):
fraction=1.0,
):
"""Initialize BaseDataset with given configuration and options."""
super().__init__()
self.img_path = img_path
@ -86,10 +88,10 @@ class BaseDataset(Dataset):
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

# Cache images
if cache == 'ram' and not self.check_cache_ram():
if cache == "ram" and not self.check_cache_ram():
cache = False
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
if cache:
self.cache_images(cache)

@ -103,23 +105,23 @@ class BaseDataset(Dataset):
for p in img_path if isinstance(img_path, list) else [img_path]:
p = Path(p)  # os-agnostic
if p.is_dir():  # dir
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
# F = list(p.rglob('*.*'))  # pathlib
elif p.is_file():  # file
with open(p) as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
# F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
else:
raise FileNotFoundError(f'{self.prefix}{p} does not exist')
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
assert im_files, f'{self.prefix}No images found in {img_path}'
assert im_files, f"{self.prefix}No images found in {img_path}"
except Exception as e:
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
if self.fraction < 1:
im_files = im_files[:round(len(im_files) * self.fraction)]
im_files = im_files[: round(len(im_files) * self.fraction)]
return im_files

def update_labels(self, include_class: Optional[list]):
@ -127,19 +129,19 @@ class BaseDataset(Dataset):
include_class_array = np.array(include_class).reshape(1, -1)
for i in range(len(self.labels)):
if include_class is not None:
cls = self.labels[i]['cls']
cls = self.labels[i]["cls"]
bboxes = self.labels[i]['bboxes']
bboxes = self.labels[i]["bboxes"]
segments = self.labels[i]['segments']
segments = self.labels[i]["segments"]
keypoints = self.labels[i]['keypoints']
keypoints = self.labels[i]["keypoints"]
j = (cls == include_class_array).any(1)
self.labels[i]['cls'] = cls[j]
self.labels[i]["cls"] = cls[j]
self.labels[i]['bboxes'] = bboxes[j]
self.labels[i]["bboxes"] = bboxes[j]
if segments:
self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
if keypoints is not None:
self.labels[i]['keypoints'] = keypoints[j]
self.labels[i]["keypoints"] = keypoints[j]
if self.single_cls:
self.labels[i]['cls'][:, 0] = 0
self.labels[i]["cls"][:, 0] = 0

def load_image(self, i, rect_mode=True):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
@ -149,13 +151,13 @@ class BaseDataset(Dataset):
try:
im = np.load(fn)
except Exception as e:
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
Path(fn).unlink(missing_ok=True)
im = cv2.imread(f)  # BGR
else:  # read image
im = cv2.imread(f)  # BGR
if im is None:
raise FileNotFoundError(f'Image Not Found {f}')
raise FileNotFoundError(f"Image Not Found {f}")

h0, w0 = im.shape[:2]  # orig hw
if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
@ -181,17 +183,17 @@ class BaseDataset(Dataset):
def cache_images(self, cache):
"""Cache images to memory or disk."""
b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(self.ni))
pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache == 'disk':
if cache == "disk":
b += self.npy_files[i].stat().st_size
else:  # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
b += self.ims[i].nbytes
pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
pbar.close()

def cache_images_to_disk(self, i):
@ -207,15 +209,17 @@ class BaseDataset(Dataset):
for _ in range(n):
im = cv2.imread(random.choice(self.im_files))  # sample image
ratio = self.imgsz / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
b += im.nbytes * ratio ** 2
b += im.nbytes * ratio**2
mem_required = b * self.ni / n * (1 + safety_margin)  # GB required to cache dataset into RAM
mem = psutil.virtual_memory()
cache = mem_required < mem.available  # to cache or not to cache, that is the question
if not cache:
LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
f'with {int(safety_margin * 100)}% safety margin but only '
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
LOGGER.info(
f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
f'with {int(safety_margin * 100)}% safety margin but only '
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
)
return cache

def set_rectangle(self):
@ -223,7 +227,7 @@ class BaseDataset(Dataset):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
nb = bi[-1] + 1  # number of batches

s = np.array([x.pop('shape') for x in self.labels])  # hw
s = np.array([x.pop("shape") for x in self.labels])  # hw
ar = s[:, 0] / s[:, 1]  # aspect ratio
irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect]
@ -250,12 +254,14 @@ class BaseDataset(Dataset):
def get_image_and_label(self, index):
"""Get and return label information from the dataset."""
label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
label.pop('shape', None)  # shape is for rect, remove it
label.pop("shape", None)  # shape is for rect, remove it
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
label['resized_shape'][1] / label['ori_shape'][1])  # for evaluation
label["ratio_pad"] = (
label["resized_shape"][0] / label["ori_shape"][0],
label["resized_shape"][1] / label["ori_shape"][1],
)  # for evaluation
if self.rect:
label['rect_shape'] = self.batch_shapes[self.batch[index]]
label["rect_shape"] = self.batch_shapes[self.batch[index]]
return self.update_labels_info(label)

def __len__(self):
|||||||
@@ -9,8 +9,16 @@ import torch
 from PIL import Image
 from torch.utils.data import dataloader, distributed

-from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
-                                      SourceTypes, autocast_list)
+from ultralytics.data.loaders import (
+    LOADERS,
+    LoadImages,
+    LoadPilAndNumpy,
+    LoadScreenshots,
+    LoadStreams,
+    LoadTensor,
+    SourceTypes,
+    autocast_list,
+)
 from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.utils import RANK, colorstr
 from ultralytics.utils.checks import check_file
@@ -29,7 +37,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
     def __init__(self, *args, **kwargs):
         """Dataloader that infinitely recycles workers, inherits from DataLoader."""
         super().__init__(*args, **kwargs)
-        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
         self.iterator = super().__iter__()

     def __len__(self):
@@ -70,29 +78,30 @@ class _RepeatSampler:

 def seed_worker(worker_id):  # noqa
     """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
-    worker_seed = torch.initial_seed() % 2 ** 32
+    worker_seed = torch.initial_seed() % 2**32
     np.random.seed(worker_seed)
     random.seed(worker_seed)


-def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
+def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
     """Build YOLO Dataset."""
     return YOLODataset(
         img_path=img_path,
         imgsz=cfg.imgsz,
         batch_size=batch,
-        augment=mode == 'train',  # augmentation
+        augment=mode == "train",  # augmentation
         hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
         rect=cfg.rect or rect,  # rectangular batches
         cache=cfg.cache or None,
         single_cls=cfg.single_cls or False,
         stride=int(stride),
-        pad=0.0 if mode == 'train' else 0.5,
+        pad=0.0 if mode == "train" else 0.5,
-        prefix=colorstr(f'{mode}: '),
+        prefix=colorstr(f"{mode}: "),
         task=cfg.task,
         classes=cfg.classes,
         data=data,
-        fraction=cfg.fraction if mode == 'train' else 1.0)
+        fraction=cfg.fraction if mode == "train" else 1.0,
+    )


 def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
@@ -103,15 +112,17 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
     generator = torch.Generator()
     generator.manual_seed(6148914691236517205 + RANK)
-    return InfiniteDataLoader(dataset=dataset,
+    return InfiniteDataLoader(
+        dataset=dataset,
         batch_size=batch,
         shuffle=shuffle and sampler is None,
         num_workers=nw,
         sampler=sampler,
         pin_memory=PIN_MEMORY,
-        collate_fn=getattr(dataset, 'collate_fn', None),
+        collate_fn=getattr(dataset, "collate_fn", None),
         worker_init_fn=seed_worker,
-        generator=generator)
+        generator=generator,
+    )


 def check_source(source):
@@ -120,9 +131,9 @@ def check_source(source):
     if isinstance(source, (str, int, Path)):  # int for local usb camera
         source = str(source)
         is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
-        is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://'))
+        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
-        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
+        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
-        screenshot = source.lower() == 'screen'
+        screenshot = source.lower() == "screen"
         if is_url and is_file:
             source = check_file(source)  # download
     elif isinstance(source, LOADERS):
@@ -135,7 +146,7 @@ def check_source(source):
     elif isinstance(source, torch.Tensor):
         tensor = True
     else:
-        raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
+        raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")

     return source, webcam, screenshot, from_img, in_memory, tensor

@@ -171,6 +182,6 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
         dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)

     # Attach source types to the dataset
-    setattr(dataset, 'source_type', source_type)
+    setattr(dataset, "source_type", source_type)

     return dataset
@@ -20,10 +20,98 @@ def coco91_to_coco80_class():
             corresponding 91-index class ID.
     """
     return [
-        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
-        None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
-        51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
-        None, 73, 74, 75, 76, 77, 78, 79, None]
+        0,
+        1,
+        2,
+        3,
+        4,
+        5,
+        6,
+        7,
+        8,
+        9,
+        10,
+        None,
+        11,
+        12,
+        13,
+        14,
+        15,
+        16,
+        17,
+        18,
+        19,
+        20,
+        21,
+        22,
+        23,
+        None,
+        24,
+        25,
+        None,
+        None,
+        26,
+        27,
+        28,
+        29,
+        30,
+        31,
+        32,
+        33,
+        34,
+        35,
+        36,
+        37,
+        38,
+        39,
+        None,
+        40,
+        41,
+        42,
+        43,
+        44,
+        45,
+        46,
+        47,
+        48,
+        49,
+        50,
+        51,
+        52,
+        53,
+        54,
+        55,
+        56,
+        57,
+        58,
+        59,
+        None,
+        60,
+        None,
+        None,
+        61,
+        None,
+        62,
+        63,
+        64,
+        65,
+        66,
+        67,
+        68,
+        69,
+        70,
+        71,
+        72,
+        None,
+        73,
+        74,
+        75,
+        76,
+        77,
+        78,
+        79,
+        None,
+    ]


 def coco80_to_coco91_class():
@@ -42,16 +130,96 @@ def coco80_to_coco91_class():
         ```
     """
     return [
-        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
-        35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
-        64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
+        1,
+        2,
+        3,
+        4,
+        5,
+        6,
+        7,
+        8,
+        9,
+        10,
+        11,
+        13,
+        14,
+        15,
+        16,
+        17,
+        18,
+        19,
+        20,
+        21,
+        22,
+        23,
+        24,
+        25,
+        27,
+        28,
+        31,
+        32,
+        33,
+        34,
+        35,
+        36,
+        37,
+        38,
+        39,
+        40,
+        41,
+        42,
+        43,
+        44,
+        46,
+        47,
+        48,
+        49,
+        50,
+        51,
+        52,
+        53,
+        54,
+        55,
+        56,
+        57,
+        58,
+        59,
+        60,
+        61,
+        62,
+        63,
+        64,
+        65,
+        67,
+        70,
+        72,
+        73,
+        74,
+        75,
+        76,
+        77,
+        78,
+        79,
+        80,
+        81,
+        82,
+        84,
+        85,
+        86,
+        87,
+        88,
+        89,
+        90,
+    ]


-def convert_coco(labels_dir='../coco/annotations/',
-                 save_dir='coco_converted/',
+def convert_coco(
+    labels_dir="../coco/annotations/",
+    save_dir="coco_converted/",
     use_segments=False,
     use_keypoints=False,
-                 cls91to80=True):
+    cls91to80=True,
+):
     """
     Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.

@@ -75,76 +243,78 @@ def convert_coco(labels_dir='../coco/annotations/',

     # Create dataset directory
     save_dir = increment_path(save_dir)  # increment if save directory already exists
-    for p in save_dir / 'labels', save_dir / 'images':
+    for p in save_dir / "labels", save_dir / "images":
         p.mkdir(parents=True, exist_ok=True)  # make dir

     # Convert classes
     coco80 = coco91_to_coco80_class()

     # Import json
-    for json_file in sorted(Path(labels_dir).resolve().glob('*.json')):
+    for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
-        fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '')  # folder name
+        fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "")  # folder name
         fn.mkdir(parents=True, exist_ok=True)
         with open(json_file) as f:
             data = json.load(f)

         # Create image dict
-        images = {f'{x["id"]:d}': x for x in data['images']}
+        images = {f'{x["id"]:d}': x for x in data["images"]}
         # Create image-annotations dict
         imgToAnns = defaultdict(list)
-        for ann in data['annotations']:
+        for ann in data["annotations"]:
-            imgToAnns[ann['image_id']].append(ann)
+            imgToAnns[ann["image_id"]].append(ann)

         # Write labels file
-        for img_id, anns in TQDM(imgToAnns.items(), desc=f'Annotations {json_file}'):
+        for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
-            img = images[f'{img_id:d}']
+            img = images[f"{img_id:d}"]
-            h, w, f = img['height'], img['width'], img['file_name']
+            h, w, f = img["height"], img["width"], img["file_name"]

             bboxes = []
             segments = []
             keypoints = []
             for ann in anns:
-                if ann['iscrowd']:
+                if ann["iscrowd"]:
                     continue
                 # The COCO box format is [top left x, top left y, width, height]
-                box = np.array(ann['bbox'], dtype=np.float64)
+                box = np.array(ann["bbox"], dtype=np.float64)
                 box[:2] += box[2:] / 2  # xy top-left corner to center
                 box[[0, 2]] /= w  # normalize x
                 box[[1, 3]] /= h  # normalize y
                 if box[2] <= 0 or box[3] <= 0:  # if w <= 0 and h <= 0
                     continue

-                cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1  # class
+                cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1  # class
                 box = [cls] + box.tolist()
                 if box not in bboxes:
                     bboxes.append(box)
-                    if use_segments and ann.get('segmentation') is not None:
+                    if use_segments and ann.get("segmentation") is not None:
-                        if len(ann['segmentation']) == 0:
+                        if len(ann["segmentation"]) == 0:
                             segments.append([])
                             continue
-                        elif len(ann['segmentation']) > 1:
+                        elif len(ann["segmentation"]) > 1:
-                            s = merge_multi_segment(ann['segmentation'])
+                            s = merge_multi_segment(ann["segmentation"])
                             s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
                         else:
-                            s = [j for i in ann['segmentation'] for j in i]  # all segments concatenated
+                            s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
                             s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
                         s = [cls] + s
                         segments.append(s)
-                    if use_keypoints and ann.get('keypoints') is not None:
+                    if use_keypoints and ann.get("keypoints") is not None:
-                        keypoints.append(box + (np.array(ann['keypoints']).reshape(-1, 3) /
-                                                np.array([w, h, 1])).reshape(-1).tolist())
+                        keypoints.append(
+                            box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
+                        )

             # Write
-            with open((fn / f).with_suffix('.txt'), 'a') as file:
+            with open((fn / f).with_suffix(".txt"), "a") as file:
                 for i in range(len(bboxes)):
                     if use_keypoints:
-                        line = *(keypoints[i]),  # cls, box, keypoints
+                        line = (*(keypoints[i]),)  # cls, box, keypoints
                     else:
-                        line = *(segments[i]
-                                 if use_segments and len(segments[i]) > 0 else bboxes[i]),  # cls, box or segments
+                        line = (
+                            *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]),
+                        )  # cls, box or segments
-                    file.write(('%g ' * len(line)).rstrip() % line + '\n')
+                    file.write(("%g " * len(line)).rstrip() % line + "\n")

-    LOGGER.info(f'COCO data converted successfully.\nResults saved to {save_dir.resolve()}')
+    LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")


 def convert_dota_to_yolo_obb(dota_root_path: str):
@@ -184,31 +354,32 @@ def convert_dota_to_yolo_obb(dota_root_path: str):

     # Class names to indices mapping
     class_mapping = {
-        'plane': 0,
+        "plane": 0,
-        'ship': 1,
+        "ship": 1,
-        'storage-tank': 2,
+        "storage-tank": 2,
-        'baseball-diamond': 3,
+        "baseball-diamond": 3,
-        'tennis-court': 4,
+        "tennis-court": 4,
-        'basketball-court': 5,
+        "basketball-court": 5,
-        'ground-track-field': 6,
+        "ground-track-field": 6,
-        'harbor': 7,
+        "harbor": 7,
-        'bridge': 8,
+        "bridge": 8,
-        'large-vehicle': 9,
+        "large-vehicle": 9,
-        'small-vehicle': 10,
+        "small-vehicle": 10,
-        'helicopter': 11,
+        "helicopter": 11,
-        'roundabout': 12,
+        "roundabout": 12,
-        'soccer-ball-field': 13,
+        "soccer-ball-field": 13,
-        'swimming-pool': 14,
+        "swimming-pool": 14,
-        'container-crane': 15,
+        "container-crane": 15,
-        'airport': 16,
+        "airport": 16,
-        'helipad': 17}
+        "helipad": 17,
+    }

     def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir):
         """Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory."""
-        orig_label_path = orig_label_dir / f'{image_name}.txt'
+        orig_label_path = orig_label_dir / f"{image_name}.txt"
-        save_path = save_dir / f'{image_name}.txt'
+        save_path = save_dir / f"{image_name}.txt"

-        with orig_label_path.open('r') as f, save_path.open('w') as g:
+        with orig_label_path.open("r") as f, save_path.open("w") as g:
             lines = f.readlines()
             for line in lines:
                 parts = line.strip().split()
@@ -218,20 +389,21 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
                 class_idx = class_mapping[class_name]
                 coords = [float(p) for p in parts[:8]]
                 normalized_coords = [
-                    coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)]
+                    coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)
+                ]
-                formatted_coords = ['{:.6g}'.format(coord) for coord in normalized_coords]
+                formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords]
                 g.write(f"{class_idx} {' '.join(formatted_coords)}\n")

-    for phase in ['train', 'val']:
+    for phase in ["train", "val"]:
-        image_dir = dota_root_path / 'images' / phase
+        image_dir = dota_root_path / "images" / phase
-        orig_label_dir = dota_root_path / 'labels' / f'{phase}_original'
+        orig_label_dir = dota_root_path / "labels" / f"{phase}_original"
-        save_dir = dota_root_path / 'labels' / phase
+        save_dir = dota_root_path / "labels" / phase

         save_dir.mkdir(parents=True, exist_ok=True)

         image_paths = list(image_dir.iterdir())
-        for image_path in TQDM(image_paths, desc=f'Processing {phase} images'):
+        for image_path in TQDM(image_paths, desc=f"Processing {phase} images"):
-            if image_path.suffix != '.png':
+            if image_path.suffix != ".png":
                 continue
             image_name_without_ext = image_path.stem
             img = cv2.imread(str(image_path))
@@ -293,7 +465,7 @@ def merge_multi_segment(segments):
                     s.append(segments[i])
                 else:
                     idx = [0, idx[1] - idx[0]]
-                    s.append(segments[i][idx[0]:idx[1] + 1])
+                    s.append(segments[i][idx[0] : idx[1] + 1])

             else:
                 for i in range(len(idx_list) - 1, -1, -1):
@@ -18,7 +18,7 @@ from .base import BaseDataset
 from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label

 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
-DATASET_CACHE_VERSION = '1.0.3'
+DATASET_CACHE_VERSION = "1.0.3"


 class YOLODataset(BaseDataset):
@@ -33,16 +33,16 @@ class YOLODataset(BaseDataset):
         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
     """

-    def __init__(self, *args, data=None, task='detect', **kwargs):
+    def __init__(self, *args, data=None, task="detect", **kwargs):
         """Initializes the YOLODataset with optional configurations for segments and keypoints."""
-        self.use_segments = task == 'segment'
+        self.use_segments = task == "segment"
-        self.use_keypoints = task == 'pose'
+        self.use_keypoints = task == "pose"
-        self.use_obb = task == 'obb'
+        self.use_obb = task == "obb"
         self.data = data
-        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
+        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
         super().__init__(*args, **kwargs)

-    def cache_labels(self, path=Path('./labels.cache')):
+    def cache_labels(self, path=Path("./labels.cache")):
         """
         Cache dataset labels, check images and read shapes.

@@ -51,19 +51,29 @@ class YOLODataset(BaseDataset):
         Returns:
             (dict): labels.
         """
-        x = {'labels': []}
+        x = {"labels": []}
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
-        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
+        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
         total = len(self.im_files)
-        nkpt, ndim = self.data.get('kpt_shape', (0, 0))
+        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
         if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
-            raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
-                             "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
+            raise ValueError(
+                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
+                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
+            )
         with ThreadPool(NUM_THREADS) as pool:
-            results = pool.imap(func=verify_image_label,
-                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
-                                             repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
-                                             repeat(ndim)))
+            results = pool.imap(
+                func=verify_image_label,
+                iterable=zip(
+                    self.im_files,
+                    self.label_files,
+                    repeat(self.prefix),
+                    repeat(self.use_keypoints),
+                    repeat(len(self.data["names"])),
+                    repeat(nkpt),
+                    repeat(ndim),
+                ),
+            )
             pbar = TQDM(results, desc=desc, total=total)
             for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                 nm += nm_f
@@ -71,7 +81,7 @@ class YOLODataset(BaseDataset):
                 ne += ne_f
                 nc += nc_f
                 if im_file:
-                    x['labels'].append(
+                    x["labels"].append(
                         dict(
                             im_file=im_file,
                             shape=shape,
@@ -80,60 +90,63 @@ class YOLODataset(BaseDataset):
                             segments=segments,
                             keypoints=keypoint,
                             normalized=True,
-                            bbox_format='xywh'))
+                            bbox_format="xywh",
+                        )
+                    )
                 if msg:
                     msgs.append(msg)
-                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
             pbar.close()

         if msgs:
-            LOGGER.info('\n'.join(msgs))
+            LOGGER.info("\n".join(msgs))
         if nf == 0:
-            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
+            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
-        x['hash'] = get_hash(self.label_files + self.im_files)
+        x["hash"] = get_hash(self.label_files + self.im_files)
-        x['results'] = nf, nm, ne, nc, len(self.im_files)
+        x["results"] = nf, nm, ne, nc, len(self.im_files)
-        x['msgs'] = msgs  # warnings
+        x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x)
         return x

     def get_labels(self):
         """Returns dictionary of labels for YOLO training."""
         self.label_files = img2label_paths(self.im_files)
-        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
+        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
         try:
             cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
-            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
-            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
+            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
         except (FileNotFoundError, AssertionError, AttributeError):
             cache, exists = self.cache_labels(cache_path), False  # run cache ops

         # Display cache
-        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
+        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in (-1, 0):
-            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
             TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
-            if cache['msgs']:
+            if cache["msgs"]:
-                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

         # Read cache
-        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
+        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
-        labels = cache['labels']
+        labels = cache["labels"]
         if not labels:
-            LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}')
+            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
-        self.im_files = [lb['im_file'] for lb in labels]  # update im_files
+        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

         # Check if the dataset is all boxes or all segments
-        lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
+        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
         len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
         if len_segments and len_boxes != len_segments:
             LOGGER.warning(
-                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
+                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
-                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
+                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
-                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
+                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
+            )
             for lb in labels:
-                lb['segments'] = []
+                lb["segments"] = []
         if len_cls == 0:
-            LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}')
+            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
         return labels

     def build_transforms(self, hyp=None):
@@ -145,14 +158,17 @@ class YOLODataset(BaseDataset):
         else:
             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
         transforms.append(
-            Format(bbox_format='xywh',
+            Format(
+                bbox_format="xywh",
                 normalize=True,
                 return_mask=self.use_segments,
                 return_keypoint=self.use_keypoints,
                 return_obb=self.use_obb,
                 batch_idx=True,
                 mask_ratio=hyp.mask_ratio,
-                   mask_overlap=hyp.overlap_mask))
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
         return transforms

     def close_mosaic(self, hyp):
@@ -166,11 +182,11 @@ class YOLODataset(BaseDataset):
         """Custom your label format here."""
         # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
         # We can make it also support classification and semantic segmentation by add or remove some dict keys there.
-        bboxes = label.pop('bboxes')
+        bboxes = label.pop("bboxes")
-        segments = label.pop('segments', [])
+        segments = label.pop("segments", [])
-        keypoints = label.pop('keypoints', None)
+        keypoints = label.pop("keypoints", None)
-        bbox_format = label.pop('bbox_format')
+        bbox_format = label.pop("bbox_format")
-        normalized = label.pop('normalized')
+        normalized = label.pop("normalized")

         # NOTE: do NOT resample oriented boxes
         segment_resamples = 100 if self.use_obb else 1000
@@ -180,7 +196,7 @@ class YOLODataset(BaseDataset):
             segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
         else:
             segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
-        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
+        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
         return label

     @staticmethod
@@ -191,15 +207,15 @@ class YOLODataset(BaseDataset):
         values = list(zip(*[list(b.values()) for b in batch]))
         for i, k in enumerate(keys):
             value = values[i]
-            if k == 'img':
+            if k == "img":
                 value = torch.stack(value, 0)
-            if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']:
+            if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
                 value = torch.cat(value, 0)
             new_batch[k] = value
-        new_batch['batch_idx'] = list(new_batch['batch_idx'])
+        new_batch["batch_idx"] = list(new_batch["batch_idx"])
-        for i in range(len(new_batch['batch_idx'])):
+        for i in range(len(new_batch["batch_idx"])):
-            new_batch['batch_idx'][i] += i  # add target image index for build_targets()
+            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
-        new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
+        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
         return new_batch


@@ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
     """

-    def __init__(self, root, args, augment=False, cache=False, prefix=''):
+    def __init__(self, root, args, augment=False, cache=False, prefix=""):
         """
         Initialize YOLO object with root, image size, augmentations, and cache settings.

@@ -231,14 +247,16 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         """
         super().__init__(root=root)
         if augment and args.fraction < 1.0:  # reduce training fraction
-            self.samples = self.samples[:round(len(self.samples) * args.fraction)]
+            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
-        self.prefix = colorstr(f'{prefix}: ') if prefix else ''
+        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
-        self.cache_ram = cache is True or cache == 'ram'
+        self.cache_ram = cache is True or cache == "ram"
-        self.cache_disk = cache == 'disk'
+        self.cache_disk = cache == "disk"
         self.samples = self.verify_images()  # filter out bad images
-        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
+        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
-        self.torch_transforms = classify_augmentations(size=args.imgsz,
+        self.torch_transforms = (
+            classify_augmentations(
+                size=args.imgsz,
                 scale=scale,
                 hflip=args.fliplr,
                 vflip=args.flipud,
@@ -246,8 +264,11 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
                 auto_augment=args.auto_augment,
                 hsv_h=args.hsv_h,
                 hsv_s=args.hsv_s,
-                                                       hsv_v=args.hsv_v) if augment else classify_transforms(
-            size=args.imgsz, crop_fraction=args.crop_fraction)
+                hsv_v=args.hsv_v,
+            )
+            if augment
+            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+        )

     def __getitem__(self, i):
         """Returns subset of data and targets corresponding to given indices."""
@@ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
             # Convert NumPy array to PIL image
             im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
         sample = self.torch_transforms(im)
-        return {'img': sample, 'cls': j}
+        return {"img": sample, "cls": j}

     def __len__(self) -> int:
         """Return the total number of samples in the dataset."""
@@ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):

     def verify_images(self):
         """Verify all images in dataset."""
-        desc = f'{self.prefix}Scanning {self.root}...'
+        desc = f"{self.prefix}Scanning {self.root}..."
-        path = Path(self.root).with_suffix('.cache')  # *.cache file path
+        path = Path(self.root).with_suffix(".cache")  # *.cache file path

         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
             cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
-            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
-            assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash
+            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
-            nf, nc, n, samples = cache.pop('results')  # found, missing, empty, corrupt, total
+            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
             if LOCAL_RANK in (-1, 0):
-                d = f'{desc} {nf} images, {nc} corrupt'
+                d = f"{desc} {nf} images, {nc} corrupt"
                 TQDM(None, desc=d, total=n, initial=n)
-                if cache['msgs']:
+                if cache["msgs"]:
-                    LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
             return samples

         # Run scan if *.cache retrieval failed
@@ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
                 msgs.append(msg)
             nf += nf_f
             nc += nc_f
-            pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+            pbar.desc = f"{desc} {nf} images, {nc} corrupt"
         pbar.close()
         if msgs:
-            LOGGER.info('\n'.join(msgs))
+            LOGGER.info("\n".join(msgs))
-        x['hash'] = get_hash([x[0] for x in self.samples])
+        x["hash"] = get_hash([x[0] for x in self.samples])
-        x['results'] = nf, nc, len(samples), samples
+        x["results"] = nf, nc, len(samples), samples
-        x['msgs'] = msgs  # warnings
+        x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x)
         return samples

@@ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
 def load_dataset_cache_file(path):
     """Load an Ultralytics *.cache dictionary from path."""
     import gc
+
     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
     cache = np.load(str(path), allow_pickle=True).item()  # load dict
     gc.enable()
@@ -320,15 +342,15 @@ def load_dataset_cache_file(path):

 def save_dataset_cache_file(prefix, path, x):
     """Save an Ultralytics dataset *.cache dictionary x to path."""
-    x['version'] = DATASET_CACHE_VERSION  # add cache version
+    x["version"] = DATASET_CACHE_VERSION  # add cache version
     if is_dir_writeable(path.parent):
         if path.exists():
             path.unlink()  # remove *.cache file if exists
         np.save(str(path), x)  # save cache for next time
-        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
+        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
-        LOGGER.info(f'{prefix}New cache created: {path}')
+        LOGGER.info(f"{prefix}New cache created: {path}")
     else:
-        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")


 # TODO: support semantic segmentation
@@ -2,4 +2,4 @@

 from .utils import plot_query_result

-__all__ = ['plot_query_result']
+__all__ = ["plot_query_result"]
@ -22,7 +22,6 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr
|
|||||||
|
|
||||||
|
|
||||||
class ExplorerDataset(YOLODataset):
|
class ExplorerDataset(YOLODataset):
|
||||||
|
|
||||||
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
||||||
super().__init__(*args, data=data, **kwargs)
|
super().__init__(*args, data=data, **kwargs)
|
||||||
|
|
||||||
@ -35,7 +34,7 @@ class ExplorerDataset(YOLODataset):
|
|||||||
else: # read image
|
else: # read image
|
||||||
im = cv2.imread(f) # BGR
|
im = cv2.imread(f) # BGR
|
||||||
if im is None:
|
if im is None:
|
||||||
raise FileNotFoundError(f'Image Not Found {f}')
|
raise FileNotFoundError(f"Image Not Found {f}")
|
||||||
h0, w0 = im.shape[:2] # orig hw
|
h0, w0 = im.shape[:2] # orig hw
|
||||||
return im, (h0, w0), im.shape[:2]
|
return im, (h0, w0), im.shape[:2]
|
||||||
|
|
||||||
@ -44,7 +43,7 @@ class ExplorerDataset(YOLODataset):
|
|||||||
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
||||||
"""Creates transforms for dataset images without resizing."""
|
"""Creates transforms for dataset images without resizing."""
|
||||||
return Format(
|
return Format(
|
||||||
bbox_format='xyxy',
|
bbox_format="xyxy",
|
||||||
normalize=False,
|
normalize=False,
|
||||||
return_mask=self.use_segments,
|
return_mask=self.use_segments,
|
||||||
return_keypoint=self.use_keypoints,
|
return_keypoint=self.use_keypoints,
|
||||||
@ -55,17 +54,16 @@ class ExplorerDataset(YOLODataset):
|
|||||||
|
|
||||||
|
|
||||||
class Explorer:
|
class Explorer:
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
|
||||||
data: Union[str, Path] = 'coco128.yaml',
|
) -> None:
|
||||||
model: str = 'yolov8n.pt',
|
checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
|
||||||
uri: str = '~/ultralytics/explorer') -> None:
|
|
||||||
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
|
|
||||||
import lancedb
|
import lancedb
|
||||||
|
|
||||||
self.connection = lancedb.connect(uri)
|
self.connection = lancedb.connect(uri)
|
||||||
self.table_name = Path(data).name.lower() + '_' + model.lower()
|
self.table_name = Path(data).name.lower() + "_" + model.lower()
|
||||||
self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower(
|
self.sim_idx_base_name = (
|
||||||
|
f"{self.table_name}_sim_idx".lower()
|
||||||
) # Use this name and append thres and top_k to reuse the table
|
) # Use this name and append thres and top_k to reuse the table
|
||||||
self.model = YOLO(model)
|
self.model = YOLO(model)
|
||||||
self.data = data # None
|
self.data = data # None
|
||||||
@ -74,7 +72,7 @@ class Explorer:
|
|||||||
self.table = None
|
self.table = None
|
||||||
self.progress = 0
|
self.progress = 0
|
||||||
|
|
||||||
def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
|
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
|
||||||
"""
|
"""
|
||||||
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
||||||
already exists. Pass force=True to overwrite the existing table.
|
already exists. Pass force=True to overwrite the existing table.
|
||||||
@ -90,20 +88,20 @@ class Explorer:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if self.table is not None and not force:
|
if self.table is not None and not force:
|
||||||
LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
|
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
|
||||||
return
|
return
|
||||||
if self.table_name in self.connection.table_names() and not force:
|
if self.table_name in self.connection.table_names() and not force:
|
||||||
LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
|
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
|
||||||
self.table = self.connection.open_table(self.table_name)
|
self.table = self.connection.open_table(self.table_name)
|
||||||
self.progress = 1
|
self.progress = 1
|
||||||
return
|
return
|
||||||
if self.data is None:
|
if self.data is None:
|
||||||
raise ValueError('Data must be provided to create embeddings table')
|
raise ValueError("Data must be provided to create embeddings table")
|
||||||
|
|
||||||
data_info = check_det_dataset(self.data)
|
data_info = check_det_dataset(self.data)
|
||||||
if split not in data_info:
|
if split not in data_info:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
|
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
choice_set = data_info[split]
|
choice_set = data_info[split]
|
||||||
@ -113,13 +111,16 @@ class Explorer:
|
|||||||
|
|
||||||
# Create the table schema
|
# Create the table schema
|
||||||
batch = dataset[0]
|
batch = dataset[0]
|
||||||
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
|
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
|
||||||
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
|
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
|
||||||
table.add(
|
table.add(
|
||||||
self._yield_batches(dataset,
|
self._yield_batches(
|
||||||
|
dataset,
|
||||||
data_info,
|
data_info,
|
||||||
self.model,
|
self.model,
|
||||||
exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
|
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.table = table
|
self.table = table
|
||||||
|
|
||||||
@ -131,12 +132,12 @@ class Explorer:
|
|||||||
for k in exclude_keys:
|
for k in exclude_keys:
|
||||||
batch.pop(k, None)
|
                batch.pop(k, None)
            batch = sanitize_batch(batch, data_info)
            batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
            batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
            yield [batch]

    def query(self,
    def query(
              imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
              limit: int = 25) -> Any:  # pyarrow.Table
    ) -> Any:  # pyarrow.Table
        """
        Query the table for similar images. Accepts a single image or a list of images.

@ -157,18 +158,18 @@ class Explorer:
            ```
        """
        if self.table is None:
            raise ValueError('Table is not created. Please create the table first.')
            raise ValueError("Table is not created. Please create the table first.")
        if isinstance(imgs, str):
            imgs = [imgs]
        assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
        assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
        embeds = self.model.embed(imgs)
        # Get avg if multiple images are passed (len > 1)
        embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
        return self.table.search(embeds).limit(limit).to_arrow()

    def sql_query(self,
    def sql_query(
                  query: str,
        self, query: str, return_type: str = "pandas"
                  return_type: str = 'pandas') -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table
    ) -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table
        """
        Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

@ -187,27 +188,29 @@ class Explorer:
            result = exp.sql_query(query)
            ```
        """
        assert return_type in ['pandas',
        assert return_type in [
                               'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
            "pandas",
            "arrow",
        ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        import duckdb

        if self.table is None:
            raise ValueError('Table is not created. Please create the table first.')
            raise ValueError("Table is not created. Please create the table first.")

        # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
        table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
        if not query.startswith('SELECT') and not query.startswith('WHERE'):
        if not query.startswith("SELECT") and not query.startswith("WHERE"):
            raise ValueError(
                f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
            )
        if query.startswith('WHERE'):
        if query.startswith("WHERE"):
            query = f"SELECT * FROM 'table' {query}"
        LOGGER.info(f'Running query: {query}')
        LOGGER.info(f"Running query: {query}")

        rs = duckdb.sql(query)
        if return_type == 'pandas':
        if return_type == "pandas":
            return rs.df()
        elif return_type == 'arrow':
        elif return_type == "arrow":
            return rs.arrow()

    def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
@ -228,18 +231,20 @@ class Explorer:
            result = exp.plot_sql_query(query)
            ```
        """
        result = self.sql_query(query, return_type='arrow')
        result = self.sql_query(query, return_type="arrow")
        if len(result) == 0:
            LOGGER.info('No results found.')
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(result, plot_labels=labels)
        return Image.fromarray(img)

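For context: `query()` embeds the input image(s) and searches the LanceDB table, while `sql_query()` routes a SELECT/WHERE string through DuckDB. A minimal usage sketch, assuming only the API visible in these hunks (the dataset and model names follow the docstring examples elsewhere in this commit):

```python
# A hedged sketch of the two query paths above, not a verbatim docstring.
from ultralytics import Explorer

exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
exp.create_embeddings_table()

# Semantic search: embeds the image(s) and searches the LanceDB table
table = exp.query(imgs="https://ultralytics.com/images/zidane.jpg", limit=10)  # pyarrow.Table

# SQL-like search: a bare WHERE clause is expanded to "SELECT * FROM 'table' ..."
df = exp.sql_query("WHERE labels LIKE '%person%'", return_type="pandas")
```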
    def get_similar(self,
    def get_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
                    return_type: str = 'pandas') -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table
        return_type: str = "pandas",
    ) -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table
        """
        Query the table for similar images. Accepts a single image or a list of images.

@ -259,21 +264,25 @@ class Explorer:
            similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        assert return_type in ['pandas',
        assert return_type in [
                               'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
            "pandas",
            "arrow",
        ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        img = self._check_imgs_or_idxs(img, idx)
        similar = self.query(img, limit=limit)

        if return_type == 'pandas':
        if return_type == "pandas":
            return similar.to_pandas()
        elif return_type == 'arrow':
        elif return_type == "arrow":
            return similar

    def plot_similar(self,
    def plot_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
                     labels: bool = True) -> Image.Image:
        labels: bool = True,
    ) -> Image.Image:
        """
        Plot the similar images. Accepts images or indexes.

@ -293,9 +302,9 @@ class Explorer:
            similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        similar = self.get_similar(img, idx, limit, return_type='arrow')
        similar = self.get_similar(img, idx, limit, return_type="arrow")
        if len(similar) == 0:
            LOGGER.info('No results found.')
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(similar, plot_labels=labels)
        return Image.fromarray(img)
@ -323,34 +332,37 @@ class Explorer:
            ```
        """
        if self.table is None:
            raise ValueError('Table is not created. Please create the table first.')
            raise ValueError("Table is not created. Please create the table first.")
        sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
            return self.connection.open_table(sim_idx_table_name).to_pandas()

        if top_k and not (1.0 >= top_k >= 0.0):
            raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
        if max_dist < 0.0:
            raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')
            raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")

        top_k = int(top_k * len(self.table)) if top_k else len(self.table)
        top_k = max(top_k, 1)
        features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
        im_files = features['im_file']
        im_files = features["im_file"]
        embeddings = features['vector']
        embeddings = features["vector"]

        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

        def _yield_sim_idx():
            """Generates a dataframe with similarity indices and distances for images."""
            for i in tqdm(range(len(embeddings))):
                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
                yield [{
                yield [
                    'idx': i,
                    {
                    'im_file': im_files[i],
                        "idx": i,
                    'count': len(sim_idx),
                        "im_file": im_files[i],
                    'sim_im_files': sim_idx['im_file'].tolist()}]
                        "count": len(sim_idx),
                        "sim_im_files": sim_idx["im_file"].tolist(),
                    }
                ]

        sim_table.add(_yield_sim_idx())
        self.sim_index = sim_table
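`get_similar()`/`plot_similar()` wrap `query()` for single images or table indexes, and `similarity_index()` materializes, per image, the set of images whose embeddings lie within `max_dist` (searching at most a `top_k` fraction of the table). A hedged sketch continuing from `exp` above; the `max_dist`/`top_k` values are illustrative, not defaults stated in this diff:

```python
# Names come from the methods in this hunk; values are illustrative.
similar_df = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg", limit=25)  # pandas by default
im = exp.plot_similar(idx=0, limit=9)  # PIL.Image, or None when nothing matched

sim_idx = exp.similarity_index(max_dist=0.2, top_k=0.01)  # one row per image: idx, im_file, count, sim_im_files
```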
@ -381,7 +393,7 @@ class Explorer:
            ```
        """
        sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
        sim_count = sim_idx['count'].tolist()
        sim_count = sim_idx["count"].tolist()
        sim_count = np.array(sim_count)

        indices = np.arange(len(sim_count))
@ -390,25 +402,26 @@ class Explorer:
        plt.bar(indices, sim_count)

        # Customize the plot (optional)
        plt.xlabel('data idx')
        plt.xlabel("data idx")
        plt.ylabel('Count')
        plt.ylabel("Count")
        plt.title('Similarity Count')
        plt.title("Similarity Count")
        buffer = BytesIO()
        plt.savefig(buffer, format='png')
        plt.savefig(buffer, format="png")
        buffer.seek(0)

        # Use Pillow to open the image from the buffer
        return Image.fromarray(np.array(Image.open(buffer)))

    def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
    def _check_imgs_or_idxs(
                            idx: Union[None, int, List[int]]) -> List[np.ndarray]:
        self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
    ) -> List[np.ndarray]:
        if img is None and idx is None:
            raise ValueError('Either img or idx must be provided.')
            raise ValueError("Either img or idx must be provided.")
        if img is not None and idx is not None:
            raise ValueError('Only one of img or idx must be provided.')
            raise ValueError("Only one of img or idx must be provided.")
        if idx is not None:
            idx = idx if isinstance(idx, list) else [idx]
            img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
            img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]

        return img if isinstance(img, list) else [img]

@ -433,7 +446,7 @@ class Explorer:
        try:
            df = self.sql_query(result)
        except Exception as e:
            LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
            LOGGER.error(e)
            return None
        return df
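The tail above belongs to `ask_ai()`, which appears to validate the AI-generated SQL by running it through `sql_query()` and returns `None` when that fails, so callers should guard the result:

```python
# Hedged sketch: ask_ai returns a DataFrame on success and None when the
# AI-generated query fails validation (see the except branch above).
df = exp.ask_ai("show me 100 images with exactly one person")
if df is not None:
    print(df.head())
```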
@ -9,100 +9,114 @@ from ultralytics import Explorer
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements

check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2'))
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))

import streamlit as st
from streamlit_select import image_select


def _get_explorer():
    """Initializes and returns an instance of the Explorer class."""
    exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
    exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
    thread = Thread(target=exp.create_embeddings_table,
    thread = Thread(
                    kwargs={'force': st.session_state.get('force_recreate_embeddings')})
        target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
    )
    thread.start()
    progress_bar = st.progress(0, text='Creating embeddings table...')
    progress_bar = st.progress(0, text="Creating embeddings table...")
    while exp.progress < 1:
        time.sleep(0.1)
        progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
        progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
    thread.join()
    st.session_state['explorer'] = exp
    st.session_state["explorer"] = exp
    progress_bar.empty()


def init_explorer_form():
    """Initializes an Explorer instance and creates embeddings table with progress tracking."""
    datasets = ROOT / 'cfg' / 'datasets'
    datasets = ROOT / "cfg" / "datasets"
    ds = [d.name for d in datasets.glob('*.yaml')]
    ds = [d.name for d in datasets.glob("*.yaml")]
    models = [
        'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
        "yolov8n.pt",
        'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
        "yolov8s.pt",
        'yolov8l-pose.pt', 'yolov8x-pose.pt']
        "yolov8m.pt",
    with st.form(key='explorer_init_form'):
        "yolov8l.pt",
        "yolov8x.pt",
        "yolov8n-seg.pt",
        "yolov8s-seg.pt",
        "yolov8m-seg.pt",
        "yolov8l-seg.pt",
        "yolov8x-seg.pt",
        "yolov8n-pose.pt",
        "yolov8s-pose.pt",
        "yolov8m-pose.pt",
        "yolov8l-pose.pt",
        "yolov8x-pose.pt",
    ]
    with st.form(key="explorer_init_form"):
        col1, col2 = st.columns(2)
        with col1:
            st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
            st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
        with col2:
            st.selectbox('Select model', models, key='model')
            st.selectbox("Select model", models, key="model")
        st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
        st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")

        st.form_submit_button('Explore', on_click=_get_explorer)
        st.form_submit_button("Explore", on_click=_get_explorer)


def query_form():
    """Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
    with st.form('query_form'):
    with st.form("query_form"):
        col1, col2 = st.columns([0.8, 0.2])
        with col1:
            st.text_input('Query',
            st.text_input(
                "Query",
                "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
                          label_visibility='collapsed',
                label_visibility="collapsed",
                          key='query')
                key="query",
            )
        with col2:
            st.form_submit_button('Query', on_click=run_sql_query)
            st.form_submit_button("Query", on_click=run_sql_query)


def ai_query_form():
    """Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
    with st.form('ai_query_form'):
    with st.form("ai_query_form"):
        col1, col2 = st.columns([0.8, 0.2])
        with col1:
            st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
            st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
        with col2:
            st.form_submit_button('Ask AI', on_click=run_ai_query)
            st.form_submit_button("Ask AI", on_click=run_ai_query)


def find_similar_imgs(imgs):
    """Initializes a Streamlit form for AI-based image querying with custom input."""
    exp = st.session_state['explorer']
    exp = st.session_state["explorer"]
    similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
    similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
    paths = similar.to_pydict()['im_file']
    paths = similar.to_pydict()["im_file"]
    st.session_state['imgs'] = paths
    st.session_state["imgs"] = paths


def similarity_form(selected_imgs):
    """Initializes a form for AI-based image querying with custom input in Streamlit."""
    st.write('Similarity Search')
    st.write("Similarity Search")
    with st.form('similarity_form'):
    with st.form("similarity_form"):
        subcol1, subcol2 = st.columns([1, 1])
        with subcol1:
            st.number_input('limit',
            st.number_input(
                            min_value=None,
                "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
                            max_value=None,
            )
                            value=25,
                            label_visibility='collapsed',
                            key='limit')

        with subcol2:
            disabled = not len(selected_imgs)
            st.write('Selected: ', len(selected_imgs))
            st.write("Selected: ", len(selected_imgs))
            st.form_submit_button(
                'Search',
                "Search",
                disabled=disabled,
                on_click=find_similar_imgs,
                args=(selected_imgs, ),
                args=(selected_imgs,),
            )
        if disabled:
            st.error('Select at least one image to search.')
            st.error("Select at least one image to search.")


# def persist_reset_form():
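`_get_explorer` above runs the (potentially slow) embeddings build on a worker thread and polls `exp.progress` to drive the Streamlit progress bar. The same pattern stripped of UI, assuming only that `exp` exposes a float `progress` in [0, 1] as used above:

```python
# A minimal sketch of the thread-plus-polling pattern in _get_explorer.
import time
from threading import Thread

thread = Thread(target=exp.create_embeddings_table, kwargs={"force": False})
thread.start()
while exp.progress < 1:
    time.sleep(0.1)  # poll; the dashboard renders f"Progress: {exp.progress * 100}%" here
thread.join()
```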
@ -117,100 +131,108 @@ def similarity_form(selected_imgs):

def run_sql_query():
    """Executes an SQL query and returns the results."""
    st.session_state['error'] = None
    st.session_state["error"] = None
    query = st.session_state.get('query')
    query = st.session_state.get("query")
    if query.rstrip().lstrip():
        exp = st.session_state['explorer']
        exp = st.session_state["explorer"]
        res = exp.sql_query(query, return_type='arrow')
        res = exp.sql_query(query, return_type="arrow")
        st.session_state['imgs'] = res.to_pydict()['im_file']
        st.session_state["imgs"] = res.to_pydict()["im_file"]


def run_ai_query():
    """Execute SQL query and update session state with query results."""
    if not SETTINGS['openai_api_key']:
    if not SETTINGS["openai_api_key"]:
        st.session_state[
            'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
            "error"
        ] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
        return
    st.session_state['error'] = None
    st.session_state["error"] = None
    query = st.session_state.get('ai_query')
    query = st.session_state.get("ai_query")
    if query.rstrip().lstrip():
        exp = st.session_state['explorer']
        exp = st.session_state["explorer"]
        res = exp.ask_ai(query)
        if not isinstance(res, pd.DataFrame) or res.empty:
            st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
            st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
            return
        st.session_state['imgs'] = res['im_file'].to_list()
        st.session_state["imgs"] = res["im_file"].to_list()


def reset_explorer():
    """Resets the explorer to its initial state by clearing session variables."""
    st.session_state['explorer'] = None
    st.session_state["explorer"] = None
    st.session_state['imgs'] = None
    st.session_state["imgs"] = None
    st.session_state['error'] = None
    st.session_state["error"] = None


def utralytics_explorer_docs_callback():
    """Resets the explorer to its initial state by clearing session variables."""
    with st.container(border=True):
        st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
        st.image(
                 width=100)
            "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
            width=100,
        )
        st.markdown(
            "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
            unsafe_allow_html=True,
            help=None)
            help=None,
        st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
        )
        st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")


def layout():
    """Resets explorer session variables and provides documentation with a link to API docs."""
    st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
    st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
    st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

    if st.session_state.get('explorer') is None:
    if st.session_state.get("explorer") is None:
        init_explorer_form()
        return

    st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
    st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
    exp = st.session_state.get('explorer')
    exp = st.session_state.get("explorer")
    col1, col2 = st.columns([0.75, 0.25], gap='small')
    col1, col2 = st.columns([0.75, 0.25], gap="small")
    imgs = []
    if st.session_state.get('error'):
    if st.session_state.get("error"):
        st.error(st.session_state['error'])
        st.error(st.session_state["error"])
    else:
        imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
        imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
    total_imgs, selected_imgs = len(imgs), []
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
        with subcol1:
            st.write('Max Images Displayed:')
            st.write("Max Images Displayed:")
        with subcol2:
            num = st.number_input('Max Images Displayed',
            num = st.number_input(
                "Max Images Displayed",
                min_value=0,
                max_value=total_imgs,
                value=min(500, total_imgs),
                key='num_imgs_displayed',
                key="num_imgs_displayed",
                                  label_visibility='collapsed')
                label_visibility="collapsed",
            )
        with subcol3:
            st.write('Start Index:')
            st.write("Start Index:")
        with subcol4:
            start_idx = st.number_input('Start Index',
            start_idx = st.number_input(
                "Start Index",
                min_value=0,
                max_value=total_imgs,
                value=0,
                key='start_index',
                key="start_index",
                                        label_visibility='collapsed')
                label_visibility="collapsed",
            )
        with subcol5:
            reset = st.button('Reset', use_container_width=False, key='reset')
            reset = st.button("Reset", use_container_width=False, key="reset")
            if reset:
                st.session_state['imgs'] = None
                st.session_state["imgs"] = None
                st.experimental_rerun()

    query_form()
    ai_query_form()
    if total_imgs:
        imgs_displayed = imgs[start_idx:start_idx + num]
        imgs_displayed = imgs[start_idx : start_idx + num]
        selected_imgs = image_select(
            f'Total samples: {total_imgs}',
            f"Total samples: {total_imgs}",
            images=imgs_displayed,
            use_container_width=False,
            # indices=[i for i in range(num)] if select_all else None,
@ -222,5 +244,5 @@ def layout():
    utralytics_explorer_docs_callback()


if __name__ == '__main__':
if __name__ == "__main__":
    layout()
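`run_ai_query` above bails out unless an OpenAI key is present in `SETTINGS`; its error message names the CLI fix, and the Python equivalent uses the same `SETTINGS.update` call seen in `prompt_sql_query` below:

```python
# From the error message above: yolo settings openai_api_key="..."
from ultralytics.utils import SETTINGS

SETTINGS.update({"openai_api_key": "sk-..."})  # placeholder key, never commit a real one
```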
@ -46,14 +46,13 @@ def get_sim_index_schema():

def sanitize_batch(batch, dataset_info):
    """Sanitizes input batch for inference, ensuring correct format and dimensions."""
    batch['cls'] = batch['cls'].flatten().int().tolist()
    batch["cls"] = batch["cls"].flatten().int().tolist()
    box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
    box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
    batch['bboxes'] = [box for box, _ in box_cls_pair]
    batch["bboxes"] = [box for box, _ in box_cls_pair]
    batch['cls'] = [cls for _, cls in box_cls_pair]
    batch["cls"] = [cls for _, cls in box_cls_pair]
    batch['labels'] = [dataset_info['names'][i] for i in batch['cls']]
    batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
    batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]]
    batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
    batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]]
    batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]

    return batch


@ -65,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True):
        similar_set (list): Pyarrow or pandas object containing the similar data points
        plot_labels (bool): Whether to plot labels or not
    """
    similar_set = similar_set.to_dict(
    similar_set = (
        orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
        similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
    )
    empty_masks = [[[]]]
    empty_boxes = [[]]
    images = similar_set.get('im_file', [])
    images = similar_set.get("im_file", [])
    bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else []
    bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
    masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else []
    masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
    kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else []
    kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
    cls = similar_set.get('cls', [])
    cls = similar_set.get("cls", [])

    plot_size = 640
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
@ -104,34 +104,26 @@ def plot_query_result(similar_set, plot_labels=True):
    batch_idx = np.concatenate(batch_idx, axis=0)
    cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)

    return plot_images(imgs,
    return plot_images(
                       batch_idx,
        imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
                       cls,
    )
                       bboxes=boxes,
                       masks=masks,
                       kpts=kpts,
                       max_subplots=len(images),
                       save=False,
                       threaded=False)


def prompt_sql_query(query):
    """Plots images with optional labels from a similar data set."""
    check_requirements('openai>=1.6.1')
    check_requirements("openai>=1.6.1")
    from openai import OpenAI

    if not SETTINGS['openai_api_key']:
    if not SETTINGS["openai_api_key"]:
        logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
        logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
        openai_api_key = getpass.getpass('OpenAI API key: ')
        openai_api_key = getpass.getpass("OpenAI API key: ")
        SETTINGS.update({'openai_api_key': openai_api_key})
        SETTINGS.update({"openai_api_key": openai_api_key})
    openai = OpenAI(api_key=SETTINGS['openai_api_key'])
    openai = OpenAI(api_key=SETTINGS["openai_api_key"])

    messages = [
        {
            'role':
            "role": "system",
            'system',
            "content": """
            'content':
            '''
                You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
                the following schema and a user request. You only need to output the format with fixed selection
                statement that selects everything from "'table'", like `SELECT * from 'table'`
@ -165,10 +157,10 @@ def prompt_sql_query(query):
                request - Get all data points that contain 2 or more people and at least one dog
                correct query-
                SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
            '''},
            """,
        {
        },
            'role': 'user',
        {"role": "user", "content": f"{query}"},
            'content': f'{query}'}, ]
    ]

    response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages)
    response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
    return response.choices[0].message.content
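`sanitize_batch()` flattens the class tensor, sorts box/class pairs by class index, and maps class ids to names. A toy example of that contract; the import path is assumed from the package layout, not stated in this diff:

```python
# Hedged example of sanitize_batch's input/output contract.
import torch
from ultralytics.data.explorer.utils import sanitize_batch  # assumed module path

batch = {"cls": torch.tensor([[1.0], [0.0]]), "bboxes": torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.5, 0.2, 0.2]])}
batch = sanitize_batch(batch, {"names": {0: "person", 1: "dog"}})
print(batch["labels"])  # ['person', 'dog'] -- pairs are re-sorted by class index
# "masks" and "keypoints" default to [[[]]] when absent from the batch
```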
@ -23,6 +23,7 @@ from ultralytics.utils.checks import check_requirements
@dataclass
class SourceTypes:
    """Class to represent various types of input sources for predictions."""

    webcam: bool = False
    screenshot: bool = False
    from_img: bool = False
@ -59,12 +60,12 @@ class LoadStreams:
        __len__: Return the length of the sources object.
    """

    def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False):
    def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False):
        """Initialize instance variables and check for consistent input stream shapes."""
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.buffer = buffer  # buffer input streams
        self.running = True  # running flag for Thread
        self.mode = 'stream'
        self.mode = "stream"
        self.imgsz = imgsz
        self.vid_stride = vid_stride  # video frame-rate stride

@ -79,33 +80,36 @@ class LoadStreams:
        self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            st = f"{i + 1}/{n}: {s}... "
            if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
            if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
                s = get_best_youtube_url(s)
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0 and (is_colab() or is_kaggle()):
                raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
                raise NotImplementedError(
                                          "Try running 'source=0' in a local environment.")
                    "'source=0' webcam not supported in Colab and Kaggle notebooks. "
                    "Try running 'source=0' in a local environment."
                )
            self.caps[i] = cv2.VideoCapture(s)  # store video capture object
            if not self.caps[i].isOpened():
                raise ConnectionError(f'{st}Failed to open {s}')
                raise ConnectionError(f"{st}Failed to open {s}")
            w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = self.caps[i].get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
                'inf')  # infinite stream fallback
                "inf"
            )  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            success, im = self.caps[i].read()  # guarantee first frame
            if not success or im is None:
                raise ConnectionError(f'{st}Failed to read images from {s}')
                raise ConnectionError(f"{st}Failed to read images from {s}")
            self.imgs[i].append(im)
            self.shape[i] = im.shape
            self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
            LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
            LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info('')  # newline
        LOGGER.info("")  # newline

        # Check for common shapes
        self.bs = self.__len__()
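`LoadStreams` spins up one daemon reader thread per source and yields one frame per stream on each iteration. A hedged sketch; the RTSP URL is a placeholder and the module path is assumed:

```python
from ultralytics.data.loaders import LoadStreams  # assumed module path

dataset = LoadStreams("rtsp://example.com/media.mp4", imgsz=640, vid_stride=1, buffer=False)  # placeholder URL
for sources, images, _, _ in dataset:  # __next__ returns (sources, images, None, "")
    print(sources[0], images[0].shape)  # one BGR numpy frame per stream
    break  # streams run until 'q' or close(); stop after one batch here
dataset.close()
```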
@ -121,7 +125,7 @@ class LoadStreams:
                    success, im = cap.retrieve()
                    if not success:
                        im = np.zeros(self.shape[i], dtype=np.uint8)
                        LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
                        LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                        cap.open(stream)  # re-open stream if signal was lost
                    if self.buffer:
                        self.imgs[i].append(im)
@ -140,7 +144,7 @@ class LoadStreams:
            try:
                cap.release()  # release video capture
            except Exception as e:
                LOGGER.warning(f'WARNING ⚠️ Could not release VideoCapture object: {e}')
                LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
        cv2.destroyAllWindows()

    def __iter__(self):
@ -154,16 +158,15 @@ class LoadStreams:

        images = []
        for i, x in enumerate(self.imgs):

            # Wait until a frame is available in each buffer
            while not x:
                if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'):  # q to quit
                if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"):  # q to quit
                    self.close()
                    raise StopIteration
                time.sleep(1 / min(self.fps))
                x = self.imgs[i]
                if not x:
                    LOGGER.warning(f'WARNING ⚠️ Waiting for stream {i}')
                    LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")

            # Get and remove the first frame from imgs buffer
            if self.buffer:
@ -174,7 +177,7 @@ class LoadStreams:
                images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
                x.clear()

        return self.sources, images, None, ''
        return self.sources, images, None, ""

    def __len__(self):
        """Return the length of the sources object."""
@ -209,7 +212,7 @@ class LoadScreenshots:

    def __init__(self, source, imgsz=640):
        """Source = [screen_number left top width height] (pixels)."""
        check_requirements('mss')
        check_requirements("mss")
        import mss  # noqa

        source, *params = source.split()
@ -221,18 +224,18 @@ class LoadScreenshots:
        elif len(params) == 5:
            self.screen, left, top, width, height = (int(x) for x in params)
        self.imgsz = imgsz
        self.mode = 'stream'
        self.mode = "stream"
        self.frame = 0
        self.sct = mss.mss()
        self.bs = 1

        # Parse monitor shape
        monitor = self.sct.monitors[self.screen]
        self.top = monitor['top'] if top is None else (monitor['top'] + top)
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
        self.left = monitor['left'] if left is None else (monitor['left'] + left)
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
        self.width = width or monitor['width']
        self.width = width or monitor["width"]
        self.height = height or monitor['height']
        self.height = height or monitor["height"]
        self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Returns an iterator of the object."""
@ -241,7 +244,7 @@ class LoadScreenshots:
    def __next__(self):
        """mss screen capture: get raw pixels from the screen as np array."""
        im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3]  # BGRA to BGR
        s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        self.frame += 1
        return [str(self.screen)], [im0], None, s  # screen, img, vid_cap, string
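`LoadScreenshots` parses its capture region from the source string per the docstring above ("Source = [screen_number left top width height]"); the leading token and the exact values below are illustrative assumptions:

```python
from ultralytics.data.loaders import LoadScreenshots  # assumed module path

loader = LoadScreenshots("screen 0 100 100 512 256", imgsz=640)  # assumed source syntax
paths, ims, _, s = next(iter(loader))  # endless BGR captures of the 512x256 region
```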
@ -274,32 +277,32 @@ class LoadImages:
    def __init__(self, path, imgsz=640, vid_stride=1):
        """Initialize the Dataloader and raise FileNotFoundError if file not found."""
        parent = None
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
            parent = Path(path).parent
            path = Path(path).read_text().splitlines()  # list of sources
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            a = str(Path(p).absolute())  # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
            if '*' in a:
            if "*" in a:
                files.extend(sorted(glob.glob(a, recursive=True)))  # glob
            elif os.path.isdir(a):
                files.extend(sorted(glob.glob(os.path.join(a, '*.*'))))  # dir
                files.extend(sorted(glob.glob(os.path.join(a, "*.*"))))  # dir
            elif os.path.isfile(a):
                files.append(a)  # files (absolute or relative to CWD)
            elif parent and (parent / p).is_file():
                files.append(str((parent / p).absolute()))  # files (relative to *.txt file parent)
            else:
                raise FileNotFoundError(f'{p} does not exist')
                raise FileNotFoundError(f"{p} does not exist")

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.imgsz = imgsz
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.mode = "image"
        self.vid_stride = vid_stride  # video frame-rate stride
        self.bs = 1
        if any(videos):
@ -307,8 +310,10 @@ class LoadImages:
        else:
            self.cap = None
        if self.nf == 0:
            raise FileNotFoundError(f'No images or videos found in {p}. '
            raise FileNotFoundError(
                                    f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
                f"No images or videos found in {p}. "
                f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
            )

    def __iter__(self):
        """Returns an iterator object for VideoStream or ImageFolder."""
@ -323,7 +328,7 @@ class LoadImages:

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            self.mode = "video"
            for _ in range(self.vid_stride):
                self.cap.grab()
            success, im0 = self.cap.retrieve()
@ -338,15 +343,15 @@ class LoadImages:

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            if im0 is None:
                raise FileNotFoundError(f'Image Not Found {path}')
                raise FileNotFoundError(f"Image Not Found {path}")
            s = f'image {self.count}/{self.nf} {path}: '
            s = f"image {self.count}/{self.nf} {path}: "

        return [path], [im0], self.cap, s

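`LoadImages` resolves its path argument into image and video lists (globs, directories, single files, or a *.txt listing sources) and iterates images first, then videos. A hedged sketch with a placeholder directory:

```python
from ultralytics.data.loaders import LoadImages  # assumed module path

dataset = LoadImages("assets/", imgsz=640, vid_stride=1)  # placeholder directory
for paths, ims, cap, s in dataset:
    print(s)  # e.g. "image 1/2 assets/bus.jpg: " per the format strings above
```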
@ -385,20 +390,20 @@ class LoadPilAndNumpy:
        """Initialize PIL and Numpy Dataloader."""
        if not isinstance(im0, list):
            im0 = [im0]
        self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
        self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
        self.im0 = [self._single_check(im) for im in im0]
        self.imgsz = imgsz
        self.mode = 'image'
        self.mode = "image"
        # Generate fake paths
        self.bs = len(self.im0)

    @staticmethod
    def _single_check(im):
        """Validate and format an image to numpy array."""
        assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
        if isinstance(im, Image.Image):
            if im.mode != 'RGB':
            if im.mode != "RGB":
                im = im.convert('RGB')
                im = im.convert("RGB")
            im = np.asarray(im)[:, :, ::-1]
            im = np.ascontiguousarray(im)  # contiguous
        return im
@ -412,7 +417,7 @@ class LoadPilAndNumpy:
        if self.count == 1:  # loop only once as it's batch inference
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, None, ''
        return self.paths, self.im0, None, ""

    def __iter__(self):
        """Enables iteration for class LoadPilAndNumpy."""
@ -441,14 +446,16 @@ class LoadTensor:
        """Initialize Tensor Dataloader."""
        self.im0 = self._single_check(im0)
        self.bs = self.im0.shape[0]
        self.mode = 'image'
        self.mode = "image"
        self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
        self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]

    @staticmethod
    def _single_check(im, stride=32):
        """Validate and format an image to torch.Tensor."""
        s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
        s = (
            f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
            f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
            f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
        )
        if len(im.shape) != 4:
            if len(im.shape) != 3:
                raise ValueError(s)
@ -457,8 +464,10 @@ class LoadTensor:
        if im.shape[2] % stride or im.shape[3] % stride:
            raise ValueError(s)
        if im.max() > 1.0 + torch.finfo(im.dtype).eps:  # torch.float32 eps is 1.2e-07
            LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
            LOGGER.warning(
                           f'Dividing input by 255.')
                f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
                f"Dividing input by 255."
            )
            im = im.float() / 255.0

        return im
@ -473,7 +482,7 @@ class LoadTensor:
        if self.count == 1:
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, None, ''
        return self.paths, self.im0, None, ""

    def __len__(self):
        """Returns the batch size."""
@ -485,12 +494,14 @@ def autocast_list(source):
    files = []
    for im in source:
        if isinstance(im, (str, Path)):  # filename or uri
            files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
            files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im))
        elif isinstance(im, (Image.Image, np.ndarray)):  # PIL or np Image
            files.append(im)
        else:
            raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
            raise TypeError(
                            f'See https://docs.ultralytics.com/modes/predict for supported source types.')
                f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
                f"See https://docs.ultralytics.com/modes/predict for supported source types."
            )

    return files

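`LoadTensor._single_check` enforces BCHW shape, stride-divisible spatial dims, and a 0-1 value range (rescaling 0-255 inputs after the warning). Illustrating those checks:

```python
import torch
from ultralytics.data.loaders import LoadTensor  # assumed module path

ok = LoadTensor(torch.rand(1, 3, 640, 640))  # BCHW, 640 % 32 == 0, values in 0-1
paths, im0, _, _ = next(iter(ok))
# torch.rand(1, 3, 641, 640) would raise ValueError (641 % 32 != 0), and a
# 0-255 tensor would be divided by 255 after the normalization warning above.
```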
@@ -513,16 +524,18 @@ def get_best_youtube_url(url, use_pafy=True):
         (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
     """
     if use_pafy:
-        check_requirements(('pafy', 'youtube_dl==2020.12.2'))
+        check_requirements(("pafy", "youtube_dl==2020.12.2"))
         import pafy  # noqa
-        return pafy.new(url).getbestvideo(preftype='mp4').url
+
+        return pafy.new(url).getbestvideo(preftype="mp4").url
     else:
-        check_requirements('yt-dlp')
+        check_requirements("yt-dlp")
         import yt_dlp
-        with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
+
+        with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
             info_dict = ydl.extract_info(url, download=False)  # extract info
-            for f in reversed(info_dict.get('formats', [])):  # reversed because best is usually last
+            for f in reversed(info_dict.get("formats", [])):  # reversed because best is usually last
                 # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size
-                good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080
-                if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
-                    return f.get('url')
+                good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
+                if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
+                    return f.get("url")
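Editor's note: a minimal usage sketch of the helper reformatted above, assuming the yt-dlp branch; the video URL and the OpenCV capture below are illustrative and not part of this commit.

    import cv2
    from ultralytics.data.loaders import get_best_youtube_url

    # Resolve a direct >=1080p video-only MP4 stream URL via yt-dlp
    url = get_best_youtube_url("https://youtu.be/LNwODJXcvt4", use_pafy=False)
    if url:  # may be None when no suitable stream exists
        cap = cv2.VideoCapture(url)
        ok, frame = cap.read()  # grab one frame to confirm the stream opens
        cap.release()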
@@ -14,7 +14,7 @@ from tqdm import tqdm
 from ultralytics.data.utils import exif_size, img2label_paths
 from ultralytics.utils.checks import check_requirements

-check_requirements('shapely')
+check_requirements("shapely")
 from shapely.geometry import Polygon


@@ -54,7 +54,7 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
     return outputs


-def load_yolo_dota(data_root, split='train'):
+def load_yolo_dota(data_root, split="train"):
     """
     Load DOTA dataset.

@@ -72,10 +72,10 @@ def load_yolo_dota(data_root, split='train'):
                 - train
                 - val
     """
-    assert split in ['train', 'val']
-    im_dir = os.path.join(data_root, f'images/{split}')
+    assert split in ["train", "val"]
+    im_dir = os.path.join(data_root, f"images/{split}")
     assert Path(im_dir).exists(), f"Can't find {im_dir}, please check your data root."
-    im_files = glob(os.path.join(data_root, f'images/{split}/*'))
+    im_files = glob(os.path.join(data_root, f"images/{split}/*"))
     lb_files = img2label_paths(im_files)
     annos = []
     for im_file, lb_file in zip(im_files, lb_files):
@@ -100,7 +100,7 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0
     h, w = im_size
     windows = []
     for crop_size, gap in zip(crop_sizes, gaps):
-        assert crop_size > gap, f'invaild crop_size gap pair [{crop_size} {gap}]'
+        assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
         step = crop_size - gap

         xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
@@ -132,8 +132,8 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0

 def get_window_obj(anno, windows, iof_thr=0.7):
     """Get objects for each window."""
-    h, w = anno['ori_size']
-    label = anno['label']
+    h, w = anno["ori_size"]
+    label = anno["label"]
     if len(label):
         label[:, 1::2] *= w
         label[:, 2::2] *= h
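Editor's note: the window grid above advances by step = crop_size - gap, so with the 1024/200 defaults a 2048 px axis yields ceil((2048 - 1024) / 824 + 1) = 3 window origins per axis. A small sketch under those assumptions (the image size is illustrative):

    from math import ceil

    from ultralytics.data.split_dota import get_windows

    windows = get_windows((2048, 2048), crop_sizes=[1024], gaps=[200])  # im_size is (h, w)
    step = 1024 - 200  # 824 px stride between window origins
    assert ceil((2048 - 1024) / step + 1) == 3  # 3 starts per axis, 9 candidate windows
    print(windows.shape)  # expected (9, 4) xyxy windows (all fully inside, so none filtered)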
@@ -166,15 +166,15 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
                 - train
                 - val
     """
-    im = cv2.imread(anno['filepath'])
-    name = Path(anno['filepath']).stem
+    im = cv2.imread(anno["filepath"])
+    name = Path(anno["filepath"]).stem
     for i, window in enumerate(windows):
         x_start, y_start, x_stop, y_stop = window.tolist()
-        new_name = name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start)
+        new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
         patch_im = im[y_start:y_stop, x_start:x_stop]
         ph, pw = patch_im.shape[:2]

-        cv2.imwrite(os.path.join(im_dir, f'{new_name}.jpg'), patch_im)
+        cv2.imwrite(os.path.join(im_dir, f"{new_name}.jpg"), patch_im)
         label = window_objs[i]
         if len(label) == 0:
             continue
@@ -183,13 +183,13 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
         label[:, 1::2] /= pw
         label[:, 2::2] /= ph

-        with open(os.path.join(lb_dir, f'{new_name}.txt'), 'w') as f:
+        with open(os.path.join(lb_dir, f"{new_name}.txt"), "w") as f:
             for lb in label:
-                formatted_coords = ['{:.6g}'.format(coord) for coord in lb[1:]]
+                formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]]
                 f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")


-def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024], gaps=[200]):
+def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=[1024], gaps=[200]):
     """
     Split both images and labels.

@@ -207,14 +207,14 @@ def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024
         - labels
             - split
     """
-    im_dir = Path(save_dir) / 'images' / split
+    im_dir = Path(save_dir) / "images" / split
     im_dir.mkdir(parents=True, exist_ok=True)
-    lb_dir = Path(save_dir) / 'labels' / split
+    lb_dir = Path(save_dir) / "labels" / split
     lb_dir.mkdir(parents=True, exist_ok=True)

     annos = load_yolo_dota(data_root, split=split)
     for anno in tqdm(annos, total=len(annos), desc=split):
-        windows = get_windows(anno['ori_size'], crop_sizes, gaps)
+        windows = get_windows(anno["ori_size"], crop_sizes, gaps)
         window_objs = get_window_obj(anno, windows)
         crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))

@@ -245,7 +245,7 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
     for r in rates:
         crop_sizes.append(int(crop_size / r))
         gaps.append(int(gap / r))
-    for split in ['train', 'val']:
+    for split in ["train", "val"]:
         split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)


@@ -267,30 +267,30 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
     for r in rates:
         crop_sizes.append(int(crop_size / r))
         gaps.append(int(gap / r))
-    save_dir = Path(save_dir) / 'images' / 'test'
+    save_dir = Path(save_dir) / "images" / "test"
     save_dir.mkdir(parents=True, exist_ok=True)

-    im_dir = Path(os.path.join(data_root, 'images/test'))
+    im_dir = Path(os.path.join(data_root, "images/test"))
     assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root."
-    im_files = glob(str(im_dir / '*'))
-    for im_file in tqdm(im_files, total=len(im_files), desc='test'):
+    im_files = glob(str(im_dir / "*"))
+    for im_file in tqdm(im_files, total=len(im_files), desc="test"):
         w, h = exif_size(Image.open(im_file))
         windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
         im = cv2.imread(im_file)
         name = Path(im_file).stem
         for window in windows:
             x_start, y_start, x_stop, y_stop = window.tolist()
-            new_name = (name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start))
+            new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
             patch_im = im[y_start:y_stop, x_start:x_stop]
-            cv2.imwrite(os.path.join(str(save_dir), f'{new_name}.jpg'), patch_im)
+            cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im)


-if __name__ == '__main__':
+if __name__ == "__main__":
     split_trainval(
-        data_root='DOTAv2',
-        save_dir='DOTAv2-split',
+        data_root="DOTAv2",
+        save_dir="DOTAv2-split",
     )
     split_test(
-        data_root='DOTAv2',
-        save_dir='DOTAv2-split',
+        data_root="DOTAv2",
+        save_dir="DOTAv2-split",
     )
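Editor's note: the __main__ guard above shows the single-scale defaults; multi-scale splitting is driven by rates, which divide crop_size and gap before windowing. A hedged sketch (the data paths are placeholders):

    from ultralytics.data.split_dota import split_test, split_trainval

    # rates=[0.5, 1.0, 1.5] crops at int(1024 / r) = 2048, 1024 and 682 px with matching gap / r overlaps
    split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split-ms", crop_size=1024, gap=200, rates=[0.5, 1.0, 1.5])
    split_test(data_root="DOTAv2", save_dir="DOTAv2-split-ms", crop_size=1024, gap=200, rates=[0.5, 1.0, 1.5])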
@@ -17,36 +17,47 @@ import numpy as np
 from PIL import Image, ImageOps

 from ultralytics.nn.autobackend import check_class_names
-from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr,
-                               emojis, yaml_load, yaml_save)
+from ultralytics.utils import (
+    DATASETS_DIR,
+    LOGGER,
+    NUM_THREADS,
+    ROOT,
+    SETTINGS_YAML,
+    TQDM,
+    clean_url,
+    colorstr,
+    emojis,
+    yaml_load,
+    yaml_save,
+)
 from ultralytics.utils.checks import check_file, check_font, is_ascii
 from ultralytics.utils.downloads import download, safe_download, unzip_file
 from ultralytics.utils.ops import segments2boxes

-HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.'
-IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # image suffixes
-VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm'  # video suffixes
-PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders
+HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
+IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # image suffixes
+VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"  # video suffixes
+PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders


 def img2label_paths(img_paths):
     """Define label paths as a function of image paths."""
-    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
-    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
+    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
+    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]


 def get_hash(paths):
     """Returns a single hash value of a list of paths (files or dirs)."""
     size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
     h = hashlib.sha256(str(size).encode())  # hash sizes
-    h.update(''.join(paths).encode())  # hash paths
+    h.update("".join(paths).encode())  # hash paths
     return h.hexdigest()  # return hash


 def exif_size(img: Image.Image):
     """Returns exif-corrected PIL size."""
     s = img.size  # (width, height)
-    if img.format == 'JPEG':  # only support JPEG images
+    if img.format == "JPEG":  # only support JPEG images
         with contextlib.suppress(Exception):
             exif = img.getexif()
             if exif:
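Editor's note: a quick sanity sketch for two helpers touched above; img2label_paths only rewrites the last /images/ path component and the suffix, while get_hash folds total file size plus the joined paths into one SHA-256 digest (the example path is illustrative):

    from ultralytics.data.utils import get_hash, img2label_paths

    print(img2label_paths(["coco8/images/train/000000000009.jpg"]))
    # expected on POSIX: ['coco8/labels/train/000000000009.txt']
    print(get_hash(["coco8/images/train/000000000009.jpg"]))  # stable hex digest; missing files count as size 0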
@@ -60,24 +71,24 @@ def verify_image(args):
     """Verify one image."""
     (im_file, cls), prefix = args
     # Number (found, corrupt), message
-    nf, nc, msg = 0, 0, ''
+    nf, nc, msg = 0, 0, ""
     try:
         im = Image.open(im_file)
         im.verify()  # PIL verify
         shape = exif_size(im)  # image size
         shape = (shape[1], shape[0])  # hw
-        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
-        if im.format.lower() in ('jpg', 'jpeg'):
-            with open(im_file, 'rb') as f:
+        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+        if im.format.lower() in ("jpg", "jpeg"):
+            with open(im_file, "rb") as f:
                 f.seek(-2, 2)
-                if f.read() != b'\xff\xd9':  # corrupt JPEG
-                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
-                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+                if f.read() != b"\xff\xd9":  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
         nf = 1
     except Exception as e:
         nc = 1
-        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
     return (im_file, cls), nf, nc, msg


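Editor's note: verify_image is written for pool.imap, so its single argument packs ((im_file, cls), prefix). A hedged single-image sketch (the path is a placeholder):

    from ultralytics.data.utils import verify_image

    (im_file, cls), nf, nc, msg = verify_image((("datasets/coco8/images/train/000000000009.jpg", 0), ""))
    print(nf, nc, msg)  # nf=1 for a good image; a corrupt one gives nc=1 plus a warning message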
@@ -85,21 +96,21 @@ def verify_image_label(args):
     """Verify one image-label pair."""
     im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
     # Number (missing, found, empty, corrupt), message, segments, keypoints
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
     try:
         # Verify images
         im = Image.open(im_file)
         im.verify()  # PIL verify
         shape = exif_size(im)  # image size
         shape = (shape[1], shape[0])  # hw
-        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
-        if im.format.lower() in ('jpg', 'jpeg'):
-            with open(im_file, 'rb') as f:
+        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+        if im.format.lower() in ("jpg", "jpeg"):
+            with open(im_file, "rb") as f:
                 f.seek(-2, 2)
-                if f.read() != b'\xff\xd9':  # corrupt JPEG
-                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
-                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+                if f.read() != b"\xff\xd9":  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

         # Verify labels
         if os.path.isfile(lb_file):
@@ -114,25 +125,26 @@ def verify_image_label(args):
             nl = len(lb)
             if nl:
                 if keypoint:
-                    assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
+                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                     points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                 else:
-                    assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
+                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                     points = lb[:, 1:]
-                assert points.max() <= 1, f'non-normalized or out of bounds coordinates {points[points > 1]}'
-                assert lb.min() >= 0, f'negative label values {lb[lb < 0]}'
+                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
+                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                 # All labels
                 max_cls = lb[:, 0].max()  # max label count
-                assert max_cls <= num_cls, \
-                    f'Label class {int(max_cls)} exceeds dataset class count {num_cls}. ' \
-                    f'Possible class labels are 0-{num_cls - 1}'
+                assert max_cls <= num_cls, (
+                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
+                    f"Possible class labels are 0-{num_cls - 1}"
+                )
                 _, i = np.unique(lb, axis=0, return_index=True)
                 if len(i) < nl:  # duplicate row check
                     lb = lb[i]  # remove duplicates
                     if segments:
                         segments = [segments[x] for x in i]
-                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
             else:
                 ne = 1  # label empty
                 lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
@@ -148,7 +160,7 @@ def verify_image_label(args):
         return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
     except Exception as e:
         nc = 1
-        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
         return [None, None, None, None, None, nm, nf, ne, nc, msg]


@@ -194,8 +206,10 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):

 def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
     """Return a (640, 640) overlap mask."""
-    masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
-                     dtype=np.int32 if len(segments) > 255 else np.uint8)
+    masks = np.zeros(
+        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
+        dtype=np.int32 if len(segments) > 255 else np.uint8,
+    )
     areas = []
     ms = []
     for si in range(len(segments)):
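Editor's note: the dtype switch kept in the reformat above exists because the overlap mask stores a 1-based instance index per pixel, so uint8 tops out at 255 instances. A minimal sketch with one synthetic square polygon (sizes are illustrative):

    import numpy as np

    from ultralytics.data.utils import polygons2masks_overlap

    segments = [np.array([[10, 10], [100, 10], [100, 100], [10, 100]], dtype=np.float32)]
    masks, index = polygons2masks_overlap((160, 160), segments, downsample_ratio=1)
    print(masks.dtype, int(masks.max()))  # expected: uint8 1 (one instance keeps the small dtype)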
@@ -226,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path:
     Returns:
         (Path): The path of the found YAML file.
     """
-    files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml'))  # try root level first and then recursive
+    files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))  # try root level first and then recursive
     assert files, f"No YAML file found in '{path.resolve()}'"
     if len(files) > 1:
         files = [f for f in files if f.stem == path.stem]  # prefer *.yaml files that match
@@ -253,7 +267,7 @@ def check_det_dataset(dataset, autodownload=True):
     file = check_file(dataset)

     # Download (optional)
-    extract_dir = ''
+    extract_dir = ""
     if zipfile.is_zipfile(file) or is_tarfile(file):
         new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
         file = find_dataset_yaml(DATASETS_DIR / new_dir)
@@ -263,43 +277,44 @@ def check_det_dataset(dataset, autodownload=True):
     data = yaml_load(file, append_filename=True)  # dictionary

     # Checks
-    for k in 'train', 'val':
+    for k in "train", "val":
         if k not in data:
-            if k != 'val' or 'validation' not in data:
+            if k != "val" or "validation" not in data:
                 raise SyntaxError(
-                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
+                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
+                )
             LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
-            data['val'] = data.pop('validation')  # replace 'validation' key with 'val' key
-    if 'names' not in data and 'nc' not in data:
+            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
+    if "names" not in data and "nc" not in data:
         raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
-    if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
+    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
         raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
-    if 'names' not in data:
-        data['names'] = [f'class_{i}' for i in range(data['nc'])]
+    if "names" not in data:
+        data["names"] = [f"class_{i}" for i in range(data["nc"])]
     else:
-        data['nc'] = len(data['names'])
+        data["nc"] = len(data["names"])

-    data['names'] = check_class_names(data['names'])
+    data["names"] = check_class_names(data["names"])

     # Resolve paths
-    path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset root
+    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
     if not path.is_absolute():
         path = (DATASETS_DIR / path).resolve()

     # Set paths
-    data['path'] = path  # download scripts
-    for k in 'train', 'val', 'test':
+    data["path"] = path  # download scripts
+    for k in "train", "val", "test":
         if data.get(k):  # prepend path
             if isinstance(data[k], str):
                 x = (path / data[k]).resolve()
-                if not x.exists() and data[k].startswith('../'):
+                if not x.exists() and data[k].startswith("../"):
                     x = (path / data[k][3:]).resolve()
                 data[k] = str(x)
             else:
                 data[k] = [str((path / x).resolve()) for x in data[k]]

     # Parse YAML
-    val, s = (data.get(x) for x in ('val', 'download'))
+    val, s = (data.get(x) for x in ("val", "download"))
     if val:
         val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
         if not all(x.exists() for x in val):
@@ -312,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True):
             raise FileNotFoundError(m)
         t = time.time()
         r = None  # success
-        if s.startswith('http') and s.endswith('.zip'):  # URL
+        if s.startswith("http") and s.endswith(".zip"):  # URL
             safe_download(url=s, dir=DATASETS_DIR, delete=True)
-        elif s.startswith('bash '):  # bash script
-            LOGGER.info(f'Running {s} ...')
+        elif s.startswith("bash "):  # bash script
+            LOGGER.info(f"Running {s} ...")
             r = os.system(s)
         else:  # python script
-            exec(s, {'yaml': data})
-        dt = f'({round(time.time() - t, 1)}s)'
-        s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
-        LOGGER.info(f'Dataset download {s}\n')
-    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fonts
+            exec(s, {"yaml": data})
+        dt = f"({round(time.time() - t, 1)}s)"
+        s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
+        LOGGER.info(f"Dataset download {s}\n")
+    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

     return data  # dictionary


-def check_cls_dataset(dataset, split=''):
+def check_cls_dataset(dataset, split=""):
     """
     Checks a classification dataset such as Imagenet.

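Editor's note: a minimal sketch of calling the checker reformatted above; coco8.yaml ships with the package and autodownloads, so this should resolve absolute train/val paths and class names:

    from ultralytics.data.utils import check_det_dataset

    data = check_det_dataset("coco8.yaml")  # downloads into DATASETS_DIR on first use
    print(data["nc"], data["names"][0])  # expected: 80 person
    print(data["val"])  # absolute path string resolved against data["path"]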
@@ -348,54 +363,59 @@ def check_cls_dataset(split=''):
     """

     # Download (optional if dataset=https://file.zip is passed directly)
-    if str(dataset).startswith(('http:/', 'https:/')):
+    if str(dataset).startswith(("http:/", "https:/")):
         dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)

     dataset = Path(dataset)
     data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
     if not data_dir.is_dir():
-        LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
         t = time.time()
-        if str(dataset) == 'imagenet':
+        if str(dataset) == "imagenet":
             subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
         else:
-            url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
+            url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
             download(url, dir=data_dir.parent)
         s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
         LOGGER.info(s)
-    train_set = data_dir / 'train'
-    val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \
-        (data_dir / 'validation').exists() else None  # data/test or data/val
-    test_set = data_dir / 'test' if (data_dir / 'test').exists() else None  # data/val or data/test
-    if split == 'val' and not val_set:
+    train_set = data_dir / "train"
+    val_set = (
+        data_dir / "val"
+        if (data_dir / "val").exists()
+        else data_dir / "validation"
+        if (data_dir / "validation").exists()
+        else None
+    )  # data/test or data/val
+    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
+    if split == "val" and not val_set:
         LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
-    elif split == 'test' and not test_set:
+    elif split == "test" and not test_set:
         LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")

-    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
-    names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
+    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
+    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
     names = dict(enumerate(sorted(names)))

     # Print to console
-    for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
+    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
         prefix = f'{colorstr(f"{k}:")} {v}...'
         if v is None:
             LOGGER.info(prefix)
         else:
-            files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
+            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
             nf = len(files)  # number of files
             nd = len({file.parent for file in files})  # number of directories
             if nf == 0:
-                if k == 'train':
+                if k == "train":
                     raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                 else:
-                    LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
+                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
             elif nd != nc:
-                LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
+                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
             else:
-                LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
+                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

-    return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
+    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
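Editor's note: a companion sketch for the classification checker above (imagenet10 is a small Ultralytics sample dataset; treat the name as an assumption if your mirror lacks it):

    from ultralytics.data.utils import check_cls_dataset

    data = check_cls_dataset("imagenet10")
    print(data["nc"], data["names"][0])  # class count plus names keyed 0..nc-1 from sorted train/ subdirs
    print(data["train"], data["val"], data["test"])  # Path objects, or None for missing splits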
@@ -423,42 +443,43 @@ class HUBDatasetStats:
         ```
     """

-    def __init__(self, path='coco8.yaml', task='detect', autodownload=False):
+    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
         """Initialize class."""
         path = Path(path).resolve()
-        LOGGER.info(f'Starting HUB dataset checks for {path}....')
+        LOGGER.info(f"Starting HUB dataset checks for {path}....")

         self.task = task  # detect, segment, pose, classify
-        if self.task == 'classify':
+        if self.task == "classify":
             unzip_dir = unzip_file(path)
             data = check_cls_dataset(unzip_dir)
-            data['path'] = unzip_dir
+            data["path"] = unzip_dir
         else:  # detect, segment, pose
             _, data_dir, yaml_path = self._unzip(Path(path))
             try:
                 # Load YAML with checks
                 data = yaml_load(yaml_path)
-                data['path'] = ''  # strip path since YAML should be in dataset root for all HUB datasets
+                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                 yaml_save(yaml_path, data)
                 data = check_det_dataset(yaml_path, autodownload)  # dict
-                data['path'] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
+                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
             except Exception as e:
-                raise Exception('error/HUB/dataset_stats/init') from e
+                raise Exception("error/HUB/dataset_stats/init") from e

         self.hub_dir = Path(f'{data["path"]}-hub')
-        self.im_dir = self.hub_dir / 'images'
+        self.im_dir = self.hub_dir / "images"
         self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
-        self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())}  # statistics dictionary
+        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
         self.data = data

     @staticmethod
     def _unzip(path):
         """Unzip data.zip."""
-        if not str(path).endswith('.zip'):  # path is data.yaml
+        if not str(path).endswith(".zip"):  # path is data.yaml
             return False, None, path
         unzip_dir = unzip_file(path, path=path.parent)
-        assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
-                                   f'path/to/abc.zip MUST unzip to path/to/abc/'
+        assert unzip_dir.is_dir(), (
+            f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
+        )
         return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

     def _hub_ops(self, f):
@@ -470,31 +491,31 @@ class HUBDatasetStats:

         def _round(labels):
             """Update labels to integer class and 4 decimal place floats."""
-            if self.task == 'detect':
-                coordinates = labels['bboxes']
-            elif self.task == 'segment':
-                coordinates = [x.flatten() for x in labels['segments']]
-            elif self.task == 'pose':
-                n = labels['keypoints'].shape[0]
-                coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
+            if self.task == "detect":
+                coordinates = labels["bboxes"]
+            elif self.task == "segment":
+                coordinates = [x.flatten() for x in labels["segments"]]
+            elif self.task == "pose":
+                n = labels["keypoints"].shape[0]
+                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
             else:
-                raise ValueError('Undefined dataset task.')
-            zipped = zip(labels['cls'], coordinates)
+                raise ValueError("Undefined dataset task.")
+            zipped = zip(labels["cls"], coordinates)
             return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

-        for split in 'train', 'val', 'test':
+        for split in "train", "val", "test":
             self.stats[split] = None  # predefine
             path = self.data.get(split)

             # Check split
             if path is None:  # no split
                 continue
-            files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
+            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
             if not files:  # no images
                 continue

             # Get dataset statistics
-            if self.task == 'classify':
+            if self.task == "classify":
                 from torchvision.datasets import ImageFolder

                 dataset = ImageFolder(self.data[split])
@@ -504,38 +525,35 @@ class HUBDatasetStats:
                     x[im[1]] += 1

                 self.stats[split] = {
-                    'instance_stats': {
-                        'total': len(dataset),
-                        'per_class': x.tolist()},
-                    'image_stats': {
-                        'total': len(dataset),
-                        'unlabelled': 0,
-                        'per_class': x.tolist()},
-                    'labels': [{
-                        Path(k).name: v} for k, v in dataset.imgs]}
+                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
+                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
+                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
+                }
             else:
                 from ultralytics.data import YOLODataset

                 dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
-                x = np.array([
-                    np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
-                    for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')])  # shape(128x80)
+                x = np.array(
+                    [
+                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
+                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
+                    ]
+                )  # shape(128x80)
                 self.stats[split] = {
-                    'instance_stats': {
-                        'total': int(x.sum()),
-                        'per_class': x.sum(0).tolist()},
-                    'image_stats': {
-                        'total': len(dataset),
-                        'unlabelled': int(np.all(x == 0, 1).sum()),
-                        'per_class': (x > 0).sum(0).tolist()},
-                    'labels': [{
-                        Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
+                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
+                    "image_stats": {
+                        "total": len(dataset),
+                        "unlabelled": int(np.all(x == 0, 1).sum()),
+                        "per_class": (x > 0).sum(0).tolist(),
+                    },
+                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
+                }

             # Save, print and return
             if save:
-                stats_path = self.hub_dir / 'stats.json'
-                LOGGER.info(f'Saving {stats_path.resolve()}...')
-                with open(stats_path, 'w') as f:
+                stats_path = self.hub_dir / "stats.json"
+                LOGGER.info(f"Saving {stats_path.resolve()}...")
+                with open(stats_path, "w") as f:
                     json.dump(self.stats, f)  # save stats.json
             if verbose:
                 LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
@@ -545,14 +563,14 @@ class HUBDatasetStats:
         """Compress images for Ultralytics HUB."""
         from ultralytics.data import YOLODataset  # ClassificationDataset

-        for split in 'train', 'val', 'test':
+        for split in "train", "val", "test":
             if self.data.get(split) is None:
                 continue
             dataset = YOLODataset(img_path=self.data[split], data=self.data)
             with ThreadPool(NUM_THREADS) as pool:
-                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
+                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                     pass
-        LOGGER.info(f'Done. All images saved to {self.im_dir}')
+        LOGGER.info(f"Done. All images saved to {self.im_dir}")
         return self.im_dir


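Editor's note: the class docstring's own example, restated as a hedged sketch of the two public entry points touched above (the zip path is a placeholder and must unzip to path/to/coco8/):

    from ultralytics.data.utils import HUBDatasetStats

    stats = HUBDatasetStats("path/to/coco8.zip", task="detect")
    stats.get_json(save=True)  # writes the stats dictionary assembled above to <dataset>-hub/stats.json
    stats.process_images()  # saves compressed copies under <dataset>-hub/images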
@@ -583,9 +601,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
         r = max_dim / max(im.height, im.width)  # ratio
         if r < 1.0:  # image too large
             im = im.resize((int(im.width * r), int(im.height * r)))
-        im.save(f_new or f, 'JPEG', quality=quality, optimize=True)  # save
+        im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
     except Exception as e:  # use OpenCV
-        LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
+        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
         im = cv2.imread(f)
         im_height, im_width = im.shape[:2]
         r = max_dim / max(im_height, im_width)  # ratio
@@ -594,7 +612,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
         cv2.imwrite(str(f_new or f), im)


-def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
+def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
     """
     Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

@@ -612,18 +630,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot
     """

     path = Path(path)  # images dir
-    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
+    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
     n = len(files)  # number of files
     random.seed(0)  # for reproducibility
     indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

-    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
+    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for x in txt:
         if (path.parent / x).exists():
             (path.parent / x).unlink()  # remove existing

-    LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
+    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
     for i, img in TQDM(zip(indices, files), total=n):
         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
-            with open(path.parent / txt[i], 'a') as f:
-                f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file
+            with open(path.parent / txt[i], "a") as f:
+                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
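Editor's note: a one-call sketch of autosplit as reformatted above (the images directory is a placeholder):

    from ultralytics.data.utils import autosplit

    # Writes autosplit_train.txt / autosplit_val.txt next to the images dir with a 90/10/0 split,
    # keeping only images that already have a *.txt label when annotated_only=True
    autosplit(path="datasets/coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=True)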
File diff suppressed because it is too large
@@ -53,7 +53,7 @@ class Model(nn.Module):
         list(ultralytics.engine.results.Results): The prediction results.
     """

-    def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
+    def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None) -> None:
         """
         Initializes the YOLO model.

@@ -89,7 +89,7 @@ class Model(nn.Module):

         # Load or create new YOLO model
         model = checks.check_model_file_from_stem(model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        if Path(model).suffix in ('.yaml', '.yml'):
+        if Path(model).suffix in (".yaml", ".yml"):
             self._new(model, task)
         else:
             self._load(model, task)
@@ -112,16 +112,20 @@ class Model(nn.Module):
     def is_triton_model(model):
         """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
         from urllib.parse import urlsplit

         url = urlsplit(model)
-        return url.netloc and url.path and url.scheme in {'http', 'grpc'}
+        return url.netloc and url.path and url.scheme in {"http", "grpc"}

     @staticmethod
     def is_hub_model(model):
         """Check if the provided model is a HUB model."""
-        return any((
-            model.startswith(f'{HUB_WEB_ROOT}/models/'),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
-            [len(x) for x in model.split('_')] == [42, 20],  # APIKEY_MODELID
-            len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\')))  # MODELID
+        return any(
+            (
+                model.startswith(f"{HUB_WEB_ROOT}/models/"),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
+                [len(x) for x in model.split("_")] == [42, 20],  # APIKEY_MODELID
+                len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),
+            )
+        )  # MODELID

     def _new(self, cfg: str, task=None, model=None, verbose=True):
         """
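Editor's note: the tuple rewrapped above accepts three HUB identifier shapes; a hedged sketch (MODEL_ID and the key/id strings are illustrative):

    from ultralytics.engine.model import Model
    from ultralytics.hub.utils import HUB_WEB_ROOT

    assert Model.is_hub_model(f"{HUB_WEB_ROOT}/models/MODEL_ID")  # full HUB URL
    assert Model.is_hub_model(f"{'a' * 42}_{'b' * 20}")  # APIKEY_MODELID pair
    assert Model.is_hub_model("b" * 20)  # bare 20-char MODELID that is not a local path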
@@ -136,9 +140,9 @@ class Model(nn.Module):
         cfg_dict = yaml_model_load(cfg)
         self.cfg = cfg
         self.task = task or guess_model_task(cfg_dict)
-        self.model = (model or self._smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model
-        self.overrides['model'] = self.cfg
-        self.overrides['task'] = self.task
+        self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1)  # build model
+        self.overrides["model"] = self.cfg
+        self.overrides["task"] = self.task

         # Below added to allow export from YAMLs
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
@@ -153,9 +157,9 @@ class Model(nn.Module):
             task (str | None): model task
         """
         suffix = Path(weights).suffix
-        if suffix == '.pt':
+        if suffix == ".pt":
             self.model, self.ckpt = attempt_load_one_weight(weights)
-            self.task = self.model.args['task']
+            self.task = self.model.args["task"]
             self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
             self.ckpt_path = self.model.pt_path
         else:
@@ -163,12 +167,12 @@ class Model(nn.Module):
             self.model, self.ckpt = weights, None
             self.task = task or guess_model_task(weights)
             self.ckpt_path = weights
-        self.overrides['model'] = weights
-        self.overrides['task'] = self.task
+        self.overrides["model"] = weights
+        self.overrides["task"] = self.task

     def _check_is_pytorch_model(self):
         """Raises TypeError is model is not a PyTorch model."""
-        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
+        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
         pt_module = isinstance(self.model, nn.Module)
         if not (pt_module or pt_str):
             raise TypeError(
@@ -176,19 +180,20 @@ class Model(nn.Module):
                 f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
                 f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
                 f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
-                f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'")
+                f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
+            )

     def reset_weights(self):
         """Resets the model modules parameters to randomly initialized values, losing all training information."""
         self._check_is_pytorch_model()
         for m in self.model.modules():
-            if hasattr(m, 'reset_parameters'):
+            if hasattr(m, "reset_parameters"):
                 m.reset_parameters()
         for p in self.model.parameters():
             p.requires_grad = True
         return self

-    def load(self, weights='yolov8n.pt'):
+    def load(self, weights="yolov8n.pt"):
         """Transfers parameters with matching names and shapes from 'weights' to model."""
         self._check_is_pytorch_model()
         if isinstance(weights, (str, Path)):
@@ -226,8 +231,8 @@ class Model(nn.Module):
         Returns:
             (List[torch.Tensor]): A list of image embeddings.
         """
-        if not kwargs.get('embed'):
-            kwargs['embed'] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
+        if not kwargs.get("embed"):
+            kwargs["embed"] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
         return self.predict(source, stream, **kwargs)

     def predict(self, source=None, stream=False, predictor=None, **kwargs):
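Editor's note: a short sketch of the embed() defaulting shown above; with no indices passed it embeds the second-to-last layer (the source URL is illustrative):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    vectors = model.embed("https://ultralytics.com/images/bus.jpg")  # kwargs["embed"] defaults as above
    print(len(vectors), vectors[0].shape)  # one embedding tensor per image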
@ -249,21 +254,22 @@ class Model(nn.Module):
|
|||||||
source = ASSETS
|
source = ASSETS
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||||
|
|
||||||
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
|
is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
|
||||||
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
|
x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
|
||||||
|
)
|
||||||
|
|
||||||
custom = {'conf': 0.25, 'save': is_cli} # method defaults
|
custom = {"conf": 0.25, "save": is_cli} # method defaults
|
||||||
args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'} # highest priority args on the right
|
args = {**self.overrides, **custom, **kwargs, "mode": "predict"} # highest priority args on the right
|
||||||
prompts = args.pop('prompts', None) # for SAM-type models
|
prompts = args.pop("prompts", None) # for SAM-type models
|
||||||
|
|
||||||
if not self.predictor:
|
if not self.predictor:
|
||||||
self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks)
|
self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
|
||||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||||
else: # only update args if predictor is already setup
|
else: # only update args if predictor is already setup
|
||||||
self.predictor.args = get_cfg(self.predictor.args, args)
|
self.predictor.args = get_cfg(self.predictor.args, args)
|
||||||
if 'project' in args or 'name' in args:
|
if "project" in args or "name" in args:
|
||||||
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
||||||
if prompts and hasattr(self.predictor, 'set_prompts'): # for SAM-type models
|
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
|
||||||
self.predictor.set_prompts(prompts)
|
self.predictor.set_prompts(prompts)
|
||||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||||
|
|
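A hedged sketch of the resulting predict() call path from user code (source and threshold values are illustrative; the method default shown above is conf=0.25):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model.predict(source="bus.jpg", conf=0.25, save=False)
for r in results:
    print(r.boxes.xyxy, r.boxes.conf, r.boxes.cls)  # per-detection boxes, scores, classes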
@@ -280,11 +286,12 @@ class Model(nn.Module):
         Returns:
             (List[ultralytics.engine.results.Results]): The tracking results.
         """
-        if not hasattr(self.predictor, 'trackers'):
+        if not hasattr(self.predictor, "trackers"):
             from ultralytics.trackers import register_tracker

             register_tracker(self, persist)
-        kwargs['conf'] = kwargs.get('conf') or 0.1  # ByteTrack-based method needs low confidence predictions as input
-        kwargs['mode'] = 'track'
+        kwargs["conf"] = kwargs.get("conf") or 0.1  # ByteTrack-based method needs low confidence predictions as input
+        kwargs["mode"] = "track"
         return self.predict(source=source, stream=stream, **kwargs)

     def val(self, validator=None, **kwargs):
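Note that track() deliberately lowers the default confidence to 0.1, since ByteTrack-style association benefits from low-confidence candidate boxes. A minimal usage sketch (the video path is illustrative):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
for r in model.track(source="video.mp4", stream=True, persist=True):
    if r.boxes.is_track:
        print(r.boxes.id)  # per-object track IDs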
@@ -295,10 +302,10 @@ class Model(nn.Module):
             validator (BaseValidator): Customized validator.
             **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
         """
-        custom = {'rect': True}  # method defaults
-        args = {**self.overrides, **custom, **kwargs, 'mode': 'val'}  # highest priority args on the right
+        custom = {"rect": True}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # highest priority args on the right

-        validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
+        validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
         validator(model=self.model)
         self.metrics = validator.metrics
         return validator.metrics
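A hedged validation sketch (the dataset YAML is illustrative; any validator-accepted arg passes through **kwargs):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
metrics = model.val(data="coco128.yaml", imgsz=640)
print(metrics.box.map50)  # mAP@0.5 for detection models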
@@ -313,16 +320,17 @@ class Model(nn.Module):
         self._check_is_pytorch_model()
         from ultralytics.utils.benchmarks import benchmark

-        custom = {'verbose': False}  # method defaults
-        args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'}
+        custom = {"verbose": False}  # method defaults
+        args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
         return benchmark(
             model=self,
-            data=kwargs.get('data'),  # if no 'data' argument passed set data=None for default datasets
-            imgsz=args['imgsz'],
-            half=args['half'],
-            int8=args['int8'],
-            device=args['device'],
-            verbose=kwargs.get('verbose'))
+            data=kwargs.get("data"),  # if no 'data' argument passed set data=None for default datasets
+            imgsz=args["imgsz"],
+            half=args["half"],
+            int8=args["int8"],
+            device=args["device"],
+            verbose=kwargs.get("verbose"),
+        )

     def export(self, **kwargs):
         """
@@ -334,8 +342,8 @@ class Model(nn.Module):
         self._check_is_pytorch_model()
         from .exporter import Exporter

-        custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False}  # method defaults
-        args = {**self.overrides, **custom, **kwargs, 'mode': 'export'}  # highest priority args on the right
+        custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, "mode": "export"}  # highest priority args on the right
         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)

     def train(self, trainer=None, **kwargs):
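Export defaults to batch=1 at the model's training imgsz, as the method defaults above show. A minimal sketch (the format name is one of the standard exporter targets):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
onnx_path = model.export(format="onnx", imgsz=640)  # returns the path of the exported file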
@@ -347,32 +355,32 @@ class Model(nn.Module):
             **kwargs (Any): Any number of arguments representing the training configuration.
         """
         self._check_is_pytorch_model()
-        if hasattr(self.session, 'model') and self.session.model.id:  # Ultralytics HUB session with loaded model
+        if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
             if any(kwargs):
-                LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
+                LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
             kwargs = self.session.train_args  # overwrite kwargs

         checks.check_pip_update_available()

-        overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
-        custom = {'data': DEFAULT_CFG_DICT['data'] or TASK2DATA[self.task]}  # method defaults
-        args = {**overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
-        if args.get('resume'):
-            args['resume'] = self.ckpt_path
+        overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
+        custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]}  # method defaults
+        args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
+        if args.get("resume"):
+            args["resume"] = self.ckpt_path

-        self.trainer = (trainer or self._smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
-        if not args.get('resume'):  # manually set model only if not resuming
+        self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
+        if not args.get("resume"):  # manually set model only if not resuming
             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
             self.model = self.trainer.model

-        if SETTINGS['hub'] is True and not self.session:
+        if SETTINGS["hub"] is True and not self.session:
             # Create a model in HUB
             try:
                 self.session = self._get_hub_session(self.model_name)
                 if self.session:
                     self.session.create_model(args)
                     # Check model was created
-                    if not getattr(self.session.model, 'id', None):
+                    if not getattr(self.session.model, "id", None):
                         self.session = None
             except PermissionError:
                 # Ignore permission error
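A hedged local-training sketch (dataset and schedule are illustrative; when an Ultralytics HUB session already holds the model, its train_args take precedence, as the code above shows):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model.train(data="coco128.yaml", epochs=3, imgsz=640)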
@@ -385,7 +393,7 @@ class Model(nn.Module):
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
             self.model, _ = attempt_load_one_weight(ckpt)
             self.overrides = self.model.args
-            self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP
+            self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
         return self.metrics

     def tune(self, use_ray=False, iterations=10, *args, **kwargs):
@@ -398,12 +406,13 @@ class Model(nn.Module):
         self._check_is_pytorch_model()
         if use_ray:
             from ultralytics.utils.tuner import run_ray_tune

             return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
         else:
             from .tuner import Tuner

             custom = {}  # method defaults
-            args = {**self.overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
+            args = {**self.overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
             return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)

     def _apply(self, fn):
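A hedged hyperparameter-tuning sketch (the iteration count is illustrative; use_ray=True routes through Ray Tune instead of the built-in Tuner):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.tune(data="coco128.yaml", iterations=10, use_ray=False)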
@@ -411,13 +420,13 @@ class Model(nn.Module):
         self._check_is_pytorch_model()
         self = super()._apply(fn)  # noqa
         self.predictor = None  # reset predictor as device may have changed
-        self.overrides['device'] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
+        self.overrides["device"] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
         return self

     @property
     def names(self):
         """Returns class names of the loaded model."""
-        return self.model.names if hasattr(self.model, 'names') else None
+        return self.model.names if hasattr(self.model, "names") else None

     @property
     def device(self):
@@ -427,7 +436,7 @@ class Model(nn.Module):
     @property
     def transforms(self):
         """Returns transform of the loaded model."""
-        return self.model.transforms if hasattr(self.model, 'transforms') else None
+        return self.model.transforms if hasattr(self.model, "transforms") else None

     def add_callback(self, event: str, func):
         """Add a callback."""
@@ -445,7 +454,7 @@ class Model(nn.Module):
     @staticmethod
     def _reset_ckpt_args(args):
         """Reset arguments when loading a PyTorch model."""
-        include = {'imgsz', 'data', 'task', 'single_cls'}  # only remember these arguments when loading a PyTorch model
+        include = {"imgsz", "data", "task", "single_cls"}  # only remember these arguments when loading a PyTorch model
         return {k: v for k, v in args.items() if k in include}

     # def __getattr__(self, attr):
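The add_callback() hook shown above registers functions on named events. A minimal sketch (the event name is one of the standard predictor callbacks):

from ultralytics import YOLO

def on_predict_start(predictor):
    print(f"Starting prediction with {type(predictor).__name__}")

model = YOLO("yolov8n.pt")
model.add_callback("on_predict_start", on_predict_start)
model.predict("bus.jpg")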
@@ -461,7 +470,8 @@ class Model(nn.Module):
             name = self.__class__.__name__
             mode = inspect.stack()[1][3]  # get the function name.
             raise NotImplementedError(
-                emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e
+                emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
+            ) from e

     @property
     def task_map(self):
@@ -471,4 +481,4 @@ class Model(nn.Module):
         Returns:
             task_map (dict): The map of model task to mode classes.
         """
-        raise NotImplementedError('Please provide task map for your model!')
+        raise NotImplementedError("Please provide task map for your model!")
ultralytics/engine/predictor.py
@@ -132,8 +132,11 @@ class BasePredictor:

     def inference(self, im, *args, **kwargs):
         """Runs inference on a given image using the specified model and arguments."""
-        visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
-                                   mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
+        visualize = (
+            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
+            if self.args.visualize and (not self.source_type.tensor)
+            else False
+        )
         return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

     def pre_transform(self, im):
@@ -153,35 +156,38 @@ class BasePredictor:
     def write_results(self, idx, results, batch):
         """Write inference results to a file or directory."""
         p, im, _ = batch
-        log_string = ''
+        log_string = ""
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
         if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
-            log_string += f'{idx}: '
+            log_string += f"{idx}: "
             frame = self.dataset.count
         else:
-            frame = getattr(self.dataset, 'frame', 0)
+            frame = getattr(self.dataset, "frame", 0)
         self.data_path = p
-        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
-        log_string += '%gx%g ' % im.shape[2:]  # print string
+        self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
+        log_string += "%gx%g " % im.shape[2:]  # print string
         result = results[idx]
         log_string += result.verbose()

         if self.args.save or self.args.show:  # Add bbox to image
             plot_args = {
-                'line_width': self.args.line_width,
-                'boxes': self.args.show_boxes,
-                'conf': self.args.show_conf,
-                'labels': self.args.show_labels}
+                "line_width": self.args.line_width,
+                "boxes": self.args.show_boxes,
+                "conf": self.args.show_conf,
+                "labels": self.args.show_labels,
+            }
             if not self.args.retina_masks:
-                plot_args['im_gpu'] = im[idx]
+                plot_args["im_gpu"] = im[idx]
             self.plotted_img = result.plot(**plot_args)
         # Write
         if self.args.save_txt:
-            result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
+            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
         if self.args.save_crop:
-            result.save_crop(save_dir=self.save_dir / 'crops',
-                             file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
+            result.save_crop(
+                save_dir=self.save_dir / "crops",
+                file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
+            )

         return log_string

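The save_txt and save_crop branches above are driven by predictor arguments. A hedged sketch of enabling them from user code (outputs land under the run's save_dir):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.predict(source="bus.jpg", save=True, save_txt=True, save_conf=True, save_crop=True)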
@@ -210,17 +216,24 @@ class BasePredictor:
     def setup_source(self, source):
         """Sets up source and inference mode."""
         self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
-        self.transforms = getattr(
-            self.model.model, 'transforms', classify_transforms(
-                self.imgsz[0], crop_fraction=self.args.crop_fraction)) if self.args.task == 'classify' else None
-        self.dataset = load_inference_source(source=source,
-                                             imgsz=self.imgsz,
-                                             vid_stride=self.args.vid_stride,
-                                             buffer=self.args.stream_buffer)
+        self.transforms = (
+            getattr(
+                self.model.model,
+                "transforms",
+                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
+            )
+            if self.args.task == "classify"
+            else None
+        )
+        self.dataset = load_inference_source(
+            source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
+        )
         self.source_type = self.dataset.source_type
-        if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or  # streams
-                                                  len(self.dataset) > 1000 or  # images
-                                                  any(getattr(self.dataset, 'video_flag', [False]))):  # videos
+        if not getattr(self, "stream", True) and (
+            self.dataset.mode == "stream"  # streams
+            or len(self.dataset) > 1000  # images
+            or any(getattr(self.dataset, "video_flag", [False]))
+        ):  # videos
             LOGGER.warning(STREAM_WARNING)
         self.vid_path = [None] * self.dataset.bs
         self.vid_writer = [None] * self.dataset.bs
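The STREAM_WARNING above fires when a large or endless source is consumed without generator mode. A hedged sketch of the recommended pattern (camera index 0 is illustrative):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
for result in model.predict(source=0, stream=True):  # generator; frames processed one at a time
    print(result.verbose())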
@@ -230,7 +243,7 @@ class BasePredictor:
     def stream_inference(self, source=None, model=None, *args, **kwargs):
         """Streams real-time inference on camera feed and saves results to file."""
         if self.args.verbose:
-            LOGGER.info('')
+            LOGGER.info("")

         # Setup model
         if not self.model:
@@ -242,7 +255,7 @@ class BasePredictor:

         # Check if save_dir/ label file exists
         if self.args.save or self.args.save_txt:
-            (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+            (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

         # Warmup model
         if not self.done_warmup:
@@ -250,10 +263,10 @@ class BasePredictor:
             self.done_warmup = True

         self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
-        self.run_callbacks('on_predict_start')
+        self.run_callbacks("on_predict_start")

         for batch in self.dataset:
-            self.run_callbacks('on_predict_batch_start')
+            self.run_callbacks("on_predict_batch_start")
             self.batch = batch
             path, im0s, vid_cap, s = batch

@@ -272,15 +285,16 @@ class BasePredictor:
             with profilers[2]:
                 self.results = self.postprocess(preds, im, im0s)

-            self.run_callbacks('on_predict_postprocess_end')
+            self.run_callbacks("on_predict_postprocess_end")
             # Visualize, save, write results
             n = len(im0s)
             for i in range(n):
                 self.seen += 1
                 self.results[i].speed = {
-                    'preprocess': profilers[0].dt * 1E3 / n,
-                    'inference': profilers[1].dt * 1E3 / n,
-                    'postprocess': profilers[2].dt * 1E3 / n}
+                    "preprocess": profilers[0].dt * 1e3 / n,
+                    "inference": profilers[1].dt * 1e3 / n,
+                    "postprocess": profilers[2].dt * 1e3 / n,
+                }
                 p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
                 p = Path(p)

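Each Results object carries per-image timings in milliseconds, as assigned above. A hedged sketch of reading them:

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
result = model.predict("bus.jpg")[0]
print(result.speed)  # e.g. {'preprocess': ..., 'inference': ..., 'postprocess': ...} in ms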
@@ -293,12 +307,12 @@ class BasePredictor:
                 if self.args.save and self.plotted_img is not None:
                     self.save_preds(vid_cap, i, str(self.save_dir / p.name))

-            self.run_callbacks('on_predict_batch_end')
+            self.run_callbacks("on_predict_batch_end")
             yield from self.results

             # Print time (inference-only)
             if self.args.verbose:
-                LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
+                LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")

         # Release assets
         if isinstance(self.vid_writer[-1], cv2.VideoWriter):
@@ -306,25 +320,29 @@ class BasePredictor:

         # Print results
         if self.args.verbose and self.seen:
-            t = tuple(x.t / self.seen * 1E3 for x in profilers)  # speeds per image
-            LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
-                        f'{(1, 3, *im.shape[2:])}' % t)
+            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
+            LOGGER.info(
+                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
+                f"{(1, 3, *im.shape[2:])}" % t
+            )
         if self.args.save or self.args.save_txt or self.args.save_crop:
-            nl = len(list(self.save_dir.glob('labels/*.txt')))  # number of labels
-            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
+            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
+            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

-        self.run_callbacks('on_predict_end')
+        self.run_callbacks("on_predict_end")

     def setup_model(self, model, verbose=True):
         """Initialize YOLO model with given parameters and set it to evaluation mode."""
-        self.model = AutoBackend(model or self.args.model,
-                                 device=select_device(self.args.device, verbose=verbose),
-                                 dnn=self.args.dnn,
-                                 data=self.args.data,
-                                 fp16=self.args.half,
-                                 fuse=True,
-                                 verbose=verbose)
+        self.model = AutoBackend(
+            model or self.args.model,
+            device=select_device(self.args.device, verbose=verbose),
+            dnn=self.args.dnn,
+            data=self.args.data,
+            fp16=self.args.half,
+            fuse=True,
+            verbose=verbose,
+        )

         self.device = self.model.device  # update device
         self.args.half = self.model.fp16  # update half
@@ -333,18 +351,18 @@ class BasePredictor:
     def show(self, p):
         """Display an image in a window using OpenCV imshow()."""
         im0 = self.plotted_img
-        if platform.system() == 'Linux' and p not in self.windows:
+        if platform.system() == "Linux" and p not in self.windows:
             self.windows.append(p)
             cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
             cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
         cv2.imshow(str(p), im0)
-        cv2.waitKey(500 if self.batch[3].startswith('image') else 1)  # 1 millisecond
+        cv2.waitKey(500 if self.batch[3].startswith("image") else 1)  # 1 millisecond

     def save_preds(self, vid_cap, idx, save_path):
         """Save video predictions as mp4 at specified path."""
         im0 = self.plotted_img
         # Save imgs
-        if self.dataset.mode == 'image':
+        if self.dataset.mode == "image":
             cv2.imwrite(save_path, im0)
         else:  # 'video' or 'stream'
             frames_path = f'{save_path.split(".", 1)[0]}_frames/'
@@ -361,15 +379,16 @@ class BasePredictor:
                     h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                 else:  # stream
                     fps, w, h = 30, im0.shape[1], im0.shape[0]
-                suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG')
-                self.vid_writer[idx] = cv2.VideoWriter(str(Path(save_path).with_suffix(suffix)),
-                                                       cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
+                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
+                self.vid_writer[idx] = cv2.VideoWriter(
+                    str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
+                )
             # Write video
             self.vid_writer[idx].write(im0)

             # Write frame
             if self.args.save_frames:
-                cv2.imwrite(f'{frames_path}{self.vid_frame[idx]}.jpg', im0)
+                cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
                 self.vid_frame[idx] += 1

     def run_callbacks(self, event: str):
ultralytics/engine/results.py
@@ -98,15 +98,15 @@ class Results(SimpleClass):
         self.probs = Probs(probs) if probs is not None else None
         self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
         self.obb = OBB(obb, self.orig_shape) if obb is not None else None
-        self.speed = {'preprocess': None, 'inference': None, 'postprocess': None}  # milliseconds per image
+        self.speed = {"preprocess": None, "inference": None, "postprocess": None}  # milliseconds per image
         self.names = names
         self.path = path
         self.save_dir = None
-        self._keys = 'boxes', 'masks', 'probs', 'keypoints', 'obb'
+        self._keys = "boxes", "masks", "probs", "keypoints", "obb"

     def __getitem__(self, idx):
         """Return a Results object for the specified index."""
-        return self._apply('__getitem__', idx)
+        return self._apply("__getitem__", idx)

     def __len__(self):
         """Return the number of detections in the Results object."""
@@ -146,19 +146,19 @@ class Results(SimpleClass):

     def cpu(self):
         """Return a copy of the Results object with all tensors on CPU memory."""
-        return self._apply('cpu')
+        return self._apply("cpu")

     def numpy(self):
         """Return a copy of the Results object with all tensors as numpy arrays."""
-        return self._apply('numpy')
+        return self._apply("numpy")

     def cuda(self):
         """Return a copy of the Results object with all tensors on GPU memory."""
-        return self._apply('cuda')
+        return self._apply("cuda")

     def to(self, *args, **kwargs):
         """Return a copy of the Results object with tensors on the specified device and dtype."""
-        return self._apply('to', *args, **kwargs)
+        return self._apply("to", *args, **kwargs)

     def new(self):
         """Return a new Results object with the same image, path, and names."""
@@ -169,7 +169,7 @@ class Results(SimpleClass):
         conf=True,
         line_width=None,
         font_size=None,
-        font='Arial.ttf',
+        font="Arial.ttf",
         pil=False,
         img=None,
         im_gpu=None,
@@ -229,14 +229,20 @@ class Results(SimpleClass):
             font_size,
             font,
             pil or (pred_probs is not None and show_probs),  # Classify tasks default to pil=True
-            example=names)
+            example=names,
+        )

         # Plot Segment results
         if pred_masks and show_masks:
             if im_gpu is None:
                 img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
-                im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
-                    2, 0, 1).flip(0).contiguous() / 255
+                im_gpu = (
+                    torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device)
+                    .permute(2, 0, 1)
+                    .flip(0)
+                    .contiguous()
+                    / 255
+                )
             idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
             annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)

@@ -244,14 +250,14 @@ class Results(SimpleClass):
         if pred_boxes is not None and show_boxes:
             for d in reversed(pred_boxes):
                 c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
-                name = ('' if id is None else f'id:{id} ') + names[c]
-                label = (f'{name} {conf:.2f}' if conf else name) if labels else None
+                name = ("" if id is None else f"id:{id} ") + names[c]
+                label = (f"{name} {conf:.2f}" if conf else name) if labels else None
                 box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
                 annotator.box_label(box, label, color=colors(c, True), rotated=is_obb)

         # Plot Classify results
         if pred_probs is not None and show_probs:
-            text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
+            text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5)
             x = round(self.orig_shape[0] * 0.03)
             annotator.text([x, x], text, txt_color=(255, 255, 255))  # TODO: allow setting colors

@@ -264,11 +270,11 @@ class Results(SimpleClass):

     def verbose(self):
         """Return log string for each task."""
-        log_string = ''
+        log_string = ""
         probs = self.probs
         boxes = self.boxes
         if len(self) == 0:
-            return log_string if probs is not None else f'{log_string}(no detections), '
+            return log_string if probs is not None else f"{log_string}(no detections), "
         if probs is not None:
             log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
         if boxes:
@@ -293,7 +299,7 @@ class Results(SimpleClass):
         texts = []
         if probs is not None:
             # Classify
-            [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
+            [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5]
         elif boxes:
             # Detect/segment/pose
             for j, d in enumerate(boxes):
@@ -304,16 +310,16 @@ class Results(SimpleClass):
                     line = (c, *seg)
                 if kpts is not None:
                     kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
-                    line += (*kpt.reshape(-1).tolist(), )
-                line += (conf, ) * save_conf + (() if id is None else (id, ))
-                texts.append(('%g ' * len(line)).rstrip() % line)
+                    line += (*kpt.reshape(-1).tolist(),)
+                line += (conf,) * save_conf + (() if id is None else (id,))
+                texts.append(("%g " * len(line)).rstrip() % line)

         if texts:
             Path(txt_file).parent.mkdir(parents=True, exist_ok=True)  # make directory
-            with open(txt_file, 'a') as f:
-                f.writelines(text + '\n' for text in texts)
+            with open(txt_file, "a") as f:
+                f.writelines(text + "\n" for text in texts)

-    def save_crop(self, save_dir, file_name=Path('im.jpg')):
+    def save_crop(self, save_dir, file_name=Path("im.jpg")):
         """
         Save cropped predictions to `save_dir/cls/file_name.jpg`.

@@ -322,21 +328,23 @@ class Results(SimpleClass):
             file_name (str | pathlib.Path): File name.
         """
         if self.probs is not None:
-            LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
+            LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.")
             return
         if self.obb is not None:
-            LOGGER.warning('WARNING ⚠️ OBB task do not support `save_crop`.')
+            LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.")
             return
         for d in self.boxes:
-            save_one_box(d.xyxy,
-                         self.orig_img.copy(),
-                         file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name)}.jpg',
-                         BGR=True)
+            save_one_box(
+                d.xyxy,
+                self.orig_img.copy(),
+                file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg",
+                BGR=True,
+            )

     def tojson(self, normalize=False):
         """Convert the object to JSON format."""
         if self.probs is not None:
-            LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
+            LOGGER.warning("Warning: Classify task do not support `tojson` yet.")
             return

         import json
@@ -346,19 +354,19 @@ class Results(SimpleClass):
         data = self.boxes.data.cpu().tolist()
         h, w = self.orig_shape if normalize else (1, 1)
         for i, row in enumerate(data):  # xyxy, track_id if tracking, conf, class_id
-            box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
+            box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h}
             conf = row[-2]
             class_id = int(row[-1])
             name = self.names[class_id]
-            result = {'name': name, 'class': class_id, 'confidence': conf, 'box': box}
+            result = {"name": name, "class": class_id, "confidence": conf, "box": box}
             if self.boxes.is_track:
-                result['track_id'] = int(row[-3])  # track ID
+                result["track_id"] = int(row[-3])  # track ID
             if self.masks:
                 x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1]  # numpy array
-                result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
+                result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()}
             if self.keypoints is not None:
                 x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1)  # torch Tensor
-                result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
+                result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
             results.append(result)

         # Convert detections to JSON
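A hedged sketch of consuming this JSON conversion (normalize=True scales coordinates by the original image size):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
result = model.predict("bus.jpg")[0]
print(result.tojson(normalize=True))  # JSON string with name, class, confidence, box per detection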
@@ -397,7 +405,7 @@ class Boxes(BaseTensor):
         if boxes.ndim == 1:
             boxes = boxes[None, :]
         n = boxes.shape[-1]
-        assert n in (6, 7), f'expected 6 or 7 values but got {n}'  # xyxy, track_id, conf, cls
+        assert n in (6, 7), f"expected 6 or 7 values but got {n}"  # xyxy, track_id, conf, cls
         super().__init__(boxes, orig_shape)
         self.is_track = n == 7
         self.orig_shape = orig_shape
@@ -474,7 +482,8 @@ class Masks(BaseTensor):
         """Return normalized segments."""
         return [
             ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
-            for x in ops.masks2segments(self.data)]
+            for x in ops.masks2segments(self.data)
+        ]

     @property
     @lru_cache(maxsize=1)
@@ -482,7 +491,8 @@ class Masks(BaseTensor):
         """Return segments in pixel coordinates."""
         return [
             ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
-            for x in ops.masks2segments(self.data)]
+            for x in ops.masks2segments(self.data)
+        ]


 class Keypoints(BaseTensor):
@@ -610,7 +620,7 @@ class OBB(BaseTensor):
         if boxes.ndim == 1:
             boxes = boxes[None, :]
         n = boxes.shape[-1]
-        assert n in (7, 8), f'expected 7 or 8 values but got {n}'  # xywh, rotation, track_id, conf, cls
+        assert n in (7, 8), f"expected 7 or 8 values but got {n}"  # xywh, rotation, track_id, conf, cls
         super().__init__(boxes, orig_shape)
         self.is_track = n == 8
         self.orig_shape = orig_shape
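A hedged sketch of reading these result containers from a segmentation run (the yolov8n-seg.pt weights are illustrative):

from ultralytics import YOLO

model = YOLO("yolov8n-seg.pt")
result = model.predict("bus.jpg")[0]
if result.masks is not None:
    print(result.masks.xy[0])  # first instance polygon in pixel coordinates
    print(result.masks.xyn[0])  # same polygon, normalized to the original image size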
ultralytics/engine/trainer.py
@@ -23,14 +23,31 @@ from torch import nn, optim
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
-from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
-                               yaml_save)
+from ultralytics.utils import (
+    DEFAULT_CFG,
+    LOGGER,
+    RANK,
+    TQDM,
+    __version__,
+    callbacks,
+    clean_url,
+    colorstr,
+    emojis,
+    yaml_save,
+)
 from ultralytics.utils.autobatch import check_train_batch_size
 from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
-from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
-                                           strip_optimizer)
+from ultralytics.utils.torch_utils import (
+    EarlyStopping,
+    ModelEMA,
+    de_parallel,
+    init_seeds,
+    one_cycle,
+    select_device,
+    strip_optimizer,
+)


 class BaseTrainer:
@@ -89,12 +106,12 @@ class BaseTrainer:
         # Dirs
         self.save_dir = get_save_dir(self.args)
         self.args.name = self.save_dir.name  # update name for loggers
-        self.wdir = self.save_dir / 'weights'  # weights dir
+        self.wdir = self.save_dir / "weights"  # weights dir
         if RANK in (-1, 0):
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
-            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args
-        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
+            yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+        self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period

         self.batch_size = self.args.batch
@@ -104,18 +121,18 @@ class BaseTrainer:
             print_args(vars(self.args))

         # Device
-        if self.device.type in ('cpu', 'mps'):
+        if self.device.type in ("cpu", "mps"):
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
         try:
-            if self.args.task == 'classify':
+            if self.args.task == "classify":
                 self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'):
+            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"):
                 self.data = check_det_dataset(self.args.data)
-                if 'yaml_file' in self.data:
-                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
+                if "yaml_file" in self.data:
+                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
         except Exception as e:
             raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

@@ -131,8 +148,8 @@ class BaseTrainer:
         self.fitness = None
         self.loss = None
         self.tloss = None
-        self.loss_names = ['Loss']
-        self.csv = self.save_dir / 'results.csv'
+        self.loss_names = ["Loss"]
+        self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]

         # Callbacks
@@ -156,7 +173,7 @@ class BaseTrainer:
     def train(self):
         """Allow device='', device=None on Multi-GPU systems to default to device=0."""
         if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-            world_size = len(self.args.device.split(','))
+            world_size = len(self.args.device.split(","))
         elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
             world_size = len(self.args.device)
         elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
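The world_size logic above accepts a device string, a list, or nothing. A hedged sketch of the equivalent user-facing calls (GPU indices are illustrative):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.train(data="coco128.yaml", device="0,1")  # string form -> world_size=2, DDP subprocess
model.train(data="coco128.yaml", device=[0, 1])  # list form, same effect
model.train(data="coco128.yaml", device="cpu")  # single-process CPU training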
@@ -165,14 +182,16 @@ class BaseTrainer:
             world_size = 0

         # Run subprocess if DDP training, else train normally
-        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
+        if world_size > 1 and "LOCAL_RANK" not in os.environ:
             # Argument checks
             if self.args.rect:
                 LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
             if self.args.batch == -1:
-                LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
-                               "default 'batch=16'")
+                LOGGER.warning(
+                    "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                    "default 'batch=16'"
+                )
                 self.args.batch = 16

             # Command
@@ -199,37 +218,45 @@ class BaseTrainer:
     def _setup_ddp(self, world_size):
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
-        self.device = torch.device('cuda', RANK)
+        self.device = torch.device("cuda", RANK)
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
-        os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
+        os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
-            'nccl' if dist.is_nccl_available() else 'gloo',
+            "nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
-            world_size=world_size)
+            world_size=world_size,
+        )

     def _setup_train(self, world_size):
         """Builds dataloaders and optimizer on correct rank process."""

         # Model
-        self.run_callbacks('on_pretrain_routine_start')
+        self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()

         # Freeze layers
-        freeze_list = self.args.freeze if isinstance(
-            self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
-        always_freeze_names = ['.dfl']  # always freeze these layers
-        freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
+        freeze_list = (
+            self.args.freeze
+            if isinstance(self.args.freeze, list)
+            else range(self.args.freeze)
+            if isinstance(self.args.freeze, int)
+            else []
+        )
+        always_freeze_names = [".dfl"]  # always freeze these layers
+        freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
         for k, v in self.model.named_parameters():
             # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
             if any(x in k for x in freeze_layer_names):
                 LOGGER.info(f"Freezing layer '{k}'")
                 v.requires_grad = False
             elif not v.requires_grad:
-                LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
-                            'See ultralytics.engine.trainer for customization of frozen layers.')
+                LOGGER.info(
+                    f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+                    "See ultralytics.engine.trainer for customization of frozen layers."
+                )
                 v.requires_grad = True

         # Check AMP
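Freezing is driven by the 'freeze' training argument, resolved above into layer-name prefixes (the '.dfl' head is always frozen). A hedged sketch:

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.train(data="coco128.yaml", epochs=3, freeze=10)  # freeze the first 10 model blocks
model.train(data="coco128.yaml", epochs=3, freeze=[0, 1, 2])  # or freeze specific block indices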
@ -246,7 +273,7 @@ class BaseTrainer:
|
|||||||
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
|
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
|
||||||
|
|
||||||
# Check imgsz
|
# Check imgsz
|
||||||
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
|
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
|
||||||
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
||||||
self.stride = gs # for multi-scale training
|
self.stride = gs # for multi-scale training
|
||||||
|
|
||||||
@ -256,15 +283,14 @@ class BaseTrainer:

        # Dataloaders
        batch_size = self.batch_size // max(world_size, 1)
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
        if RANK in (-1, 0):
            # NOTE: When training DOTA dataset, double batch size could get OOM cause some images got more than 2000 objects.
            self.test_loader = self.get_dataloader(self.testset,
                                                   batch_size=batch_size if self.args.task == 'obb' else batch_size * 2,
                                                   rank=-1,
                                                   mode='val')
            self.test_loader = self.get_dataloader(
                self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
            )
            self.validator = self.get_validator()
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
            self.ema = ModelEMA(self.model)
            if self.args.plots:
@ -274,18 +300,20 @@ class BaseTrainer:
        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
        self.optimizer = self.build_optimizer(model=self.model,
                                              name=self.args.optimizer,
                                              lr=self.args.lr0,
                                              momentum=self.args.momentum,
                                              decay=weight_decay,
                                              iterations=iterations)
        self.optimizer = self.build_optimizer(
            model=self.model,
            name=self.args.optimizer,
            lr=self.args.lr0,
            momentum=self.args.momentum,
            decay=weight_decay,
            iterations=iterations,
        )
        # Scheduler
        self._setup_scheduler()
        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
        self.run_callbacks('on_pretrain_routine_end')
        self.run_callbacks("on_pretrain_routine_end")

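Note: the accumulate/weight_decay lines above implement gradient accumulation against a nominal batch size nbs and rescale weight decay so that the decay applied per optimizer step stays tied to the nominal batch. A worked sketch, assuming the default nbs=64:

nbs, batch_size, base_decay = 64, 16, 0.0005  # nbs=64 is the default nominal batch size

accumulate = max(round(nbs / batch_size), 1)  # 4 backward passes per optimizer step
weight_decay = base_decay * batch_size * accumulate / nbs  # 0.0005 * 16 * 4 / 64 = 0.0005
print(accumulate, weight_decay)  # 4 0.0005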
    def _do_train(self, world_size=1):
        """Train completed, evaluate and plot if specified by arguments."""
@ -299,19 +327,23 @@ class BaseTrainer:
        self.epoch_time = None
        self.epoch_time_start = time.time()
        self.train_time_start = time.time()
        self.run_callbacks('on_train_start')
        self.run_callbacks("on_train_start")
        LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                    f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                    f"Logging results to {colorstr('bold', self.save_dir)}\n"
                    f'Starting training for '
                    f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...')
        LOGGER.info(
            f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
            f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
            f"Logging results to {colorstr('bold', self.save_dir)}\n"
            f'Starting training for '
            f'{self.args.time} hours...'
            if self.args.time
            else f"{self.epochs} epochs..."
        )
        if self.args.close_mosaic:
            base_idx = (self.epochs - self.args.close_mosaic) * nb
            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
        epoch = self.epochs  # predefine for resume fully trained model edge cases
        for epoch in range(self.start_epoch, self.epochs):
            self.epoch = epoch
            self.run_callbacks('on_train_epoch_start')
            self.run_callbacks("on_train_epoch_start")
            self.model.train()
            if RANK != -1:
                self.train_loader.sampler.set_epoch(epoch)
@ -327,7 +359,7 @@ class BaseTrainer:
            self.tloss = None
            self.optimizer.zero_grad()
            for i, batch in pbar:
                self.run_callbacks('on_train_batch_start')
                self.run_callbacks("on_train_batch_start")
                # Warmup
                ni = i + nb * epoch
                if ni <= nw:
@ -335,10 +367,11 @@ class BaseTrainer:
                    self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                    for j, x in enumerate(self.optimizer.param_groups):
                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                        x['lr'] = np.interp(
                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
                        if 'momentum' in x:
                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
                        x["lr"] = np.interp(
                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                        )
                        if "momentum" in x:
                            x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                # Forward
                with torch.cuda.amp.autocast(self.amp):
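Note: the warmup above is plain linear interpolation over the first nw iterations with knots xi = [0, nw]: the bias group (j == 0) decays from warmup_bias_lr down to the scheduled LR while the other groups ramp up from 0. A standalone sketch with illustrative numbers, simplifying the end value to a fixed lr0 (in the code it is x["initial_lr"] * self.lf(epoch), the scheduler's value for the current epoch):

import numpy as np

nw, lr0, warmup_bias_lr = 100, 0.01, 0.1
xi = [0, nw]
for ni in (0, 50, 100):
    bias_lr = np.interp(ni, xi, [warmup_bias_lr, lr0])  # falls 0.1 -> 0.01
    other_lr = np.interp(ni, xi, [0.0, lr0])            # rises 0.0 -> 0.01
    print(ni, float(bias_lr), float(other_lr))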
@ -346,8 +379,9 @@ class BaseTrainer:
                    self.loss, self.loss_items = self.model(batch)
                    if RANK != -1:
                        self.loss *= world_size
                    self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                        else self.loss_items
                    self.tloss = (
                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                    )

                # Backward
                self.scaler.scale(self.loss).backward()
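Note: the self.tloss update is an incremental mean of the per-batch loss items, mean_i = (mean_{i-1} * i + x_i) / (i + 1), which is what the progress bar displays. A three-step sketch:

tloss = None
for i, loss in enumerate([4.0, 2.0, 3.0]):
    tloss = (tloss * i + loss) / (i + 1) if tloss is not None else loss
print(tloss)  # 3.0, the mean of the three values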
@ -368,24 +402,25 @@ class BaseTrainer:
                        break

                # Log
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
                loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                if RANK in (-1, 0):
                    pbar.set_description(
                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
                    self.run_callbacks('on_batch_end')
                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                    )
                    self.run_callbacks("on_batch_end")
                    if self.args.plots and ni in self.plot_idx:
                        self.plot_training_samples(batch, ni)

                self.run_callbacks('on_train_batch_end')
                self.run_callbacks("on_train_batch_end")

            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
            self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
            self.run_callbacks('on_train_epoch_end')
            self.run_callbacks("on_train_epoch_end")
            if RANK in (-1, 0):
                final_epoch = epoch + 1 == self.epochs
                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

                # Validation
                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
|
|||||||
# Save model
|
# Save model
|
||||||
if self.args.save or final_epoch:
|
if self.args.save or final_epoch:
|
||||||
self.save_model()
|
self.save_model()
|
||||||
self.run_callbacks('on_model_save')
|
self.run_callbacks("on_model_save")
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
t = time.time()
|
t = time.time()
|
||||||
self.epoch_time = t - self.epoch_time_start
|
self.epoch_time = t - self.epoch_time_start
|
||||||
self.epoch_time_start = t
|
self.epoch_time_start = t
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||||
if self.args.time:
|
if self.args.time:
|
||||||
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
||||||
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
||||||
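Note: when a wall-clock budget args.time (hours) is set, the target epoch count is re-estimated every epoch from the observed mean epoch time, so training stretches or shrinks to fit the budget. Worked numbers under illustrative assumptions:

import math

time_h, elapsed_s, epochs_done = 2.0, 1800.0, 5  # hypothetical budget and progress
mean_epoch_time = elapsed_s / epochs_done        # 360 s per epoch so far
epochs = math.ceil(time_h * 3600 / mean_epoch_time)
print(epochs)  # 20 epochs fit in the 2-hour budget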
@ -413,7 +448,7 @@ class BaseTrainer:
                    self.scheduler.last_epoch = self.epoch  # do not move
                    self.stop |= epoch >= self.epochs  # stop if exceeded epochs
                self.scheduler.step()
            self.run_callbacks('on_fit_epoch_end')
            self.run_callbacks("on_fit_epoch_end")
            torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

            # Early Stopping
|
|||||||
|
|
||||||
if RANK in (-1, 0):
|
if RANK in (-1, 0):
|
||||||
# Do final val with best.pt
|
# Do final val with best.pt
|
||||||
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
LOGGER.info(
|
||||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
||||||
|
f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
|
||||||
|
)
|
||||||
self.final_eval()
|
self.final_eval()
|
||||||
if self.args.plots:
|
if self.args.plots:
|
||||||
self.plot_metrics()
|
self.plot_metrics()
|
||||||
self.run_callbacks('on_train_end')
|
self.run_callbacks("on_train_end")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
self.run_callbacks('teardown')
|
self.run_callbacks("teardown")
|
||||||
|
|
||||||
def save_model(self):
|
def save_model(self):
|
||||||
"""Save model training checkpoints with additional metadata."""
|
"""Save model training checkpoints with additional metadata."""
|
||||||
import pandas as pd # scope for faster startup
|
import pandas as pd # scope for faster startup
|
||||||
metrics = {**self.metrics, **{'fitness': self.fitness}}
|
|
||||||
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient='list').items()}
|
metrics = {**self.metrics, **{"fitness": self.fitness}}
|
||||||
|
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
|
||||||
ckpt = {
|
ckpt = {
|
||||||
'epoch': self.epoch,
|
"epoch": self.epoch,
|
||||||
'best_fitness': self.best_fitness,
|
"best_fitness": self.best_fitness,
|
||||||
'model': deepcopy(de_parallel(self.model)).half(),
|
"model": deepcopy(de_parallel(self.model)).half(),
|
||||||
'ema': deepcopy(self.ema.ema).half(),
|
"ema": deepcopy(self.ema.ema).half(),
|
||||||
'updates': self.ema.updates,
|
"updates": self.ema.updates,
|
||||||
'optimizer': self.optimizer.state_dict(),
|
"optimizer": self.optimizer.state_dict(),
|
||||||
'train_args': vars(self.args), # save as dict
|
"train_args": vars(self.args), # save as dict
|
||||||
'train_metrics': metrics,
|
"train_metrics": metrics,
|
||||||
'train_results': results,
|
"train_results": results,
|
||||||
'date': datetime.now().isoformat(),
|
"date": datetime.now().isoformat(),
|
||||||
'version': __version__}
|
"version": __version__,
|
||||||
|
}
|
||||||
|
|
||||||
# Save last and best
|
# Save last and best
|
||||||
torch.save(ckpt, self.last)
|
torch.save(ckpt, self.last)
|
||||||
if self.best_fitness == self.fitness:
|
if self.best_fitness == self.fitness:
|
||||||
torch.save(ckpt, self.best)
|
torch.save(ckpt, self.best)
|
||||||
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
||||||
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
|
torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
|
||||||
|
|
||||||
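Note: save_model serializes everything needed for an exact resume in one dict: model and EMA weights (halved to FP16), optimizer state, train args, and metrics. A minimal round-trip sketch with a toy model (the real checkpoint stores the model object itself rather than a bare state_dict, and the filename here is hypothetical):

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ckpt = {"epoch": 3, "model": model.state_dict(), "optimizer": optimizer.state_dict()}
torch.save(ckpt, "last_sketch.pt")  # hypothetical filename

resumed = torch.load("last_sketch.pt")
model.load_state_dict(resumed["model"])
optimizer.load_state_dict(resumed["optimizer"])
start_epoch = resumed["epoch"] + 1  # resume from the following epoch, as in resume_training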
    @staticmethod
    def get_dataset(data):
@ -467,7 +506,7 @@ class BaseTrainer:

        Returns None if data format is not recognized.
        """
        return data['train'], data.get('val') or data.get('test')
        return data["train"], data.get("val") or data.get("test")

    def setup_model(self):
        """Load/create/download model for any task."""
@ -476,9 +515,9 @@ class BaseTrainer:

        model, weights = self.model, None
        ckpt = None
        if str(model).endswith('.pt'):
        if str(model).endswith(".pt"):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = ckpt['model'].yaml
            cfg = ckpt["model"].yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@ -505,7 +544,7 @@ class BaseTrainer:
        The returned dict is expected to contain "fitness" key.
        """
        metrics = self.validator(self)
        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness
@ -516,24 +555,24 @@ class BaseTrainer:

    def get_validator(self):
        """Returns a NotImplementedError when the get_validator function is called."""
        raise NotImplementedError('get_validator function not implemented in trainer')
        raise NotImplementedError("get_validator function not implemented in trainer")

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """Returns dataloader derived from torch.data.Dataloader."""
        raise NotImplementedError('get_dataloader function not implemented in trainer')
        raise NotImplementedError("get_dataloader function not implemented in trainer")

    def build_dataset(self, img_path, mode='train', batch=None):
    def build_dataset(self, img_path, mode="train", batch=None):
        """Build dataset."""
        raise NotImplementedError('build_dataset function not implemented in trainer')
        raise NotImplementedError("build_dataset function not implemented in trainer")

    def label_loss_items(self, loss_items=None, prefix='train'):
    def label_loss_items(self, loss_items=None, prefix="train"):
        """Returns a loss dict with labelled training loss items tensor."""
        # Not needed for classification but necessary for segmentation & detection
        return {'loss': loss_items} if loss_items is not None else ['loss']
        return {"loss": loss_items} if loss_items is not None else ["loss"]

    def set_model_attributes(self):
        """To set or update model parameters before training."""
        self.model.names = self.data['names']
        self.model.names = self.data["names"]

    def build_targets(self, preds, targets):
        """Builds target tensors for training YOLO model."""
@ -541,7 +580,7 @@ class BaseTrainer:

    def progress_string(self):
        """Returns a string describing training progress."""
        return ''
        return ""

    # TODO: may need to put these following functions into callback
    def plot_training_samples(self, batch, ni):
@ -556,9 +595,9 @@ class BaseTrainer:
        """Saves training metrics to a CSV file."""
        keys, vals = list(metrics.keys()), list(metrics.values())
        n = len(metrics) + 1  # number of cols
        s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header
        s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
        with open(self.csv, 'a') as f:
        with open(self.csv, "a") as f:
            f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')
            f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")

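Note: save_metrics produces fixed-width CSV, one 23-character right-aligned column per value with an epoch column in front; the header row is only written when the file is first created. A sketch of the exact formatting with illustrative metric keys:

metrics = {"metrics/mAP50": 0.42, "val/box_loss": 1.234}  # illustrative keys and values
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1
header = ("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n"
row = ("%23.5g," * n % tuple([1] + vals)).rstrip(",") + "\n"
print(header + row, end="")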
    def plot_metrics(self):
        """Plot and display metrics visually."""
@ -567,7 +606,7 @@ class BaseTrainer:
    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)"""
        path = Path(name)
        self.plots[path] = {'data': data, 'timestamp': time.time()}
        self.plots[path] = {"data": data, "timestamp": time.time()}

    def final_eval(self):
        """Performs final evaluation and validation for object detection YOLO model."""
@ -575,11 +614,11 @@ class BaseTrainer:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f'\nValidating {f}...')
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop('fitness', None)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks('on_fit_epoch_end')
                    self.run_callbacks("on_fit_epoch_end")

    def check_resume(self, overrides):
        """Check if resume checkpoint exists and update arguments accordingly."""
@ -591,19 +630,21 @@ class BaseTrainer:

                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
                ckpt_args = attempt_load_weights(last).args
                if not Path(ckpt_args['data']).exists():
                if not Path(ckpt_args["data"]).exists():
                    ckpt_args['data'] = self.args.data
                    ckpt_args["data"] = self.args.data

                resume = True
                self.args = get_cfg(ckpt_args)
                self.args.model = str(last)  # reinstate model
                for k in 'imgsz', 'batch':  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
                for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
                    if k in overrides:
                        setattr(self.args, k, overrides[k])

            except Exception as e:
                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
                raise FileNotFoundError(
                    "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                    "i.e. 'yolo train resume model=path/to/last.pt'"
                ) from e
        self.resume = resume

    def resume_training(self, ckpt):
@ -611,23 +652,26 @@ class BaseTrainer:
        if ckpt is None:
            return
        best_fitness = 0.0
        start_epoch = ckpt['epoch'] + 1
        start_epoch = ckpt["epoch"] + 1
        if ckpt['optimizer'] is not None:
        if ckpt["optimizer"] is not None:
            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
            best_fitness = ckpt['best_fitness']
            best_fitness = ckpt["best_fitness"]
        if self.ema and ckpt.get('ema'):
        if self.ema and ckpt.get("ema"):
            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
            self.ema.updates = ckpt['updates']
            self.ema.updates = ckpt["updates"]
        if self.resume:
            assert start_epoch > 0, \
                f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
            assert start_epoch > 0, (
                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
            )
            LOGGER.info(
                f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
                f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
            )
        if self.epochs < start_epoch:
            LOGGER.info(
                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
            )
            self.epochs += ckpt['epoch']  # finetune additional epochs
            self.epochs += ckpt["epoch"]  # finetune additional epochs
        self.best_fitness = best_fitness
        self.start_epoch = start_epoch
        if start_epoch > (self.epochs - self.args.close_mosaic):
@ -635,13 +679,13 @@ class BaseTrainer:

    def _close_dataloader_mosaic(self):
        """Update dataloaders to stop using mosaic augmentation."""
        if hasattr(self.train_loader.dataset, 'mosaic'):
        if hasattr(self.train_loader.dataset, "mosaic"):
            self.train_loader.dataset.mosaic = False
        if hasattr(self.train_loader.dataset, 'close_mosaic'):
        if hasattr(self.train_loader.dataset, "close_mosaic"):
            LOGGER.info('Closing dataloader mosaic')
            LOGGER.info("Closing dataloader mosaic")
            self.train_loader.dataset.close_mosaic(hyp=self.args)

    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """
        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
        weight decay, and number of iterations.
@ -661,41 +705,45 @@ class BaseTrainer:
        """

        g = [], [], []  # optimizer parameter groups
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
        if name == 'auto':
        if name == "auto":
            LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
                        f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
                        f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
            LOGGER.info(
                f"{colorstr('optimizer:')} 'optimizer=auto' found, "
                f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
                f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
            )
            nc = getattr(model, 'nc', 10)  # number of classes
            nc = getattr(model, "nc", 10)  # number of classes
            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
            name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                fullname = f'{module_name}.{param_name}' if module_name else param_name
                fullname = f"{module_name}.{param_name}" if module_name else param_name
                if 'bias' in fullname:  # bias (no decay)
                if "bias" in fullname:  # bias (no decay)
                    g[2].append(param)
                elif isinstance(module, bn):  # weight (no decay)
                    g[1].append(param)
                else:  # weight (with decay)
                    g[0].append(param)

        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
        if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
        elif name == 'RMSProp':
        elif name == "RMSProp":
            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
        elif name == 'SGD':
        elif name == "SGD":
            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
        else:
            raise NotImplementedError(
                f"Optimizer '{name}' not found in list of available optimizers "
                f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
                f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
                'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
                "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
            )

        optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
        optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
        LOGGER.info(
            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
        )
        return optimizer

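Note: build_optimizer above splits parameters into three groups (decayed weights g[0], norm-layer weights g[1] with no decay, biases g[2] with no decay), and under optimizer=auto picks SGD for long runs (iterations > 10000) or AdamW with the lr0 fit 0.002 * 5 / (4 + nc) for short ones. A self-contained sketch of the grouping on a toy model, reusing the logic above:

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 8, 3))
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # all *Norm layer classes

g = [], [], []  # with-decay weights, norm weights, biases
for module_name, module in model.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        fullname = f"{module_name}.{param_name}" if module_name else param_name
        if "bias" in fullname:
            g[2].append(param)
        elif isinstance(module, bn):
            g[1].append(param)
        else:
            g[0].append(param)

optimizer = optim.SGD(g[2], lr=0.01, momentum=0.9, nesterov=True)
optimizer.add_param_group({"params": g[0], "weight_decay": 5e-4})
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})
print(len(g[0]), len(g[1]), len(g[2]))  # 2 conv weights, 1 norm weight, 3 biases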
@ -73,40 +73,43 @@ class Tuner:
        Args:
            args (dict, optional): Configuration for hyperparameter evolution.
        """
        self.space = args.pop('space', None) or {  # key: (min, max, gain(optional))
        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
            'lr0': (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
            'lrf': (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
            'momentum': (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
            'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
            'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
            'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
            'box': (1.0, 20.0),  # box loss gain
            "box": (1.0, 20.0),  # box loss gain
            'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
            'dfl': (0.4, 6.0),  # dfl loss gain
            "dfl": (0.4, 6.0),  # dfl loss gain
            'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            'degrees': (0.0, 45.0),  # image rotation (+/- deg)
            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
            'translate': (0.0, 0.9),  # image translation (+/- fraction)
            "translate": (0.0, 0.9),  # image translation (+/- fraction)
            'scale': (0.0, 0.95),  # image scale (+/- gain)
            "scale": (0.0, 0.95),  # image scale (+/- gain)
            'shear': (0.0, 10.0),  # image shear (+/- deg)
            "shear": (0.0, 10.0),  # image shear (+/- deg)
            'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            'flipud': (0.0, 1.0),  # image flip up-down (probability)
            "flipud": (0.0, 1.0),  # image flip up-down (probability)
            'fliplr': (0.0, 1.0),  # image flip left-right (probability)
            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
            'mosaic': (0.0, 1.0),  # image mixup (probability)
            "mosaic": (0.0, 1.0),  # image mixup (probability)
            'mixup': (0.0, 1.0),  # image mixup (probability)
            "mixup": (0.0, 1.0),  # image mixup (probability)
            'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
        }
        self.args = get_cfg(overrides=args)
        self.tune_dir = get_save_dir(self.args, name='tune')
        self.tune_dir = get_save_dir(self.args, name="tune")
        self.tune_csv = self.tune_dir / 'tune_results.csv'
        self.tune_csv = self.tune_dir / "tune_results.csv"
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.prefix = colorstr('Tuner: ')
        self.prefix = colorstr("Tuner: ")
        callbacks.add_integration_callbacks(self)
        LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
                    f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
        LOGGER.info(
            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
        )

    def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
        """
        Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.

@ -121,15 +124,15 @@ class Tuner:
        """
        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
            # Select parent(s)
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            w = x[:, 0] - x[:, 0].min() + 1E-6  # weights (sum > 0)
            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
            if parent == 'single' or len(x) == 1:
            if parent == "single" or len(x) == 1:
                # x = x[random.randint(0, n - 1)]  # random selection
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == 'weighted':
            elif parent == "weighted":
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

            # Mutate
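Note: parent selection above is fitness-weighted; w shifts fitnesses to be strictly positive so random.choices can use them as weights, and "weighted" instead averages the top-n rows. The mutation step that follows is not shown in this hunk; one plausible sketch of the idea, perturbing each gene multiplicatively within the bounds of self.space (an illustration, not the actual _mutate body):

import numpy as np

rng = np.random.default_rng(0)
space = {"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98)}  # subset of the space above
parent = {"lr0": 0.01, "momentum": 0.9}
mutation, sigma = 0.8, 0.2  # defaults from the _mutate signature

hyp = {}
for k, bounds in space.items():
    lo, hi = bounds[0], bounds[1]
    factor = 1.0 + rng.standard_normal() * sigma if rng.random() < mutation else 1.0
    hyp[k] = float(np.clip(parent[k] * factor, lo, hi))  # keep inside the search bounds
print(hyp)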
@ -174,44 +177,44 @@ class Tuner:

        t0 = time.time()
        best_save_dir, best_metrics = None, None
        (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True)
        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
        for i in range(iterations):
            # Mutate hyperparameters
            mutated_hyp = self._mutate()
            LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
            LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")

            metrics = {}
            train_args = {**vars(self.args), **mutated_hyp}
            save_dir = get_save_dir(get_cfg(train_args))
            weights_dir = save_dir / 'weights'
            weights_dir = save_dir / "weights"
            ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt')
            ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
            try:
                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
                cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())]
                cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
                return_code = subprocess.run(cmd, check=True).returncode
                metrics = torch.load(ckpt_file)['train_metrics']
                metrics = torch.load(ckpt_file)["train_metrics"]
                assert return_code == 0, 'training failed'
                assert return_code == 0, "training failed"

            except Exception as e:
                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}')
                LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}")

            # Save results and mutated_hyp to CSV
            fitness = metrics.get('fitness', 0.0)
            fitness = metrics.get("fitness", 0.0)
            log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
            headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n')
            headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
            with open(self.tune_csv, 'a') as f:
            with open(self.tune_csv, "a") as f:
                f.write(headers + ','.join(map(str, log_row)) + '\n')
                f.write(headers + ",".join(map(str, log_row)) + "\n")

            # Get best results
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            best_idx = fitness.argmax()
            best_is_current = best_idx == i
            if best_is_current:
                best_save_dir = save_dir
                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
                for ckpt in weights_dir.glob('*.pt'):
                for ckpt in weights_dir.glob("*.pt"):
                    shutil.copy2(ckpt, self.tune_dir / 'weights')
                    shutil.copy2(ckpt, self.tune_dir / "weights")
            elif cleanup:
                shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space

@ -219,15 +222,19 @@ class Tuner:
            plot_tune_results(self.tune_csv)

            # Save and print tune results
            header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
                      f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
                      f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
                      f'{self.prefix}Best fitness metrics are {best_metrics}\n'
                      f'{self.prefix}Best fitness model is {best_save_dir}\n'
                      f'{self.prefix}Best fitness hyperparameters are printed below.\n')
            LOGGER.info('\n' + header)
            header = (
                f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
                f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
                f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
                f'{self.prefix}Best fitness metrics are {best_metrics}\n'
                f'{self.prefix}Best fitness model is {best_save_dir}\n'
                f'{self.prefix}Best fitness hyperparameters are printed below.\n'
            )
            LOGGER.info("\n" + header)
            data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
            yaml_save(self.tune_dir / 'best_hyperparameters.yaml',
                      data=data,
                      header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n')
            yaml_print(self.tune_dir / 'best_hyperparameters.yaml')
            yaml_save(
                self.tune_dir / "best_hyperparameters.yaml",
                data=data,
                header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
            )
            yaml_print(self.tune_dir / "best_hyperparameters.yaml")
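Note: each tuning iteration above launches training as a separate "yolo train" process rather than calling the trainer in-process, so a dataloader hang or CUDA crash in one iteration cannot take down the tuner loop; fitness is then read back from the iteration's checkpoint. The pattern in isolation (argument values are placeholders):

import subprocess

train_args = {"data": "coco8.yaml", "epochs": 1}  # hypothetical arguments
cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
subprocess.run(cmd, check=True)  # raises CalledProcessError on a non-zero exit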
@ -89,10 +89,10 @@ class BaseValidator:
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.001  # default conf=0.001
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
@ -110,7 +110,7 @@ class BaseValidator:
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            self.args.half = self.device.type != 'cpu'  # force FP16 val during training
            self.args.half = self.device.type != "cpu"  # force FP16 val during training
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            # self.model = model
@ -119,11 +119,13 @@ class BaseValidator:
            model.eval()
        else:
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(model or self.args.model,
                                device=select_device(self.args.device, self.args.batch),
                                dnn=self.args.dnn,
                                data=self.args.data,
                                fp16=self.args.half)
            model = AutoBackend(
                model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
@ -133,16 +135,16 @@ class BaseValidator:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = 1  # export.py models default to batch-size 1
                LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

            if str(self.args.data).split('.')[-1] in ('yaml', 'yml'):
            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == 'classify':
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in ('cpu', 'mps'):
            if self.device.type in ("cpu", "mps"):
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False
@ -152,13 +154,13 @@ class BaseValidator:
            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        self.run_callbacks('on_val_start')
        self.run_callbacks("on_val_start")
        dt = Profile(), Profile(), Profile(), Profile()
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks('on_val_batch_start')
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
@ -166,7 +168,7 @@ class BaseValidator:

            # Inference
            with dt[1]:
                preds = model(batch['img'], augment=augment)
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
@ -182,23 +184,25 @@ class BaseValidator:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks('on_val_batch_end')
            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.check_stats(stats)
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks('on_val_end')
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
                        tuple(self.speed.values()))
            LOGGER.info(
                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
                % tuple(self.speed.values())
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / 'predictions.json'), 'w') as f:
                with open(str(self.save_dir / "predictions.json"), "w") as f:
                    LOGGER.info(f'Saving {f.name}...')
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
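Note: each Profile in dt accumulates total seconds for its stage; the dict(zip(...)) line converts those totals to milliseconds per image. A sketch with a stand-in timer class:

class Timer:  # stand-in for ultralytics' Profile, which exposes cumulative time .t
    def __init__(self, t):
        self.t = t

speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
dt = (Timer(0.8), Timer(3.2), Timer(0.1), Timer(0.4))  # total seconds, illustrative
num_images = 100
speed = dict(zip(speed.keys(), (x.t / num_images * 1e3 for x in dt)))
print(speed)  # {'preprocess': 8.0, 'inference': 32.0, 'loss': 1.0, 'postprocess': 4.0}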
@ -228,6 +232,7 @@ class BaseValidator:
        if use_scipy:
            # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
            import scipy  # scope import to avoid importing for all commands

            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
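Note: the scipy branch above performs one-to-one Hungarian matching: IoUs below the threshold are zeroed, then linear_sum_assignment(..., maximize=True) picks the assignment of labels to detections with the largest total IoU. Minimal example:

import numpy as np
import scipy.optimize

iou = np.array([[0.9, 0.2], [0.3, 0.6]])  # labels x detections, illustrative values
threshold = 0.5
cost_matrix = iou * (iou >= threshold)
if cost_matrix.any():
    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
    print(labels_idx, detections_idx)  # label 0 -> detection 0, label 1 -> detection 1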
@ -257,11 +262,11 @@ class BaseValidator:

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError('get_dataloader function not implemented for this validator')
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError('build_dataset function not implemented in validator')
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocesses an input batch."""
@ -306,7 +311,7 @@ class BaseValidator:

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)"""
        self.plots[Path(name)] = {'data': data, 'timestamp': time.time()}
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
@ -21,10 +21,10 @@ def login(api_key: str = None, save=True) -> bool:
    Returns:
        bool: True if authentication is successful, False otherwise.
    """
    api_key_url = f'{HUB_WEB_ROOT}/settings?tab=api+keys'  # Set the redirect URL
    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # set the redirect URL
    saved_key = SETTINGS.get('api_key')
    saved_key = SETTINGS.get("api_key")
    active_key = api_key or saved_key
    credentials = {'api_key': active_key} if active_key and active_key != '' else None  # Set credentials
    credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials

    client = HUBClient(credentials)  # initialize HUBClient

@ -32,17 +32,18 @@ def login(api_key: str = None, save=True) -> bool:
        # Successfully authenticated with HUB

        if save and client.api_key != saved_key:
            SETTINGS.update({'api_key': client.api_key})  # update settings with valid API key
            SETTINGS.update({"api_key": client.api_key})  # update settings with valid API key

        # Set message based on whether key was provided or retrieved from settings
        log_message = ('New authentication successful ✅'
                       if client.api_key == api_key or not credentials else 'Authenticated ✅')
        LOGGER.info(f'{PREFIX}{log_message}')
        log_message = (
            "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
        )
        LOGGER.info(f"{PREFIX}{log_message}")

        return True
    else:
        # Failed to authenticate with HUB
        LOGGER.info(f'{PREFIX}Retrieve API key from {api_key_url}')
        LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}")
        return False


@ -57,50 +58,50 @@ def logout():
        hub.logout()
        ```
    """
    SETTINGS['api_key'] = ''
    SETTINGS["api_key"] = ""
    SETTINGS.save()
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")


def reset_model(model_id=''):
def reset_model(model_id=""):
    """Reset a trained model to an untrained state."""
    r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'modelId': model_id}, headers={'x-api-key': Auth().api_key})
    r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
    if r.status_code == 200:
        LOGGER.info(f'{PREFIX}Model reset successfully')
        LOGGER.info(f"{PREFIX}Model reset successfully")
        return
    LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
    LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")


def export_fmts_hub():
    """Returns a list of HUB-supported export formats."""
    from ultralytics.engine.exporter import export_formats
    return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']

    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]


def export_model(model_id='', format='torchscript'):
def export_model(model_id="", format="torchscript"):
    """Export a model to all formats."""
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
                      json={'format': format},
                      headers={'x-api-key': Auth().api_key})
    r = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
    LOGGER.info(f'{PREFIX}{format} export started ✅')
    LOGGER.info(f"{PREFIX}{format} export started ✅")


def get_export(model_id='', format='torchscript'):
|
def get_export(model_id="", format="torchscript"):
|
||||||
"""Get an exported model dictionary with download URL."""
|
"""Get an exported model dictionary with download URL."""
|
||||||
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
||||||
r = requests.post(f'{HUB_API_ROOT}/get-export',
|
r = requests.post(
|
||||||
json={
|
f"{HUB_API_ROOT}/get-export",
|
||||||
'apiKey': Auth().api_key,
|
json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
|
||||||
'modelId': model_id,
|
headers={"x-api-key": Auth().api_key},
|
||||||
'format': format},
|
)
|
||||||
headers={'x-api-key': Auth().api_key})
|
assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
|
||||||
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
|
|
||||||
return r.json()
|
return r.json()
|
||||||
|
|
||||||
|
|
||||||
def check_dataset(path='', task='detect'):
|
def check_dataset(path="", task="detect"):
|
||||||
"""
|
"""
|
||||||
Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
|
Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
|
||||||
to the HUB. Usage examples are given below.
|
to the HUB. Usage examples are given below.
|
||||||
@ -119,4 +120,4 @@ def check_dataset(path='', task='detect'):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
HUBDatasetStats(path=path, task=task).get_json()
|
HUBDatasetStats(path=path, task=task).get_json()
|
||||||
LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
|
LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
|
||||||
|
|||||||
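The key-resolution rule in `login()` above is worth isolating: an explicitly passed key always wins over the saved one, and empty credentials collapse to `None` so the client falls back to anonymous/interactive auth. A minimal runnable sketch of just that rule, with an illustrative `settings` dict standing in for `SETTINGS`:

```python
def resolve_credentials(api_key, settings):
    """Mirror of the credential-selection logic in login(); storage is illustrative."""
    saved_key = settings.get("api_key")
    active_key = api_key or saved_key  # an explicitly passed key takes precedence
    return {"api_key": active_key} if active_key else None


assert resolve_credentials(None, {"api_key": "saved"}) == {"api_key": "saved"}
assert resolve_credentials("fresh", {"api_key": "saved"}) == {"api_key": "fresh"}
assert resolve_credentials(None, {}) is None  # no key anywhere -> anonymous client
```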
@ -6,7 +6,7 @@ from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT
from ultralytics.hub.utils import PREFIX, request_with_credentials
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab

-API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
+API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"


class Auth:

@ -23,9 +23,10 @@ class Auth:
        api_key (str or bool): API key for authentication, initialized as False.
        model_key (bool): Placeholder for model key, initialized as False.
    """

    id_token = api_key = model_key = False

-   def __init__(self, api_key='', verbose=False):
+   def __init__(self, api_key="", verbose=False):
        """
        Initialize the Auth class with an optional API key.

@ -33,18 +34,18 @@ class Auth:
            api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
        """
        # Split the input API key in case it contains a combined key_model and keep only the API key part
-       api_key = api_key.split('_')[0]
+       api_key = api_key.split("_")[0]

        # Set API key attribute as value passed or SETTINGS API key if none passed
-       self.api_key = api_key or SETTINGS.get('api_key', '')
+       self.api_key = api_key or SETTINGS.get("api_key", "")

        # If an API key is provided
        if self.api_key:
            # If the provided API key matches the API key in the SETTINGS
-           if self.api_key == SETTINGS.get('api_key'):
+           if self.api_key == SETTINGS.get("api_key"):
                # Log that the user is already logged in
                if verbose:
-                   LOGGER.info(f'{PREFIX}Authenticated ✅')
+                   LOGGER.info(f"{PREFIX}Authenticated ✅")
                return
            else:
                # Attempt to authenticate with the provided API key

@ -59,12 +60,12 @@ class Auth:

        # Update SETTINGS with the new API key after successful authentication
        if success:
-           SETTINGS.update({'api_key': self.api_key})
+           SETTINGS.update({"api_key": self.api_key})
            # Log that the new login was successful
            if verbose:
-               LOGGER.info(f'{PREFIX}New authentication successful ✅')
+               LOGGER.info(f"{PREFIX}New authentication successful ✅")
        elif verbose:
-           LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
+           LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}")

    def request_api_key(self, max_attempts=3):
        """

@ -73,13 +74,14 @@ class Auth:
        Returns the model ID.
        """
        import getpass

        for attempts in range(max_attempts):
-           LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
-           input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
-           self.api_key = input_key.split('_')[0]  # remove model id if present
+           LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
+           input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
+           self.api_key = input_key.split("_")[0]  # remove model id if present
            if self.authenticate():
                return True
-       raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
+       raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    def authenticate(self) -> bool:
        """

@ -90,14 +92,14 @@ class Auth:
        """
        try:
            if header := self.get_auth_header():
-               r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
-               if not r.json().get('success', False):
-                   raise ConnectionError('Unable to authenticate.')
+               r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
+               if not r.json().get("success", False):
+                   raise ConnectionError("Unable to authenticate.")
                return True
-           raise ConnectionError('User has not authenticated locally.')
+           raise ConnectionError("User has not authenticated locally.")
        except ConnectionError:
            self.id_token = self.api_key = False  # reset invalid
-           LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
+           LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
            return False

    def auth_with_cookies(self) -> bool:

@ -111,12 +113,12 @@ class Auth:
        if not is_colab():
            return False  # Currently only works with Colab
        try:
-           authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
-           if authn.get('success', False):
-               self.id_token = authn.get('data', {}).get('idToken', None)
+           authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
+           if authn.get("success", False):
+               self.id_token = authn.get("data", {}).get("idToken", None)
                self.authenticate()
                return True
-           raise ConnectionError('Unable to fetch browser authentication details.')
+           raise ConnectionError("Unable to fetch browser authentication details.")
        except ConnectionError:
            self.id_token = False  # reset invalid
            return False

@ -129,7 +131,7 @@ class Auth:
            (dict): The authentication header if id_token or API key is set, None otherwise.
        """
        if self.id_token:
-           return {'authorization': f'Bearer {self.id_token}'}
+           return {"authorization": f"Bearer {self.id_token}"}
        elif self.api_key:
-           return {'x-api-key': self.api_key}
+           return {"x-api-key": self.api_key}
        # else returns None
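The precedence in `Auth.get_auth_header()` is the one behavioral rule in this file: a browser `id_token` (Colab cookie auth) outranks a plain API key, and neither means unauthenticated. A standalone sketch of just that rule, illustrative rather than the Ultralytics class itself:

```python
def get_auth_header(id_token, api_key):
    """Mirror of Auth.get_auth_header() precedence: id_token > api_key > None."""
    if id_token:
        return {"authorization": f"Bearer {id_token}"}
    elif api_key:
        return {"x-api-key": api_key}
    return None  # unauthenticated


assert get_auth_header("tok", "key") == {"authorization": "Bearer tok"}
assert get_auth_header(False, "key") == {"x-api-key": "key"}
assert get_auth_header(False, False) is None
```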
@ -12,16 +12,13 @@ from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError

-AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local')
+AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"


class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

-   Args:
-       url (str): Model identifier used to initialize the HUB training session.
-
    Attributes:
        agent_id (str): Identifier for the instance communicating with the server.
        model_id (str): Identifier for the YOLO model being trained.

@ -40,7 +37,7 @@ class HUBTrainingSession:
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
-           url (str): Model identifier used to initialize the HUB training session.
+           identifier (str): Model identifier used to initialize the HUB training session.
                It can be a URL string or a model key with specific format.

        Raises:

@ -48,9 +45,10 @@ class HUBTrainingSession:
            ConnectionError: If connecting with global API key is not supported.
        """
        self.rate_limits = {
-           'metrics': 3.0,
-           'ckpt': 900.0,
-           'heartbeat': 300.0, }  # rate limits (seconds)
+           "metrics": 3.0,
+           "ckpt": 900.0,
+           "heartbeat": 300.0,
+       }  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py

@ -58,8 +56,8 @@ class HUBTrainingSession:
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials
-       active_key = api_key or SETTINGS.get('api_key')
-       credentials = {'api_key': active_key} if active_key else None  # set credentials
+       active_key = api_key or SETTINGS.get("api_key")
+       credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

@ -72,35 +70,37 @@ class HUBTrainingSession:
    def load_model(self, model_id):
        # Initialize model
        self.model = self.client.model(model_id)
-       self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
+       self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        self._set_train_args()

        # Start heartbeats for HUB to monitor agent
-       self.model.start_heartbeat(self.rate_limits['heartbeat'])
-       LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
+       self.model.start_heartbeat(self.rate_limits["heartbeat"])
+       LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args):
        # Initialize model
        payload = {
-           'config': {
-               'batchSize': model_args.get('batch', -1),
-               'epochs': model_args.get('epochs', 300),
-               'imageSize': model_args.get('imgsz', 640),
-               'patience': model_args.get('patience', 100),
-               'device': model_args.get('device', ''),
-               'cache': model_args.get('cache', 'ram'), },
-           'dataset': {
-               'name': model_args.get('data')},
-           'lineage': {
-               'architecture': {
-                   'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
-               'parent': {}, },
-           'meta': {
-               'name': self.filename}, }
+           "config": {
+               "batchSize": model_args.get("batch", -1),
+               "epochs": model_args.get("epochs", 300),
+               "imageSize": model_args.get("imgsz", 640),
+               "patience": model_args.get("patience", 100),
+               "device": model_args.get("device", ""),
+               "cache": model_args.get("cache", "ram"),
+           },
+           "dataset": {"name": model_args.get("data")},
+           "lineage": {
+               "architecture": {
+                   "name": self.filename.replace(".pt", "").replace(".yaml", ""),
+               },
+               "parent": {},
+           },
+           "meta": {"name": self.filename},
+       }

-       if self.filename.endswith('.pt'):
-           payload['lineage']['parent']['name'] = self.filename
+       if self.filename.endswith(".pt"):
+           payload["lineage"]["parent"]["name"] = self.filename

        self.model.create_model(payload)

@ -109,12 +109,12 @@ class HUBTrainingSession:
        if not self.model.id:
            return

-       self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
+       self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
-       self.model.start_heartbeat(self.rate_limits['heartbeat'])
+       self.model.start_heartbeat(self.rate_limits["heartbeat"])

-       LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
+       LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def _parse_identifier(self, identifier):
        """

@ -140,12 +140,12 @@ class HUBTrainingSession:
        api_key, model_id, filename = None, None, None

        # Check if identifier is a HUB URL
-       if identifier.startswith(f'{HUB_WEB_ROOT}/models/'):
+       if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            # Extract the model_id after the HUB_WEB_ROOT URL
-           model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1]
+           model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
        else:
            # Split the identifier based on underscores only if it's not a HUB URL
-           parts = identifier.split('_')
+           parts = identifier.split("_")

            # Check if identifier is in the format of API key and model ID
            if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:

@ -154,43 +154,46 @@ class HUBTrainingSession:
            elif len(parts) == 1 and len(parts[0]) == 20:
                model_id = parts[0]
            # Check if identifier is a local filename
-           elif identifier.endswith('.pt') or identifier.endswith('.yaml'):
+           elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
                filename = identifier
            else:
                raise HUBModelError(
                    f"model='{identifier}' could not be parsed. Check format is correct. "
-                   f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')
+                   f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
+               )

        return api_key, model_id, filename

    def _set_train_args(self, **kwargs):
        if self.model.is_trained():
            # Model is already trained
-           raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
+           raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))

        if self.model.is_resumable():
            # Model has saved weights
-           self.train_args = {'data': self.model.get_dataset_url(), 'resume': True}
-           self.model_file = self.model.get_weights_url('last')
+           self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
+           self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            def get_train_args(config):
                return {
-                   'batch': config['batchSize'],
-                   'epochs': config['epochs'],
-                   'imgsz': config['imageSize'],
-                   'patience': config['patience'],
-                   'device': config['device'],
-                   'cache': config['cache'],
-                   'data': self.model.get_dataset_url(), }
+                   "batch": config["batchSize"],
+                   "epochs": config["epochs"],
+                   "imgsz": config["imageSize"],
+                   "patience": config["patience"],
+                   "device": config["device"],
+                   "cache": config["cache"],
+                   "data": self.model.get_dataset_url(),
+               }

-           self.train_args = get_train_args(self.model.data.get('config'))
+           self.train_args = get_train_args(self.model.data.get("config"))
            # Set the model file as either a *.pt or *.yaml file
-           self.model_file = (self.model.get_weights_url('parent')
-                              if self.model.is_pretrained() else self.model.get_architecture())
+           self.model_file = (
+               self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
+           )

-       if not self.train_args.get('data'):
-           raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
+       if not self.train_args.get("data"):
+           raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

@ -206,12 +209,11 @@ class HUBTrainingSession:
        *args,
        **kwargs,
    ):
        def retry_request():
            t0 = time.time()  # Record the start time for the timeout
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
-                   LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
+                   LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)

@ -219,8 +221,8 @@ class HUBTrainingSession:
                    self._show_upload_progress(progress_total, response)

                if response is None:
-                   LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
-                   time.sleep(2 ** i)  # Exponential backoff before retrying
+                   LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
+                   time.sleep(2**i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:

@ -231,13 +233,13 @@ class HUBTrainingSession:
                message = self._get_failure_message(response, retry, timeout)

                if verbose:
-                   LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')
+                   LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
-                   LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code}')
+                   LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
                    break  # Not an error that should be retried, exit loop

-               time.sleep(2 ** i)  # Exponential backoff for retries
+               time.sleep(2**i)  # Exponential backoff for retries

            return response

@ -253,7 +255,8 @@ class HUBTrainingSession:
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
-           HTTPStatus.GATEWAY_TIMEOUT, }
+           HTTPStatus.GATEWAY_TIMEOUT,
+       }
        return True if status_code in retry_codes else False

    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):

@ -269,16 +272,18 @@ class HUBTrainingSession:
            str: The retry message.
        """
        if self._should_retry(response.status_code):
-           return f'Retrying {retry}x for {timeout}s.' if retry else ''
+           return f"Retrying {retry}x for {timeout}s." if retry else ""
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
-           return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
-                   f"Please retry after {headers['Retry-After']}s.")
+           return (
+               f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
+               f"Please retry after {headers['Retry-After']}s."
+           )
        else:
            try:
-               return response.json().get('message', 'No JSON message.')
+               return response.json().get("message", "No JSON message.")
            except AttributeError:
-               return 'Unable to read JSON.'
+               return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""

@ -303,7 +308,7 @@ class HUBTrainingSession:
            final (bool): Indicates if the model is the final model after training.
        """
        if Path(weights).is_file():
-           progress_total = (Path(weights).stat().st_size if final else None)  # Only show progress if final
+           progress_total = Path(weights).stat().st_size if final else None  # Only show progress if final
            self.request_queue(
                self.model.upload_model,
                epoch=epoch,

@ -317,7 +322,7 @@ class HUBTrainingSession:
                progress_total=progress_total,
            )
        else:
-           LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
+           LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")

    def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
        """

@ -330,6 +335,6 @@ class HUBTrainingSession:
        Returns:
            (None)
        """
-       with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+       with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))
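The retry loop reformatted above combines three policies: retry only transient HTTP codes, back off `2**i` seconds between attempts, and stop once a wall-clock timeout is exceeded. A minimal runnable sketch of the same policy, where `send` stands in for the real request function and the names are illustrative:

```python
import time
from http import HTTPStatus

# Transient failures worth retrying, matching the diff's retry_codes set
RETRY_CODES = {HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT}


def retry_request(send, retry=3, timeout=30):
    """Retry `send()` with exponential backoff, capped by a wall-clock timeout."""
    t0 = time.time()
    response = None
    for i in range(retry + 1):
        if time.time() - t0 > timeout:
            break  # overall deadline passed, give up
        response = send()
        if response is not None and response.status_code < 300:
            return response  # 2xx: success
        if response is not None and response.status_code not in RETRY_CODES:
            break  # permanent failure, do not retry
        time.sleep(2**i)  # exponential backoff: 1s, 2s, 4s, ...
    return response
```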
@ -9,12 +9,26 @@ from pathlib import Path

import requests

-from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__,
-                               colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
+from ultralytics.utils import (
+    ENVIRONMENT,
+    LOGGER,
+    ONLINE,
+    RANK,
+    SETTINGS,
+    TESTS_RUNNING,
+    TQDM,
+    TryExcept,
+    __version__,
+    colorstr,
+    get_git_origin_url,
+    is_colab,
+    is_git_dir,
+    is_pip_package,
+)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES

-PREFIX = colorstr('Ultralytics HUB: ')
-HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
+PREFIX = colorstr("Ultralytics HUB: ")
+HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."


def request_with_credentials(url: str) -> any:

@ -31,11 +45,13 @@ def request_with_credentials(url: str) -> any:
        OSError: If the function is not run in a Google Colab environment.
    """
    if not is_colab():
-       raise OSError('request_with_credentials() must run in a Colab environment')
+       raise OSError("request_with_credentials() must run in a Colab environment")
    from google.colab import output  # noqa
    from IPython import display  # noqa

    display.display(
-       display.Javascript("""
+       display.Javascript(
+           """
            window._hub_tmp = new Promise((resolve, reject) => {
            const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
            fetch("%s", {

@ -50,8 +66,11 @@ def request_with_credentials(url: str) -> any:
                reject(err);
            });
            });
-           """ % url))
-   return output.eval_js('_hub_tmp')
+           """
+           % url
+       )
+   )
+   return output.eval_js("_hub_tmp")


def requests_with_progress(method, url, **kwargs):

@ -71,13 +90,13 @@ def requests_with_progress(method, url, **kwargs):
        content length.
        - If 'progress' is a number then progress bar will display assuming content length = progress.
    """
-   progress = kwargs.pop('progress', False)
+   progress = kwargs.pop("progress", False)
    if not progress:
        return requests.request(method, url, **kwargs)
    response = requests.request(method, url, stream=True, **kwargs)
-   total = int(response.headers.get('content-length', 0) if isinstance(progress, bool) else progress)  # total size
+   total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size
    try:
-       pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024)
+       pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
        for data in response.iter_content(chunk_size=1024):
            pbar.update(len(data))
        pbar.close()

@ -118,25 +137,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
            if r.status_code < 300:  # return codes in the 2xx range are generally considered "good" or "successful"
                break
            try:
-               m = r.json().get('message', 'No JSON message.')
+               m = r.json().get("message", "No JSON message.")
            except AttributeError:
-               m = 'Unable to read JSON.'
+               m = "Unable to read JSON."
            if i == 0:
                if r.status_code in retry_codes:
-                   m += f' Retrying {retry}x for {timeout}s.' if retry else ''
+                   m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit
                    h = r.headers  # response headers
-                   m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
-                       f"Please retry after {h['Retry-After']}s."
+                   m = (
+                       f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
+                       f"Please retry after {h['Retry-After']}s."
+                   )
                if verbose:
-                   LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
+                   LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
                if r.status_code not in retry_codes:
                    return r
-           time.sleep(2 ** i)  # exponential standoff
+           time.sleep(2**i)  # exponential standoff
        return r

    args = method, url
-   kwargs['progress'] = progress
+   kwargs["progress"] = progress
    if thread:
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    else:

@ -155,7 +176,7 @@ class Events:
        enabled (bool): A flag to enable or disable Events based on certain conditions.
    """

-   url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
+   url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"

    def __init__(self):
        """Initializes the Events object with default values for events, rate_limit, and metadata."""

@ -163,19 +184,21 @@ class Events:
        self.rate_limit = 60.0  # rate limit (seconds)
        self.t = 0.0  # rate limit timer (seconds)
        self.metadata = {
-           'cli': Path(sys.argv[0]).name == 'yolo',
-           'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
-           'python': '.'.join(platform.python_version_tuple()[:2]),  # i.e. 3.10
-           'version': __version__,
-           'env': ENVIRONMENT,
-           'session_id': round(random.random() * 1E15),
-           'engagement_time_msec': 1000}
-       self.enabled = \
-           SETTINGS['sync'] and \
-           RANK in (-1, 0) and \
-           not TESTS_RUNNING and \
-           ONLINE and \
-           (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
+           "cli": Path(sys.argv[0]).name == "yolo",
+           "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
+           "python": ".".join(platform.python_version_tuple()[:2]),  # i.e. 3.10
+           "version": __version__,
+           "env": ENVIRONMENT,
+           "session_id": round(random.random() * 1e15),
+           "engagement_time_msec": 1000,
+       }
+       self.enabled = (
+           SETTINGS["sync"]
+           and RANK in (-1, 0)
+           and not TESTS_RUNNING
+           and ONLINE
+           and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
+       )

    def __call__(self, cfg):
        """

@ -191,11 +214,13 @@ class Events:
        # Attempt to add to events
        if len(self.events) < 25:  # Events list limited to 25 events (drop any events past this)
            params = {
-               **self.metadata, 'task': cfg.task,
-               'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'}
-           if cfg.mode == 'export':
-               params['format'] = cfg.format
-           self.events.append({'name': cfg.mode, 'params': params})
+               **self.metadata,
+               "task": cfg.task,
+               "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
+           }
+           if cfg.mode == "export":
+               params["format"] = cfg.format
+           self.events.append({"name": cfg.mode, "params": params})

        # Check rate limit
        t = time.time()

@ -204,10 +229,10 @@ class Events:
            return

        # Time is over rate limiter, send now
-       data = {'client_id': SETTINGS['uuid'], 'events': self.events}  # SHA-256 anonymized UUID hash and events list
+       data = {"client_id": SETTINGS["uuid"], "events": self.events}  # SHA-256 anonymized UUID hash and events list

        # POST equivalent to requests.post(self.url, json=data)
-       smart_request('post', self.url, json=data, retry=0, verbose=False)
+       smart_request("post", self.url, json=data, retry=0, verbose=False)

        # Reset events and rate limit timer
        self.events = []
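The `Events` class above follows a simple pattern: queue at most 25 events, flush at most once per rate-limit window, then reset. A minimal runnable sketch of that pattern, where `send` stands in for `smart_request()` and the class itself is illustrative:

```python
import time


class EventQueue:
    """Illustrative rate-limited event queue, mirroring the Events flush logic."""

    def __init__(self, rate_limit=60.0):
        self.events = []
        self.rate_limit = rate_limit  # seconds between flushes
        self.t = 0.0  # timestamp of last flush

    def add(self, event, send):
        if len(self.events) < 25:  # drop events beyond the cap
            self.events.append(event)
        now = time.time()
        if now - self.t < self.rate_limit:
            return  # still inside the rate-limit window
        send(self.events)  # flush the queued events in one call
        self.events = []
        self.t = now


queue = EventQueue(rate_limit=0.0)  # zero window so the example flushes immediately
queue.add({"name": "train"}, send=print)
```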
@ -4,4 +4,4 @@ from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO

-__all__ = 'YOLO', 'RTDETR', 'SAM'  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM"  # allow simpler import
@ -5,4 +5,4 @@ from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
from .val import FastSAMValidator

-__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'
+__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"
@ -21,14 +21,14 @@ class FastSAM(Model):
        ```
    """

-   def __init__(self, model='FastSAM-x.pt'):
+   def __init__(self, model="FastSAM-x.pt"):
        """Call the __init__ method of the parent class (YOLO) with the updated default model."""
-       if str(model) == 'FastSAM.pt':
-           model = 'FastSAM-x.pt'
-       assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
-       super().__init__(model=model, task='segment')
+       if str(model) == "FastSAM.pt":
+           model = "FastSAM-x.pt"
+       assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
+       super().__init__(model=model, task="segment")

    @property
    def task_map(self):
        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
-       return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
+       return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
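A short usage sketch for the `FastSAM` wrapper above. The weights file and image path are placeholders; as the `assert` in `__init__` enforces, only pre-trained `*.pt` weights are accepted, never YAML configs, and `"FastSAM.pt"` is remapped to `"FastSAM-x.pt"`:

```python
from ultralytics.models.fastsam import FastSAM

model = FastSAM("FastSAM-x.pt")       # pre-trained weights only; YAML configs are rejected
results = model("path/to/image.jpg")  # runs the 'segment' task by default
```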
@ -33,7 +33,7 @@ class FastSAMPredictor(DetectionPredictor):
            _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
-       self.args.task = 'segment'
+       self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """

@ -55,7 +55,8 @@ class FastSAMPredictor(DetectionPredictor):
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=1,  # set to 1 class since SAM has no class predictions
-           classes=self.args.classes)
+           classes=self.args.classes,
+       )
        full_box = torch.zeros(p[0].shape[1], device=p[0].device)
        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
        full_box = full_box.view(1, -1)
@ -23,7 +23,7 @@ class FastSAMPrompt:
        clip: CLIP model for linear assignment.
    """

-   def __init__(self, source, results, device='cuda') -> None:
+   def __init__(self, source, results, device="cuda") -> None:
        """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
        self.device = device
        self.results = results

@ -34,7 +34,8 @@ class FastSAMPrompt:
            import clip  # for linear_assignment
        except ImportError:
            from ultralytics.utils.checks import check_requirements
-           check_requirements('git+https://github.com/openai/CLIP.git')
+
+           check_requirements("git+https://github.com/openai/CLIP.git")
            import clip
        self.clip = clip

@ -46,11 +47,11 @@ class FastSAMPrompt:
        x1, y1, x2, y2 = bbox
        segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
        segmented_image = Image.fromarray(segmented_image_array)
-       black_image = Image.new('RGB', image.size, (255, 255, 255))
+       black_image = Image.new("RGB", image.size, (255, 255, 255))
        # transparency_mask = np.zeros_like((), dtype=np.uint8)
        transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
        transparency_mask[y1:y2, x1:x2] = 255
-       transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+       transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
        black_image.paste(segmented_image, mask=transparency_mask_image)
        return black_image

@ -65,11 +66,12 @@ class FastSAMPrompt:
            mask = result.masks.data[i] == 1.0
            if torch.sum(mask) >= filter:
                annotation = {
-                   'id': i,
-                   'segmentation': mask.cpu().numpy(),
-                   'bbox': result.boxes.data[i],
-                   'score': result.boxes.conf[i]}
-               annotation['area'] = annotation['segmentation'].sum()
+                   "id": i,
+                   "segmentation": mask.cpu().numpy(),
+                   "bbox": result.boxes.data[i],
+                   "score": result.boxes.conf[i],
+               }
+               annotation["area"] = annotation["segmentation"].sum()
                annotations.append(annotation)
        return annotations

@ -91,7 +93,8 @@ class FastSAMPrompt:
            y2 = max(y2, y_t + h_t)
        return [x1, y1, x2, y2]

-   def plot(self,
+   def plot(
+       self,
        annotations,
        output,
        bbox=None,

@ -100,7 +103,8 @@ class FastSAMPrompt:
        mask_random_color=True,
        better_quality=True,
        retina=False,
-       with_contours=True):
+       with_contours=True,
+   ):
        """
        Plots annotations, bounding boxes, and points on images and saves the output.

@ -139,7 +143,8 @@ class FastSAMPrompt:
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                    masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

-           self.fast_show_mask(masks,
+           self.fast_show_mask(
+               masks,
                plt.gca(),
                random_color=mask_random_color,
                bbox=bbox,

@ -147,7 +152,8 @@ class FastSAMPrompt:
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
-               target_width=original_w)
+               target_width=original_w,
+           )

            if with_contours:
                contour_all = []

@ -166,10 +172,10 @@ class FastSAMPrompt:
            # Save the figure
            save_path = Path(output) / result_name
            save_path.parent.mkdir(exist_ok=True, parents=True)
-           plt.axis('off')
-           plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True)
+           plt.axis("off")
+           plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
            plt.close()
-           pbar.set_description(f'Saving {result_name} to {save_path}')
+           pbar.set_description(f"Saving {result_name} to {save_path}")

    @staticmethod
    def fast_show_mask(

@ -212,26 +218,26 @@ class FastSAMPrompt:
        mask_image = np.expand_dims(annotation, -1) * visual

        show = np.zeros((h, w, 4))
-       h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
+       h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))

        show[h_indices, w_indices, :] = mask_image[indices]
        if bbox is not None:
            x1, y1, x2, y2 = bbox
-           ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+           ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
        # Draw point
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
-               c='y',
+               c="y",
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
-               c='m',
+               c="m",
            )

        if not retinamask:

@ -258,7 +264,7 @@ class FastSAMPrompt:
        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
        ori_w, ori_h = image.size
        annotations = format_results
-       mask_h, mask_w = annotations[0]['segmentation'].shape
+       mask_h, mask_w = annotations[0]["segmentation"].shape
        if ori_w != mask_w or ori_h != mask_h:
            image = image.resize((mask_w, mask_h))
        cropped_boxes = []

@ -266,19 +272,19 @@ class FastSAMPrompt:
        not_crop = []
        filter_id = []
        for _, mask in enumerate(annotations):
-           if np.sum(mask['segmentation']) <= 100:
+           if np.sum(mask["segmentation"]) <= 100:
                filter_id.append(_)
                continue
-           bbox = self._get_bbox_from_mask(mask['segmentation'])  # mask 的 bbox
-           cropped_boxes.append(self._segment_image(image, bbox))  # 保存裁剪的图片
-           cropped_images.append(bbox)  # 保存裁剪的图片的bbox
+           bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
+           cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
+           cropped_images.append(bbox)  # save cropped image bbox

        return cropped_boxes, cropped_images, not_crop, filter_id, annotations

    def box_prompt(self, bbox):
        """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
        if self.results[0].masks is not None:
-           assert (bbox[2] != 0 and bbox[3] != 0)
+           assert bbox[2] != 0 and bbox[3] != 0
            if os.path.isdir(self.source):
                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
            masks = self.results[0].masks.data

@ -290,7 +296,8 @@ class FastSAMPrompt:
                int(bbox[0] * w / target_width),
                int(bbox[1] * h / target_height),
                int(bbox[2] * w / target_width),
-               int(bbox[3] * h / target_height), ]
+               int(bbox[3] * h / target_height),
+           ]
            bbox[0] = max(round(bbox[0]), 0)
            bbox[1] = max(round(bbox[1]), 0)
            bbox[2] = min(round(bbox[2]), w)

@ -299,7 +306,7 @@ class FastSAMPrompt:
            # IoUs = torch.zeros(len(masks), dtype=torch.float32)
            bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])

-           masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+           masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
            orig_masks_area = torch.sum(masks, dim=(1, 2))

            union = bbox_area + orig_masks_area - masks_area

@ -316,13 +323,13 @@ class FastSAMPrompt:
                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
            masks = self._format_results(self.results[0], 0)
            target_height, target_width = self.results[0].orig_shape
-           h = masks[0]['segmentation'].shape[0]
-           w = masks[0]['segmentation'].shape[1]
+           h = masks[0]["segmentation"].shape[0]
+           w = masks[0]["segmentation"].shape[1]
            if h != target_height or w != target_width:
                points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
            onemask = np.zeros((h, w))
            for annotation in masks:
-               mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
+               mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
                for i, point in enumerate(points):
                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                        onemask += mask

@ -337,12 +344,12 @@ class FastSAMPrompt:
        if self.results[0].masks is not None:
            format_results = self._format_results(self.results[0], 0)
            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
-           clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
+           clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
            max_idx = scores.argsort()
            max_idx = max_idx[-1]
            max_idx += sum(np.array(filter_id) <= int(max_idx))
-           self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]['segmentation']]))
+           self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
            return self.results

    def everything_prompt(self):
@ -35,6 +35,6 @@ class FastSAMValidator(SegmentationValidator):
|
|||||||
Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
|
Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
|
||||||
"""
|
"""
|
||||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||||
self.args.task = 'segment'
|
self.args.task = "segment"
|
||||||
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
|
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
|
||||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||||
|
|||||||
@ -4,4 +4,4 @@ from .model import NAS
|
|||||||
from .predict import NASPredictor
|
from .predict import NASPredictor
|
||||||
from .val import NASValidator
|
from .val import NASValidator
|
||||||
|
|
||||||
__all__ = 'NASPredictor', 'NASValidator', 'NAS'
|
__all__ = "NASPredictor", "NASValidator", "NAS"
|
||||||
|
|||||||
@ -44,20 +44,21 @@ class NAS(Model):
|
|||||||
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
|
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model='yolo_nas_s.pt') -> None:
|
def __init__(self, model="yolo_nas_s.pt") -> None:
|
||||||
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
|
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
|
||||||
assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
|
assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
|
||||||
super().__init__(model, task='detect')
|
super().__init__(model, task="detect")
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def _load(self, weights: str, task: str):
|
def _load(self, weights: str, task: str):
|
||||||
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
|
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
|
||||||
import super_gradients
|
import super_gradients
|
||||||
|
|
||||||
suffix = Path(weights).suffix
|
suffix = Path(weights).suffix
|
||||||
if suffix == '.pt':
|
if suffix == ".pt":
|
||||||
self.model = torch.load(weights)
|
self.model = torch.load(weights)
|
||||||
elif suffix == '':
|
elif suffix == "":
|
||||||
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
|
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
|
||||||
# Standardize model
|
# Standardize model
|
||||||
self.model.fuse = lambda verbose=True: self.model
|
self.model.fuse = lambda verbose=True: self.model
|
||||||
self.model.stride = torch.tensor([32])
|
self.model.stride = torch.tensor([32])
|
||||||
@ -65,7 +66,7 @@ class NAS(Model):
|
|||||||
self.model.is_fused = lambda: False # for info()
|
self.model.is_fused = lambda: False # for info()
|
||||||
self.model.yaml = {} # for info()
|
self.model.yaml = {} # for info()
|
||||||
self.model.pt_path = weights # for export()
|
self.model.pt_path = weights # for export()
|
||||||
self.model.task = 'detect' # for export()
|
self.model.task = "detect" # for export()
|
||||||
|
|
||||||
def info(self, detailed=False, verbose=True):
|
def info(self, detailed=False, verbose=True):
|
||||||
"""
|
"""
|
||||||
@ -80,4 +81,4 @@ class NAS(Model):
|
|||||||
@property
|
@property
|
||||||
def task_map(self):
|
def task_map(self):
|
||||||
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""
|
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""
|
||||||
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
|
return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
|
||||||
|
|||||||
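Review note: a minimal usage sketch of the NAS wrapper reformatted above, to confirm the quote-style changes leave the public interface untouched; "bus.jpg" is a hypothetical local image path.

    from ultralytics import NAS

    model = NAS("yolo_nas_s.pt")  # pre-trained weights only; YAML configs are rejected by the assert above
    model.info()  # model summary via info(detailed=False, verbose=True)
    results = model("bus.jpg")  # 'detect' task, dispatched to NASPredictor per task_map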
@@ -39,12 +39,14 @@ class NASPredictor(BasePredictor):
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)

-        preds = ops.non_max_suppression(preds,
+        preds = ops.non_max_suppression(
+            preds,
             self.args.conf,
             self.args.iou,
             agnostic=self.args.agnostic_nms,
             max_det=self.args.max_det,
-            classes=self.args.classes)
+            classes=self.args.classes,
+        )

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

@@ -5,7 +5,7 @@ import torch
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import ops

-__all__ = ['NASValidator']
+__all__ = ["NASValidator"]


 class NASValidator(DetectionValidator):

@@ -38,11 +38,13 @@ class NASValidator(DetectionValidator):
         """Apply Non-maximum suppression to prediction outputs."""
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
-        return ops.non_max_suppression(preds,
+        return ops.non_max_suppression(
+            preds,
             self.args.conf,
             self.args.iou,
             labels=self.lb,
             multi_label=False,
             agnostic=self.args.single_cls,
             max_det=self.args.max_det,
-            max_time_img=0.5)
+            max_time_img=0.5,
+        )
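Review note: both call sites above only re-wrap ops.non_max_suppression across lines; behavior is unchanged. Below is a standalone sketch with a dummy tensor, assuming the (batch, 4 + num_classes, num_boxes) layout implied by the xywh/score concatenation above.

    import torch

    from ultralytics.utils import ops

    preds = torch.rand(1, 4 + 80, 100)  # 1 image, 80 classes, 100 candidate boxes (dummy values)
    out = ops.non_max_suppression(preds, 0.25, 0.45, multi_label=False, max_det=300)
    print(out[0].shape)  # (n, 6): x1, y1, x2, y2, confidence, class for each kept box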
@@ -4,4 +4,4 @@ from .model import RTDETR
 from .predict import RTDETRPredictor
 from .val import RTDETRValidator

-__all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'
+__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"

@@ -24,7 +24,7 @@ class RTDETR(Model):
         model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
     """

-    def __init__(self, model='rtdetr-l.pt') -> None:
+    def __init__(self, model="rtdetr-l.pt") -> None:
         """
         Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.

@@ -34,9 +34,9 @@ class RTDETR(Model):
         Raises:
             NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
         """
-        if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
-            raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.')
-        super().__init__(model=model, task='detect')
+        if model and model.split(".")[-1] not in ("pt", "yaml", "yml"):
+            raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
+        super().__init__(model=model, task="detect")

     @property
     def task_map(self) -> dict:

@@ -47,8 +47,10 @@ class RTDETR(Model):
             dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
         """
         return {
-            'detect': {
-                'predictor': RTDETRPredictor,
-                'validator': RTDETRValidator,
-                'trainer': RTDETRTrainer,
-                'model': RTDETRDetectionModel}}
+            "detect": {
+                "predictor": RTDETRPredictor,
+                "validator": RTDETRValidator,
+                "trainer": RTDETRTrainer,
+                "model": RTDETRDetectionModel,
+            }
+        }
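Review note: usage sketch for the RT-DETR wrapper above; "bus.jpg" is a hypothetical local image path.

    from ultralytics import RTDETR

    model = RTDETR("rtdetr-l.pt")  # only *.pt, *.yaml, or *.yml sources pass the suffix check above
    results = model("bus.jpg")  # 'detect' task, dispatched to RTDETRPredictor per task_map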
@@ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer):
         Returns:
             (RTDETRDetectionModel): Initialized model.
         """
-        model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
         return model

-    def build_dataset(self, img_path, mode='val', batch=None):
+    def build_dataset(self, img_path, mode="val", batch=None):
         """
         Build and return an RT-DETR dataset for training or validation.

@@ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer):
         Returns:
             (RTDETRDataset): Dataset object for the specific mode.
         """
-        return RTDETRDataset(img_path=img_path,
+        return RTDETRDataset(
+            img_path=img_path,
             imgsz=self.args.imgsz,
             batch_size=batch,
-            augment=mode == 'train',
+            augment=mode == "train",
             hyp=self.args,
             rect=False,
             cache=self.args.cache or None,
-            prefix=colorstr(f'{mode}: '),
-            data=self.data)
+            prefix=colorstr(f"{mode}: "),
+            data=self.data,
+        )

     def get_validator(self):
         """

@@ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer):
         Returns:
             (RTDETRValidator): Validator object for model validation.
         """
-        self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
+        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

     def preprocess_batch(self, batch):

@@ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer):
             (dict): Preprocessed batch.
         """
         batch = super().preprocess_batch(batch)
-        bs = len(batch['img'])
-        batch_idx = batch['batch_idx']
+        bs = len(batch["img"])
+        batch_idx = batch["batch_idx"]
         gt_bbox, gt_class = [], []
         for i in range(bs):
-            gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
-            gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
+            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
+            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
         return batch

@@ -7,7 +7,7 @@ from ultralytics.data.augment import Compose, Format, v8_transforms
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import colorstr, ops

-__all__ = 'RTDETRValidator',  # tuple or list
+__all__ = ("RTDETRValidator",)  # tuple or list


 class RTDETRDataset(YOLODataset):

@@ -37,13 +37,16 @@ class RTDETRDataset(YOLODataset):
             # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
             transforms = Compose([])
         transforms.append(
-            Format(bbox_format='xywh',
+            Format(
+                bbox_format="xywh",
                 normalize=True,
                 return_mask=self.use_segments,
                 return_keypoint=self.use_keypoints,
                 batch_idx=True,
                 mask_ratio=hyp.mask_ratio,
-                mask_overlap=hyp.overlap_mask))
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
         return transforms


@@ -68,7 +71,7 @@ class RTDETRValidator(DetectionValidator):
     For further details on the attributes and methods, refer to the parent DetectionValidator class.
     """

-    def build_dataset(self, img_path, mode='val', batch=None):
+    def build_dataset(self, img_path, mode="val", batch=None):
         """
         Build an RTDETR Dataset.

@@ -85,8 +88,9 @@ class RTDETRValidator(DetectionValidator):
             hyp=self.args,
             rect=False,  # no rect
             cache=self.args.cache or None,
-            prefix=colorstr(f'{mode}: '),
-            data=self.data)
+            prefix=colorstr(f"{mode}: "),
+            data=self.data,
+        )

     def postprocess(self, preds):
         """Apply Non-maximum suppression to prediction outputs."""

@@ -108,12 +112,12 @@ class RTDETRValidator(DetectionValidator):

     def _prepare_batch(self, si, batch):
         """Prepares a batch for training or inference by applying transformations."""
-        idx = batch['batch_idx'] == si
-        cls = batch['cls'][idx].squeeze(-1)
-        bbox = batch['bboxes'][idx]
-        ori_shape = batch['ori_shape'][si]
-        imgsz = batch['img'].shape[2:]
-        ratio_pad = batch['ratio_pad'][si]
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
         if len(cls):
             bbox = ops.xywh2xyxy(bbox)  # target boxes
             bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred

@@ -124,6 +128,6 @@ class RTDETRValidator(DetectionValidator):
     def _prepare_pred(self, pred, pbatch):
         """Prepares and returns a batch with transformed bounding boxes and class labels."""
         predn = pred.clone()
-        predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz  # native-space pred
-        predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz  # native-space pred
+        predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
         return predn.float()

@@ -3,4 +3,4 @@
 from .model import SAM
 from .predict import Predictor

-__all__ = 'SAM', 'Predictor'  # tuple or list
+__all__ = "SAM", "Predictor"  # tuple or list

@@ -8,10 +8,9 @@ import numpy as np
 import torch


-def is_box_near_crop_edge(boxes: torch.Tensor,
-                          crop_box: List[int],
-                          orig_box: List[int],
-                          atol: float = 20.0) -> torch.Tensor:
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
     """Return a boolean tensor indicating if boxes are near the crop edge."""
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)

@@ -24,10 +23,10 @@ def is_box_near_crop_edge(boxes: torch.Tensor,

 def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
     """Yield batches of data from the input arguments."""
-    assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
+    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
     for b in range(n_batches):
-        yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
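Review note: batch_iterator (shown in full above) slices all of its argument sequences in lock-step, with a shorter final batch; a tiny demonstration.

    from ultralytics.models.sam.amg import batch_iterator

    points = [0, 1, 2, 3, 4]
    labels = ["a", "b", "c", "d", "e"]
    for batch in batch_iterator(2, points, labels):
        print(batch)  # [[0, 1], ['a', 'b']] -> [[2, 3], ['c', 'd']] -> [[4], ['e']]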
@@ -39,9 +38,8 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     """
     # One mask is always contained inside the other.
     # Save memory by preventing unnecessary cast to torch.int64
-    intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
-                                                                                                  dtype=torch.int32))
-    unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
+    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     return intersections / unions

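Review note: the reflowed lines compute the stability score as the IoU between the mask binarized at (mask_threshold + threshold_offset) and at (mask_threshold - threshold_offset). A shape-level sketch on dummy logits.

    import torch

    from ultralytics.models.sam.amg import calculate_stability_score

    masks = torch.randn(3, 256, 256)  # 3 mask logit maps
    scores = calculate_stability_score(masks, mask_threshold=0.0, threshold_offset=1.0)
    print(scores.shape)  # torch.Size([3]): one stability score per mask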
@@ -56,11 +54,12 @@ def build_point_grid(n_per_side: int) -> np.ndarray:

 def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
     """Generate point grids for all crop layers."""
-    return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
+    return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]


-def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
-                        overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
     """
     Generates a list of crop boxes of different sizes.

@@ -132,8 +131,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
     """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
     import cv2  # type: ignore

-    assert mode in {'holes', 'islands'}
-    correct_holes = mode == 'holes'
+    assert mode in {"holes", "islands"}
+    correct_holes = mode == "holes"
     working_mask = (correct_holes ^ mask).astype(np.uint8)
     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
     sizes = stats[:, -1][1:]  # Row 0 is background label
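Review note: a sketch of remove_small_regions on a toy boolean mask; mode="islands" drops small foreground specks, mode="holes" fills small background holes. The expected output assumes the cv2 connected-components filtering that follows the lines shown in this hunk.

    import numpy as np

    from ultralytics.models.sam.amg import remove_small_regions

    mask = np.zeros((64, 64), dtype=bool)
    mask[8:40, 8:40] = True  # one large 32x32 region
    mask[60, 60] = True  # one 1-pixel speck
    cleaned, changed = remove_small_regions(mask, area_thresh=10.0, mode="islands")
    print(changed, cleaned.sum())  # True, 1024 -> the speck was removed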
@@ -64,18 +64,16 @@ def build_mobile_sam(checkpoint=None):
     )


-def _build_sam(encoder_embed_dim,
-               encoder_depth,
-               encoder_num_heads,
-               encoder_global_attn_indexes,
-               checkpoint=None,
-               mobile_sam=False):
+def _build_sam(
+    encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
+):
     """Builds the selected SAM model architecture."""
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
     image_embedding_size = image_size // vit_patch_size
-    image_encoder = (TinyViT(
+    image_encoder = (
+        TinyViT(
             img_size=1024,
             in_chans=3,
             num_classes=1000,

@@ -90,7 +88,9 @@ def _build_sam(encoder_embed_dim,
             mbconv_expand_ratio=4.0,
             local_conv_size=3,
             layer_lr_decay=0.8,
-        ) if mobile_sam else ImageEncoderViT(
+        )
+        if mobile_sam
+        else ImageEncoderViT(
             depth=encoder_depth,
             embed_dim=encoder_embed_dim,
             img_size=image_size,

@@ -103,7 +103,8 @@ def _build_sam(encoder_embed_dim,
             global_attn_indexes=encoder_global_attn_indexes,
             window_size=14,
             out_chans=prompt_embed_dim,
-        ))
+        )
+    )
     sam = Sam(
         image_encoder=image_encoder,
         prompt_encoder=PromptEncoder(

@@ -129,7 +130,7 @@ def _build_sam(encoder_embed_dim,
     )
     if checkpoint is not None:
         checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, 'rb') as f:
+        with open(checkpoint, "rb") as f:
             state_dict = torch.load(f)
         sam.load_state_dict(state_dict)
     sam.eval()

@@ -139,13 +140,14 @@ def _build_sam(encoder_embed_dim,


 sam_model_map = {
-    'sam_h.pt': build_sam_vit_h,
-    'sam_l.pt': build_sam_vit_l,
-    'sam_b.pt': build_sam_vit_b,
-    'mobile_sam.pt': build_mobile_sam, }
+    "sam_h.pt": build_sam_vit_h,
+    "sam_l.pt": build_sam_vit_l,
+    "sam_b.pt": build_sam_vit_b,
+    "mobile_sam.pt": build_mobile_sam,
+}


-def build_sam(ckpt='sam_b.pt'):
+def build_sam(ckpt="sam_b.pt"):
     """Build a SAM model specified by ckpt."""
     model_builder = None
     ckpt = str(ckpt)  # to allow Path ckpt types

@@ -154,6 +156,6 @@ def build_sam(ckpt='sam_b.pt'):
         model_builder = sam_model_map.get(k)

     if not model_builder:
-        raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
+        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")

     return model_builder(ckpt)
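Review note: build_sam usage sketch; the function lives in the sam package's build module (it is imported as `from .build import build_sam` below), and the checkpoint is fetched by attempt_download_asset when missing locally.

    from ultralytics.models.sam.build import build_sam

    sam = build_sam("mobile_sam.pt")  # any key of sam_model_map; other names raise FileNotFoundError
    print(type(sam).__name__)  # Sam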
@@ -32,7 +32,7 @@ class SAM(Model):
     dataset.
     """

-    def __init__(self, model='sam_b.pt') -> None:
+    def __init__(self, model="sam_b.pt") -> None:
         """
         Initializes the SAM model with a pre-trained model file.

@@ -42,9 +42,9 @@ class SAM(Model):
         Raises:
             NotImplementedError: If the model file extension is not .pt or .pth.
         """
-        if model and Path(model).suffix not in ('.pt', '.pth'):
-            raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
-        super().__init__(model=model, task='segment')
+        if model and Path(model).suffix not in (".pt", ".pth"):
+            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        super().__init__(model=model, task="segment")

     def _load(self, weights: str, task=None):
         """

@@ -70,7 +70,7 @@ class SAM(Model):
         Returns:
             (list): The model predictions.
         """
-        overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
+        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
         kwargs.update(overrides)
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)

@@ -112,4 +112,4 @@ class SAM(Model):
         Returns:
             (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
         """
-        return {'segment': {'predictor': Predictor}}
+        return {"segment": {"predictor": Predictor}}
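Review note: prompted-inference sketch matching the SAM wrapper above; predict() forwards bboxes/points/labels to the Predictor via the prompts dict, as shown in the hunk. "bus.jpg" and the coordinates are hypothetical.

    from ultralytics import SAM

    model = SAM("sam_b.pt")  # only .pt/.pth weights pass the suffix check above
    results = model.predict("bus.jpg", bboxes=[100, 100, 400, 400])  # box prompt (xyxy)
    results = model.predict("bus.jpg", points=[250, 250], labels=[1])  # single foreground point prompt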
@@ -64,8 +64,9 @@ class MaskDecoder(nn.Module):
             nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
             activation(),
         )
-        self.output_hypernetworks_mlps = nn.ModuleList([
-            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+        )

         self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)

@@ -132,13 +133,14 @@ class MaskDecoder(nn.Module):
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
-        mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
         hyper_in_list: List[torch.Tensor] = [
-            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
+            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+        ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

@@ -283,9 +283,9 @@ class PromptEncoder(nn.Module):
         if masks is not None:
             dense_embeddings = self._embed_masks(masks)
         else:
-            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
-                                                                 1).expand(bs, -1, self.image_embedding_size[0],
-                                                                           self.image_embedding_size[1])
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )

         return sparse_embeddings, dense_embeddings

@@ -298,7 +298,7 @@ class PositionEmbeddingRandom(nn.Module):
         super().__init__()
         if scale is None or scale <= 0.0:
             scale = 1.0
-        self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
+        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))

         # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
         torch.use_deterministic_algorithms(False)

@@ -425,14 +425,14 @@ class Attention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
+        self.scale = head_dim**-0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)

         self.use_rel_pos = use_rel_pos
         if self.use_rel_pos:
-            assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
+            assert input_size is not None, "Input size must be provided if using relative positional encoding."
             # Initialize relative positional embeddings
             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

@@ -479,8 +479,9 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
     return windows, (Hp, Wp)


-def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
-                       hw: Tuple[int, int]) -> torch.Tensor:
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
     """
     Window unpartition into original sequences and removing padding.

@@ -523,7 +524,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
         rel_pos_resized = F.interpolate(
             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
             size=max_rel_dist,
-            mode='linear',
+            mode="linear",
         )
         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
     else:

@@ -567,11 +568,12 @@ def add_decomposed_rel_pos(

     B, _, dim = q.shape
     r_q = q.reshape(B, q_h, q_w, dim)
-    rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
-    rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

     attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
-        B, q_h * q_w, k_h * k_w)
+        B, q_h * q_w, k_h * k_w
+    )

     return attn

@@ -30,8 +30,9 @@ class Sam(nn.Module):
         pixel_mean (List[float]): Mean pixel values for image normalization.
         pixel_std (List[float]): Standard deviation values for image normalization.
     """

     mask_threshold: float = 0.0
-    image_format: str = 'RGB'
+    image_format: str = "RGB"

     def __init__(
         self,

@@ -39,7 +40,7 @@ class Sam(nn.Module):
         prompt_encoder: PromptEncoder,
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = (123.675, 116.28, 103.53),
-        pixel_std: List[float] = (58.395, 57.12, 57.375)
+        pixel_std: List[float] = (58.395, 57.12, 57.375),
     ) -> None:
         """
         Initialize the Sam class to predict object masks from an image and input prompts.

@@ -60,5 +61,5 @@ class Sam(nn.Module):
         self.image_encoder = image_encoder
         self.prompt_encoder = prompt_encoder
         self.mask_decoder = mask_decoder
-        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

@@ -28,11 +28,11 @@ class Conv2d_BN(torch.nn.Sequential):
         drop path.
         """
         super().__init__()
-        self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
+        self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
         torch.nn.init.constant_(bn.weight, bn_weight_init)
         torch.nn.init.constant_(bn.bias, 0)
-        self.add_module('bn', bn)
+        self.add_module("bn", bn)


 class PatchEmbed(nn.Module):

@@ -146,11 +146,11 @@ class ConvLayer(nn.Module):
         input_resolution,
         depth,
         activation,
-        drop_path=0.,
+        drop_path=0.0,
         downsample=None,
         use_checkpoint=False,
         out_dim=None,
-        conv_expand_ratio=4.,
+        conv_expand_ratio=4.0,
     ):
         """
         Initializes the ConvLayer with the given dimensions and settings.

@@ -173,18 +173,25 @@ class ConvLayer(nn.Module):
         self.use_checkpoint = use_checkpoint

         # Build blocks
-        self.blocks = nn.ModuleList([
+        self.blocks = nn.ModuleList(
+            [
                 MBConv(
                     dim,
                     dim,
                     conv_expand_ratio,
                     activation,
                     drop_path[i] if isinstance(drop_path, list) else drop_path,
-                ) for i in range(depth)])
+                )
+                for i in range(depth)
+            ]
+        )

         # Patch merging layer
-        self.downsample = None if downsample is None else downsample(
-            input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        self.downsample = (
+            None
+            if downsample is None
+            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        )

     def forward(self, x):
         """Processes the input through a series of convolutional layers and returns the activated output."""

@@ -200,7 +207,7 @@ class Mlp(nn.Module):
     This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
     """

-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
         """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
         super().__init__()
         out_features = out_features or in_features

@@ -256,7 +263,7 @@ class Attention(torch.nn.Module):

         assert isinstance(resolution, tuple) and len(resolution) == 2
         self.num_heads = num_heads
-        self.scale = key_dim ** -0.5
+        self.scale = key_dim**-0.5
         self.key_dim = key_dim
         self.nh_kd = nh_kd = key_dim * num_heads
         self.d = int(attn_ratio * key_dim)

@@ -279,13 +286,13 @@ class Attention(torch.nn.Module):
                 attention_offsets[offset] = len(attention_offsets)
                 idxs.append(attention_offsets[offset])
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
-        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
+        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)

     @torch.no_grad()
     def train(self, mode=True):
         """Sets the module in training mode and handles attribute 'ab' based on the mode."""
         super().train(mode)
-        if mode and hasattr(self, 'ab'):
+        if mode and hasattr(self, "ab"):
             del self.ab
         else:
             self.ab = self.attention_biases[:, self.attention_bias_idxs]

@@ -306,8 +313,9 @@ class Attention(torch.nn.Module):
         v = v.permute(0, 2, 1, 3)
         self.ab = self.ab.to(self.attention_biases.device)

-        attn = ((q @ k.transpose(-2, -1)) * self.scale +
-                (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
+        attn = (q @ k.transpose(-2, -1)) * self.scale + (
+            self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
+        )
         attn = attn.softmax(dim=-1)
         x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         return self.proj(x)

@@ -322,9 +330,9 @@ class TinyViTBlock(nn.Module):
         input_resolution,
         num_heads,
         window_size=7,
-        mlp_ratio=4.,
-        drop=0.,
-        drop_path=0.,
+        mlp_ratio=4.0,
+        drop=0.0,
+        drop_path=0.0,
         local_conv_size=3,
         activation=nn.GELU,
     ):

@@ -350,7 +358,7 @@ class TinyViTBlock(nn.Module):
         self.dim = dim
         self.input_resolution = input_resolution
         self.num_heads = num_heads
-        assert window_size > 0, 'window_size must be greater than 0'
+        assert window_size > 0, "window_size must be greater than 0"
         self.window_size = window_size
         self.mlp_ratio = mlp_ratio

@@ -358,7 +366,7 @@ class TinyViTBlock(nn.Module):
         # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.drop_path = nn.Identity()

-        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
+        assert dim % num_heads == 0, "dim must be divisible by num_heads"
         head_dim = dim // num_heads

         window_resolution = (window_size, window_size)

@@ -377,7 +385,7 @@ class TinyViTBlock(nn.Module):
         """
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, 'input feature has wrong size'
+        assert L == H * W, "input feature has wrong size"
         res_x = x
         if H == self.window_size and W == self.window_size:
             x = self.attn(x)

@@ -394,8 +402,11 @@ class TinyViTBlock(nn.Module):
             nH = pH // self.window_size
             nW = pW // self.window_size
             # Window partition
-            x = x.view(B, nH, self.window_size, nW, self.window_size,
-                       C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
+            x = (
+                x.view(B, nH, self.window_size, nW, self.window_size, C)
+                .transpose(2, 3)
+                .reshape(B * nH * nW, self.window_size * self.window_size, C)
+            )
             x = self.attn(x)
             # Window reverse
             x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
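Review note: a shape check for the window partition/reverse reshapes above, with illustrative numbers (B=1, pH=pW=14, window_size=7, C=8); the round trip restores the original tensor.

    import torch

    B, pH, pW, C, ws = 1, 14, 14, 8, 7
    nH, nW = pH // ws, pW // ws
    x = torch.randn(B, pH, pW, C)
    win = x.view(B, nH, ws, nW, ws, C).transpose(2, 3).reshape(B * nH * nW, ws * ws, C)
    print(win.shape)  # torch.Size([4, 49, 8]): 4 windows of 49 tokens each
    back = win.view(B, nH, nW, ws, ws, C).transpose(2, 3).reshape(B, pH, pW, C)
    assert torch.equal(back, x)  # window reverse restores the original layout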
@ -417,8 +428,10 @@ class TinyViTBlock(nn.Module):
|
|||||||
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
|
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
|
||||||
attentions heads, window size, and MLP ratio.
|
attentions heads, window size, and MLP ratio.
|
||||||
"""
|
"""
|
||||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
|
return (
|
||||||
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
|
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
||||||
|
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BasicLayer(nn.Module):
|
class BasicLayer(nn.Module):
|
||||||
@ -431,9 +444,9 @@ class BasicLayer(nn.Module):
|
|||||||
depth,
|
depth,
|
||||||
num_heads,
|
num_heads,
|
||||||
window_size,
|
window_size,
|
||||||
mlp_ratio=4.,
|
mlp_ratio=4.0,
|
||||||
drop=0.,
|
drop=0.0,
|
||||||
drop_path=0.,
|
drop_path=0.0,
|
||||||
downsample=None,
|
downsample=None,
|
||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
local_conv_size=3,
|
local_conv_size=3,
|
||||||
@ -468,7 +481,8 @@ class BasicLayer(nn.Module):
|
|||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
|
|
||||||
# Build blocks
|
# Build blocks
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
TinyViTBlock(
|
TinyViTBlock(
|
||||||
dim=dim,
|
dim=dim,
|
||||||
input_resolution=input_resolution,
|
input_resolution=input_resolution,
|
||||||
@ -479,11 +493,17 @@ class BasicLayer(nn.Module):
|
|||||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||||
local_conv_size=local_conv_size,
|
local_conv_size=local_conv_size,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
) for i in range(depth)])
|
)
|
||||||
|
for i in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Patch merging layer
|
# Patch merging layer
|
||||||
self.downsample = None if downsample is None else downsample(
|
self.downsample = (
|
||||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
None
|
||||||
|
if downsample is None
|
||||||
|
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
|
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
|
||||||
@ -493,7 +513,7 @@ class BasicLayer(nn.Module):
|
|||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
"""Returns a string representation of the extra_repr function with the layer's parameters."""
|
"""Returns a string representation of the extra_repr function with the layer's parameters."""
|
||||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
|
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm2d(nn.Module):
|
class LayerNorm2d(nn.Module):
|
||||||
@ -549,8 +569,8 @@ class TinyViT(nn.Module):
|
|||||||
depths=[2, 2, 6, 2],
|
depths=[2, 2, 6, 2],
|
||||||
num_heads=[3, 6, 12, 24],
|
num_heads=[3, 6, 12, 24],
|
||||||
window_sizes=[7, 7, 14, 7],
|
window_sizes=[7, 7, 14, 7],
|
||||||
mlp_ratio=4.,
|
mlp_ratio=4.0,
|
||||||
drop_rate=0.,
|
drop_rate=0.0,
|
||||||
drop_path_rate=0.1,
|
drop_path_rate=0.1,
|
||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
mbconv_expand_ratio=4.0,
|
mbconv_expand_ratio=4.0,
|
||||||
@ -585,10 +605,9 @@ class TinyViT(nn.Module):
|
|||||||
|
|
||||||
activation = nn.GELU
|
activation = nn.GELU
|
||||||
|
|
||||||
self.patch_embed = PatchEmbed(in_chans=in_chans,
|
self.patch_embed = PatchEmbed(
|
||||||
embed_dim=embed_dims[0],
|
in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
|
||||||
resolution=img_size,
|
)
|
||||||
activation=activation)
|
|
||||||
|
|
||||||
patches_resolution = self.patch_embed.patches_resolution
|
patches_resolution = self.patch_embed.patches_resolution
|
||||||
self.patches_resolution = patches_resolution
|
self.patches_resolution = patches_resolution
|
||||||
@ -601,27 +620,30 @@ class TinyViT(nn.Module):
|
|||||||
for i_layer in range(self.num_layers):
|
for i_layer in range(self.num_layers):
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
dim=embed_dims[i_layer],
|
dim=embed_dims[i_layer],
|
||||||
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
input_resolution=(
|
||||||
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
|
patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||||
|
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||||
|
),
|
||||||
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
||||||
# patches_resolution[1] // (2 ** i_layer)),
|
# patches_resolution[1] // (2 ** i_layer)),
|
||||||
depth=depths[i_layer],
|
depth=depths[i_layer],
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
||||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
out_dim=embed_dims[min(i_layer + 1,
|
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
|
||||||
len(embed_dims) - 1)],
|
|
||||||
activation=activation,
|
activation=activation,
|
||||||
)
|
)
|
||||||
if i_layer == 0:
|
if i_layer == 0:
|
||||||
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
|
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
|
||||||
else:
|
else:
|
||||||
layer = BasicLayer(num_heads=num_heads[i_layer],
|
layer = BasicLayer(
|
||||||
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_sizes[i_layer],
|
window_size=window_sizes[i_layer],
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
drop=drop_rate,
|
drop=drop_rate,
|
||||||
local_conv_size=local_conv_size,
|
local_conv_size=local_conv_size,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
self.layers.append(layer)
|
self.layers.append(layer)
|
||||||
|
|
||||||
# Classifier head
|
# Classifier head
|
||||||
@ -680,7 +702,7 @@ class TinyViT(nn.Module):
|
|||||||
def _check_lr_scale(m):
|
def _check_lr_scale(m):
|
||||||
"""Checks if the learning rate scale attribute is present in module's parameters."""
|
"""Checks if the learning rate scale attribute is present in module's parameters."""
|
||||||
for p in m.parameters():
|
for p in m.parameters():
|
||||||
assert hasattr(p, 'lr_scale'), p.param_name
|
assert hasattr(p, "lr_scale"), p.param_name
|
||||||
|
|
||||||
self.apply(_check_lr_scale)
|
self.apply(_check_lr_scale)
|
||||||
|
|
||||||
@ -698,7 +720,7 @@ class TinyViT(nn.Module):
|
|||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def no_weight_decay_keywords(self):
|
def no_weight_decay_keywords(self):
|
||||||
"""Returns a dictionary of parameter names where weight decay should not be applied."""
|
"""Returns a dictionary of parameter names where weight decay should not be applied."""
|
||||||
return {'attention_biases'}
|
return {"attention_biases"}
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
"""Runs the input through the model layers and returns the transformed output."""
|
"""Runs the input through the model layers and returns the transformed output."""
|
||||||
|
|||||||
@ -62,7 +62,8 @@ class TwoWayTransformer(nn.Module):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
attention_downsample_rate=attention_downsample_rate,
|
attention_downsample_rate=attention_downsample_rate,
|
||||||
skip_first_layer_pe=(i == 0),
|
skip_first_layer_pe=(i == 0),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||||
@ -227,7 +228,7 @@ class Attention(nn.Module):
|
|||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.internal_dim = embedding_dim // downsample_rate
|
self.internal_dim = embedding_dim // downsample_rate
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
|
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
||||||
|
|
||||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||||
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||||
|
|||||||
@@ -19,8 +19,17 @@ from ultralytics.engine.results import Results
 from ultralytics.utils import DEFAULT_CFG, ops
 from ultralytics.utils.torch_utils import select_device

-from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
-                  generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
+from .amg import (
+    batch_iterator,
+    batched_mask_to_box,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    remove_small_regions,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+)
 from .build import build_sam


@@ -58,7 +67,7 @@ class Predictor(BasePredictor):
         """
         if overrides is None:
             overrides = {}
-        overrides.update(dict(task='segment', mode='predict', imgsz=1024))
+        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
         super().__init__(cfg, overrides, _callbacks)
         self.args.retina_masks = True
         self.im = None
@@ -107,7 +116,7 @@ class Predictor(BasePredictor):
         Returns:
             (List[np.ndarray]): List of transformed images.
         """
-        assert len(im) == 1, 'SAM model does not currently support batched inference'
+        assert len(im) == 1, "SAM model does not currently support batched inference"
         letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
         return [letterbox(image=x) for x in im]

@@ -132,9 +141,9 @@ class Predictor(BasePredictor):
             - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
         # Override prompts if any stored in self.prompts
-        bboxes = self.prompts.pop('bboxes', bboxes)
-        points = self.prompts.pop('points', points)
-        masks = self.prompts.pop('masks', masks)
+        bboxes = self.prompts.pop("bboxes", bboxes)
+        points = self.prompts.pop("points", points)
+        masks = self.prompts.pop("masks", masks)

         if all(i is None for i in [bboxes, points, masks]):
             return self.generate(im, *args, **kwargs)
@@ -199,7 +208,8 @@ class Predictor(BasePredictor):
         # `d` could be 1 or 3 depends on `multimask_output`.
         return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

-    def generate(self,
+    def generate(
+        self,
         im,
         crop_n_layers=0,
         crop_overlap_ratio=512 / 1500,
@@ -210,7 +220,8 @@ class Predictor(BasePredictor):
         conf_thres=0.88,
         stability_score_thresh=0.95,
         stability_score_offset=0.95,
-        crop_nms_thresh=0.7):
+        crop_nms_thresh=0.7,
+    ):
         """
         Perform image segmentation using the Segment Anything Model (SAM).

@@ -248,19 +259,20 @@ class Predictor(BasePredictor):
             area = torch.tensor(w * h, device=im.device)
             points_scale = np.array([[w, h]])  # w, h
             # Crop image and interpolate to input size
-            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
+            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
             # (num_points, 2)
             points_for_image = point_grids[layer_idx] * points_scale
             crop_masks, crop_scores, crop_bboxes = [], [], []
-            for (points, ) in batch_iterator(points_batch_size, points_for_image):
+            for (points,) in batch_iterator(points_batch_size, points_for_image):
                 pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                 # Interpolate predicted masks to input size
-                pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
+                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
                 idx = pred_score > conf_thres
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]

-                stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
-                                                            stability_score_offset)
+                stability_score = calculate_stability_score(
+                    pred_mask, self.model.mask_threshold, stability_score_offset
+                )
                 idx = stability_score > stability_score_thresh
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                 # Bool type is much more memory-efficient.
@@ -404,7 +416,7 @@ class Predictor(BasePredictor):
             model = build_sam(self.args.model)
         self.setup_model(model)
         self.setup_source(image)
-        assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
+        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
         for batch in self.dataset:
             im = self.preprocess(batch[1])
             self.features = self.model.image_encoder(im)
@@ -446,9 +458,9 @@ class Predictor(BasePredictor):
         scores = []
         for mask in masks:
             mask = mask.cpu().numpy().astype(np.uint8)
-            mask, changed = remove_small_regions(mask, min_area, mode='holes')
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
             unchanged = not changed
-            mask, changed = remove_small_regions(mask, min_area, mode='islands')
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
             unchanged = unchanged and not changed

             new_masks.append(torch.as_tensor(mask).unsqueeze(0))
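The prompt handling above pops any stored `bboxes`/`points`/`masks` before falling back to full-image `generate`. A usage sketch via the high-level `SAM` wrapper — the checkpoint name and image path are illustrative:

from ultralytics import SAM

model = SAM("sam_b.pt")  # illustrative checkpoint name

# Prompted inference: a box prompt routes through prompt_inference()
results = model("ultralytics/assets/bus.jpg", bboxes=[100, 100, 500, 600])

# No prompts at all falls through to generate(), i.e. segment-everything mode
results = model("ultralytics/assets/bus.jpg")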
@@ -30,14 +30,9 @@ class DETRLoss(nn.Module):
         device (torch.device): Device on which tensors are stored.
     """

-    def __init__(self,
-                 nc=80,
-                 loss_gain=None,
-                 aux_loss=True,
-                 use_fl=True,
-                 use_vfl=False,
-                 use_uni_match=False,
-                 uni_match_ind=0):
+    def __init__(
+        self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
+    ):
         """
         DETR loss function.

@@ -52,9 +47,9 @@ class DETRLoss(nn.Module):
         super().__init__()

         if loss_gain is None:
-            loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
+            loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
         self.nc = nc
-        self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
+        self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
         self.loss_gain = loss_gain
         self.aux_loss = aux_loss
         self.fl = FocalLoss() if use_fl else None
@@ -64,10 +59,10 @@ class DETRLoss(nn.Module):
         self.uni_match_ind = uni_match_ind
         self.device = None

-    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
+    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
         """Computes the classification loss based on predictions, target values, and ground truth scores."""
         # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
-        name_class = f'loss_class{postfix}'
+        name_class = f"loss_class{postfix}"
         bs, nq = pred_scores.shape[:2]
         # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1]  # (bs, num_queries, num_classes)
         one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
@@ -82,28 +77,28 @@ class DETRLoss(nn.Module):
                 loss_cls = self.fl(pred_scores, one_hot.float())
             loss_cls /= max(num_gts, 1) / nq
         else:
-            loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss
+            loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss

-        return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
+        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

-    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
+    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
         """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
         boxes.
         """
         # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
-        name_bbox = f'loss_bbox{postfix}'
-        name_giou = f'loss_giou{postfix}'
+        name_bbox = f"loss_bbox{postfix}"
+        name_giou = f"loss_giou{postfix}"

         loss = {}
         if len(gt_bboxes) == 0:
-            loss[name_bbox] = torch.tensor(0., device=self.device)
-            loss[name_giou] = torch.tensor(0., device=self.device)
+            loss[name_bbox] = torch.tensor(0.0, device=self.device)
+            loss[name_giou] = torch.tensor(0.0, device=self.device)
             return loss

-        loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
+        loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
         loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
         loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
-        loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
+        loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
         return {k: v.squeeze() for k, v in loss.items()}

     # This function is for future RT-DETR Segment models
@@ -137,30 +132,35 @@ class DETRLoss(nn.Module):
    #     loss = 1 - (numerator + 1) / (denominator + 1)
    #     return loss.sum() / num_gts

-    def _get_loss_aux(self,
+    def _get_loss_aux(
+        self,
         pred_bboxes,
         pred_scores,
         gt_bboxes,
         gt_cls,
         gt_groups,
         match_indices=None,
-        postfix='',
+        postfix="",
         masks=None,
-        gt_mask=None):
+        gt_mask=None,
+    ):
         """Get auxiliary losses."""
         # NOTE: loss class, bbox, giou, mask, dice
         loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
         if match_indices is None and self.use_uni_match:
-            match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
+            match_indices = self.matcher(
+                pred_bboxes[self.uni_match_ind],
                 pred_scores[self.uni_match_ind],
                 gt_bboxes,
                 gt_cls,
                 gt_groups,
                 masks=masks[self.uni_match_ind] if masks is not None else None,
-                gt_mask=gt_mask)
+                gt_mask=gt_mask,
+            )
         for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
             aux_masks = masks[i] if masks is not None else None
-            loss_ = self._get_loss(aux_bboxes,
+            loss_ = self._get_loss(
+                aux_bboxes,
                 aux_scores,
                 gt_bboxes,
                 gt_cls,
@@ -168,19 +168,21 @@ class DETRLoss(nn.Module):
                 masks=aux_masks,
                 gt_mask=gt_mask,
                 postfix=postfix,
-                match_indices=match_indices)
-            loss[0] += loss_[f'loss_class{postfix}']
-            loss[1] += loss_[f'loss_bbox{postfix}']
-            loss[2] += loss_[f'loss_giou{postfix}']
+                match_indices=match_indices,
+            )
+            loss[0] += loss_[f"loss_class{postfix}"]
+            loss[1] += loss_[f"loss_bbox{postfix}"]
+            loss[2] += loss_[f"loss_giou{postfix}"]
             # if masks is not None and gt_mask is not None:
             #     loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
             #     loss[3] += loss_[f'loss_mask{postfix}']
             #     loss[4] += loss_[f'loss_dice{postfix}']

         loss = {
-            f'loss_class_aux{postfix}': loss[0],
-            f'loss_bbox_aux{postfix}': loss[1],
-            f'loss_giou_aux{postfix}': loss[2]}
+            f"loss_class_aux{postfix}": loss[0],
+            f"loss_bbox_aux{postfix}": loss[1],
+            f"loss_giou_aux{postfix}": loss[2],
+        }
         # if masks is not None and gt_mask is not None:
         #     loss[f'loss_mask_aux{postfix}'] = loss[3]
         #     loss[f'loss_dice_aux{postfix}'] = loss[4]
@@ -196,15 +198,22 @@ class DETRLoss(nn.Module):

     def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
         """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
-        pred_assigned = torch.cat([
+        pred_assigned = torch.cat(
+            [
                 t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
-            for t, (I, _) in zip(pred_bboxes, match_indices)])
-        gt_assigned = torch.cat([
+                for t, (I, _) in zip(pred_bboxes, match_indices)
+            ]
+        )
+        gt_assigned = torch.cat(
+            [
                 t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
-            for t, (_, J) in zip(gt_bboxes, match_indices)])
+                for t, (_, J) in zip(gt_bboxes, match_indices)
+            ]
+        )
         return pred_assigned, gt_assigned

-    def _get_loss(self,
+    def _get_loss(
+        self,
         pred_bboxes,
         pred_scores,
         gt_bboxes,
@@ -212,17 +221,14 @@ class DETRLoss(nn.Module):
         gt_groups,
         masks=None,
         gt_mask=None,
-        postfix='',
-        match_indices=None):
+        postfix="",
+        match_indices=None,
+    ):
         """Get losses."""
         if match_indices is None:
-            match_indices = self.matcher(pred_bboxes,
-                                         pred_scores,
-                                         gt_bboxes,
-                                         gt_cls,
-                                         gt_groups,
-                                         masks=masks,
-                                         gt_mask=gt_mask)
+            match_indices = self.matcher(
+                pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
+            )

         idx, gt_idx = self._get_index(match_indices)
         pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
@@ -242,7 +248,7 @@ class DETRLoss(nn.Module):
         # loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
         return loss

-    def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
+    def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
         """
         Args:
             pred_bboxes (torch.Tensor): [l, b, query, 4]
@@ -254,21 +260,19 @@ class DETRLoss(nn.Module):
             postfix (str): postfix of loss name.
         """
         self.device = pred_bboxes.device
-        match_indices = kwargs.get('match_indices', None)
-        gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
+        match_indices = kwargs.get("match_indices", None)
+        gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]

-        total_loss = self._get_loss(pred_bboxes[-1],
-                                    pred_scores[-1],
-                                    gt_bboxes,
-                                    gt_cls,
-                                    gt_groups,
-                                    postfix=postfix,
-                                    match_indices=match_indices)
+        total_loss = self._get_loss(
+            pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
+        )

         if self.aux_loss:
             total_loss.update(
-                self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
-                                   postfix))
+                self._get_loss_aux(
+                    pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
+                )
+            )

         return total_loss

@@ -300,18 +304,18 @@ class RTDETRDetectionLoss(DETRLoss):

         # Check for denoising metadata to compute denoising training loss
         if dn_meta is not None:
-            dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
-            assert len(batch['gt_groups']) == len(dn_pos_idx)
+            dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
+            assert len(batch["gt_groups"]) == len(dn_pos_idx)

             # Get the match indices for denoising
-            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
+            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])

             # Compute the denoising training loss
-            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
+            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
             total_loss.update(dn_loss)
         else:
             # If no denoising metadata is provided, set denoising loss to zero
-            total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
+            total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})

         return total_loss

@@ -334,8 +338,8 @@ class RTDETRDetectionLoss(DETRLoss):
             if num_gt > 0:
                 gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                 gt_idx = gt_idx.repeat(dn_num_group)
-                assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
-                f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
+                assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
+                f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
                 dn_match_indices.append((dn_pos_idx[i], gt_idx))
             else:
                 dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
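For orientation, `loss_gain` weights each DETR loss component before summation. A toy sketch with made-up raw losses, using the default gains from the hunk above:

import torch

# Default component weights from DETRLoss.__init__
loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}

# Illustrative raw (unweighted) losses for one forward pass
raw = {"class": torch.tensor(0.8), "bbox": torch.tensor(0.05), "giou": torch.tensor(0.3)}

total = sum(loss_gain[k] * v for k, v in raw.items())
print(float(total))  # 0.8*1 + 0.05*5 + 0.3*2 = 1.65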
@@ -37,7 +37,7 @@ class HungarianMatcher(nn.Module):
         """
         super().__init__()
         if cost_gain is None:
-            cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
+            cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
         self.cost_gain = cost_gain
         self.use_fl = use_fl
         self.with_mask = with_mask
@@ -86,7 +86,7 @@ class HungarianMatcher(nn.Module):
         # Compute the classification cost
         pred_scores = pred_scores[:, gt_cls]
         if self.use_fl:
-            neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
+            neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
             pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
             cost_class = pos_cost_class - neg_cost_class
         else:
@@ -99,9 +99,11 @@ class HungarianMatcher(nn.Module):
         cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

         # Final cost matrix
-        C = self.cost_gain['class'] * cost_class + \
-            self.cost_gain['bbox'] * cost_bbox + \
-            self.cost_gain['giou'] * cost_giou
+        C = (
+            self.cost_gain["class"] * cost_class
+            + self.cost_gain["bbox"] * cost_bbox
+            + self.cost_gain["giou"] * cost_giou
+        )
         # Compute the mask cost and dice cost
         if self.with_mask:
             C += self._cost_mask(bs, gt_groups, masks, gt_mask)
@@ -111,10 +113,11 @@ class HungarianMatcher(nn.Module):

         C = C.view(bs, nq, -1).cpu()
         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
-        gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
-        # (idx for queries, idx for gt)
-        return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
-                for k, (i, j) in enumerate(indices)]
+        gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # (idx for queries, idx for gt)
+        return [
+            (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
+            for k, (i, j) in enumerate(indices)
+        ]

     # This function is for future RT-DETR Segment models
     # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
@@ -147,14 +150,9 @@ class HungarianMatcher(nn.Module):
    #     return C


-def get_cdn_group(batch,
-                  num_classes,
-                  num_queries,
-                  class_embed,
-                  num_dn=100,
-                  cls_noise_ratio=0.5,
-                  box_noise_scale=1.0,
-                  training=False):
+def get_cdn_group(
+    batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
+):
     """
     Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
     and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
@@ -180,7 +178,7 @@ def get_cdn_group(batch,

     if (not training) or num_dn <= 0:
         return None, None, None, None
-    gt_groups = batch['gt_groups']
+    gt_groups = batch["gt_groups"]
     total_num = sum(gt_groups)
     max_nums = max(gt_groups)
     if max_nums == 0:
@@ -190,9 +188,9 @@
     num_group = 1 if num_group == 0 else num_group
     # Pad gt to max_num of a batch
     bs = len(gt_groups)
-    gt_cls = batch['cls']  # (bs*num, )
-    gt_bbox = batch['bboxes']  # bs*num, 4
-    b_idx = batch['batch_idx']
+    gt_cls = batch["cls"]  # (bs*num, )
+    gt_bbox = batch["bboxes"]  # bs*num, 4
+    b_idx = batch["batch_idx"]

     # Each group has positive and negative queries.
     dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
@@ -245,16 +243,21 @@
     # Reconstruct cannot see each other
     for i in range(num_group):
         if i == 0:
-            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
         if i == num_group - 1:
-            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
         else:
-            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
-            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
     dn_meta = {
-        'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
-        'dn_num_group': num_group,
-        'dn_num_split': [num_dn, num_queries]}
+        "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
+        "dn_num_group": num_group,
+        "dn_num_split": [num_dn, num_queries],
+    }

-    return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
-        class_embed.device), dn_meta
+    return (
+        padding_cls.to(class_embed.device),
+        padding_bbox.to(class_embed.device),
+        attn_mask.to(class_embed.device),
+        dn_meta,
+    )
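The matcher builds one weighted cost matrix and solves it with SciPy's Hungarian solver. A self-contained sketch with random stand-in costs (the real matcher derives them from class scores and box IoU):

import torch
from scipy.optimize import linear_sum_assignment

cost_gain = {"class": 2, "bbox": 5, "giou": 2}  # matcher defaults from the hunk above
nq, ng = 6, 3  # queries, ground-truth boxes (illustrative sizes)

# Stand-in per-pair costs, one (nq, ng) matrix per component
cost_class, cost_bbox, cost_giou = (torch.rand(nq, ng) for _ in range(3))

C = cost_gain["class"] * cost_class + cost_gain["bbox"] * cost_bbox + cost_gain["giou"] * cost_giou
q_idx, g_idx = linear_sum_assignment(C.numpy())  # minimal-cost one-to-one matching
print(list(zip(q_idx, g_idx)))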
@@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment

 from .model import YOLO

-__all__ = 'classify', 'segment', 'detect', 'pose', 'obb', 'YOLO'
+__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO"
@@ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor
 from ultralytics.models.yolo.classify.train import ClassificationTrainer
 from ultralytics.models.yolo.classify.val import ClassificationValidator

-__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'
+__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
@@ -30,19 +30,21 @@ class ClassificationPredictor(BasePredictor):
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """Initializes ClassificationPredictor setting the task to 'classify'."""
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task = 'classify'
-        self._legacy_transform_name = 'ultralytics.yolo.data.augment.ToTensor'
+        self.args.task = "classify"
+        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"

     def preprocess(self, img):
         """Converts input image to model-compatible data type."""
         if not isinstance(img, torch.Tensor):
-            is_legacy_transform = any(self._legacy_transform_name in str(transform)
-                                      for transform in self.transforms.transforms)
+            is_legacy_transform = any(
+                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
+            )
             if is_legacy_transform:  # to handle legacy transforms
                 img = torch.stack([self.transforms(im) for im in img], dim=0)
             else:
-                img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img],
-                                  dim=0)
+                img = torch.stack(
+                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+                )
         img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
         return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

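The non-legacy branch above converts OpenCV's BGR frames to RGB PIL images before stacking them into a batch. A standalone sketch of the same pattern (the resize transform and frame content are illustrative):

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

tfm = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# OpenCV loads BGR; classification transforms expect RGB PIL images
imgs = [np.zeros((480, 640, 3), dtype=np.uint8)]  # stand-ins for loaded frames
batch = torch.stack([tfm(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in imgs], dim=0)
print(batch.shape)  # torch.Size([1, 3, 224, 224])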
@@ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer):
         """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
         if overrides is None:
             overrides = {}
-        overrides['task'] = 'classify'
-        if overrides.get('imgsz') is None:
-            overrides['imgsz'] = 224
+        overrides["task"] = "classify"
+        if overrides.get("imgsz") is None:
+            overrides["imgsz"] = 224
         super().__init__(cfg, overrides, _callbacks)

     def set_model_attributes(self):
         """Set the YOLO model's class names from the loaded dataset."""
-        self.model.names = self.data['names']
+        self.model.names = self.data["names"]

     def get_model(self, cfg=None, weights=None, verbose=True):
         """Returns a modified PyTorch model configured for training YOLO."""
-        model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)

         for m in model.modules():
-            if not self.args.pretrained and hasattr(m, 'reset_parameters'):
+            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                 m.reset_parameters()
             if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                 m.p = self.args.dropout  # set dropout
@@ -64,32 +64,32 @@ class ClassificationTrainer(BaseTrainer):

         model, ckpt = str(self.model), None
         # Load a YOLO model locally, from torchvision, or from Ultralytics assets
-        if model.endswith('.pt'):
-            self.model, ckpt = attempt_load_one_weight(model, device='cpu')
+        if model.endswith(".pt"):
+            self.model, ckpt = attempt_load_one_weight(model, device="cpu")
             for p in self.model.parameters():
                 p.requires_grad = True  # for training
-        elif model.split('.')[-1] in ('yaml', 'yml'):
+        elif model.split(".")[-1] in ("yaml", "yml"):
             self.model = self.get_model(cfg=model)
         elif model in torchvision.models.__dict__:
-            self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
+            self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
         else:
-            FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
-        ClassificationModel.reshape_outputs(self.model, self.data['nc'])
+            FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
+        ClassificationModel.reshape_outputs(self.model, self.data["nc"])

         return ckpt

-    def build_dataset(self, img_path, mode='train', batch=None):
+    def build_dataset(self, img_path, mode="train", batch=None):
         """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
-        return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
+        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

-    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
         """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
             dataset = self.build_dataset(dataset_path, mode)

         loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
         # Attach inference transforms
-        if mode != 'train':
+        if mode != "train":
             if is_parallel(self.model):
                 self.model.module.transforms = loader.dataset.torch_transforms
             else:
@@ -98,27 +98,32 @@ class ClassificationTrainer(BaseTrainer):

     def preprocess_batch(self, batch):
         """Preprocesses a batch of images and classes."""
-        batch['img'] = batch['img'].to(self.device)
-        batch['cls'] = batch['cls'].to(self.device)
+        batch["img"] = batch["img"].to(self.device)
+        batch["cls"] = batch["cls"].to(self.device)
         return batch

     def progress_string(self):
         """Returns a formatted string showing training progress."""
-        return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
-            ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
+        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+            "Epoch",
+            "GPU_mem",
+            *self.loss_names,
+            "Instances",
+            "Size",
+        )

     def get_validator(self):
         """Returns an instance of ClassificationValidator for validation."""
-        self.loss_names = ['loss']
+        self.loss_names = ["loss"]
         return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)

-    def label_loss_items(self, loss_items=None, prefix='train'):
+    def label_loss_items(self, loss_items=None, prefix="train"):
         """
         Returns a loss dict with labelled training loss items tensor.

         Not needed for classification but necessary for segmentation & detection
         """
-        keys = [f'{prefix}/{x}' for x in self.loss_names]
+        keys = [f"{prefix}/{x}" for x in self.loss_names]
         if loss_items is None:
             return keys
         loss_items = [round(float(loss_items), 5)]
@@ -134,19 +139,20 @@ class ClassificationTrainer(BaseTrainer):
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
                 if f is self.best:
-                    LOGGER.info(f'\nValidating {f}...')
+                    LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.data = self.args.data
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)
-                    self.metrics.pop('fitness', None)
-                    self.run_callbacks('on_fit_epoch_end')
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")
         LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")

     def plot_training_samples(self, batch, ni):
         """Plots training samples with their annotations."""
         plot_images(
-            images=batch['img'],
-            batch_idx=torch.arange(len(batch['img'])),
-            cls=batch['cls'].view(-1),  # warning: use .view(), not .squeeze() for Classify models
-            fname=self.save_dir / f'train_batch{ni}.jpg',
-            on_plot=self.on_plot)
+            images=batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
+            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
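The `progress_string` expression builds a fixed-width header via %-formatting. A quick sketch of the same pattern in isolation:

loss_names = ["loss"]  # ClassificationTrainer sets this in get_validator()

# "%11s" repeated once per column right-aligns each header to 11 characters
header = ("\n" + "%11s" * (4 + len(loss_names))) % ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
print(header)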
@@ -31,43 +31,42 @@ class ClassificationValidator(BaseValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.targets = None
         self.pred = None
-        self.args.task = 'classify'
+        self.args.task = "classify"
         self.metrics = ClassifyMetrics()

     def get_desc(self):
         """Returns a formatted string summarizing classification metrics."""
-        return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
+        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")

     def init_metrics(self, model):
         """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
         self.names = model.names
         self.nc = len(model.names)
-        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify')
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
         self.pred = []
         self.targets = []

     def preprocess(self, batch):
         """Preprocesses input batch and returns it."""
-        batch['img'] = batch['img'].to(self.device, non_blocking=True)
-        batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
-        batch['cls'] = batch['cls'].to(self.device)
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
+        batch["cls"] = batch["cls"].to(self.device)
         return batch

     def update_metrics(self, preds, batch):
         """Updates running metrics with model predictions and batch targets."""
         n5 = min(len(self.names), 5)
         self.pred.append(preds.argsort(1, descending=True)[:, :n5])
-        self.targets.append(batch['cls'])
+        self.targets.append(batch["cls"])

     def finalize_metrics(self, *args, **kwargs):
         """Finalizes metrics of the model such as confusion_matrix and speed."""
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
             for normalize in True, False:
-                self.confusion_matrix.plot(save_dir=self.save_dir,
-                                           names=self.names.values(),
-                                           normalize=normalize,
-                                           on_plot=self.on_plot)
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir
@@ -88,24 +87,27 @@ class ClassificationValidator(BaseValidator):

     def print_results(self):
         """Prints evaluation metrics for YOLO object detection model."""
-        pf = '%22s' + '%11.3g' * len(self.metrics.keys)  # print format
-        LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
+        pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
+        LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))

     def plot_val_samples(self, batch, ni):
         """Plot validation image samples."""
         plot_images(
-            images=batch['img'],
-            batch_idx=torch.arange(len(batch['img'])),
-            cls=batch['cls'].view(-1),  # warning: use .view(), not .squeeze() for Classify models
-            fname=self.save_dir / f'val_batch{ni}_labels.jpg',
+            images=batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
+            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
             names=self.names,
-            on_plot=self.on_plot)
+            on_plot=self.on_plot,
+        )

     def plot_predictions(self, batch, preds, ni):
         """Plots predicted bounding boxes on input images and saves the result."""
-        plot_images(batch['img'],
-                    batch_idx=torch.arange(len(batch['img'])),
+        plot_images(
+            batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
             cls=torch.argmax(preds, dim=1),
-            fname=self.save_dir / f'val_batch{ni}_pred.jpg',
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
-            on_plot=self.on_plot)  # pred
+            on_plot=self.on_plot,
+        )  # pred
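`update_metrics` keeps the top-5 ranked classes per image; top-1/top-5 accuracy then reduce over those ranks. A minimal sketch with random logits:

import torch

preds = torch.randn(8, 10)            # logits for 8 images over 10 classes
targets = torch.randint(0, 10, (8,))  # ground-truth class per image

n5 = min(preds.shape[1], 5)
topk = preds.argsort(1, descending=True)[:, :n5]  # same ranking used in update_metrics

top1 = (topk[:, 0] == targets).float().mean()
top5 = (topk == targets[:, None]).any(1).float().mean()
print(float(top1), float(top5))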
@@ -4,4 +4,4 @@ from .predict import DetectionPredictor
 from .train import DetectionTrainer
 from .val import DetectionValidator

-__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator'
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
@@ -22,12 +22,14 @@ class DetectionPredictor(BasePredictor):

     def postprocess(self, preds, img, orig_imgs):
         """Post-processes predictions and returns a list of Results objects."""
-        preds = ops.non_max_suppression(preds,
+        preds = ops.non_max_suppression(
+            preds,
             self.args.conf,
             self.args.iou,
             agnostic=self.args.agnostic_nms,
             max_det=self.args.max_det,
-            classes=self.args.classes)
+            classes=self.args.classes,
+        )

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
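`postprocess` funnels raw head outputs through `ops.non_max_suppression`. A sketch with a random stand-in tensor; the (1, 84, 8400) shape assumes the usual COCO-sized head (4 box coords + 80 classes):

import torch
from ultralytics.utils import ops

# Random stand-in for raw head output shaped (batch, 4 + num_classes, num_anchors)
preds = torch.rand(1, 84, 8400)

# Keep boxes above 0.25 confidence, merge overlaps past 0.45 IoU, cap at 300 detections
out = ops.non_max_suppression(preds, 0.25, 0.45, agnostic=False, max_det=300)
print(out[0].shape)  # (n, 6): x1, y1, x2, y2, confidence, class per kept box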
@ -30,7 +30,7 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def build_dataset(self, img_path, mode='train', batch=None):
|
def build_dataset(self, img_path, mode="train", batch=None):
|
||||||
"""
|
"""
|
||||||
Build YOLO Dataset.
|
Build YOLO Dataset.
|
||||||
|
|
||||||
@ -40,33 +40,37 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||||
"""
|
"""
|
||||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||||
"""Construct and return dataloader."""
|
"""Construct and return dataloader."""
|
||||||
assert mode in ['train', 'val']
|
assert mode in ["train", "val"]
|
||||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||||
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
||||||
shuffle = mode == 'train'
|
shuffle = mode == "train"
|
||||||
if getattr(dataset, 'rect', False) and shuffle:
|
if getattr(dataset, "rect", False) and shuffle:
|
||||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||||
shuffle = False
|
shuffle = False
|
||||||
workers = self.args.workers if mode == 'train' else self.args.workers * 2
|
workers = self.args.workers if mode == "train" else self.args.workers * 2
|
||||||
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
|
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
|
||||||
|
|
||||||
def preprocess_batch(self, batch):
|
def preprocess_batch(self, batch):
|
||||||
"""Preprocesses a batch of images by scaling and converting to float."""
|
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||||
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
|
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||||
if self.args.multi_scale:
|
if self.args.multi_scale:
|
||||||
imgs = batch['img']
|
imgs = batch["img"]
|
||||||
sz = (random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) // self.stride *
|
sz = (
|
||||||
self.stride) # size
|
random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
|
||||||
|
// self.stride
|
||||||
|
* self.stride
|
||||||
|
) # size
|
||||||
sf = sz / max(imgs.shape[2:]) # scale factor
|
sf = sz / max(imgs.shape[2:]) # scale factor
|
||||||
if sf != 1:
|
if sf != 1:
|
||||||
ns = [math.ceil(x * sf / self.stride) * self.stride
|
ns = [
|
||||||
for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
|
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
|
||||||
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
|
] # new shape (stretched to gs-multiple)
|
||||||
batch['img'] = imgs
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
batch["img"] = imgs
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def set_model_attributes(self):
|
def set_model_attributes(self):
|
||||||
@ -74,33 +78,32 @@ class DetectionTrainer(BaseTrainer):
# self.args.box *= 3 / nl  # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
self.model.nc = self.data['nc']  # attach number of classes to model
self.model.nc = self.data["nc"]  # attach number of classes to model
self.model.names = self.data['names']  # attach class names to model
self.model.names = self.data["names"]  # attach class names to model
self.model.args = self.args  # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc

def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model

def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.detect.DetectionValidator(self.test_loader,
save_dir=self.save_dir,
args=copy(self.args),
_callbacks=self.callbacks)
return yolo.detect.DetectionValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)

def label_loss_items(self, loss_items=None, prefix='train'):
def label_loss_items(self, loss_items=None, prefix="train"):
"""
Returns a loss dict with labelled training loss items tensor.

Not needed for classification but necessary for segmentation & detection
"""
keys = [f'{prefix}/{x}' for x in self.loss_names]
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
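Aside: a small sketch of how label_loss_items pairs loss names with values (names and values are stand-ins):

    loss_names = ("box_loss", "cls_loss", "dfl_loss")
    loss_items = [1.234567, 0.98765, 2.3456789]  # stand-in per-component losses
    keys = [f"train/{x}" for x in loss_names]
    labelled = dict(zip(keys, (round(float(x), 5) for x in loss_items)))
    print(labelled)  # {'train/box_loss': 1.23457, 'train/cls_loss': 0.98765, 'train/dfl_loss': 2.34568}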
@ -109,18 +112,25 @@ class DetectionTrainer(BaseTrainer):

def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)

def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=batch['batch_idx'],
cls=batch['cls'].squeeze(-1),
bboxes=batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
plot_images(
images=batch["img"],
batch_idx=batch["batch_idx"],
cls=batch["cls"].squeeze(-1),
bboxes=batch["bboxes"],
paths=batch["im_file"],
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)

def plot_metrics(self):
"""Plots metrics from a CSV file."""
@ -128,6 +138,6 @@ class DetectionTrainer(BaseTrainer):

def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
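Aside: the progress header built by progress_string above, reproduced with stand-in loss names:

    loss_names = ("box_loss", "cls_loss", "dfl_loss")
    header = ("\n" + "%11s" * (4 + len(loss_names))) % ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
    print(header)  # right-aligned 11-char columns: Epoch, GPU_mem, box_loss, cls_loss, dfl_loss, Instances, Size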
@ -34,7 +34,7 @@ class DetectionValidator(BaseValidator):
self.nt_per_class = None
self.is_coco = False
self.class_map = None
self.args.task = 'detect'
self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10)  # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
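Aside: the IoU threshold vector above, expanded:

    import torch

    iouv = torch.linspace(0.5, 0.95, 10)  # tensor([0.50, 0.55, ..., 0.95])
    niou = iouv.numel()  # 10 thresholds, one true-positive column per threshold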
@ -42,25 +42,30 @@ class DetectionValidator(BaseValidator):

def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch["img"] = batch["img"].to(self.device, non_blocking=True)
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
for k in ['batch_idx', 'cls', 'bboxes']:
for k in ["batch_idx", "cls", "bboxes"]:
batch[k] = batch[k].to(self.device)

if self.args.save_hybrid:
height, width = batch['img'].shape[2:]
height, width = batch["img"].shape[2:]
nb = len(batch['img'])
nb = len(batch["img"])
bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device)
bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
self.lb = [
torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1)
for i in range(nb)] if self.args.save_hybrid else []  # for autolabelling
self.lb = (
[
torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
for i in range(nb)
]
if self.args.save_hybrid
else []
)  # for autolabelling

return batch

def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, '')  # validation path
val = self.data.get(self.args.split, "")  # validation path
self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt')  # is COCO
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO
self.names = model.names
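Aside: a minimal sketch of the per-image hybrid label grouping above, on a tiny stand-in batch:

    import torch

    nb = 2  # images in the batch
    batch_idx = torch.tensor([0, 0, 1])  # which image each label row belongs to
    cls = torch.tensor([[0.0], [1.0], [2.0]])
    bboxes = torch.rand(3, 4) * 640  # pixel-space xywh boxes
    lb = [torch.cat([cls[batch_idx == i], bboxes[batch_idx == i]], dim=-1) for i in range(nb)]
    # lb[0] has shape (2, 5) and lb[1] has shape (1, 5): one (cls, x, y, w, h) row per label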
@ -74,26 +79,28 @@ class DetectionValidator(BaseValidator):

def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")

def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
return ops.non_max_suppression(preds,
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det)
max_det=self.args.max_det,
)

def _prepare_batch(self, si, batch):
"""Prepares a batch of images and annotations for validation."""
idx = batch['batch_idx'] == si
idx = batch["batch_idx"] == si
cls = batch['cls'][idx].squeeze(-1)
cls = batch["cls"][idx].squeeze(-1)
bbox = batch['bboxes'][idx]
bbox = batch["bboxes"][idx]
ori_shape = batch['ori_shape'][si]
ori_shape = batch["ori_shape"][si]
imgsz = batch['img'].shape[2:]
imgsz = batch["img"].shape[2:]
ratio_pad = batch['ratio_pad'][si]
ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
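Aside: a sketch of the label denormalization in _prepare_batch, with a local stand-in for ops.xywh2xyxy and assumed shapes:

    import torch

    imgsz = (640, 480)  # (h, w) of the letterboxed batch
    bbox = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # normalized (x_center, y_center, w, h)

    def xywh2xyxy(x):
        # convert center-based xywh to corner-based xyxy
        y = x.clone()
        y[..., 0] = x[..., 0] - x[..., 2] / 2
        y[..., 1] = x[..., 1] - x[..., 3] / 2
        y[..., 2] = x[..., 0] + x[..., 2] / 2
        y[..., 3] = x[..., 1] + x[..., 3] / 2
        return y

    scale = torch.tensor(imgsz)[[1, 0, 1, 0]]  # (w, h, w, h) multiplier
    print(xywh2xyxy(bbox) * scale)  # tensor([[192., 192., 288., 448.]])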
@ -103,8 +110,9 @@ class DetectionValidator(BaseValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares a batch of images and annotations for validation."""
predn = pred.clone()
ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'],
ratio_pad=pbatch['ratio_pad'])  # native-space pred
ops.scale_boxes(
pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
)  # native-space pred
return predn

def update_metrics(self, preds, batch):
@ -112,19 +120,21 @@ class DetectionValidator(BaseValidator):
for si, pred in enumerate(preds):
self.seen += 1
npr = len(pred)
stat = dict(conf=torch.zeros(0, device=self.device),
stat = dict(
conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
)
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat['target_cls'] = cls
stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
self.stats[k].append(stat[k])
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != 'obb':
if self.args.plots and self.args.task != "obb":
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue

@ -132,24 +142,24 @@ class DetectionValidator(BaseValidator):
if self.args.single_cls:
pred[:, 5] = 0
predn = self._prepare_pred(pred, pbatch)
stat['conf'] = predn[:, 4]
stat["conf"] = predn[:, 4]
stat['pred_cls'] = predn[:, 5]
stat["pred_cls"] = predn[:, 5]

# Evaluate
if nl:
stat['tp'] = self._process_batch(predn, bbox, cls)
stat["tp"] = self._process_batch(predn, bbox, cls)
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != 'obb':
if self.args.plots and self.args.task != "obb":
self.confusion_matrix.process_batch(predn, bbox, cls)
for k in self.stats.keys():
self.stats[k].append(stat[k])

# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
self.pred_to_json(predn, batch["im_file"][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, pbatch['ori_shape'], file)
self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)

def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
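Aside: the per-image stat dict accumulated above, sketched with assumed niou=10 thresholds:

    import torch

    npr, niou = 5, 10  # predictions in this image, IoU thresholds
    stat = dict(
        conf=torch.zeros(0),  # filled with predn[:, 4] when predictions exist
        pred_cls=torch.zeros(0),  # filled with predn[:, 5]
        tp=torch.zeros(npr, niou, dtype=torch.bool),  # one TP flag per prediction per IoU threshold
    )
    stat["target_cls"] = torch.tensor([0.0, 2.0])  # ground-truth classes for this image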
@ -159,19 +169,19 @@ class DetectionValidator(BaseValidator):
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
if len(stats) and stats['tp'].any():
if len(stats) and stats["tp"].any():
self.metrics.process(**stats)
self.nt_per_class = np.bincount(stats['target_cls'].astype(int),
minlength=self.nc)  # number of targets per class
self.nt_per_class = np.bincount(
stats["target_cls"].astype(int), minlength=self.nc
)  # number of targets per class
return self.metrics.results_dict

def print_results(self):
"""Prints training/validation set metrics per class."""
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys)  # print format
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
LOGGER.warning(
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")

# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
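Aside: counting targets per class with np.bincount as in get_stats, on stand-in data:

    import numpy as np

    target_cls = np.array([0.0, 2.0, 2.0, 5.0])  # flattened ground-truth class ids
    nc = 8  # assumed number of classes
    nt_per_class = np.bincount(target_cls.astype(int), minlength=nc)
    print(nt_per_class)  # [1 0 2 0 0 1 0 0]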
@ -180,10 +190,9 @@ class DetectionValidator(BaseValidator):

if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.confusion_matrix.plot(
save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
)

def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
@ -201,7 +210,7 @@ class DetectionValidator(BaseValidator):
iou = box_iou(gt_bboxes, detections[:, :4])
return self.match_predictions(detections[:, 5], gt_cls, iou)

def build_dataset(self, img_path, mode='val', batch=None):
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build YOLO Dataset.

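Aside: a self-contained stand-in for the box_iou call that feeds match_predictions (the real implementation lives in ultralytics.utils.metrics; this sketch only assumes xyxy boxes):

    import torch

    def box_iou(box1, box2):
        # box1: (N, 4) xyxy, box2: (M, 4) xyxy; returns (N, M) pairwise IoU
        area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
        area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
        lt = torch.max(box1[:, None, :2], box2[None, :, :2])  # intersection top-left
        rb = torch.min(box1[:, None, 2:], box2[None, :, 2:])  # intersection bottom-right
        wh = (rb - lt).clamp(min=0)
        inter = wh[..., 0] * wh[..., 1]
        return inter / (area1[:, None] + area2[None, :] - inter)

    gt = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    det = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
    print(box_iou(gt, det))  # tensor([[0.1429]]) = 25 / (100 + 100 - 25)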
@ -214,28 +223,32 @@ class DetectionValidator(BaseValidator):

def get_dataloader(self, dataset_path, batch_size):
"""Construct and return dataloader."""
dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val')
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader

def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot)
on_plot=self.on_plot,
)

def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
plot_images(
batch["img"],
*output_to_target(preds, max_det=self.args.max_det),
paths=batch['im_file'],
paths=batch["im_file"],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot)  # pred
on_plot=self.on_plot,
)  # pred

def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
|
|||||||
for *xyxy, conf, cls in predn.tolist():
|
for *xyxy, conf, cls in predn.tolist():
|
||||||
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
||||||
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
|
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
|
||||||
with open(file, 'a') as f:
|
with open(file, "a") as f:
|
||||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
f.write(("%g " * len(line)).rstrip() % line + "\n")
|
||||||
|
|
||||||
def pred_to_json(self, predn, filename):
|
def pred_to_json(self, predn, filename):
|
||||||
"""Serialize YOLO predictions to COCO json format."""
|
"""Serialize YOLO predictions to COCO json format."""
|
||||||
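Aside: the '%g ' line formatting used in save_one_txt, on stand-in values:

    cls, xywh, conf = 0, (0.5, 0.5, 0.25, 0.4), 0.9123
    line = (cls, *xywh, conf)  # save_conf=True layout: cls x y w h conf
    text = ("%g " * len(line)).rstrip() % line
    print(text)  # 0 0.5 0.5 0.25 0.4 0.9123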
@ -253,28 +266,31 @@ class DetectionValidator(BaseValidator):
box = ops.xyxy2xywh(predn[:, :4])  # xywh
box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5)})
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(p[5])],
"bbox": [round(x, 3) for x in b],
"score": round(p[4], 5),
}
)

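Aside: one COCO-format detection record as appended in pred_to_json, with assumed ids:

    jdict = []
    p = [100.0, 200.0, 150.0, 260.0, 0.91234, 2.0]  # xyxy, conf, cls
    b = [100.0, 200.0, 50.0, 60.0]  # xywh with top-left origin
    jdict.append(
        {
            "image_id": 42,  # hypothetical image id
            "category_id": 3,  # e.g. COCO 91-class id mapped from class 2
            "bbox": [round(x, 3) for x in b],
            "score": round(p[4], 5),
        }
    )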
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
pred_json = self.save_dir / 'predictions.json'  # predictions
pred_json = self.save_dir / "predictions.json"  # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO  # noqa
from pycocotools.cocoeval import COCOeval  # noqa

for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json))  # init annotations api
pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, 'bbox')
eval = COCOeval(anno, pred, "bbox")
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
eval.evaluate()
@ -282,5 +298,5 @@ class DetectionValidator(BaseValidator):
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
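Aside: the pycocotools evaluation flow used above, as a hedged standalone sketch (file paths are placeholders):

    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    anno = COCO("annotations/instances_val2017.json")  # ground-truth annotations
    pred = anno.loadRes("predictions.json")  # detections; must be a str path or list
    coco_eval = COCOeval(anno, pred, "bbox")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    map50_95, map50 = coco_eval.stats[:2]  # AP@[.5:.95] and AP@.5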
@ -12,28 +12,34 @@ class YOLO(Model):
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes."""
return {
'classify': {
'model': ClassificationModel,
'trainer': yolo.classify.ClassificationTrainer,
'validator': yolo.classify.ClassificationValidator,
'predictor': yolo.classify.ClassificationPredictor, },
'detect': {
'model': DetectionModel,
'trainer': yolo.detect.DetectionTrainer,
'validator': yolo.detect.DetectionValidator,
'predictor': yolo.detect.DetectionPredictor, },
'segment': {
'model': SegmentationModel,
'trainer': yolo.segment.SegmentationTrainer,
'validator': yolo.segment.SegmentationValidator,
'predictor': yolo.segment.SegmentationPredictor, },
'pose': {
'model': PoseModel,
'trainer': yolo.pose.PoseTrainer,
'validator': yolo.pose.PoseValidator,
'predictor': yolo.pose.PosePredictor, },
'obb': {
'model': OBBModel,
'trainer': yolo.obb.OBBTrainer,
'validator': yolo.obb.OBBValidator,
'predictor': yolo.obb.OBBPredictor, }, }
return {
"classify": {
"model": ClassificationModel,
"trainer": yolo.classify.ClassificationTrainer,
"validator": yolo.classify.ClassificationValidator,
"predictor": yolo.classify.ClassificationPredictor,
},
"detect": {
"model": DetectionModel,
"trainer": yolo.detect.DetectionTrainer,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
},
"segment": {
"model": SegmentationModel,
"trainer": yolo.segment.SegmentationTrainer,
"validator": yolo.segment.SegmentationValidator,
"predictor": yolo.segment.SegmentationPredictor,
},
"pose": {
"model": PoseModel,
"trainer": yolo.pose.PoseTrainer,
"validator": yolo.pose.PoseValidator,
"predictor": yolo.pose.PosePredictor,
},
"obb": {
"model": OBBModel,
"trainer": yolo.obb.OBBTrainer,
"validator": yolo.obb.OBBValidator,
"predictor": yolo.obb.OBBPredictor,
},
}
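Aside: how a task_map-style registry dispatches components, sketched with stand-in classes:

    class DetectionModel: ...
    class DetectionTrainer: ...

    TASK_MAP = {
        "detect": {"model": DetectionModel, "trainer": DetectionTrainer},
    }

    task = "detect"
    model_cls = TASK_MAP[task]["model"]  # look up the class registered for the active task
    model = model_cls()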
@ -4,4 +4,4 @@ from .predict import OBBPredictor
from .train import OBBTrainer
from .val import OBBValidator

__all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator'
__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
@ -25,26 +25,27 @@ class OBBPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes OBBPredictor with optional model and data configuration overrides."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'obb'
self.args.task = "obb"

def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds,
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes,
rotated=True)
rotated=True,
)

if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

results = []
for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)):
for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
img_path = self.batch[0][i]
# xywh, r, conf, cls
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
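Aside: the OBB column reordering above, shown on one stand-in row. NMS returns rows laid out as (x, y, w, h, conf, cls, ..., angle), while Results expects (x, y, w, h, angle, conf, cls):

    import torch

    pred = torch.tensor([[320.0, 240.0, 100.0, 40.0, 0.9, 2.0, 0.5]])  # xywh, conf, cls, rotation
    obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
    print(obb)  # tensor([[320., 240., 100., 40., 0.5, 0.9, 2.]])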
@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
"""Initialize a OBBTrainer object with given arguments."""
if overrides is None:
overrides = {}
overrides['task'] = 'obb'
overrides["task"] = "obb"
super().__init__(cfg, overrides, _callbacks)

def get_model(self, cfg=None, weights=None, verbose=True):
"""Return OBBModel initialized with specified config and weights."""
model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)

@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer):

def get_validator(self):
"""Return an instance of OBBValidator for validation of YOLO model."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
@ -27,18 +27,19 @@ class OBBValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'obb'
self.args.task = "obb"
self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)

def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
super().init_metrics(model)
val = self.data.get(self.args.split, '')  # validation path
val = self.data.get(self.args.split, "")  # validation path
self.is_dota = isinstance(val, str) and 'DOTA' in val  # is COCO
self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO

def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
return ops.non_max_suppression(preds,
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
@ -46,7 +47,8 @@ class OBBValidator(DetectionValidator):
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
rotated=True)
rotated=True,
)

def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
@ -66,12 +68,12 @@ class OBBValidator(DetectionValidator):

def _prepare_batch(self, si, batch):
"""Prepares and returns a batch for OBB validation."""
idx = batch['batch_idx'] == si
idx = batch["batch_idx"] == si
cls = batch['cls'][idx].squeeze(-1)
cls = batch["cls"][idx].squeeze(-1)
bbox = batch['bboxes'][idx]
bbox = batch["bboxes"][idx]
ori_shape = batch['ori_shape'][si]
ori_shape = batch["ori_shape"][si]
imgsz = batch['img'].shape[2:]
imgsz = batch["img"].shape[2:]
ratio_pad = batch['ratio_pad'][si]
ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
@ -81,18 +83,21 @@ class OBBValidator(DetectionValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
predn = pred.clone()
ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
xywh=True)  # native-space pred
ops.scale_boxes(
pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
)  # native-space pred
return predn

def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
plot_images(
batch["img"],
*output_to_rotated_target(preds, max_det=self.args.max_det),
paths=batch['im_file'],
paths=batch["im_file"],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot)  # pred
on_plot=self.on_plot,
)  # pred

def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
@ -101,12 +106,15 @@ class OBBValidator(DetectionValidator):
rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(predn[i, 5].item())],
'score': round(predn[i, 4].item(), 5),
'rbox': [round(x, 3) for x in r],
'poly': [round(x, 3) for x in b]})
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(predn[i, 5].item())],
"score": round(predn[i, 4].item(), 5),
"rbox": [round(x, 3) for x in r],
"poly": [round(x, 3) for x in b],
}
)

def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@ -116,8 +124,8 @@ class OBBValidator(DetectionValidator):
xywha[:, :4] /= gn
xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist()  # normalized xywh
line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy)  # label format
with open(file, 'a') as f:
with open(file, "a") as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
f.write(("%g " * len(line)).rstrip() % line + "\n")

def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
@ -125,42 +133,43 @@ class OBBValidator(DetectionValidator):
import json
import re
from collections import defaultdict
pred_json = self.save_dir / 'predictions.json'  # predictions
pred_txt = self.save_dir / 'predictions_txt'  # predictions

pred_json = self.save_dir / "predictions.json"  # predictions
pred_txt = self.save_dir / "predictions_txt"  # predictions
pred_txt.mkdir(parents=True, exist_ok=True)
data = json.load(open(pred_json))
# Save split results
LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
for d in data:
image_id = d['image_id']
image_id = d["image_id"]
score = d['score']
score = d["score"]
classname = self.names[d['category_id']].replace(' ', '-')
classname = self.names[d["category_id"]].replace(" ", "-")

lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
d['poly'][0],
d['poly'][1],
d['poly'][2],
d['poly'][3],
d['poly'][4],
d['poly'][5],
d['poly'][6],
d['poly'][7],
d["poly"][0],
d["poly"][1],
d["poly"][2],
d["poly"][3],
d["poly"][4],
d["poly"][5],
d["poly"][6],
d["poly"][7],
)
with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)
# Save merged results, this could result slightly lower map than using official merging script,
# because of the probiou calculation.
pred_merged_txt = self.save_dir / 'predictions_merged_txt'  # predictions
pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
pred_merged_txt.mkdir(parents=True, exist_ok=True)
merged_results = defaultdict(list)
LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...')
LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
for d in data:
image_id = d['image_id'].split('__')[0]
image_id = d["image_id"].split("__")[0]
pattern = re.compile(r'\d+___\d+')
pattern = re.compile(r"\d+___\d+")
x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___'))
x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
bbox, score, cls = d['rbox'], d['score'], d['category_id']
bbox, score, cls = d["rbox"], d["score"], d["category_id"]
bbox[0] += x
bbox[1] += y
bbox.extend([score, cls])
@ -178,11 +187,11 @@ class OBBValidator(DetectionValidator):

b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
classname = self.names[int(x[-1])].replace(' ', '-')
classname = self.names[int(x[-1])].replace(" ", "-")
poly = [round(i, 3) for i in x[:-2]]
score = round(x[-2], 3)

lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
poly[0],
@ -194,7 +203,7 @@ class OBBValidator(DetectionValidator):
poly[6],
poly[7],
)
with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f:
with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)

return stats
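Aside: the patch-offset merge above, sketched. DOTA patch ids look like "name__1024__0___512" (format assumed from the regex), so each patch's boxes are shifted back by the (x, y) offset parsed from the id:

    import re

    d = {"image_id": "P0006__1024__0___512", "rbox": [100.0, 50.0, 80.0, 30.0, 0.3]}
    image_id = d["image_id"].split("__")[0]  # "P0006"
    x, y = (int(c) for c in re.findall(r"\d+___\d+", d["image_id"])[0].split("___"))  # (0, 512)
    d["rbox"][0] += x  # shift box center back into full-image coordinates
    d["rbox"][1] += y
    print(image_id, d["rbox"][:2])  # P0006 [100.0, 562.0]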
@ -4,4 +4,4 @@ from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator

__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'
__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'pose'
self.args.task = "pose"
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
LOGGER.warning(
"WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)

def postprocess(self, preds, img, orig_imgs):
"""Return detection results for a given input image or list of images."""
preds = ops.non_max_suppression(preds,
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
nc=len(self.model.names))
nc=len(self.model.names),
)

if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor):
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
img_path = self.batch[0][i]
results.append(
Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
)
return results
@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
if overrides is None:
overrides = {}
overrides['task'] = 'pose'
overrides["task"] = "pose"
super().__init__(cfg, overrides, _callbacks)

if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
LOGGER.warning(
"WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)

def get_model(self, cfg=None, weights=None, verbose=True):
"""Get pose estimation model with specified configuration and weights."""
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
if weights:
model.load(weights)

@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
def set_model_attributes(self):
"""Sets keypoints shape attribute of PoseModel."""
super().set_model_attributes()
self.model.kpt_shape = self.data['kpt_shape']
self.model.kpt_shape = self.data["kpt_shape"]

def get_validator(self):
"""Returns an instance of the PoseValidator class for validation."""
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
return yolo.pose.PoseValidator(self.test_loader,
save_dir=self.save_dir,
args=copy(self.args),
_callbacks=self.callbacks)
return yolo.pose.PoseValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)

def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
images = batch['img']
images = batch["img"]
kpts = batch['keypoints']
kpts = batch["keypoints"]
cls = batch['cls'].squeeze(-1)
cls = batch["cls"].squeeze(-1)
bboxes = batch['bboxes']
bboxes = batch["bboxes"]
paths = batch['im_file']
paths = batch["im_file"]
batch_idx = batch['batch_idx']
batch_idx = batch["batch_idx"]
plot_images(images,
plot_images(
images,
batch_idx,
cls,
bboxes,
kpts=kpts,
paths=paths,
fname=self.save_dir / f'train_batch{ni}.jpg',
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot)
on_plot=self.on_plot,
)

def plot_metrics(self):
"""Plots training/val metrics."""
@ -31,38 +31,53 @@ class PoseValidator(DetectionValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = 'pose'
self.args.task = "pose"
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
LOGGER.warning(
"WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)

def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
batch = super().preprocess(batch)
batch['keypoints'] = batch['keypoints'].to(self.device).float()
batch["keypoints"] = batch["keypoints"].to(self.device).float()
return batch

def get_desc(self):
"""Returns description of evaluation metrics in string format."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
'R', 'mAP50', 'mAP50-95)')
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Pose(P",
"R",
"mAP50",
"mAP50-95)",
)

def postprocess(self, preds):
"""Apply non-maximum suppression and return detections with high confidence scores."""
return ops.non_max_suppression(preds,
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
nc=self.nc,
)

def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model)
self.kpt_shape = self.data['kpt_shape']
self.kpt_shape = self.data["kpt_shape"]
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
@ -71,21 +86,21 @@ class PoseValidator(DetectionValidator):
def _prepare_batch(self, si, batch):
"""Prepares a batch for processing by converting keypoints to float and moving to device."""
pbatch = super()._prepare_batch(si, batch)
kpts = batch['keypoints'][batch['batch_idx'] == si]
kpts = batch["keypoints"][batch["batch_idx"] == si]
h, w = pbatch['imgsz']
h, w = pbatch["imgsz"]
kpts = kpts.clone()
kpts[..., 0] *= w
kpts[..., 1] *= h
kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
pbatch['kpts'] = kpts
pbatch["kpts"] = kpts
return pbatch

def _prepare_pred(self, pred, pbatch):
"""Prepares and scales keypoints in a batch for pose processing."""
predn = super()._prepare_pred(pred, pbatch)
nk = pbatch['kpts'].shape[1]
nk = pbatch["kpts"].shape[1]
pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
return predn, pred_kpts

def update_metrics(self, preds, batch):
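Aside: the keypoint denormalization in _prepare_batch, sketched with an assumed (n, 17, 3) layout of (x, y, visibility):

    import torch

    h, w = 640, 480  # letterboxed image size
    kpts = torch.rand(2, 17, 3)  # normalized keypoints for two instances
    kpts = kpts.clone()
    kpts[..., 0] *= w  # x back to pixels
    kpts[..., 1] *= h  # y back to pixels
    # ops.scale_coords would then undo letterbox padding/scaling to native image space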
@ -93,14 +108,16 @@ class PoseValidator(DetectionValidator):
for si, pred in enumerate(preds):
self.seen += 1
npr = len(pred)
stat = dict(conf=torch.zeros(0, device=self.device),
stat = dict(
conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
)
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat['target_cls'] = cls
stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
@ -113,13 +130,13 @@ class PoseValidator(DetectionValidator):
if self.args.single_cls:
pred[:, 5] = 0
predn, pred_kpts = self._prepare_pred(pred, pbatch)
stat['conf'] = predn[:, 4]
stat["conf"] = predn[:, 4]
stat['pred_cls'] = predn[:, 5]
stat["pred_cls"] = predn[:, 5]

# Evaluate
if nl:
stat['tp'] = self._process_batch(predn, bbox, cls)
stat["tp"] = self._process_batch(predn, bbox, cls)
stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts'])
stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)

@ -128,7 +145,7 @@ class PoseValidator(DetectionValidator):

# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
self.pred_to_json(predn, batch["im_file"][si])
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')

@ -159,26 +176,30 @@ class PoseValidator(DetectionValidator):

def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
kpts=batch['keypoints'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
kpts=batch["keypoints"],
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot)
on_plot=self.on_plot,
)

def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
plot_images(batch['img'],
plot_images(
batch["img"],
*output_to_target(preds, max_det=self.args.max_det),
kpts=pred_kpts,
paths=batch['im_file'],
paths=batch["im_file"],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot)  # pred
on_plot=self.on_plot,
)  # pred

def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""
@ -187,37 +208,41 @@ class PoseValidator(DetectionValidator):
|
|||||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||||
for p, b in zip(predn.tolist(), box.tolist()):
|
for p, b in zip(predn.tolist(), box.tolist()):
|
||||||
self.jdict.append({
|
self.jdict.append(
|
||||||
'image_id': image_id,
|
{
|
||||||
'category_id': self.class_map[int(p[5])],
|
"image_id": image_id,
|
||||||
'bbox': [round(x, 3) for x in b],
|
"category_id": self.class_map[int(p[5])],
|
||||||
'keypoints': p[6:],
|
"bbox": [round(x, 3) for x in b],
|
||||||
'score': round(p[4], 5)})
|
"keypoints": p[6:],
|
||||||
|
"score": round(p[4], 5),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def eval_json(self, stats):
|
def eval_json(self, stats):
|
||||||
"""Evaluates object detection model using COCO JSON format."""
|
"""Evaluates object detection model using COCO JSON format."""
|
||||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||||
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
|
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
|
||||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
pred_json = self.save_dir / "predictions.json" # predictions
|
||||||
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
|
||||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||||
check_requirements('pycocotools>=2.0.6')
|
check_requirements("pycocotools>=2.0.6")
|
||||||
from pycocotools.coco import COCO # noqa
|
from pycocotools.coco import COCO # noqa
|
||||||
from pycocotools.cocoeval import COCOeval # noqa
|
from pycocotools.cocoeval import COCOeval # noqa
|
||||||
|
|
||||||
for x in anno_json, pred_json:
|
for x in anno_json, pred_json:
|
||||||
assert x.is_file(), f'{x} file not found'
|
assert x.is_file(), f"{x} file not found"
|
||||||
anno = COCO(str(anno_json)) # init annotations api
|
anno = COCO(str(anno_json)) # init annotations api
|
||||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||||
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
|
for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
|
||||||
if self.is_coco:
|
if self.is_coco:
|
||||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
|
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
|
||||||
eval.evaluate()
|
eval.evaluate()
|
||||||
eval.accumulate()
|
eval.accumulate()
|
||||||
eval.summarize()
|
eval.summarize()
|
||||||
idx = i * 4 + 2
|
idx = i * 4 + 2
|
||||||
stats[self.metrics.keys[idx + 1]], stats[
|
stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
|
||||||
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
|
:2
|
||||||
|
] # update mAP50-95 and mAP50
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f'pycocotools unable to run: {e}')
|
LOGGER.warning(f"pycocotools unable to run: {e}")
|
||||||
return stats
|
return stats
|
||||||
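The `eval_json` hunk above only re-quotes strings and reflows the tuple assignment; the underlying pycocotools flow is unchanged. For reference, the same bbox-plus-keypoints evaluation can be run standalone; a minimal sketch, assuming a COCO ground-truth file and a `predictions.json` written by a prior validation run (both paths are illustrative):

```python
# Hedged sketch of the COCOeval flow used by PoseValidator.eval_json (paths are illustrative).
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno = COCO("annotations/person_keypoints_val2017.json")  # ground-truth annotations
pred = anno.loadRes("runs/pose/val/predictions.json")  # loadRes() requires a str path, not a Path

for iou_type in ("bbox", "keypoints"):  # one evaluator per metric family, as in the loop above
    e = COCOeval(anno, pred, iou_type)
    e.evaluate()
    e.accumulate()
    e.summarize()  # prints the COCO summary; e.stats[0] is mAP50-95, e.stats[1] is mAP50
```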
ultralytics/models/yolo/segment/__init__.py
@@ -4,4 +4,4 @@ from .predict import SegmentationPredictor
 from .train import SegmentationTrainer
 from .val import SegmentationValidator
 
-__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
ultralytics/models/yolo/segment/predict.py
@@ -23,17 +23,19 @@ class SegmentationPredictor(DetectionPredictor):
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task = 'segment'
+        self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
         """Applies non-max suppression and processes detections for each image in an input batch."""
-        p = ops.non_max_suppression(preds[0],
-                                    self.args.conf,
-                                    self.args.iou,
-                                    agnostic=self.args.agnostic_nms,
-                                    max_det=self.args.max_det,
-                                    nc=len(self.model.names),
-                                    classes=self.args.classes)
+        p = ops.non_max_suppression(
+            preds[0],
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=len(self.model.names),
+            classes=self.args.classes,
+        )
 
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
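In normal use `SegmentationPredictor` is constructed indirectly through the `YOLO` wrapper, which routes segmentation weights to it and calls `postprocess` on every batch. A minimal usage sketch (the weight and image names are placeholders):

```python
# Hedged usage sketch; "yolov8n-seg.pt" and "bus.jpg" are placeholder names.
from ultralytics import YOLO

model = YOLO("yolov8n-seg.pt")  # segmentation weights select SegmentationPredictor internally
results = model.predict("bus.jpg", conf=0.25, iou=0.7)  # conf/iou feed ops.non_max_suppression above
for r in results:
    print(r.boxes.xyxy.shape, r.masks.data.shape)  # per-image boxes (N, 4) and masks (N, H, W)
```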
ultralytics/models/yolo/segment/train.py
@@ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         """Initialize a SegmentationTrainer object with given arguments."""
         if overrides is None:
             overrides = {}
-        overrides['task'] = 'segment'
+        overrides["task"] = "segment"
         super().__init__(cfg, overrides, _callbacks)
 
     def get_model(self, cfg=None, weights=None, verbose=True):
         """Return SegmentationModel initialized with specified config and weights."""
-        model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
+        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
 
@@ -39,22 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
 
     def get_validator(self):
         """Return an instance of SegmentationValidator for validation of YOLO model."""
-        self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
-        return yolo.segment.SegmentationValidator(self.test_loader,
-                                                  save_dir=self.save_dir,
-                                                  args=copy(self.args),
-                                                  _callbacks=self.callbacks)
+        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
+        return yolo.segment.SegmentationValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
 
     def plot_training_samples(self, batch, ni):
         """Creates a plot of training sample images with labels and box coordinates."""
-        plot_images(batch['img'],
-                    batch['batch_idx'],
-                    batch['cls'].squeeze(-1),
-                    batch['bboxes'],
-                    masks=batch['masks'],
-                    paths=batch['im_file'],
-                    fname=self.save_dir / f'train_batch{ni}.jpg',
-                    on_plot=self.on_plot)
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            masks=batch["masks"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
 
     def plot_metrics(self):
         """Plots training/val metrics."""
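`SegmentationTrainer` is likewise reached through the `YOLO` wrapper; a minimal training sketch (the dataset YAML and epoch count are illustrative):

```python
# Hedged training sketch; dataset and hyperparameters are illustrative.
from ultralytics import YOLO

model = YOLO("yolov8n-seg.pt")
model.train(data="coco128-seg.yaml", epochs=3, imgsz=640)  # overrides["task"] = "segment" is set above
metrics = model.val()  # get_validator() returns the SegmentationValidator used here
```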
ultralytics/models/yolo/segment/val.py
@@ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.plot_masks = None
         self.process = None
-        self.args.task = 'segment'
+        self.args.task = "segment"
         self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
 
     def preprocess(self, batch):
         """Preprocesses batch by converting masks to float and sending to device."""
         batch = super().preprocess(batch)
-        batch['masks'] = batch['masks'].to(self.device).float()
+        batch["masks"] = batch["masks"].to(self.device).float()
         return batch
 
     def init_metrics(self, model):
@@ -47,7 +47,7 @@ class SegmentationValidator(DetectionValidator):
         super().init_metrics(model)
         self.plot_masks = []
         if self.args.save_json:
-            check_requirements('pycocotools>=2.0.6')
+            check_requirements("pycocotools>=2.0.6")
             self.process = ops.process_mask_upsample  # more accurate
         else:
             self.process = ops.process_mask  # faster
@@ -55,33 +55,46 @@ class SegmentationValidator(DetectionValidator):
 
     def get_desc(self):
         """Return a formatted description of evaluation metrics."""
-        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
-                                         'R', 'mAP50', 'mAP50-95)')
+        return ("%22s" + "%11s" * 10) % (
+            "Class",
+            "Images",
+            "Instances",
+            "Box(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+            "Mask(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+        )
 
     def postprocess(self, preds):
         """Post-processes YOLO predictions and returns output detections with proto."""
-        p = ops.non_max_suppression(preds[0],
-                                    self.args.conf,
-                                    self.args.iou,
-                                    labels=self.lb,
-                                    multi_label=True,
-                                    agnostic=self.args.single_cls,
-                                    max_det=self.args.max_det,
-                                    nc=self.nc)
+        p = ops.non_max_suppression(
+            preds[0],
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+            nc=self.nc,
+        )
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         return p, proto
 
     def _prepare_batch(self, si, batch):
         """Prepares a batch for training or inference by processing images and targets."""
         prepared_batch = super()._prepare_batch(si, batch)
-        midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si
-        prepared_batch['masks'] = batch['masks'][midx]
+        midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
+        prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch
 
     def _prepare_pred(self, pred, pbatch, proto):
         """Prepares a batch for training or inference by processing images and targets."""
         predn = super()._prepare_pred(pred, pbatch)
-        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz'])
+        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
         return predn, pred_masks
 
     def update_metrics(self, preds, batch):
@@ -89,14 +102,16 @@ class SegmentationValidator(DetectionValidator):
         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
             self.seen += 1
             npr = len(pred)
-            stat = dict(conf=torch.zeros(0, device=self.device),
-                        pred_cls=torch.zeros(0, device=self.device),
-                        tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                        tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
             pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
             nl = len(cls)
-            stat['target_cls'] = cls
+            stat["target_cls"] = cls
             if npr == 0:
                 if nl:
                     for k in self.stats.keys():
@@ -106,24 +121,20 @@ class SegmentationValidator(DetectionValidator):
                 continue
 
             # Masks
-            gt_masks = pbatch.pop('masks')
+            gt_masks = pbatch.pop("masks")
             # Predictions
             if self.args.single_cls:
                 pred[:, 5] = 0
             predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
-            stat['conf'] = predn[:, 4]
-            stat['pred_cls'] = predn[:, 5]
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
 
             # Evaluate
             if nl:
-                stat['tp'] = self._process_batch(predn, bbox, cls)
-                stat['tp_m'] = self._process_batch(predn,
-                                                   bbox,
-                                                   cls,
-                                                   pred_masks,
-                                                   gt_masks,
-                                                   self.args.overlap_mask,
-                                                   masks=True)
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+                stat["tp_m"] = self._process_batch(
+                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
+                )
                 if self.args.plots:
                     self.confusion_matrix.process_batch(predn, bbox, cls)
 
@@ -136,10 +147,12 @@ class SegmentationValidator(DetectionValidator):
 
             # Save
             if self.args.save_json:
-                pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                                             pbatch['ori_shape'],
-                                             ratio_pad=batch['ratio_pad'][si])
-                self.pred_to_json(predn, batch['im_file'][si], pred_masks)
+                pred_masks = ops.scale_image(
+                    pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                    pbatch["ori_shape"],
+                    ratio_pad=batch["ratio_pad"][si],
+                )
+                self.pred_to_json(predn, batch["im_file"][si], pred_masks)
             # if self.args.save_txt:
             #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 
@@ -166,7 +179,7 @@ class SegmentationValidator(DetectionValidator):
                 gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
                 gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
             if gt_masks.shape[1:] != pred_masks.shape[1:]:
-                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
+                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                 gt_masks = gt_masks.gt_(0.5)
             iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
         else:  # boxes
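The masks branch of `_process_batch` above upsamples ground-truth masks to the prediction resolution, then compares flattened binary masks. A self-contained sketch of that comparison, with a local stand-in for `ultralytics.utils.metrics.mask_iou` (shapes and the random masks are illustrative):

```python
# Hedged sketch of the mask-IoU comparison; mask_iou here is a local stand-in.
import torch
import torch.nn.functional as F


def mask_iou(mask1, mask2, eps=1e-7):
    """mask1 (N, HW) and mask2 (M, HW) flattened 0/1 masks -> (N, M) IoU matrix."""
    intersection = torch.matmul(mask1, mask2.T).clamp_(0)
    union = mask1.sum(1)[:, None] + mask2.sum(1)[None] - intersection
    return intersection / (union + eps)


gt_masks = (torch.rand(3, 160, 160) > 0.5).float()  # ground truth at mask-head resolution
pred_masks = (torch.rand(5, 640, 640) > 0.5).float()  # predictions at network resolution
if gt_masks.shape[1:] != pred_masks.shape[1:]:  # mirror the interpolate step above
    gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
    gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
print(iou.shape)  # torch.Size([3, 5])
```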
@@ -176,26 +189,29 @@ class SegmentationValidator(DetectionValidator):
 
     def plot_val_samples(self, batch, ni):
         """Plots validation samples with bounding box labels."""
-        plot_images(batch['img'],
-                    batch['batch_idx'],
-                    batch['cls'].squeeze(-1),
-                    batch['bboxes'],
-                    masks=batch['masks'],
-                    paths=batch['im_file'],
-                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
-                    names=self.names,
-                    on_plot=self.on_plot)
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            masks=batch["masks"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
 
     def plot_predictions(self, batch, preds, ni):
         """Plots batch predictions with masks and bounding boxes."""
         plot_images(
-            batch['img'],
+            batch["img"],
             *output_to_target(preds[0], max_det=15),  # not set to self.args.max_det due to slow plotting speed
             torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
-            paths=batch['im_file'],
-            fname=self.save_dir / f'val_batch{ni}_pred.jpg',
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
-            on_plot=self.on_plot)  # pred
+            on_plot=self.on_plot,
+        )  # pred
         self.plot_masks.clear()
 
     def pred_to_json(self, predn, filename, pred_masks):
@@ -205,8 +221,8 @@ class SegmentationValidator(DetectionValidator):
 
         def single_encode(x):
             """Encode predicted masks as RLE and append results to jdict."""
-            rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
-            rle['counts'] = rle['counts'].decode('utf-8')
+            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
+            rle["counts"] = rle["counts"].decode("utf-8")
             return rle
 
         stem = Path(filename).stem
@@ -217,37 +233,41 @@ class SegmentationValidator(DetectionValidator):
         with ThreadPool(NUM_THREADS) as pool:
             rles = pool.map(single_encode, pred_masks)
         for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
-            self.jdict.append({
-                'image_id': image_id,
-                'category_id': self.class_map[int(p[5])],
-                'bbox': [round(x, 3) for x in b],
-                'score': round(p[4], 5),
-                'segmentation': rles[i]})
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(p[4], 5),
+                    "segmentation": rles[i],
+                }
+            )
 
     def eval_json(self, stats):
         """Return COCO-style object detection evaluation metrics."""
         if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
-            pred_json = self.save_dir / 'predictions.json'  # predictions
-            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements('pycocotools>=2.0.6')
+                check_requirements("pycocotools>=2.0.6")
                 from pycocotools.coco import COCO  # noqa
                 from pycocotools.cocoeval import COCOeval  # noqa
 
                 for x in anno_json, pred_json:
-                    assert x.is_file(), f'{x} file not found'
+                    assert x.is_file(), f"{x} file not found"
                 anno = COCO(str(anno_json))  # init annotations api
                 pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
+                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
                     if self.is_coco:
                         eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                     eval.evaluate()
                     eval.accumulate()
                     eval.summarize()
                     idx = i * 4 + 2
-                    stats[self.metrics.keys[idx + 1]], stats[
-                        self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50
+                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
+                        :2
+                    ]  # update mAP50-95 and mAP50
             except Exception as e:
-                LOGGER.warning(f'pycocotools unable to run: {e}')
+                LOGGER.warning(f"pycocotools unable to run: {e}")
         return stats
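`single_encode` above compresses each binary mask to COCO run-length encoding before it is written to JSON; a standalone sketch of that step (the random mask is illustrative):

```python
# Hedged sketch of the RLE step in pred_to_json; the mask contents are illustrative.
import json

import numpy as np
from pycocotools.mask import encode

mask = (np.random.rand(160, 160) > 0.5).astype(np.uint8)  # one binary HxW mask
rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]  # COCO expects Fortran-ordered HxWx1
rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str so the dict survives json.dumps
print(json.dumps({"segmentation": rle})[:80])
```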
ultralytics/nn/__init__.py
@@ -1,9 +1,29 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
-                    attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load,
-                    yaml_model_load)
+from .tasks import (
+    BaseModel,
+    ClassificationModel,
+    DetectionModel,
+    SegmentationModel,
+    attempt_load_one_weight,
+    attempt_load_weights,
+    guess_model_scale,
+    guess_model_task,
+    parse_model,
+    torch_safe_load,
+    yaml_model_load,
+)
 
-__all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task',
-           'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel',
-           'BaseModel')
+__all__ = (
+    "attempt_load_one_weight",
+    "attempt_load_weights",
+    "parse_model",
+    "yaml_model_load",
+    "guess_model_task",
+    "guess_model_scale",
+    "torch_safe_load",
+    "DetectionModel",
+    "SegmentationModel",
+    "ClassificationModel",
+    "BaseModel",
+)
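The reflowed `__all__` tuple changes nothing semantically; it still just pins the public surface of the package for star-imports:

```python
# __all__ only governs star-imports; explicit imports are unaffected.
from ultralytics.nn import *  # noqa: F403 - brings in exactly the names listed in __all__

print(DetectionModel.__name__, parse_model.__name__)  # noqa: F405 - both are exported above
```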
ultralytics/nn/autobackend.py
@@ -32,10 +32,12 @@ def check_class_names(names):
         names = {int(k): str(v) for k, v in names.items()}
         n = len(names)
         if max(names.keys()) >= n:
-            raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
-                           f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
-        if isinstance(names[0], str) and names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764'
-            names_map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map']  # human-readable names
+            raise KeyError(
+                f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
+                f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
+            )
+        if isinstance(names[0], str) and names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
+            names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
             names = {k: names_map[v] for k, v in names.items()}
     return names
 
@@ -44,8 +46,8 @@ def default_class_names(data=None):
     """Applies default class names to an input YAML file or returns numerical class names."""
     if data:
         with contextlib.suppress(Exception):
-            return yaml_load(check_yaml(data))['names']
-    return {i: f'class{i}' for i in range(999)}  # return default if above errors
+            return yaml_load(check_yaml(data))["names"]
+    return {i: f"class{i}" for i in range(999)}  # return default if above errors
 
 
 class AutoBackend(nn.Module):
@@ -77,14 +79,16 @@ class AutoBackend(nn.Module):
     """
 
     @torch.no_grad()
-    def __init__(self,
-                 weights='yolov8n.pt',
-                 device=torch.device('cpu'),
-                 dnn=False,
-                 data=None,
-                 fp16=False,
-                 fuse=True,
-                 verbose=True):
+    def __init__(
+        self,
+        weights="yolov8n.pt",
+        device=torch.device("cpu"),
+        dnn=False,
+        data=None,
+        fp16=False,
+        fuse=True,
+        verbose=True,
+    ):
         """
         Initialize the AutoBackend for inference.
 
@@ -100,17 +104,31 @@ class AutoBackend(nn.Module):
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
         nn_module = isinstance(weights, torch.nn.Module)
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
-            self._model_type(w)
+        (
+            pt,
+            jit,
+            onnx,
+            xml,
+            engine,
+            coreml,
+            saved_model,
+            pb,
+            tflite,
+            edgetpu,
+            tfjs,
+            paddle,
+            ncnn,
+            triton,
+        ) = self._model_type(w)
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
         model, metadata = None, None
 
         # Set device
-        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
+        cuda = torch.cuda.is_available() and device.type != "cpu"  # use CUDA
         if cuda and not any([nn_module, pt, jit, engine, onnx]):  # GPU dataloader formats
-            device = torch.device('cpu')
+            device = torch.device("cpu")
            cuda = False
 
         # Download if not local
@@ -121,77 +139,79 @@ class AutoBackend(nn.Module):
         if nn_module:  # in-memory PyTorch model
             model = weights.to(device)
             model = model.fuse(verbose=verbose) if fuse else model
-            if hasattr(model, 'kpt_shape'):
+            if hasattr(model, "kpt_shape"):
                 kpt_shape = model.kpt_shape  # pose-only
             stride = max(int(model.stride.max()), 32)  # model stride
-            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+            names = model.module.names if hasattr(model, "module") else model.names  # get class names
             model.half() if fp16 else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
             pt = True
         elif pt:  # PyTorch
             from ultralytics.nn.tasks import attempt_load_weights
-            model = attempt_load_weights(weights if isinstance(weights, list) else w,
-                                         device=device,
-                                         inplace=True,
-                                         fuse=fuse)
-            if hasattr(model, 'kpt_shape'):
+
+            model = attempt_load_weights(
+                weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
+            )
+            if hasattr(model, "kpt_shape"):
                 kpt_shape = model.kpt_shape  # pose-only
             stride = max(int(model.stride.max()), 32)  # model stride
-            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+            names = model.module.names if hasattr(model, "module") else model.names  # get class names
             model.half() if fp16 else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
         elif jit:  # TorchScript
-            LOGGER.info(f'Loading {w} for TorchScript inference...')
-            extra_files = {'config.txt': ''}  # model metadata
+            LOGGER.info(f"Loading {w} for TorchScript inference...")
+            extra_files = {"config.txt": ""}  # model metadata
             model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
             model.half() if fp16 else model.float()
-            if extra_files['config.txt']:  # load metadata dict
-                metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
+            if extra_files["config.txt"]:  # load metadata dict
+                metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
         elif dnn:  # ONNX OpenCV DNN
-            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
-            check_requirements('opencv-python>=4.5.4')
+            LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
+            check_requirements("opencv-python>=4.5.4")
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
-            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
+            LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
+            check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
             import onnxruntime
-            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+
+            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
             session = onnxruntime.InferenceSession(w, providers=providers)
             output_names = [x.name for x in session.get_outputs()]
             metadata = session.get_modelmeta().custom_metadata_map  # metadata
         elif xml:  # OpenVINO
-            LOGGER.info(f'Loading {w} for OpenVINO inference...')
-            check_requirements('openvino>=2023.0')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+            LOGGER.info(f"Loading {w} for OpenVINO inference...")
+            check_requirements("openvino>=2023.0")  # requires openvino-dev: https://pypi.org/project/openvino-dev/
             from openvino.runtime import Core, Layout, get_batch  # noqa
 
             core = Core()
             w = Path(w)
             if not w.is_file():  # if not *.xml
-                w = next(w.glob('*.xml'))  # get *.xml file from *_openvino_model dir
-            ov_model = core.read_model(model=str(w), weights=w.with_suffix('.bin'))
+                w = next(w.glob("*.xml"))  # get *.xml file from *_openvino_model dir
+            ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
             if ov_model.get_parameters()[0].get_layout().empty:
-                ov_model.get_parameters()[0].set_layout(Layout('NCHW'))
+                ov_model.get_parameters()[0].set_layout(Layout("NCHW"))
             batch_dim = get_batch(ov_model)
             if batch_dim.is_static:
                 batch_size = batch_dim.get_length()
-            ov_compiled_model = core.compile_model(ov_model, device_name='AUTO')  # AUTO selects best available device
-            metadata = w.parent / 'metadata.yaml'
+            ov_compiled_model = core.compile_model(ov_model, device_name="AUTO")  # AUTO selects best available device
+            metadata = w.parent / "metadata.yaml"
         elif engine:  # TensorRT
-            LOGGER.info(f'Loading {w} for TensorRT inference...')
+            LOGGER.info(f"Loading {w} for TensorRT inference...")
             try:
                 import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download
             except ImportError:
                 if LINUX:
-                    check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
+                    check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
                 import tensorrt as trt  # noqa
-            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
-            if device.type == 'cpu':
-                device = torch.device('cuda:0')
-            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+            check_version(trt.__version__, "7.0.0", hard=True)  # require tensorrt>=7.0.0
+            if device.type == "cpu":
+                device = torch.device("cuda:0")
+            Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
             logger = trt.Logger(trt.Logger.INFO)
             # Read file
-            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
-                meta_len = int.from_bytes(f.read(4), byteorder='little')  # read metadata length
-                metadata = json.loads(f.read(meta_len).decode('utf-8'))  # read metadata
+            with open(w, "rb") as f, trt.Runtime(logger) as runtime:
+                meta_len = int.from_bytes(f.read(4), byteorder="little")  # read metadata length
+                metadata = json.loads(f.read(meta_len).decode("utf-8"))  # read metadata
                 model = runtime.deserialize_cuda_engine(f.read())  # read engine
             context = model.create_execution_context()
             bindings = OrderedDict()
@@ -213,116 +233,124 @@ class AutoBackend(nn.Module):
                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-            batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
+            batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size
         elif coreml:  # CoreML
-            LOGGER.info(f'Loading {w} for CoreML inference...')
+            LOGGER.info(f"Loading {w} for CoreML inference...")
            import coremltools as ct
 
             model = ct.models.MLModel(w)
             metadata = dict(model.user_defined_metadata)
         elif saved_model:  # TF SavedModel
-            LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+            LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
             import tensorflow as tf
 
             keras = False  # assume TF1 saved_model
             model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
-            metadata = Path(w) / 'metadata.yaml'
+            metadata = Path(w) / "metadata.yaml"
         elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
-            LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+            LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
             import tensorflow as tf
 
             from ultralytics.engine.exporter import gd_outputs
 
             def wrap_frozen_graph(gd, inputs, outputs):
                 """Wrap frozen graphs for deployment."""
-                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
                 ge = x.graph.as_graph_element
                 return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
 
             gd = tf.Graph().as_graph_def()  # TF GraphDef
-            with open(w, 'rb') as f:
+            with open(w, "rb") as f:
                 gd.ParseFromString(f.read())
-            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
+            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                 from tflite_runtime.interpreter import Interpreter, load_delegate
             except ImportError:
                 import tensorflow as tf
 
                 Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
             if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
-                LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
-                delegate = {
-                    'Linux': 'libedgetpu.so.1',
-                    'Darwin': 'libedgetpu.1.dylib',
-                    'Windows': 'edgetpu.dll'}[platform.system()]
+                LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...")
+                delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
+                    platform.system()
+                ]
                 interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
             else:  # TFLite
-                LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
                 interpreter = Interpreter(model_path=w)  # load TFLite model
             interpreter.allocate_tensors()  # allocate
             input_details = interpreter.get_input_details()  # inputs
             output_details = interpreter.get_output_details()  # outputs
             # Load metadata
             with contextlib.suppress(zipfile.BadZipFile):
-                with zipfile.ZipFile(w, 'r') as model:
+                with zipfile.ZipFile(w, "r") as model:
                     meta_file = model.namelist()[0]
-                    metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
+                    metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
         elif tfjs:  # TF.js
-            raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.')
+            raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
         elif paddle:  # PaddlePaddle
-            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
-            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
+            LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
+            check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
             import paddle.inference as pdi  # noqa
 
             w = Path(w)
             if not w.is_file():  # if not *.pdmodel
-                w = next(w.rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
-            config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
+                w = next(w.rglob("*.pdmodel"))  # get *.pdmodel file from *_paddle_model dir
+            config = pdi.Config(str(w), str(w.with_suffix(".pdiparams")))
             if cuda:
                 config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
             predictor = pdi.create_predictor(config)
             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
             output_names = predictor.get_output_names()
-            metadata = w.parents[1] / 'metadata.yaml'
+            metadata = w.parents[1] / "metadata.yaml"
         elif ncnn:  # ncnn
-            LOGGER.info(f'Loading {w} for ncnn inference...')
-            check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires ncnn
+            LOGGER.info(f"Loading {w} for ncnn inference...")
+            check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires ncnn
             import ncnn as pyncnn
 
             net = pyncnn.Net()
             net.opt.use_vulkan_compute = cuda
             w = Path(w)
             if not w.is_file():  # if not *.param
-                w = next(w.glob('*.param'))  # get *.param file from *_ncnn_model dir
+                w = next(w.glob("*.param"))  # get *.param file from *_ncnn_model dir
             net.load_param(str(w))
-            net.load_model(str(w.with_suffix('.bin')))
-            metadata = w.parent / 'metadata.yaml'
+            net.load_model(str(w.with_suffix(".bin")))
+            metadata = w.parent / "metadata.yaml"
         elif triton:  # NVIDIA Triton Inference Server
-            check_requirements('tritonclient[all]')
+            check_requirements("tritonclient[all]")
             from ultralytics.utils.triton import TritonRemoteModel
 
             model = TritonRemoteModel(w)
         else:
             from ultralytics.engine.exporter import export_formats
-            raise TypeError(f"model='{w}' is not a supported model format. "
-                            'See https://docs.ultralytics.com/modes/predict for help.'
-                            f'\n\n{export_formats()}')
+
+            raise TypeError(
+                f"model='{w}' is not a supported model format. "
+                "See https://docs.ultralytics.com/modes/predict for help."
+                f"\n\n{export_formats()}"
+            )
 
         # Load external metadata YAML
         if isinstance(metadata, (str, Path)) and Path(metadata).exists():
             metadata = yaml_load(metadata)
         if metadata:
             for k, v in metadata.items():
-                if k in ('stride', 'batch'):
+                if k in ("stride", "batch"):
                     metadata[k] = int(v)
-                elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
+                elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
                     metadata[k] = eval(v)
-            stride = metadata['stride']
-            task = metadata['task']
-            batch = metadata['batch']
-            imgsz = metadata['imgsz']
-            names = metadata['names']
-            kpt_shape = metadata.get('kpt_shape')
+            stride = metadata["stride"]
+            task = metadata["task"]
+            batch = metadata["batch"]
+            imgsz = metadata["imgsz"]
+            names = metadata["names"]
+            kpt_shape = metadata.get("kpt_shape")
         elif not (pt or triton or nn_module):
             LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
 
         # Check names
-        if 'names' not in locals():  # names missing
+        if "names" not in locals():  # names missing
             names = default_class_names(data)
         names = check_class_names(names)
 
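`AutoBackend` is normally constructed inside the predictor, but the constructor above can be exercised directly to inspect what was loaded; a minimal sketch (the weight name is a placeholder):

```python
# Hedged sketch of direct AutoBackend use; "yolov8n.pt" is a placeholder weight file.
import torch

from ultralytics.nn.autobackend import AutoBackend

backend = AutoBackend(weights="yolov8n.pt", device=torch.device("cpu"), fp16=False, fuse=True)
im = torch.zeros(1, 3, 640, 640)  # BCHW input; NHWC backends are handled internally via the nhwc flag
y = backend(im)  # dispatches to the branch selected by _model_type() in __init__
print(backend.stride, backend.names[0])  # metadata resolved by the constructor above
```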
@ -367,26 +395,28 @@ class AutoBackend(nn.Module):
|
|||||||
im = im.cpu().numpy() # FP32
|
im = im.cpu().numpy() # FP32
|
||||||
y = list(self.ov_compiled_model(im).values())
|
y = list(self.ov_compiled_model(im).values())
|
||||||
elif self.engine: # TensorRT
|
elif self.engine: # TensorRT
|
||||||
if self.dynamic and im.shape != self.bindings['images'].shape:
|
if self.dynamic and im.shape != self.bindings["images"].shape:
|
||||||
i = self.model.get_binding_index('images')
|
i = self.model.get_binding_index("images")
|
||||||
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
||||||
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
|
||||||
for name in self.output_names:
|
for name in self.output_names:
|
||||||
i = self.model.get_binding_index(name)
|
i = self.model.get_binding_index(name)
|
||||||
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
||||||
s = self.bindings['images'].shape
|
s = self.bindings["images"].shape
|
||||||
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
-            self.binding_addrs['images'] = int(im.data_ptr())
+            self.binding_addrs["images"] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = [self.bindings[x].data for x in sorted(self.output_names)]
         elif self.coreml:  # CoreML
             im = im[0].cpu().numpy()
-            im_pil = Image.fromarray((im * 255).astype('uint8'))
+            im_pil = Image.fromarray((im * 255).astype("uint8"))
             # im = im.resize((192, 320), Image.BILINEAR)
-            y = self.model.predict({'image': im_pil})  # coordinates are xywh normalized
-            if 'confidence' in y:
-                raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with '
-                                f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.")
+            y = self.model.predict({"image": im_pil})  # coordinates are xywh normalized
+            if "confidence" in y:
+                raise TypeError(
+                    "Ultralytics only supports inference of non-pipelined CoreML models exported with "
+                    f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
+                )
                 # TODO: CoreML NMS inference handling
                 # from ultralytics.utils.ops import xywh2xyxy
                 # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
@@ -425,20 +455,20 @@ class AutoBackend(nn.Module):
             if len(y) == 2 and len(self.names) == 999:  # segments and names not defined
                 ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0)  # index of protos, boxes
                 nc = y[ib].shape[1] - y[ip].shape[3] - 4  # y = (1, 160, 160, 32), (1, 116, 8400)
-                self.names = {i: f'class{i}' for i in range(nc)}
+                self.names = {i: f"class{i}" for i in range(nc)}
             else:  # Lite or Edge TPU
                 details = self.input_details[0]
-                integer = details['dtype'] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model
+                integer = details["dtype"] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model
                 if integer:
-                    scale, zero_point = details['quantization']
-                    im = (im / scale + zero_point).astype(details['dtype'])  # de-scale
-                self.interpreter.set_tensor(details['index'], im)
+                    scale, zero_point = details["quantization"]
+                    im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
+                self.interpreter.set_tensor(details["index"], im)
                 self.interpreter.invoke()
                 y = []
                 for output in self.output_details:
-                    x = self.interpreter.get_tensor(output['index'])
+                    x = self.interpreter.get_tensor(output["index"])
                     if integer:
-                        scale, zero_point = output['quantization']
+                        scale, zero_point = output["quantization"]
                         x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                     if x.ndim > 2:  # if task is not classification
                         # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
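The Lite/Edge TPU branch above brackets the integer interpreter with a de-scale and a re-scale step. A minimal sketch of that int8 round trip, with assumed quantization parameters (real values come from the interpreter's input and output details):

```python
import numpy as np

# Assumed parameters; a real interpreter reports them via input_details[0]["quantization"]
scale, zero_point = 1 / 255, -128  # typical for float inputs in [0, 1]

im = np.random.rand(1, 640, 640, 3).astype(np.float32)  # float image in [0, 1]
im_q = (im / scale + zero_point).astype(np.int8)  # de-scale before set_tensor()
im_back = (im_q.astype(np.float32) - zero_point) * scale  # re-scale after get_tensor()
print(np.abs(im - im_back).max() <= scale)  # True: error within one quantization step
```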
@@ -483,13 +513,13 @@ class AutoBackend(nn.Module):
             (None): This method runs the forward pass and don't return any value
         """
         warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
-        if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
+        if any(warmup_types) and (self.device.type != "cpu" or self.triton):
             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
             for _ in range(2 if self.jit else 1):
                 self.forward(im)  # warmup

     @staticmethod
-    def _model_type(p='path/to/model.pt'):
+    def _model_type(p="path/to/model.pt"):
         """
         This function takes a path to a model file and returns the model type.

@@ -499,18 +529,20 @@ class AutoBackend(nn.Module):
         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
         # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
         from ultralytics.engine.exporter import export_formats
+
         sf = list(export_formats().Suffix)  # export suffixes
         if not is_url(p, check=False) and not isinstance(p, str):
             check_suffix(p, sf)  # checks
         name = Path(p).name
         types = [s in name for s in sf]
-        types[5] |= name.endswith('.mlmodel')  # retain support for older Apple CoreML *.mlmodel formats
+        types[5] |= name.endswith(".mlmodel")  # retain support for older Apple CoreML *.mlmodel formats
         types[8] &= not types[9]  # tflite &= not edgetpu
         if any(types):
             triton = False
         else:
             from urllib.parse import urlsplit
+
             url = urlsplit(p)
-            triton = url.netloc and url.path and url.scheme in {'http', 'grpc'}
+            triton = url.netloc and url.path and url.scheme in {"http", "grpc"}

         return types + [triton]
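`_model_type` is plain substring matching against the export suffix table, with a Triton fallback for URLs. A standalone sketch of the same logic under an abbreviated, assumed suffix list (the real list comes from `export_formats()`):

```python
from pathlib import Path
from urllib.parse import urlsplit

# Abbreviated suffix table for illustration; the real one comes from export_formats()
SUFFIXES = [".pt", ".torchscript", ".onnx", "_openvino_model", ".engine", ".mlpackage"]


def model_type(p):
    """Return one bool per known suffix plus a trailing Triton-URL flag."""
    name = Path(p).name
    types = [s in name for s in SUFFIXES]
    if any(types):
        triton = False
    else:
        url = urlsplit(p)  # 'http://localhost:8000/yolov8n' -> scheme/netloc/path
        triton = bool(url.netloc and url.path and url.scheme in {"http", "grpc"})
    return types + [triton]


print(model_type("yolov8n.onnx"))  # ONNX slot True, triton False
print(model_type("http://localhost:8000/yolov8n"))  # all False, triton True
```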
@@ -17,18 +17,101 @@ Example:
     ```
 """

-from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
-                    HGBlock, HGStem, Proto, RepC3, ResNetLayer)
-from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
-                   GhostConv, LightConv, RepConv, SpatialAttention)
+from .block import (
+    C1,
+    C2,
+    C3,
+    C3TR,
+    DFL,
+    SPP,
+    SPPF,
+    Bottleneck,
+    BottleneckCSP,
+    C2f,
+    C3Ghost,
+    C3x,
+    GhostBottleneck,
+    HGBlock,
+    HGStem,
+    Proto,
+    RepC3,
+    ResNetLayer,
+)
+from .conv import (
+    CBAM,
+    ChannelAttention,
+    Concat,
+    Conv,
+    Conv2,
+    ConvTranspose,
+    DWConv,
+    DWConvTranspose2d,
+    Focus,
+    GhostConv,
+    LightConv,
+    RepConv,
+    SpatialAttention,
+)
 from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment
-from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
-                          MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
+from .transformer import (
+    AIFI,
+    MLP,
+    DeformableTransformerDecoder,
+    DeformableTransformerDecoderLayer,
+    LayerNorm2d,
+    MLPBlock,
+    MSDeformAttn,
+    TransformerBlock,
+    TransformerEncoderLayer,
+    TransformerLayer,
+)

-__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
-           'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
-           'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
-           'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
-           'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
-           'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer',
-           'OBB')
+__all__ = (
+    "Conv",
+    "Conv2",
+    "LightConv",
+    "RepConv",
+    "DWConv",
+    "DWConvTranspose2d",
+    "ConvTranspose",
+    "Focus",
+    "GhostConv",
+    "ChannelAttention",
+    "SpatialAttention",
+    "CBAM",
+    "Concat",
+    "TransformerLayer",
+    "TransformerBlock",
+    "MLPBlock",
+    "LayerNorm2d",
+    "DFL",
+    "HGBlock",
+    "HGStem",
+    "SPP",
+    "SPPF",
+    "C1",
+    "C2",
+    "C3",
+    "C2f",
+    "C3x",
+    "C3TR",
+    "C3Ghost",
+    "GhostBottleneck",
+    "Bottleneck",
+    "BottleneckCSP",
+    "Proto",
+    "Detect",
+    "Segment",
+    "Pose",
+    "Classify",
+    "TransformerEncoderLayer",
+    "RepC3",
+    "RTDETRDecoder",
+    "AIFI",
+    "DeformableTransformerDecoder",
+    "DeformableTransformerDecoderLayer",
+    "MSDeformAttn",
+    "MLP",
+    "ResNetLayer",
+    "OBB",
+)
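A side note on why the exploded `__all__` tuple matters beyond diff hygiene: it is the contract for star imports, so only the listed names are re-exported. A quick check:

```python
from ultralytics.nn.modules import *  # noqa: F403 - pulls in exactly the names in __all__

print(Conv, Detect, RTDETRDecoder)  # noqa: F405 - exported names resolve
# Names left out of __all__ (e.g. private helpers) would raise NameError here.
```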
@@ -8,8 +8,26 @@ import torch.nn.functional as F
 from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
 from .transformer import TransformerBlock

-__all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
-           'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3', 'ResNetLayer')
+__all__ = (
+    "DFL",
+    "HGBlock",
+    "HGStem",
+    "SPP",
+    "SPPF",
+    "C1",
+    "C2",
+    "C3",
+    "C2f",
+    "C3x",
+    "C3TR",
+    "C3Ghost",
+    "GhostBottleneck",
+    "Bottleneck",
+    "BottleneckCSP",
+    "Proto",
+    "RepC3",
+    "ResNetLayer",
+)


 class DFL(nn.Module):
@@ -284,9 +302,11 @@ class GhostBottleneck(nn.Module):
         self.conv = nn.Sequential(
             GhostConv(c1, c_, 1, 1),  # pw
             DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
-            GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
-        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
-                                                                            act=False)) if s == 2 else nn.Identity()
+            GhostConv(c_, c2, 1, 1, act=False),  # pw-linear
+        )
+        self.shortcut = (
+            nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
+        )

     def forward(self, x):
         """Applies skip connection and concatenation to input tensor."""
@@ -359,8 +379,9 @@ class ResNetLayer(nn.Module):
         self.is_first = is_first

         if self.is_first:
-            self.layer = nn.Sequential(Conv(c1, c2, k=7, s=2, p=3, act=True),
-                                       nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+            self.layer = nn.Sequential(
+                Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+            )
         else:
             blocks = [ResNetBlock(c1, c2, s, e=e)]
             blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])
@@ -7,8 +7,21 @@ import numpy as np
 import torch
 import torch.nn as nn

-__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
-           'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
+__all__ = (
+    "Conv",
+    "Conv2",
+    "LightConv",
+    "DWConv",
+    "DWConvTranspose2d",
+    "ConvTranspose",
+    "Focus",
+    "GhostConv",
+    "ChannelAttention",
+    "SpatialAttention",
+    "CBAM",
+    "Concat",
+    "RepConv",
+)


 def autopad(k, p=None, d=1):  # kernel, padding, dilation
@@ -22,6 +35,7 @@ def autopad(k, p=None, d=1):  # kernel, padding, dilation

 class Conv(nn.Module):
     """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
+
     default_act = nn.SiLU()  # default activation

     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
@@ -60,9 +74,9 @@ class Conv2(Conv):
         """Fuse parallel convolutions."""
         w = torch.zeros_like(self.conv.weight.data)
         i = [x // 2 for x in w.shape[2:]]
-        w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
+        w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()
         self.conv.weight.data += w
-        self.__delattr__('cv2')
+        self.__delattr__("cv2")
         self.forward = self.forward_fuse

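The slice rewritten above is `Conv2.fuse_convs` dropping the parallel 1x1 kernel into the center tap of the fused 3x3 kernel; a 1x1 convolution is exactly a 3x3 convolution that is zero everywhere else. A self-contained check of that equivalence (shapes are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

w3 = torch.zeros(8, 4, 3, 3)  # fused 3x3 weight buffer (c2, c1, k, k)
w1 = torch.randn(8, 4, 1, 1)  # parallel 1x1 branch weight
i = [x // 2 for x in w3.shape[2:]]  # center tap indices -> [1, 1]
w3[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = w1  # place the 1x1 kernel at the center

x = torch.randn(1, 4, 16, 16)
y1 = F.conv2d(x, w1)  # the 1x1 branch
y3 = F.conv2d(x, w3, padding=1)  # the centered 3x3 equivalent
print(torch.allclose(y1, y3, atol=1e-6))  # True: identical outputs
```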
@@ -102,6 +116,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d):

 class ConvTranspose(nn.Module):
     """Convolution transpose 2d layer."""
+
     default_act = nn.SiLU()  # default activation

     def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
@@ -164,6 +179,7 @@ class RepConv(nn.Module):
     This module is used in RT-DETR.
     Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
     """
+
     default_act = nn.SiLU()  # default activation

     def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
@@ -214,7 +230,7 @@ class RepConv(nn.Module):
             beta = branch.bn.bias
             eps = branch.bn.eps
         elif isinstance(branch, nn.BatchNorm2d):
-            if not hasattr(self, 'id_tensor'):
+            if not hasattr(self, "id_tensor"):
                 input_dim = self.c1 // self.g
                 kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
                 for i in range(self.c1):
@@ -232,29 +248,31 @@ class RepConv(nn.Module):

     def fuse_convs(self):
         """Combines two convolution layers into a single layer and removes unused attributes from the class."""
-        if hasattr(self, 'conv'):
+        if hasattr(self, "conv"):
             return
         kernel, bias = self.get_equivalent_kernel_bias()
-        self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
-                              out_channels=self.conv1.conv.out_channels,
-                              kernel_size=self.conv1.conv.kernel_size,
-                              stride=self.conv1.conv.stride,
-                              padding=self.conv1.conv.padding,
-                              dilation=self.conv1.conv.dilation,
-                              groups=self.conv1.conv.groups,
-                              bias=True).requires_grad_(False)
+        self.conv = nn.Conv2d(
+            in_channels=self.conv1.conv.in_channels,
+            out_channels=self.conv1.conv.out_channels,
+            kernel_size=self.conv1.conv.kernel_size,
+            stride=self.conv1.conv.stride,
+            padding=self.conv1.conv.padding,
+            dilation=self.conv1.conv.dilation,
+            groups=self.conv1.conv.groups,
+            bias=True,
+        ).requires_grad_(False)
         self.conv.weight.data = kernel
         self.conv.bias.data = bias
         for para in self.parameters():
             para.detach_()
-        self.__delattr__('conv1')
-        self.__delattr__('conv2')
-        if hasattr(self, 'nm'):
-            self.__delattr__('nm')
-        if hasattr(self, 'bn'):
-            self.__delattr__('bn')
-        if hasattr(self, 'id_tensor'):
-            self.__delattr__('id_tensor')
+        self.__delattr__("conv1")
+        self.__delattr__("conv2")
+        if hasattr(self, "nm"):
+            self.__delattr__("nm")
+        if hasattr(self, "bn"):
+            self.__delattr__("bn")
+        if hasattr(self, "id_tensor"):
+            self.__delattr__("id_tensor")


 class ChannelAttention(nn.Module):
@@ -278,7 +296,7 @@ class SpatialAttention(nn.Module):
     def __init__(self, kernel_size=7):
         """Initialize Spatial-attention module with kernel size argument."""
         super().__init__()
-        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
         padding = 3 if kernel_size == 7 else 1
         self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
         self.act = nn.Sigmoid()
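`RepConv.fuse_convs` relies on convolution being linear in its kernel: the 3x3 branch, the 1x1 branch padded to 3x3, and the identity branch can be summed into one kernel. A minimal two-branch sketch of the idea:

```python
import torch
import torch.nn.functional as F

k3 = torch.randn(8, 4, 3, 3)  # 3x3 branch kernel
k1 = torch.randn(8, 4, 1, 1)  # 1x1 branch kernel
k_fused = k3 + F.pad(k1, [1, 1, 1, 1])  # zero-pad the 1x1 to 3x3, then sum

x = torch.randn(2, 4, 16, 16)
y_branches = F.conv2d(x, k3, padding=1) + F.conv2d(x, k1)
y_fused = F.conv2d(x, k_fused, padding=1)
print(torch.allclose(y_branches, y_fused, atol=1e-5))  # True: one conv replaces two
```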
@@ -14,11 +14,12 @@ from .conv import Conv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init_

-__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'OBB', 'RTDETRDecoder'
+__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"


 class Detect(nn.Module):
     """YOLOv8 Detect head for detection models."""
+
     dynamic = False  # force grid reconstruction
     export = False  # export mode
     shape = None
@@ -35,7 +36,8 @@ class Detect(nn.Module):
         self.stride = torch.zeros(self.nl)  # strides computed during build
         c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
         self.cv2 = nn.ModuleList(
-            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
+            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
+        )
         self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
         self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

@@ -53,14 +55,14 @@ class Detect(nn.Module):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape

-        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
-            box = x_cat[:, :self.reg_max * 4]
-            cls = x_cat[:, self.reg_max * 4:]
+        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
+            box = x_cat[:, : self.reg_max * 4]
+            cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
         dbox = self.decode_bboxes(box)

-        if self.export and self.format in ('tflite', 'edgetpu'):
+        if self.export and self.format in ("tflite", "edgetpu"):
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
             img_h = shape[2]
@@ -79,7 +81,7 @@ class Detect(nn.Module):
         # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
         for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
             a[-1].bias.data[:] = 1.0  # box
-            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

     def decode_bboxes(self, bboxes):
         """Decode bounding boxes."""
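The class-bias initialization above encodes a prior of roughly 5 objects per 640-pixel image spread over nc classes and the head's grid cells. Checking the arithmetic with assumed values (80 classes, stride-8 head):

```python
import math

nc, imgsz, s = 80, 640, 8  # classes, image size, stride of the finest head (assumed)
cells = (imgsz / s) ** 2  # 6400 grid cells at stride 8
b = math.log(5 / nc / cells)  # initial class-logit bias ~ -11.5
p = 1 / (1 + math.exp(-b))  # per-class, per-cell prior ~ 9.8e-6
print(round(p * nc * cells, 2))  # ~5.0 expected positive predictions at init
```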
@@ -214,6 +216,7 @@ class RTDETRDecoder(nn.Module):
     and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
     Transformer decoder layers to output the final predictions.
     """
+
     export = False  # export mode

     def __init__(
@@ -226,14 +229,15 @@ class RTDETRDecoder(nn.Module):
         nh=8,  # num head
         ndl=6,  # num decoder layers
         d_ffn=1024,  # dim of feedforward
-        dropout=0.,
+        dropout=0.0,
         act=nn.ReLU(),
         eval_idx=-1,
         # Training args
         nd=100,  # num denoising
         label_noise_ratio=0.5,
         box_noise_scale=1.0,
-        learnt_init_query=False):
+        learnt_init_query=False,
+    ):
         """
         Initializes the RTDETRDecoder module with the given parameters.

@@ -302,28 +306,30 @@ class RTDETRDecoder(nn.Module):
         feats, shapes = self._get_encoder_input(x)

         # Prepare denoising training
-        dn_embed, dn_bbox, attn_mask, dn_meta = \
-            get_cdn_group(batch,
-                          self.nc,
-                          self.num_queries,
-                          self.denoising_class_embed.weight,
-                          self.num_denoising,
-                          self.label_noise_ratio,
-                          self.box_noise_scale,
-                          self.training)
+        dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
+            batch,
+            self.nc,
+            self.num_queries,
+            self.denoising_class_embed.weight,
+            self.num_denoising,
+            self.label_noise_ratio,
+            self.box_noise_scale,
+            self.training,
+        )

-        embed, refer_bbox, enc_bboxes, enc_scores = \
-            self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
+        embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

         # Decoder
-        dec_bboxes, dec_scores = self.decoder(embed,
-                                              refer_bbox,
-                                              feats,
-                                              shapes,
-                                              self.dec_bbox_head,
-                                              self.dec_score_head,
-                                              self.query_pos_head,
-                                              attn_mask=attn_mask)
+        dec_bboxes, dec_scores = self.decoder(
+            embed,
+            refer_bbox,
+            feats,
+            shapes,
+            self.dec_bbox_head,
+            self.dec_score_head,
+            self.query_pos_head,
+            attn_mask=attn_mask,
+        )
         x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
         if self.training:
             return x
@@ -331,24 +337,24 @@ class RTDETRDecoder(nn.Module):
         y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
         return y if self.export else (y, x)

-    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
+    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
         """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
         anchors = []
         for i, (h, w) in enumerate(shapes):
             sy = torch.arange(end=h, dtype=dtype, device=device)
             sx = torch.arange(end=w, dtype=dtype, device=device)
-            grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
+            grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
             grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

             valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
             grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)
-            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
+            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
             anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

         anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1
         anchors = torch.log(anchors / (1 - anchors))
-        anchors = anchors.masked_fill(~valid_mask, float('inf'))
+        anchors = anchors.masked_fill(~valid_mask, float("inf"))
         return anchors, valid_mask

     def _get_encoder_input(self, x):
@@ -415,13 +421,13 @@ class RTDETRDecoder(nn.Module):
         # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
         # linear_init_(self.enc_score_head)
         constant_(self.enc_score_head.bias, bias_cls)
-        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
-        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
+        constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
+        constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
         for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
             # linear_init_(cls_)
             constant_(cls_.bias, bias_cls)
-            constant_(reg_.layers[-1].weight, 0.)
-            constant_(reg_.layers[-1].bias, 0.)
+            constant_(reg_.layers[-1].weight, 0.0)
+            constant_(reg_.layers[-1].bias, 0.0)

         linear_init_(self.enc_output[0])
         xavier_uniform_(self.enc_output[0].weight)
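`_generate_anchors` builds normalized cell-center boxes per level and moves them into inverse-sigmoid space so decoder refinements stay bounded. A small sketch for one assumed 4x4 level:

```python
import torch

h, w, grid_size, i, eps = 4, 4, 0.05, 0, 1e-2  # one assumed 4x4 feature level
sy = torch.arange(h, dtype=torch.float32)
sx = torch.arange(w, dtype=torch.float32)
grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij")  # torch>=1.10 signature
xy = (torch.stack([grid_x, grid_y], -1) + 0.5) / torch.tensor([w, h])  # cell centers in (0, 1)
wh = torch.full_like(xy, grid_size * 2.0**i)  # coarser levels get wider boxes
anchors = torch.cat([xy, wh], -1).view(-1, 4)  # (h*w, 4) normalized xywh

valid = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)
logits = torch.log(anchors / (1 - anchors))  # inverse sigmoid
logits = logits.masked_fill(~valid, float("inf"))  # out-of-range anchors are disabled
print(logits.shape, int(valid.sum()))  # torch.Size([16, 4]) 16
```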
@@ -11,8 +11,18 @@ from torch.nn.init import constant_, xavier_uniform_
 from .conv import Conv
 from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch

-__all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI',
-           'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
+__all__ = (
+    "TransformerEncoderLayer",
+    "TransformerLayer",
+    "TransformerBlock",
+    "MLPBlock",
+    "LayerNorm2d",
+    "AIFI",
+    "DeformableTransformerDecoder",
+    "DeformableTransformerDecoderLayer",
+    "MSDeformAttn",
+    "MLP",
+)


 class TransformerEncoderLayer(nn.Module):
@@ -22,9 +32,11 @@ class TransformerEncoderLayer(nn.Module):
         """Initialize the TransformerEncoderLayer with specified parameters."""
         super().__init__()
         from ...utils.torch_utils import TORCH_1_9
+
         if not TORCH_1_9:
             raise ModuleNotFoundError(
-                'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).')
+                "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
+            )
         self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
         # Implementation of Feedforward model
         self.fc1 = nn.Linear(c1, cm)
@@ -91,12 +103,11 @@ class AIFI(TransformerEncoderLayer):
         """Builds 2D sine-cosine position embedding."""
         grid_w = torch.arange(int(w), dtype=torch.float32)
         grid_h = torch.arange(int(h), dtype=torch.float32)
-        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
-        assert embed_dim % 4 == 0, \
-            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
         pos_dim = embed_dim // 4
         omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
-        omega = 1. / (temperature ** omega)
+        omega = 1.0 / (temperature**omega)

         out_w = grid_w.flatten()[..., None] @ omega[None]
         out_h = grid_h.flatten()[..., None] @ omega[None]
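The AIFI embedding splits channels four ways: sin and cos of x, sin and cos of y, each over a geometric frequency ladder. A sketch reconstructed from the context lines above (the concatenation order is an assumption consistent with this hunk):

```python
import torch

w, h, embed_dim, temperature = 20, 20, 256, 10000.0
grid_w = torch.arange(w, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
assert embed_dim % 4 == 0
pos_dim = embed_dim // 4
omega = 1.0 / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
out_w = grid_w.flatten()[..., None] @ omega[None]  # (w*h, pos_dim) phase per position
out_h = grid_h.flatten()[..., None] @ omega[None]
pe = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
print(pe.shape)  # torch.Size([1, 400, 256])
```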
@@ -213,10 +224,10 @@ class MSDeformAttn(nn.Module):
         """Initialize MSDeformAttn with the given parameters."""
         super().__init__()
         if d_model % n_heads != 0:
-            raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
+            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
         _d_per_head = d_model // n_heads
         # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
-        assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`'
+        assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"

         self.im2col_step = 64

@@ -234,21 +245,24 @@ class MSDeformAttn(nn.Module):

     def _reset_parameters(self):
         """Reset module parameters."""
-        constant_(self.sampling_offsets.weight.data, 0.)
+        constant_(self.sampling_offsets.weight.data, 0.0)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
-        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(
-            1, self.n_levels, self.n_points, 1)
+        grid_init = (
+            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+            .view(self.n_heads, 1, 1, 2)
+            .repeat(1, self.n_levels, self.n_points, 1)
+        )
         for i in range(self.n_points):
             grid_init[:, :, i, :] *= i + 1
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
-        constant_(self.attention_weights.weight.data, 0.)
-        constant_(self.attention_weights.bias.data, 0.)
+        constant_(self.attention_weights.weight.data, 0.0)
+        constant_(self.attention_weights.bias.data, 0.0)
         xavier_uniform_(self.value_proj.weight.data)
-        constant_(self.value_proj.bias.data, 0.)
+        constant_(self.value_proj.bias.data, 0.0)
         xavier_uniform_(self.output_proj.weight.data)
-        constant_(self.output_proj.bias.data, 0.)
+        constant_(self.output_proj.bias.data, 0.0)

     def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
         """
@@ -288,7 +302,7 @@ class MSDeformAttn(nn.Module):
             add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
             sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
         else:
-            raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
+            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
         output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
         return self.output_proj(output)

@@ -301,7 +315,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
     https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
     """

-    def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4):
+    def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4):
         """Initialize the DeformableTransformerDecoderLayer with the given parameters."""
         super().__init__()

@@ -339,14 +353,16 @@ class DeformableTransformerDecoderLayer(nn.Module):

         # Self attention
         q = k = self.with_pos_embed(embed, query_pos)
-        tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
-                             attn_mask=attn_mask)[0].transpose(0, 1)
+        tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
+            0
+        ].transpose(0, 1)
         embed = embed + self.dropout1(tgt)
         embed = self.norm1(embed)

         # Cross attention
-        tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
-                              padding_mask)
+        tgt = self.cross_attn(
+            self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
+        )
         embed = embed + self.dropout2(tgt)
         embed = self.norm2(embed)

@@ -379,7 +395,8 @@ class DeformableTransformerDecoder(nn.Module):
         score_head,
         pos_mlp,
         attn_mask=None,
-        padding_mask=None):
+        padding_mask=None,
+    ):
         """Perform the forward pass through the entire decoder."""
         output = embed
         dec_bboxes = []
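The `_reset_parameters` block reshaped above seeds sampling offsets as unit directions, one per head, fanned outward per point. The same construction, runnable standalone:

```python
import math

import torch

n_heads, n_levels, n_points = 8, 4, 4
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid = torch.stack([thetas.cos(), thetas.sin()], -1)  # one unit direction per head
grid = (
    (grid / grid.abs().max(-1, keepdim=True)[0])  # project onto the unit square edge
    .view(n_heads, 1, 1, 2)
    .repeat(1, n_levels, n_points, 1)
)
for i in range(n_points):
    grid[:, :, i, :] *= i + 1  # successive points step further out along each direction
print(grid.shape)  # torch.Size([8, 4, 4, 2]) -> flattened into the bias vector
```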
@@ -10,7 +10,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.init import uniform_

-__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
+__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid"


 def _get_clones(module, n):
@@ -27,7 +27,7 @@ def linear_init_(module):
     """Initialize the weights and biases of a linear module."""
     bound = 1 / math.sqrt(module.weight.shape[0])
     uniform_(module.weight, -bound, bound)
-    if hasattr(module, 'bias') and module.bias is not None:
+    if hasattr(module, "bias") and module.bias is not None:
         uniform_(module.bias, -bound, bound)


@@ -39,9 +39,12 @@ def inverse_sigmoid(x, eps=1e-5):
    return torch.log(x1 / x2)


-def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
-                                        sampling_locations: torch.Tensor,
-                                        attention_weights: torch.Tensor) -> torch.Tensor:
+def multi_scale_deformable_attn_pytorch(
+    value: torch.Tensor,
+    value_spatial_shapes: torch.Tensor,
+    sampling_locations: torch.Tensor,
+    attention_weights: torch.Tensor,
+) -> torch.Tensor:
     """
     Multi-scale deformable attention.

@@ -58,23 +61,25 @@ def multi_scale_deformable_attn_pytorch(
         # bs, H_*W_, num_heads*embed_dims ->
         # bs, num_heads*embed_dims, H_*W_ ->
         # bs*num_heads, embed_dims, H_, W_
-        value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
+        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
         # bs, num_queries, num_heads, num_points, 2 ->
         # bs, num_heads, num_queries, num_points, 2 ->
         # bs*num_heads, num_queries, num_points, 2
         sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
         # bs*num_heads, embed_dims, num_queries, num_points
-        sampling_value_l_ = F.grid_sample(value_l_,
-                                          sampling_grid_l_,
-                                          mode='bilinear',
-                                          padding_mode='zeros',
-                                          align_corners=False)
+        sampling_value_l_ = F.grid_sample(
+            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+        )
         sampling_value_list.append(sampling_value_l_)
     # (bs, num_queries, num_heads, num_levels, num_points) ->
     # (bs, num_heads, num_queries, num_levels, num_points) ->
     # (bs, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
-                                                                  num_levels * num_points)
-    output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
-        bs, num_heads * embed_dims, num_queries))
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points
+    )
+    output = (
+        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+        .sum(-1)
+        .view(bs, num_heads * embed_dims, num_queries)
+    )
     return output.transpose(1, 2).contiguous()
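The heart of the PyTorch fallback is bilinear `F.grid_sample` at fractional locations followed by an attention-weighted sum. A minimal single-level sketch with assumed sizes:

```python
import torch
import torch.nn.functional as F

bs_heads, embed_dims, H, W, num_queries, num_points = 2, 8, 16, 16, 5, 4
value = torch.randn(bs_heads, embed_dims, H, W)  # one level's per-head feature map
# grid_sample wants coordinates in [-1, 1]; the module maps its normalized
# sampling locations there with 2 * loc - 1 before this call.
grid = torch.rand(bs_heads, num_queries, num_points, 2) * 2 - 1
sampled = F.grid_sample(value, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
weights = torch.softmax(torch.randn(bs_heads, 1, num_queries, num_points), -1)
out = (sampled * weights).sum(-1)  # blend the sampled features per query
print(out.shape)  # torch.Size([2, 8, 5])
```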
@@ -7,16 +7,54 @@ from pathlib import Path
 import torch
 import torch.nn as nn

-from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost,
-                                    C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
-                                    DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
-                                    RepConv, ResNetLayer, RTDETRDecoder, Segment)
+from ultralytics.nn.modules import (
+    AIFI,
+    C1,
+    C2,
+    C3,
+    C3TR,
+    OBB,
+    SPP,
+    SPPF,
+    Bottleneck,
+    BottleneckCSP,
+    C2f,
+    C3Ghost,
+    C3x,
+    Classify,
+    Concat,
+    Conv,
+    Conv2,
+    ConvTranspose,
+    Detect,
+    DWConv,
+    DWConvTranspose2d,
+    Focus,
+    GhostBottleneck,
+    GhostConv,
+    HGBlock,
+    HGStem,
+    Pose,
+    RepC3,
+    RepConv,
+    ResNetLayer,
+    RTDETRDecoder,
+    Segment,
+)
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
 from ultralytics.utils.plotting import feature_visualization
-from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
-                                           make_divisible, model_info, scale_img, time_sync)
+from ultralytics.utils.torch_utils import (
+    fuse_conv_and_bn,
+    fuse_deconv_and_bn,
+    initialize_weights,
+    intersect_dicts,
+    make_divisible,
+    model_info,
+    scale_img,
+    time_sync,
+)

 try:
     import thop
@@ -90,8 +128,10 @@ class BaseModel(nn.Module):

     def _predict_augment(self, x):
         """Perform augmentations on input image x and return augmented inference."""
-        LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
-                       f'Reverting to single-scale inference instead.')
+        LOGGER.warning(
+            f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
+            f"Reverting to single-scale inference instead."
+        )
         return self._predict_once(x)

     def _profile_one_layer(self, m, x, dt):
@@ -108,14 +148,14 @@ class BaseModel(nn.Module):
             None
         """
         c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
-        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
+        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPs
         t = time_sync()
         for _ in range(10):
             m(x.copy() if c else x)
         dt.append((time_sync() - t) * 100)
         if m == self.model[0]:
             LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
-        LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}')
+        LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}")
         if c:
             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")
@@ -129,15 +169,15 @@ class BaseModel(nn.Module):
         """
         if not self.is_fused():
             for m in self.model.modules():
-                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
+                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
                     if isinstance(m, Conv2):
                         m.fuse_convs()
                     m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
-                    delattr(m, 'bn')  # remove batchnorm
+                    delattr(m, "bn")  # remove batchnorm
                     m.forward = m.forward_fuse  # update forward
-                if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
+                if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
                     m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
-                    delattr(m, 'bn')  # remove batchnorm
+                    delattr(m, "bn")  # remove batchnorm
                     m.forward = m.forward_fuse  # update forward
                 if isinstance(m, RepConv):
                     m.fuse_convs()
@@ -156,7 +196,7 @@ class BaseModel(nn.Module):
         Returns:
             (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
         """
-        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
         return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

     def info(self, detailed=False, verbose=True, imgsz=640):
@@ -196,12 +236,12 @@ class BaseModel(nn.Module):
             weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
             verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
         """
-        model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts
+        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
         csd = model.float().state_dict()  # checkpoint state_dict as FP32
         csd = intersect_dicts(csd, self.state_dict())  # intersect
         self.load_state_dict(csd, strict=False)  # load
         if verbose:
-            LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
+            LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")

     def loss(self, batch, preds=None):
         """
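`fuse_conv_and_bn` (called in the fuse() hunk above) folds BatchNorm into the preceding conv: W' = W * g/sqrt(var + eps) and b' = beta + (b - mu) * g/sqrt(var + eps). A sketch verifying the algebra, assuming a bias-free conv:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(4, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)  # pretend-trained statistics
bn.running_var.uniform_(0.5, 2.0)

with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(4, 8, 3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
    fused.bias.copy_(bn.bias - bn.running_mean * scale)  # conv bias is zero here

    x = torch.randn(1, 4, 16, 16)
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```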
@@ -211,33 +251,33 @@ class BaseModel(nn.Module):
             batch (dict): Batch to compute loss on
             preds (torch.Tensor | List[torch.Tensor]): Predictions.
         """
-        if not hasattr(self, 'criterion'):
+        if not hasattr(self, "criterion"):
             self.criterion = self.init_criterion()

-        preds = self.forward(batch['img']) if preds is None else preds
+        preds = self.forward(batch["img"]) if preds is None else preds
         return self.criterion(preds, batch)

     def init_criterion(self):
         """Initialize the loss criterion for the BaseModel."""
-        raise NotImplementedError('compute_loss() needs to be implemented by task heads')
+        raise NotImplementedError("compute_loss() needs to be implemented by task heads")


 class DetectionModel(BaseModel):
     """YOLOv8 detection model."""

-    def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True):  # model, input channels, number of classes
+    def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True):  # model, input channels, number of classes
         """Initialize the YOLOv8 detection model with the given config and parameters."""
         super().__init__()
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

         # Define model
-        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
-        if nc and nc != self.yaml['nc']:
+        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
+        if nc and nc != self.yaml["nc"]:
             LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
-            self.yaml['nc'] = nc  # override YAML value
+            self.yaml["nc"] = nc  # override YAML value
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
-        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
-        self.inplace = self.yaml.get('inplace', True)
+        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
+        self.inplace = self.yaml.get("inplace", True)

         # Build strides
         m = self.model[-1]  # Detect()
@@ -255,7 +295,7 @@ class DetectionModel(BaseModel):
         initialize_weights(self)
         if verbose:
             self.info()
-            LOGGER.info('')
+            LOGGER.info("")

     def _predict_augment(self, x):
         """Perform augmentations on input image x and return augmented inference and train outputs."""
@@ -285,9 +325,9 @@ class DetectionModel(BaseModel):
     def _clip_augmented(self, y):
         """Clip YOLO augmented inference tails."""
         nl = self.model[-1].nl  # number of detection layers (P3-P5)
-        g = sum(4 ** x for x in range(nl))  # grid points
+        g = sum(4**x for x in range(nl))  # grid points
         e = 1  # exclude layer count
-        i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e))  # indices
+        i = (y[0].shape[-1] // g) * sum(4**x for x in range(e))  # indices
         y[0] = y[0][..., :-i]  # large
         i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
         y[-1] = y[-1][..., i:]  # small
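The 4**x sums above count relative grid cells per detection layer, since each finer level has four times the cells of the next coarser one. Plugging in assumed numbers for a 640x640 input:

```python
nl = 3  # detection layers P3-P5
g = sum(4**x for x in range(nl))  # 1 + 4 + 16 = 21 relative grid points
n = 8400  # anchors at 640x640: 80*80 + 40*40 + 20*20
e = 1  # exclude one scale's worth from each augmented tail

i_large = (n // g) * sum(4**x for x in range(e))  # 400: clipped off the end of y[0]
i_small = (n // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # 6400: off the front of y[-1]
print(g, i_large, i_small)  # 21 400 6400
```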
@@ -301,7 +341,7 @@ class DetectionModel(BaseModel):
 class OBBModel(DetectionModel):
     """"YOLOv8 Oriented Bounding Box (OBB) model."""

-    def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
+    def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 OBB model with given config and parameters."""
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)


@@ -313,7 +353,7 @@ class OBBModel(DetectionModel):
 class SegmentationModel(DetectionModel):
     """YOLOv8 segmentation model."""

-    def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
+    def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 segmentation model with given config and parameters."""
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)


@@ -325,13 +365,13 @@ class SegmentationModel(DetectionModel):
 class PoseModel(DetectionModel):
     """YOLOv8 pose model."""

-    def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
+    def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
         """Initialize YOLOv8 Pose model."""
         if not isinstance(cfg, dict):
             cfg = yaml_model_load(cfg)  # load model YAML
-        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
+        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
             LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
-            cfg['kpt_shape'] = data_kpt_shape
+            cfg["kpt_shape"] = data_kpt_shape
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

     def init_criterion(self):
@@ -342,7 +382,7 @@ class PoseModel(DetectionModel):
 class ClassificationModel(BaseModel):
     """YOLOv8 classification model."""

-    def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
+    def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
         """Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
         super().__init__()
         self._from_yaml(cfg, ch, nc, verbose)

@@ -352,21 +392,21 @@ class ClassificationModel(BaseModel):
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

         # Define model
-        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
-        if nc and nc != self.yaml['nc']:
+        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
+        if nc and nc != self.yaml["nc"]:
             LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
-            self.yaml['nc'] = nc  # override YAML value
-        elif not nc and not self.yaml.get('nc', None):
-            raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
+            self.yaml["nc"] = nc  # override YAML value
+        elif not nc and not self.yaml.get("nc", None):
+            raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
         self.stride = torch.Tensor([1])  # no stride constraints
-        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
+        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
         self.info()

     @staticmethod
     def reshape_outputs(model, nc):
         """Update a TorchVision classification model to class count 'n' if required."""
-        name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
+        name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module
         if isinstance(m, Classify):  # YOLO Classify() head
             if m.linear.out_features != nc:
                 m.linear = nn.Linear(m.linear.in_features, nc)
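For a plain TorchVision backbone (no Classify head), the same re-heading idea reduces to swapping the final Linear layer; a minimal sketch, assuming a resnet18 whose last child is ('fc', nn.Linear):

import torch.nn as nn
import torchvision

model = torchvision.models.resnet18()  # illustrative choice, not from this commit
name, m = list(model.named_children())[-1]  # last module, e.g. ('fc', Linear(512, 1000))
nc = 10  # desired class count
if isinstance(m, nn.Linear) and m.out_features != nc:
    setattr(model, name, nn.Linear(m.in_features, nc))  # re-head to nc classes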
@@ -409,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel):
         predict: Performs a forward pass through the network and returns the output.
     """

-    def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
+    def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
         """
         Initialize the RTDETRDetectionModel.

@@ -438,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel):
         Returns:
             (tuple): A tuple containing the total loss and main three losses in a tensor.
         """
-        if not hasattr(self, 'criterion'):
+        if not hasattr(self, "criterion"):
             self.criterion = self.init_criterion()

-        img = batch['img']
+        img = batch["img"]
         # NOTE: preprocess gt_bbox and gt_labels to list.
         bs = len(img)
-        batch_idx = batch['batch_idx']
+        batch_idx = batch["batch_idx"]
         gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
         targets = {
-            'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
-            'bboxes': batch['bboxes'].to(device=img.device),
-            'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
-            'gt_groups': gt_groups}
+            "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
+            "bboxes": batch["bboxes"].to(device=img.device),
+            "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
+            "gt_groups": gt_groups,
+        }

         preds = self.predict(img, batch=targets) if preds is None else preds
         dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
         if dn_meta is None:
             dn_bboxes, dn_scores = None, None
         else:
-            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
-            dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
+            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
+            dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)

         dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)
         dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

-        loss = self.criterion((dec_bboxes, dec_scores),
-                              targets,
-                              dn_bboxes=dn_bboxes,
-                              dn_scores=dn_scores,
-                              dn_meta=dn_meta)
+        loss = self.criterion(
+            (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
+        )
         # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
-        return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
-                                                   device=img.device)
+        return sum(loss.values()), torch.as_tensor(
+            [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
+        )

     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
         """
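The dn_num_split split above separates denoising queries from the regular ones along the query axis; a toy check with made-up sizes (6 decoder layers, batch 2, 100 + 300 queries), purely illustrative:

import torch

dec_bboxes = torch.zeros(6, 2, 400, 4)  # assumed (layers, batch, queries, xywh)
dn_num_split = [100, 300]  # what dn_meta["dn_num_split"] would carry here
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_num_split, dim=2)
print(dn_bboxes.shape, dec_bboxes.shape)  # (6, 2, 100, 4) (6, 2, 300, 4)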
@@ -553,6 +593,7 @@ def temporary_modules(modules=None):

     import importlib
     import sys
+
     try:
         # Set modules in sys.modules under their old name
         for old, new in modules.items():
@@ -580,30 +621,38 @@ def torch_safe_load(weight):
     """
     from ultralytics.utils.downloads import attempt_download_asset

-    check_suffix(file=weight, suffix='.pt')
+    check_suffix(file=weight, suffix=".pt")
     file = attempt_download_asset(weight)  # search online if missing locally
     try:
-        with temporary_modules({
-                'ultralytics.yolo.utils': 'ultralytics.utils',
-                'ultralytics.yolo.v8': 'ultralytics.models.yolo',
-                'ultralytics.yolo.data': 'ultralytics.data'}):  # for legacy 8.0 Classify and Pose models
-            return torch.load(file, map_location='cpu'), file  # load
+        with temporary_modules(
+            {
+                "ultralytics.yolo.utils": "ultralytics.utils",
+                "ultralytics.yolo.v8": "ultralytics.models.yolo",
+                "ultralytics.yolo.data": "ultralytics.data",
+            }
+        ):  # for legacy 8.0 Classify and Pose models
+            return torch.load(file, map_location="cpu"), file  # load

     except ModuleNotFoundError as e:  # e.name is missing module name
-        if e.name == 'models':
+        if e.name == "models":
             raise TypeError(
-                emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
-                       f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
-                       f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
-                       f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-                       f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
-        LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
-                       f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
-                       f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-                       f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
+                emojis(
+                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
+                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
+                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
+                    f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
+                    f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+                )
+            ) from e
+        LOGGER.warning(
+            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
+            f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
+            f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
+            f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+        )
         check_requirements(e.name)  # install missing module

-    return torch.load(file, map_location='cpu'), file  # load
+    return torch.load(file, map_location="cpu"), file  # load


 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
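The remapping passed to temporary_modules() exists so that checkpoints pickled under pre-8.1 module paths still unpickle; the core trick is sys.modules aliasing, sketched here in isolation (simplified; the real context manager also restores sys.modules on exit):

import importlib
import sys

old, new = "ultralytics.yolo.utils", "ultralytics.utils"
sys.modules[old] = importlib.import_module(new)  # imports of the legacy path now resolve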
@ -612,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||||||
ensemble = Ensemble()
|
ensemble = Ensemble()
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
ckpt, w = torch_safe_load(w) # load ckpt
|
ckpt, w = torch_safe_load(w) # load ckpt
|
||||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args
|
args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
|
||||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
||||||
|
|
||||||
# Model compatibility updates
|
# Model compatibility updates
|
||||||
model.args = args # attach args to model
|
model.args = args # attach args to model
|
||||||
model.pt_path = w # attach *.pt file path to model
|
model.pt_path = w # attach *.pt file path to model
|
||||||
model.task = guess_model_task(model)
|
model.task = guess_model_task(model)
|
||||||
if not hasattr(model, 'stride'):
|
if not hasattr(model, "stride"):
|
||||||
model.stride = torch.tensor([32.])
|
model.stride = torch.tensor([32.0])
|
||||||
|
|
||||||
# Append
|
# Append
|
||||||
ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
|
ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
|
||||||
|
|
||||||
# Module updates
|
# Module updates
|
||||||
for m in ensemble.modules():
|
for m in ensemble.modules():
|
||||||
t = type(m)
|
t = type(m)
|
||||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||||
m.inplace = inplace
|
m.inplace = inplace
|
||||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||||
|
|
||||||
# Return model
|
# Return model
|
||||||
@ -638,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||||||
return ensemble[-1]
|
return ensemble[-1]
|
||||||
|
|
||||||
# Return ensemble
|
# Return ensemble
|
||||||
LOGGER.info(f'Ensemble created with {weights}\n')
|
LOGGER.info(f"Ensemble created with {weights}\n")
|
||||||
for k in 'names', 'nc', 'yaml':
|
for k in "names", "nc", "yaml":
|
||||||
setattr(ensemble, k, getattr(ensemble[0], k))
|
setattr(ensemble, k, getattr(ensemble[0], k))
|
||||||
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
||||||
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
|
assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
|
||||||
return ensemble
|
return ensemble
|
||||||
|
|
||||||
|
|
||||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||||
"""Loads a single model weights."""
|
"""Loads a single model weights."""
|
||||||
ckpt, weight = torch_safe_load(weight) # load ckpt
|
ckpt, weight = torch_safe_load(weight) # load ckpt
|
||||||
args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
|
args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
|
||||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
||||||
|
|
||||||
# Model compatibility updates
|
# Model compatibility updates
|
||||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||||
model.pt_path = weight # attach *.pt file path to model
|
model.pt_path = weight # attach *.pt file path to model
|
||||||
model.task = guess_model_task(model)
|
model.task = guess_model_task(model)
|
||||||
if not hasattr(model, 'stride'):
|
if not hasattr(model, "stride"):
|
||||||
model.stride = torch.tensor([32.])
|
model.stride = torch.tensor([32.0])
|
||||||
|
|
||||||
model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
|
model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode
|
||||||
|
|
||||||
# Module updates
|
# Module updates
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
t = type(m)
|
t = type(m)
|
||||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||||
m.inplace = inplace
|
m.inplace = inplace
|
||||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||||
|
|
||||||
# Return model and ckpt
|
# Return model and ckpt
|
||||||
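A note on the args merge above: {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} lets checkpoint arguments win over defaults, because later keys overwrite earlier ones. With toy values (stand-ins, not real config keys):

defaults = {"imgsz": 640, "lr0": 0.01}  # stand-in for DEFAULT_CFG_DICT entries
train_args = {"imgsz": 1280}  # stand-in for ckpt["train_args"]
print({**defaults, **train_args})  # {'imgsz': 1280, 'lr0': 0.01}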
@ -678,11 +727,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||||||
import ast
|
import ast
|
||||||
|
|
||||||
# Args
|
# Args
|
||||||
max_channels = float('inf')
|
max_channels = float("inf")
|
||||||
nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
|
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
|
||||||
depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
|
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
|
||||||
if scales:
|
if scales:
|
||||||
scale = d.get('scale')
|
scale = d.get("scale")
|
||||||
if not scale:
|
if not scale:
|
||||||
scale = tuple(scales.keys())[0]
|
scale = tuple(scales.keys())[0]
|
||||||
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
|
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
|
@@ -697,16 +746,37 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
     LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
     ch = [ch]
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
-    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
-        m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
+    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
+        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
         for j, a in enumerate(args):
             if isinstance(a, str):
                 with contextlib.suppress(ValueError):
                     args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
-        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
-                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
+        if m in (
+            Classify,
+            Conv,
+            ConvTranspose,
+            GhostConv,
+            Bottleneck,
+            GhostBottleneck,
+            SPP,
+            SPPF,
+            DWConv,
+            Focus,
+            BottleneckCSP,
+            C1,
+            C2,
+            C2f,
+            C3,
+            C3TR,
+            C3Ghost,
+            nn.ConvTranspose2d,
+            DWConvTranspose2d,
+            C3x,
+            RepC3,
+        ):
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
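make_divisible(min(c2, max_channels) * width, 8) rounds the width-scaled channel count up to a multiple of 8; a sketch assuming the helper is ceil-based rounding, which matches its usual definition (the values below are illustrative):

import math

def make_divisible(x, divisor=8):  # assumed implementation, for illustration only
    return math.ceil(x / divisor) * divisor

width, max_channels, c2 = 0.25, float("inf"), 256  # e.g. an 'n'-scale model
print(make_divisible(min(c2, max_channels) * width, 8))  # 64
print(make_divisible(84.48, 8))  # 88 -> non-multiples round up, never down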
@ -739,11 +809,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||||||
c2 = ch[f]
|
c2 = ch[f]
|
||||||
|
|
||||||
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
||||||
t = str(m)[8:-2].replace('__main__.', '') # module type
|
t = str(m)[8:-2].replace("__main__.", "") # module type
|
||||||
m.np = sum(x.numel() for x in m_.parameters()) # number params
|
m.np = sum(x.numel() for x in m_.parameters()) # number params
|
||||||
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
|
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
|
||||||
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
||||||
layers.append(m_)
|
layers.append(m_)
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -757,16 +827,16 @@ def yaml_model_load(path):
     import re

     path = Path(path)
-    if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
-        new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
-        LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
+    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
+        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
+        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
         path = path.with_name(new_stem + path.suffix)

-    unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
+    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
     yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
     d = yaml_load(yaml_file)  # model dict
-    d['scale'] = guess_model_scale(path)
-    d['yaml_file'] = str(path)
+    d["scale"] = guess_model_scale(path)
+    d["yaml_file"] = str(path)
     return d

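The two substitutions above can be checked in isolation: the first renames legacy P6 stems, the second strips the scale letter so every scale resolves to one base YAML.

import re

print(re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", "yolov8x6"))  # yolov8x-p6
print(re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", "yolov8x.yaml"))  # yolov8.yaml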
@@ -784,8 +854,9 @@ def guess_model_scale(model_path):
     """
     with contextlib.suppress(AttributeError):
         import re
-        return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1)  # n, s, m, l, or x
-    return ''
+
+        return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
+    return ""


 def guess_model_task(model):
@@ -804,17 +875,17 @@ def guess_model_task(model):

     def cfg2task(cfg):
         """Guess from YAML dictionary."""
-        m = cfg['head'][-1][-2].lower()  # output module name
-        if m in ('classify', 'classifier', 'cls', 'fc'):
-            return 'classify'
-        if m == 'detect':
-            return 'detect'
-        if m == 'segment':
-            return 'segment'
-        if m == 'pose':
-            return 'pose'
-        if m == 'obb':
-            return 'obb'
+        m = cfg["head"][-1][-2].lower()  # output module name
+        if m in ("classify", "classifier", "cls", "fc"):
+            return "classify"
+        if m == "detect":
+            return "detect"
+        if m == "segment":
+            return "segment"
+        if m == "pose":
+            return "pose"
+        if m == "obb":
+            return "obb"

     # Guess from model cfg
     if isinstance(model, dict):
@@ -823,40 +894,42 @@ def guess_model_task(model):

     # Guess from PyTorch model
     if isinstance(model, nn.Module):  # PyTorch model
-        for x in 'model.args', 'model.model.args', 'model.model.model.args':
+        for x in "model.args", "model.model.args", "model.model.model.args":
             with contextlib.suppress(Exception):
-                return eval(x)['task']
-        for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
+                return eval(x)["task"]
+        for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
             with contextlib.suppress(Exception):
                 return cfg2task(eval(x))

         for m in model.modules():
             if isinstance(m, Detect):
-                return 'detect'
+                return "detect"
             elif isinstance(m, Segment):
-                return 'segment'
+                return "segment"
             elif isinstance(m, Classify):
-                return 'classify'
+                return "classify"
             elif isinstance(m, Pose):
-                return 'pose'
+                return "pose"
             elif isinstance(m, OBB):
-                return 'obb'
+                return "obb"

     # Guess from model filename
     if isinstance(model, (str, Path)):
         model = Path(model)
-        if '-seg' in model.stem or 'segment' in model.parts:
-            return 'segment'
-        elif '-cls' in model.stem or 'classify' in model.parts:
-            return 'classify'
-        elif '-pose' in model.stem or 'pose' in model.parts:
-            return 'pose'
-        elif '-obb' in model.stem or 'obb' in model.parts:
-            return 'obb'
-        elif 'detect' in model.parts:
-            return 'detect'
+        if "-seg" in model.stem or "segment" in model.parts:
+            return "segment"
+        elif "-cls" in model.stem or "classify" in model.parts:
+            return "classify"
+        elif "-pose" in model.stem or "pose" in model.parts:
+            return "pose"
+        elif "-obb" in model.stem or "obb" in model.parts:
+            return "obb"
+        elif "detect" in model.parts:
+            return "detect"

     # Unable to determine task from model
-    LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
-                   "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.")
-    return 'detect'  # assume detect
+    LOGGER.warning(
+        "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
+        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
+    )
+    return "detect"  # assume detect
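The filename fallback at the end reduces to substring checks on the stem; in miniature, with illustrative names:

from pathlib import Path

def task_from_name(p):  # simplified sketch of the filename branch above
    stem = Path(p).stem
    for suffix, task in (("-seg", "segment"), ("-cls", "classify"), ("-pose", "pose"), ("-obb", "obb")):
        if suffix in stem:
            return task
    return "detect"

print([task_from_name(p) for p in ("yolov8n-seg.pt", "yolov8n-cls.pt", "yolov8n.pt")])
# ['segment', 'classify', 'detect']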
@@ -26,7 +26,7 @@ class AIGym:
         self.angle = None
         self.count = None
         self.stage = None
-        self.pose_type = 'pushup'
+        self.pose_type = "pushup"
         self.kpts_to_check = None

         # Visual Information
@@ -36,13 +36,15 @@ class AIGym:
         # Check if environment support imshow
         self.env_check = check_imshow(warn=True)

-    def set_args(self,
-                 kpts_to_check,
-                 line_thickness=2,
-                 view_img=False,
-                 pose_up_angle=145.0,
-                 pose_down_angle=90.0,
-                 pose_type='pullup'):
+    def set_args(
+        self,
+        kpts_to_check,
+        line_thickness=2,
+        view_img=False,
+        pose_up_angle=145.0,
+        pose_down_angle=90.0,
+        pose_type="pullup",
+    ):
         """
         Configures the AIGym line_thickness, save image and view image parameters
         Args:
@@ -72,65 +74,75 @@ class AIGym:
         if frame_count == 1:
             self.count = [0] * len(results[0])
             self.angle = [0] * len(results[0])
-            self.stage = ['-' for _ in results[0]]
+            self.stage = ["-" for _ in results[0]]
         self.keypoints = results[0].keypoints.data
         self.annotator = Annotator(im0, line_width=2)

         for ind, k in enumerate(reversed(self.keypoints)):
-            if self.pose_type == 'pushup' or self.pose_type == 'pullup':
-                self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(),
-                                                                     k[int(self.kpts_to_check[1])].cpu(),
-                                                                     k[int(self.kpts_to_check[2])].cpu())
+            if self.pose_type == "pushup" or self.pose_type == "pullup":
+                self.angle[ind] = self.annotator.estimate_pose_angle(
+                    k[int(self.kpts_to_check[0])].cpu(),
+                    k[int(self.kpts_to_check[1])].cpu(),
+                    k[int(self.kpts_to_check[2])].cpu(),
+                )
                 self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)

-            if self.pose_type == 'abworkout':
-                self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(),
-                                                                     k[int(self.kpts_to_check[1])].cpu(),
-                                                                     k[int(self.kpts_to_check[2])].cpu())
+            if self.pose_type == "abworkout":
+                self.angle[ind] = self.annotator.estimate_pose_angle(
+                    k[int(self.kpts_to_check[0])].cpu(),
+                    k[int(self.kpts_to_check[1])].cpu(),
+                    k[int(self.kpts_to_check[2])].cpu(),
+                )
                 self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)
                 if self.angle[ind] > self.poseup_angle:
-                    self.stage[ind] = 'down'
-                if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down':
-                    self.stage[ind] = 'up'
+                    self.stage[ind] = "down"
+                if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
+                    self.stage[ind] = "up"
                     self.count[ind] += 1
-                self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind],
-                                                              count_text=self.count[ind],
-                                                              stage_text=self.stage[ind],
-                                                              center_kpt=k[int(self.kpts_to_check[1])],
-                                                              line_thickness=self.tf)
+                self.annotator.plot_angle_and_count_and_stage(
+                    angle_text=self.angle[ind],
+                    count_text=self.count[ind],
+                    stage_text=self.stage[ind],
+                    center_kpt=k[int(self.kpts_to_check[1])],
+                    line_thickness=self.tf,
+                )

-            if self.pose_type == 'pushup':
+            if self.pose_type == "pushup":
                 if self.angle[ind] > self.poseup_angle:
-                    self.stage[ind] = 'up'
-                if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'up':
-                    self.stage[ind] = 'down'
+                    self.stage[ind] = "up"
+                if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":
+                    self.stage[ind] = "down"
                     self.count[ind] += 1
-                self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind],
-                                                              count_text=self.count[ind],
-                                                              stage_text=self.stage[ind],
-                                                              center_kpt=k[int(self.kpts_to_check[1])],
-                                                              line_thickness=self.tf)
-            if self.pose_type == 'pullup':
+                self.annotator.plot_angle_and_count_and_stage(
+                    angle_text=self.angle[ind],
+                    count_text=self.count[ind],
+                    stage_text=self.stage[ind],
+                    center_kpt=k[int(self.kpts_to_check[1])],
+                    line_thickness=self.tf,
+                )
+            if self.pose_type == "pullup":
                 if self.angle[ind] > self.poseup_angle:
-                    self.stage[ind] = 'down'
-                if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down':
-                    self.stage[ind] = 'up'
+                    self.stage[ind] = "down"
+                if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
+                    self.stage[ind] = "up"
                     self.count[ind] += 1
-                self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind],
-                                                              count_text=self.count[ind],
-                                                              stage_text=self.stage[ind],
-                                                              center_kpt=k[int(self.kpts_to_check[1])],
-                                                              line_thickness=self.tf)
+                self.annotator.plot_angle_and_count_and_stage(
+                    angle_text=self.angle[ind],
+                    count_text=self.count[ind],
+                    stage_text=self.stage[ind],
+                    center_kpt=k[int(self.kpts_to_check[1])],
+                    line_thickness=self.tf,
+                )

             self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True)

         if self.env_check and self.view_img:
-            cv2.imshow('Ultralytics YOLOv8 AI GYM', self.im0)
-            if cv2.waitKey(1) & 0xFF == ord('q'):
+            cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0)
+            if cv2.waitKey(1) & 0xFF == ord("q"):
                 return

         return self.im0


-if __name__ == '__main__':
+if __name__ == "__main__":
     AIGym()
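A hedged usage sketch for the class above; start_counting() and the keypoint indices follow the Ultralytics solutions docs of this release and are assumptions, not part of this diff:

import cv2
from ultralytics import YOLO
from ultralytics.solutions import ai_gym

model = YOLO("yolov8n-pose.pt")
gym = ai_gym.AIGym()
gym.set_args(kpts_to_check=[6, 8, 10], pose_type="pushup", view_img=True)  # assumed kpt indices

cap = cv2.VideoCapture("workout.mp4")  # illustrative input
frame_count = 0
while cap.isOpened():
    ok, im0 = cap.read()
    if not ok:
        break
    frame_count += 1
    results = model.predict(im0, verbose=False)
    im0 = gym.start_counting(im0, results, frame_count)  # assumed entry point
cap.release()
cv2.destroyAllWindows()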
@@ -41,13 +41,15 @@ class DistanceCalculation:
         # Check if environment support imshow
         self.env_check = check_imshow(warn=True)

-    def set_args(self,
-                 names,
-                 pixels_per_meter=10,
-                 view_img=False,
-                 line_thickness=2,
-                 line_color=(255, 255, 0),
-                 centroid_color=(255, 0, 255)):
+    def set_args(
+        self,
+        names,
+        pixels_per_meter=10,
+        view_img=False,
+        line_thickness=2,
+        line_color=(255, 255, 0),
+        centroid_color=(255, 0, 255),
+    ):
         """
         Configures the distance calculation and display parameters.

@@ -129,8 +131,9 @@ class DistanceCalculation:
             distance (float): Distance between two centroids
         """
         cv2.rectangle(self.im0, (15, 25), (280, 70), (255, 255, 255), -1)
-        cv2.putText(self.im0, f'Distance : {distance:.2f}m', (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2,
-                    cv2.LINE_AA)
+        cv2.putText(
+            self.im0, f"Distance : {distance:.2f}m", (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2, cv2.LINE_AA
+        )
         cv2.line(self.im0, self.centroids[0], self.centroids[1], self.line_color, 3)
         cv2.circle(self.im0, self.centroids[0], 6, self.centroid_color, -1)
         cv2.circle(self.im0, self.centroids[1], 6, self.centroid_color, -1)
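The distance drawn above comes from centroid separation in pixels divided by pixels_per_meter; the arithmetic in isolation, with illustrative points:

import math

pixels_per_meter = 10  # the set_args() default
c0, c1 = (120, 200), (420, 600)  # illustrative centroids
distance = math.hypot(c1[0] - c0[0], c1[1] - c0[1]) / pixels_per_meter
print(f"Distance : {distance:.2f}m")  # Distance : 50.00m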
@@ -179,13 +182,13 @@ class DistanceCalculation:

     def display_frames(self):
         """Display frame."""
-        cv2.namedWindow('Ultralytics Distance Estimation')
-        cv2.setMouseCallback('Ultralytics Distance Estimation', self.mouse_event_for_distance)
-        cv2.imshow('Ultralytics Distance Estimation', self.im0)
+        cv2.namedWindow("Ultralytics Distance Estimation")
+        cv2.setMouseCallback("Ultralytics Distance Estimation", self.mouse_event_for_distance)
+        cv2.imshow("Ultralytics Distance Estimation", self.im0)

-        if cv2.waitKey(1) & 0xFF == ord('q'):
+        if cv2.waitKey(1) & 0xFF == ord("q"):
             return


-if __name__ == '__main__':
+if __name__ == "__main__":
     DistanceCalculation()
Some files were not shown because too many files have changed in this diff.