mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 05:55:51 +08:00
Update .pre-commit-config.yaml
(#1026)
This commit is contained in:
parent
9047d737f4
commit
edd3ff1669
@ -1,8 +1,5 @@
|
|||||||
# Define hooks for code formations
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
# Will be applied on any updated commit files if a user has installed and linked commit hook
|
# Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
|
||||||
|
|
||||||
default_language_version:
|
|
||||||
python: python3.8
|
|
||||||
|
|
||||||
exclude: 'docs/'
|
exclude: 'docs/'
|
||||||
# Define bot property if installed via https://github.com/marketplace/pre-commit-ci
|
# Define bot property if installed via https://github.com/marketplace/pre-commit-ci
|
||||||
@ -16,13 +13,13 @@ repos:
|
|||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.4.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
# - id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: check-toml
|
|
||||||
- id: pretty-format-json
|
|
||||||
- id: check-docstring-first
|
- id: check-docstring-first
|
||||||
|
- id: double-quote-string-fixer
|
||||||
|
- id: detect-private-key
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v3.3.1
|
rev: v3.3.1
|
||||||
@ -64,7 +61,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: codespell
|
- id: codespell
|
||||||
args:
|
args:
|
||||||
- --ignore-words-list=crate,nd,strack
|
- --ignore-words-list=crate,nd,strack,dota
|
||||||
|
|
||||||
#- repo: https://github.com/asottile/yesqa
|
#- repo: https://github.com/asottile/yesqa
|
||||||
# rev: v1.4.0
|
# rev: v1.4.0
|
||||||
|
@ -31,8 +31,7 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
|||||||
# Install pip packages
|
# Install pip packages
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN python3 -m pip install --upgrade pip wheel
|
RUN python3 -m pip install --upgrade pip wheel
|
||||||
RUN pip install --no-cache ultralytics[export] albumentations comet gsutil notebook \
|
RUN pip install --no-cache ultralytics[export] albumentations comet gsutil notebook
|
||||||
# tensorflow tensorflowjs \
|
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV OMP_NUM_THREADS=1
|
ENV OMP_NUM_THREADS=1
|
||||||
|
@ -27,8 +27,6 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
|||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN python3 -m pip install --upgrade pip wheel
|
RUN python3 -m pip install --upgrade pip wheel
|
||||||
RUN pip install --no-cache ultralytics albumentations gsutil notebook
|
RUN pip install --no-cache ultralytics albumentations gsutil notebook
|
||||||
# coremltools onnx onnxruntime \
|
|
||||||
# tensorflow-aarch64 tensorflowjs \
|
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
ENV DEBIAN_FRONTEND teletype
|
ENV DEBIAN_FRONTEND teletype
|
||||||
|
@ -27,7 +27,6 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
|||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN python3 -m pip install --upgrade pip wheel
|
RUN python3 -m pip install --upgrade pip wheel
|
||||||
RUN pip install --no-cache ultralytics[export] albumentations gsutil notebook \
|
RUN pip install --no-cache ultralytics[export] albumentations gsutil notebook \
|
||||||
# tensorflow-cpu tensorflowjs \
|
|
||||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
|
@ -92,7 +92,7 @@ Export a YOLOv8n model to a different format like ONNX, CoreML, etc.
|
|||||||
|
|
||||||
## Overriding default arguments
|
## Overriding default arguments
|
||||||
|
|
||||||
Default arguments can be overriden by simply passing them as arguments in the CLI in `arg=value` pairs.
|
Default arguments can be overridden by simply passing them as arguments in the CLI in `arg=value` pairs.
|
||||||
|
|
||||||
!!! tip ""
|
!!! tip ""
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ Class reference documentation for `Results` module and its components can be fou
|
|||||||
|
|
||||||
## Visualizing results
|
## Visualizing results
|
||||||
|
|
||||||
You can use `visualize()` function of `Result` object to get a visualization. It plots all componenets(boxes, masks, classification logits, etc) found in the results object
|
You can use the `visualize()` method of a `Result` object to get a visualization. It plots all components (boxes, masks, classification logits, etc.) found in the results object.
|
||||||
```python
|
```python
|
||||||
res = model(img)
|
res = model(img)
|
||||||
res_plotted = res[0].visualize()
|
res_plotted = res[0].visualize()
|
||||||
|
@ -2,7 +2,7 @@ The simplest way of simply using YOLOv8 directly in a Python environment.
|
|||||||
|
|
||||||
!!! example "Train"
|
!!! example "Train"
|
||||||
|
|
||||||
=== "From pretrained(recommanded)"
|
=== "From pretrained (recommended)"
|
||||||
```python
|
```python
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
2
setup.py
2
setup.py
@ -16,7 +16,7 @@ PKG_REQUIREMENTS = ['sentry_sdk'] # pip-only requirements
|
|||||||
|
|
||||||
def get_version():
    """
    Read the package version string from ultralytics/__init__.py.

    Returns:
        (str): The text between the quotes of the `__version__ = '...'` assignment.

    Raises:
        RuntimeError: If the file contains no `__version__ = '...'` assignment.
    """
    file = PARENT / 'ultralytics/__init__.py'
    # re.M makes ^ match at the start of every line, not only at the start of the file
    match = re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding='utf-8'), re.M)
    if match is None:
        # Subscripting the None result used to fail with "'NoneType' object is not subscriptable"
        raise RuntimeError(f'Unable to find a __version__ string in {file}')
    return match[1]
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
@ -49,9 +49,9 @@ def test_val_classify():
|
|||||||
# Predict checks -------------------------------------------------------------------------------------------------------
|
# Predict checks -------------------------------------------------------------------------------------------------------
|
||||||
def test_predict_detect():
    """Run `yolo predict` through `run` on the assets folder, one .jpg URL and two .mov URLs."""
    sources = (
        ROOT / 'assets',
        'https://ultralytics.com/images/bus.jpg',
        'https://ultralytics.com/assets/decelera_landscape_min.mov',
        'https://ultralytics.com/assets/decelera_portrait_min.mov')
    for source in sources:
        run(f'yolo predict model={MODEL}.pt source={source} imgsz=32')
|
||||||
|
|
||||||
|
|
||||||
def test_predict_segment():
|
def test_predict_segment():
|
||||||
|
@ -11,12 +11,12 @@ CFG_SEG = 'yolov8n-seg.yaml'
|
|||||||
CFG_CLS = 'squeezenet1_0'
|
CFG_CLS = 'squeezenet1_0'
|
||||||
CFG = get_cfg(DEFAULT_CFG)
|
CFG = get_cfg(DEFAULT_CFG)
|
||||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
||||||
SOURCE = ROOT / "assets"
|
SOURCE = ROOT / 'assets'
|
||||||
|
|
||||||
|
|
||||||
def test_detect():
|
def test_detect():
|
||||||
overrides = {"data": "coco8.yaml", "model": CFG_DET, "imgsz": 32, "epochs": 1, "save": False}
|
overrides = {'data': 'coco8.yaml', 'model': CFG_DET, 'imgsz': 32, 'epochs': 1, 'save': False}
|
||||||
CFG.data = "coco8.yaml"
|
CFG.data = 'coco8.yaml'
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
trainer = detect.DetectionTrainer(overrides=overrides)
|
trainer = detect.DetectionTrainer(overrides=overrides)
|
||||||
@ -27,24 +27,24 @@ def test_detect():
|
|||||||
val(model=trainer.best) # validate best.pt
|
val(model=trainer.best) # validate best.pt
|
||||||
|
|
||||||
# Predictor
|
# Predictor
|
||||||
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
|
pred = detect.DetectionPredictor(overrides={'imgsz': [64, 64]})
|
||||||
result = pred(source=SOURCE, model=f"{MODEL}.pt")
|
result = pred(source=SOURCE, model=f'{MODEL}.pt')
|
||||||
assert len(result), "predictor test failed"
|
assert len(result), 'predictor test failed'
|
||||||
|
|
||||||
overrides["resume"] = trainer.last
|
overrides['resume'] = trainer.last
|
||||||
trainer = detect.DetectionTrainer(overrides=overrides)
|
trainer = detect.DetectionTrainer(overrides=overrides)
|
||||||
try:
|
try:
|
||||||
trainer.train()
|
trainer.train()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Expected exception caught: {e}")
|
print(f'Expected exception caught: {e}')
|
||||||
return
|
return
|
||||||
|
|
||||||
Exception("Resume test failed!")
|
raise Exception('Resume test failed!')
|
||||||
|
|
||||||
|
|
||||||
def test_segment():
|
def test_segment():
|
||||||
overrides = {"data": "coco8-seg.yaml", "model": CFG_SEG, "imgsz": 32, "epochs": 1, "save": False}
|
overrides = {'data': 'coco8-seg.yaml', 'model': CFG_SEG, 'imgsz': 32, 'epochs': 1, 'save': False}
|
||||||
CFG.data = "coco8-seg.yaml"
|
CFG.data = 'coco8-seg.yaml'
|
||||||
CFG.v5loader = False
|
CFG.v5loader = False
|
||||||
# YOLO(CFG_SEG).train(**overrides) # works
|
# YOLO(CFG_SEG).train(**overrides) # works
|
||||||
|
|
||||||
@ -57,25 +57,25 @@ def test_segment():
|
|||||||
val(model=trainer.best) # validate best.pt
|
val(model=trainer.best) # validate best.pt
|
||||||
|
|
||||||
# Predictor
|
# Predictor
|
||||||
pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
|
pred = segment.SegmentationPredictor(overrides={'imgsz': [64, 64]})
|
||||||
result = pred(source=SOURCE, model=f"{MODEL}-seg.pt")
|
result = pred(source=SOURCE, model=f'{MODEL}-seg.pt')
|
||||||
assert len(result) == 2, "predictor test failed"
|
assert len(result) == 2, 'predictor test failed'
|
||||||
|
|
||||||
# Test resume
|
# Test resume
|
||||||
overrides["resume"] = trainer.last
|
overrides['resume'] = trainer.last
|
||||||
trainer = segment.SegmentationTrainer(overrides=overrides)
|
trainer = segment.SegmentationTrainer(overrides=overrides)
|
||||||
try:
|
try:
|
||||||
trainer.train()
|
trainer.train()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Expected exception caught: {e}")
|
print(f'Expected exception caught: {e}')
|
||||||
return
|
return
|
||||||
|
|
||||||
Exception("Resume test failed!")
|
raise Exception('Resume test failed!')
|
||||||
|
|
||||||
|
|
||||||
def test_classify():
|
def test_classify():
|
||||||
overrides = {"data": "mnist160", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "batch": 64, "save": False}
|
overrides = {'data': 'mnist160', 'model': 'yolov8n-cls.yaml', 'imgsz': 32, 'epochs': 1, 'batch': 64, 'save': False}
|
||||||
CFG.data = "mnist160"
|
CFG.data = 'mnist160'
|
||||||
CFG.imgsz = 32
|
CFG.imgsz = 32
|
||||||
CFG.batch = 64
|
CFG.batch = 64
|
||||||
# YOLO(CFG_SEG).train(**overrides) # works
|
# YOLO(CFG_SEG).train(**overrides) # works
|
||||||
@ -89,6 +89,6 @@ def test_classify():
|
|||||||
val(model=trainer.best)
|
val(model=trainer.best)
|
||||||
|
|
||||||
# Predictor
|
# Predictor
|
||||||
pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
|
pred = classify.ClassificationPredictor(overrides={'imgsz': [64, 64]})
|
||||||
result = pred(source=SOURCE, model=trainer.best)
|
result = pred(source=SOURCE, model=trainer.best)
|
||||||
assert len(result) == 2, "predictor test failed"
|
assert len(result) == 2, 'predictor test failed'
|
||||||
|
@ -37,24 +37,24 @@ def test_model_fuse():
|
|||||||
|
|
||||||
def test_predict_dir():
    """Call the model built from MODEL with the `ROOT / 'assets'` directory as its source."""
    predictor = YOLO(MODEL)
    assets_dir = ROOT / 'assets'
    predictor(source=assets_dir)
|
||||||
|
|
||||||
|
|
||||||
def test_predict_img():
|
def test_predict_img():
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
img = Image.open(str(SOURCE))
|
img = Image.open(str(SOURCE))
|
||||||
output = model(source=img, save=True, verbose=True) # PIL
|
output = model(source=img, save=True, verbose=True) # PIL
|
||||||
assert len(output) == 1, "predict test failed"
|
assert len(output) == 1, 'predict test failed'
|
||||||
img = cv2.imread(str(SOURCE))
|
img = cv2.imread(str(SOURCE))
|
||||||
output = model(source=img, save=True, save_txt=True) # ndarray
|
output = model(source=img, save=True, save_txt=True) # ndarray
|
||||||
assert len(output) == 1, "predict test failed"
|
assert len(output) == 1, 'predict test failed'
|
||||||
output = model(source=[img, img], save=True, save_txt=True) # batch
|
output = model(source=[img, img], save=True, save_txt=True) # batch
|
||||||
assert len(output) == 2, "predict test failed"
|
assert len(output) == 2, 'predict test failed'
|
||||||
output = model(source=[img, img], save=True, stream=True) # stream
|
output = model(source=[img, img], save=True, stream=True) # stream
|
||||||
assert len(list(output)) == 2, "predict test failed"
|
assert len(list(output)) == 2, 'predict test failed'
|
||||||
tens = torch.zeros(320, 640, 3)
|
tens = torch.zeros(320, 640, 3)
|
||||||
output = model(tens.numpy())
|
output = model(tens.numpy())
|
||||||
assert len(output) == 1, "predict test failed"
|
assert len(output) == 1, 'predict test failed'
|
||||||
# test multiple source
|
# test multiple source
|
||||||
imgs = [
|
imgs = [
|
||||||
SOURCE, # filename
|
SOURCE, # filename
|
||||||
@ -64,23 +64,23 @@ def test_predict_img():
|
|||||||
Image.open(SOURCE), # PIL
|
Image.open(SOURCE), # PIL
|
||||||
np.zeros((320, 640, 3))] # numpy
|
np.zeros((320, 640, 3))] # numpy
|
||||||
output = model(imgs)
|
output = model(imgs)
|
||||||
assert len(output) == 6, "predict test failed!"
|
assert len(output) == 6, 'predict test failed!'
|
||||||
|
|
||||||
|
|
||||||
def test_val():
    """Run validation of the model built from MODEL on coco8.yaml with imgsz=32."""
    YOLO(MODEL).val(data='coco8.yaml', imgsz=32)
|
||||||
|
|
||||||
|
|
||||||
def test_train_scratch():
    """Build a model from CFG, train it for one epoch on coco8.yaml, then call it on SOURCE."""
    net = YOLO(CFG)
    train_kwargs = dict(data='coco8.yaml', epochs=1, imgsz=32)
    net.train(**train_kwargs)
    net(SOURCE)
|
||||||
|
|
||||||
|
|
||||||
def test_train_pretrained():
    """Build a model from MODEL, train it for one epoch on coco8.yaml, then call it on SOURCE."""
    net = YOLO(MODEL)
    train_kwargs = dict(data='coco8.yaml', epochs=1, imgsz=32)
    net.train(**train_kwargs)
    net(SOURCE)
|
||||||
|
|
||||||
|
|
||||||
@ -139,10 +139,10 @@ def test_all_model_yamls():
|
|||||||
|
|
||||||
def test_workflow():
    """Chain train, val, predict and ONNX export on one model instance built from MODEL."""
    net = YOLO(MODEL)
    net.train(**dict(data='coco8.yaml', epochs=1, imgsz=32))
    net.val()
    net.predict(SOURCE)
    net.export(format='onnx')  # export the model to ONNX format
|
||||||
|
|
||||||
|
|
||||||
def test_predict_callback_and_setup():
|
def test_predict_callback_and_setup():
|
||||||
@ -154,8 +154,8 @@ def test_predict_callback_and_setup():
|
|||||||
bs = [predictor.dataset.bs for _ in range(len(path))]
|
bs = [predictor.dataset.bs for _ in range(len(path))]
|
||||||
predictor.results = zip(predictor.results, im0s, bs)
|
predictor.results = zip(predictor.results, im0s, bs)
|
||||||
|
|
||||||
model = YOLO("yolov8n.pt")
|
model = YOLO('yolov8n.pt')
|
||||||
model.add_callback("on_predict_batch_end", on_predict_batch_end)
|
model.add_callback('on_predict_batch_end', on_predict_batch_end)
|
||||||
|
|
||||||
dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
|
dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
|
||||||
bs = dataset.bs # noqa access predictor properties
|
bs = dataset.bs # noqa access predictor properties
|
||||||
@ -168,8 +168,8 @@ def test_predict_callback_and_setup():
|
|||||||
|
|
||||||
|
|
||||||
def test_result():
|
def test_result():
|
||||||
model = YOLO("yolov8n-seg.pt")
|
model = YOLO('yolov8n-seg.pt')
|
||||||
img = str(ROOT / "assets/bus.jpg")
|
img = str(ROOT / 'assets/bus.jpg')
|
||||||
res = model([img, img])
|
res = model([img, img])
|
||||||
res[0].numpy()
|
res[0].numpy()
|
||||||
res[0].cpu().numpy()
|
res[0].cpu().numpy()
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.40"
|
__version__ = '8.0.40'
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||||
|
|
||||||
__all__ = ["__version__", "YOLO", "checks"] # allow simpler import
|
__all__ = ['__version__', 'YOLO', 'checks'] # allow simpler import
|
||||||
|
@ -10,10 +10,10 @@ from ultralytics.yolo.engine.model import YOLO
|
|||||||
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
|
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
|
||||||
|
|
||||||
# Define all export formats
|
# Define all export formats
|
||||||
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
|
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml']
|
||||||
|
|
||||||
|
|
||||||
def start(key=""):
|
def start(key=''):
|
||||||
"""
|
"""
|
||||||
Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
|
Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
|
||||||
"""
|
"""
|
||||||
@ -34,7 +34,7 @@ def start(key=""):
|
|||||||
session.register_callbacks(trainer)
|
session.register_callbacks(trainer)
|
||||||
trainer.train(**session.train_args)
|
trainer.train(**session.train_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f"{PREFIX}{e}")
|
LOGGER.warning(f'{PREFIX}{e}')
|
||||||
|
|
||||||
|
|
||||||
def request_api_key(auth, max_attempts=3):
|
def request_api_key(auth, max_attempts=3):
|
||||||
@ -43,56 +43,56 @@ def request_api_key(auth, max_attempts=3):
|
|||||||
"""
|
"""
|
||||||
import getpass
|
import getpass
|
||||||
for attempts in range(max_attempts):
|
for attempts in range(max_attempts):
|
||||||
LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
|
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
|
||||||
input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
|
input_key = getpass.getpass('Enter your Ultralytics HUB API key:\n')
|
||||||
auth.api_key, model_id = split_key(input_key)
|
auth.api_key, model_id = split_key(input_key)
|
||||||
|
|
||||||
if auth.authenticate():
|
if auth.authenticate():
|
||||||
LOGGER.info(f"{PREFIX}Authenticated ✅")
|
LOGGER.info(f'{PREFIX}Authenticated ✅')
|
||||||
return model_id
|
return model_id
|
||||||
|
|
||||||
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
|
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n')
|
||||||
|
|
||||||
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
|
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
|
||||||
|
|
||||||
|
|
||||||
def reset_model(key=''):
    """
    Ask the model-reset endpoint to reset a trained model to an untrained state.

    The outcome is only logged: an info message on HTTP 200, a warning with the status code and
    reason otherwise. Nothing is returned.

    Args:
        key (str): Key string that `split_key` splits into the API key and the model ID.
    """
    api_key, model_id = split_key(key)
    response = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': api_key, 'modelId': model_id})

    if response.status_code != 200:
        LOGGER.warning(f'{PREFIX}model reset failure {response.status_code} {response.reason}')
    else:
        LOGGER.info(f'{PREFIX}model reset successfully')
|
||||||
|
|
||||||
|
|
||||||
def export_model(key='', format='torchscript'):
    """
    Ask the export endpoint to start exporting a model to one format.

    Args:
        key (str): Key string that `split_key` splits into the API key and the model ID.
        format (str): Requested export format; must be a member of EXPORT_FORMATS_HUB.

    Raises:
        AssertionError: If `format` is not in EXPORT_FORMATS_HUB or the response status is not 200.
    """
    # Explicit raises instead of `assert` statements so the checks also run under `python -O`.
    # AssertionError is kept so callers see the same exception type as before.
    if format not in EXPORT_FORMATS_HUB:
        raise AssertionError(f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}")
    api_key, model_id = split_key(key)
    r = requests.post('https://api.ultralytics.com/export',
                      json={
                          'apiKey': api_key,
                          'modelId': model_id,
                          'format': format})
    if r.status_code != 200:
        raise AssertionError(f'{PREFIX}{format} export failure {r.status_code} {r.reason}')
    LOGGER.info(f'{PREFIX}{format} export started ✅')
|
||||||
|
|
||||||
|
|
||||||
def get_export(key='', format='torchscript'):
    """
    Fetch the export record of a model for one format from the get-export endpoint.

    Args:
        key (str): Key string that `split_key` splits into the API key and the model ID.
        format (str): Export format to look up; must be a member of EXPORT_FORMATS_HUB.

    Returns:
        The decoded JSON body of the response (per the original comment, a dictionary that
        carries a download URL).

    Raises:
        AssertionError: If `format` is not in EXPORT_FORMATS_HUB or the response status is not 200.
    """
    # Explicit raises instead of `assert` statements so that under `python -O` a failed request
    # is not silently returned as if it were a valid export. AssertionError keeps the old type.
    if format not in EXPORT_FORMATS_HUB:
        raise AssertionError(f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}")
    api_key, model_id = split_key(key)
    r = requests.post('https://api.ultralytics.com/get-export',
                      json={
                          'apiKey': api_key,
                          'modelId': model_id,
                          'format': format})
    if r.status_code != 200:
        raise AssertionError(f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}')
    return r.json()
|
||||||
|
|
||||||
|
|
||||||
# temp. For checking
|
# temp. For checking
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
start()
|
start()
|
||||||
|
@ -5,7 +5,7 @@ import requests
|
|||||||
from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
|
from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
|
||||||
from ultralytics.yolo.utils import is_colab
|
from ultralytics.yolo.utils import is_colab
|
||||||
|
|
||||||
API_KEY_PATH = "https://hub.ultralytics.com/settings?tab=api+keys"
|
API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys'
|
||||||
|
|
||||||
|
|
||||||
class Auth:
|
class Auth:
|
||||||
@ -18,7 +18,7 @@ class Auth:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _clean_api_key(key: str) -> str:
|
def _clean_api_key(key: str) -> str:
|
||||||
"""Strip model from key if present"""
|
"""Strip model from key if present"""
|
||||||
separator = "_"
|
separator = '_'
|
||||||
return key.split(separator)[0] if separator in key else key
|
return key.split(separator)[0] if separator in key else key
|
||||||
|
|
||||||
def authenticate(self) -> bool:
|
def authenticate(self) -> bool:
|
||||||
@ -26,11 +26,11 @@ class Auth:
|
|||||||
try:
|
try:
|
||||||
header = self.get_auth_header()
|
header = self.get_auth_header()
|
||||||
if header:
|
if header:
|
||||||
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
|
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
|
||||||
if not r.json().get('success', False):
|
if not r.json().get('success', False):
|
||||||
raise ConnectionError("Unable to authenticate.")
|
raise ConnectionError('Unable to authenticate.')
|
||||||
return True
|
return True
|
||||||
raise ConnectionError("User has not authenticated locally.")
|
raise ConnectionError('User has not authenticated locally.')
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
self.id_token = self.api_key = False # reset invalid
|
self.id_token = self.api_key = False # reset invalid
|
||||||
return False
|
return False
|
||||||
@ -43,21 +43,21 @@ class Auth:
|
|||||||
if not is_colab():
|
if not is_colab():
|
||||||
return False # Currently only works with Colab
|
return False # Currently only works with Colab
|
||||||
try:
|
try:
|
||||||
authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
|
authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
|
||||||
if authn.get("success", False):
|
if authn.get('success', False):
|
||||||
self.id_token = authn.get("data", {}).get("idToken", None)
|
self.id_token = authn.get('data', {}).get('idToken', None)
|
||||||
self.authenticate()
|
self.authenticate()
|
||||||
return True
|
return True
|
||||||
raise ConnectionError("Unable to fetch browser authentication details.")
|
raise ConnectionError('Unable to fetch browser authentication details.')
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
self.id_token = False # reset invalid
|
self.id_token = False # reset invalid
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_auth_header(self):
    """
    Build the request header that carries this object's credentials.

    Returns:
        (dict | None): A bearer `authorization` header when `id_token` is truthy, otherwise an
            `x-api-key` header when `api_key` is truthy, otherwise None.
    """
    # The id token takes precedence over the API key when both are set
    if self.id_token:
        return {'authorization': f'Bearer {self.id_token}'}
    if self.api_key:
        return {'x-api-key': self.api_key}
    return None
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_
|
|||||||
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, emojis, is_colab, threaded
|
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, emojis, is_colab, threaded
|
||||||
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
|
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
|
||||||
|
|
||||||
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
|
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
||||||
session = None
|
session = None
|
||||||
|
|
||||||
|
|
||||||
@ -20,9 +20,9 @@ class HubTrainingSession:
|
|||||||
def __init__(self, model_id, auth):
|
def __init__(self, model_id, auth):
|
||||||
self.agent_id = None # identifies which instance is communicating with server
|
self.agent_id = None # identifies which instance is communicating with server
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.api_url = f"{HUB_API_ROOT}/v1/models/{model_id}"
|
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
|
||||||
self.auth_header = auth.get_auth_header()
|
self.auth_header = auth.get_auth_header()
|
||||||
self._rate_limits = {"metrics": 3.0, "ckpt": 900.0, "heartbeat": 300.0} # rate limits (seconds)
|
self._rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
|
||||||
self._timers = {} # rate limit timers (seconds)
|
self._timers = {} # rate limit timers (seconds)
|
||||||
self._metrics_queue = {} # metrics queue
|
self._metrics_queue = {} # metrics queue
|
||||||
self.model = self._get_model()
|
self.model = self._get_model()
|
||||||
@ -40,7 +40,7 @@ class HubTrainingSession:
|
|||||||
passed by signal.
|
passed by signal.
|
||||||
"""
|
"""
|
||||||
if self.alive is True:
|
if self.alive is True:
|
||||||
LOGGER.info(f"{PREFIX}Kill signal received! ❌")
|
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
|
||||||
self._stop_heartbeat()
|
self._stop_heartbeat()
|
||||||
sys.exit(signum)
|
sys.exit(signum)
|
||||||
|
|
||||||
@ -49,23 +49,23 @@ class HubTrainingSession:
|
|||||||
self.alive = False
|
self.alive = False
|
||||||
|
|
||||||
def upload_metrics(self):
    """Pass a copy of the queued metrics, tagged with type 'metrics', to `smart_request` for this model's API URL."""
    smart_request(f'{self.api_url}',
                  json={'metrics': self._metrics_queue.copy(), 'type': 'metrics'},
                  headers=self.auth_header,
                  code=2)
|
||||||
|
|
||||||
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
||||||
# Upload a model to HUB
|
# Upload a model to HUB
|
||||||
file = None
|
file = None
|
||||||
if Path(weights).is_file():
|
if Path(weights).is_file():
|
||||||
with open(weights, "rb") as f:
|
with open(weights, 'rb') as f:
|
||||||
file = f.read()
|
file = f.read()
|
||||||
if final:
|
if final:
|
||||||
smart_request(
|
smart_request(
|
||||||
f"{self.api_url}/upload",
|
f'{self.api_url}/upload',
|
||||||
data={
|
data={
|
||||||
"epoch": epoch,
|
'epoch': epoch,
|
||||||
"type": "final",
|
'type': 'final',
|
||||||
"map": map},
|
'map': map},
|
||||||
files={"best.pt": file},
|
files={'best.pt': file},
|
||||||
headers=self.auth_header,
|
headers=self.auth_header,
|
||||||
retry=10,
|
retry=10,
|
||||||
timeout=3600,
|
timeout=3600,
|
||||||
@ -73,66 +73,66 @@ class HubTrainingSession:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
smart_request(
|
smart_request(
|
||||||
f"{self.api_url}/upload",
|
f'{self.api_url}/upload',
|
||||||
data={
|
data={
|
||||||
"epoch": epoch,
|
'epoch': epoch,
|
||||||
"type": "epoch",
|
'type': 'epoch',
|
||||||
"isBest": bool(is_best)},
|
'isBest': bool(is_best)},
|
||||||
headers=self.auth_header,
|
headers=self.auth_header,
|
||||||
files={"last.pt": file},
|
files={'last.pt': file},
|
||||||
code=3,
|
code=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_model(self):
|
def _get_model(self):
|
||||||
# Returns model from database by id
|
# Returns model from database by id
|
||||||
api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
|
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
|
||||||
headers = self.auth_header
|
headers = self.auth_header
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
|
response = smart_request(api_url, method='get', headers=headers, thread=False, code=0)
|
||||||
data = response.json().get("data", None)
|
data = response.json().get('data', None)
|
||||||
|
|
||||||
if data.get("status", None) == "trained":
|
if data.get('status', None) == 'trained':
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
emojis(f"Model is already trained and uploaded to "
|
emojis(f'Model is already trained and uploaded to '
|
||||||
f"https://hub.ultralytics.com/models/{self.model_id} 🚀"))
|
f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))
|
||||||
|
|
||||||
if not data.get("data", None):
|
if not data.get('data', None):
|
||||||
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
|
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
|
||||||
self.model_id = data["id"]
|
self.model_id = data['id']
|
||||||
|
|
||||||
# TODO: restore when server keys when dataset URL and GPU train is working
|
# TODO: restore when server keys when dataset URL and GPU train is working
|
||||||
|
|
||||||
self.train_args = {
|
self.train_args = {
|
||||||
"batch": data["batch_size"],
|
'batch': data['batch_size'],
|
||||||
"epochs": data["epochs"],
|
'epochs': data['epochs'],
|
||||||
"imgsz": data["imgsz"],
|
'imgsz': data['imgsz'],
|
||||||
"patience": data["patience"],
|
'patience': data['patience'],
|
||||||
"device": data["device"],
|
'device': data['device'],
|
||||||
"cache": data["cache"],
|
'cache': data['cache'],
|
||||||
"data": data["data"]}
|
'data': data['data']}
|
||||||
|
|
||||||
self.input_file = data.get("cfg", data["weights"])
|
self.input_file = data.get('cfg', data['weights'])
|
||||||
|
|
||||||
# hack for yolov5 cfg adds u
|
# hack for yolov5 cfg adds u
|
||||||
if "cfg" in data and "yolov5" in data["cfg"]:
|
if 'cfg' in data and 'yolov5' in data['cfg']:
|
||||||
self.input_file = data["cfg"].replace(".yaml", "u.yaml")
|
self.input_file = data['cfg'].replace('.yaml', 'u.yaml')
|
||||||
|
|
||||||
return data
|
return data
|
||||||
except requests.exceptions.ConnectionError as e:
|
except requests.exceptions.ConnectionError as e:
|
||||||
raise ConnectionRefusedError("ERROR: The HUB server is not online. Please try again later.") from e
|
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
|
||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def check_disk_space(self):
|
def check_disk_space(self):
|
||||||
if not check_dataset_disk_space(self.model["data"]):
|
if not check_dataset_disk_space(self.model['data']):
|
||||||
raise MemoryError("Not enough disk space")
|
raise MemoryError('Not enough disk space')
|
||||||
|
|
||||||
def register_callbacks(self, trainer):
|
def register_callbacks(self, trainer):
|
||||||
trainer.add_callback("on_pretrain_routine_end", self.on_pretrain_routine_end)
|
trainer.add_callback('on_pretrain_routine_end', self.on_pretrain_routine_end)
|
||||||
trainer.add_callback("on_fit_epoch_end", self.on_fit_epoch_end)
|
trainer.add_callback('on_fit_epoch_end', self.on_fit_epoch_end)
|
||||||
trainer.add_callback("on_model_save", self.on_model_save)
|
trainer.add_callback('on_model_save', self.on_model_save)
|
||||||
trainer.add_callback("on_train_end", self.on_train_end)
|
trainer.add_callback('on_train_end', self.on_train_end)
|
||||||
|
|
||||||
def on_pretrain_routine_end(self, trainer):
|
def on_pretrain_routine_end(self, trainer):
|
||||||
"""
|
"""
|
||||||
@ -140,57 +140,57 @@ class HubTrainingSession:
|
|||||||
This method does not use trainer. It is passed to all callbacks by default.
|
This method does not use trainer. It is passed to all callbacks by default.
|
||||||
"""
|
"""
|
||||||
# Start timer for upload rate limit
|
# Start timer for upload rate limit
|
||||||
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
|
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
|
||||||
self._timers = {"metrics": time(), "ckpt": time()} # start timer on self.rate_limit
|
self._timers = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
|
||||||
|
|
||||||
def on_fit_epoch_end(self, trainer):
|
def on_fit_epoch_end(self, trainer):
|
||||||
# Upload metrics after val end
|
# Upload metrics after val end
|
||||||
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}
|
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
|
||||||
|
|
||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
model_info = {
|
model_info = {
|
||||||
"model/parameters": get_num_params(trainer.model),
|
'model/parameters': get_num_params(trainer.model),
|
||||||
"model/GFLOPs": round(get_flops(trainer.model), 3),
|
'model/GFLOPs': round(get_flops(trainer.model), 3),
|
||||||
"model/speed(ms)": round(trainer.validator.speed[1], 3)}
|
'model/speed(ms)': round(trainer.validator.speed[1], 3)}
|
||||||
all_plots = {**all_plots, **model_info}
|
all_plots = {**all_plots, **model_info}
|
||||||
self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
|
self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
|
||||||
if time() - self._timers["metrics"] > self._rate_limits["metrics"]:
|
if time() - self._timers['metrics'] > self._rate_limits['metrics']:
|
||||||
self.upload_metrics()
|
self.upload_metrics()
|
||||||
self._timers["metrics"] = time() # reset timer
|
self._timers['metrics'] = time() # reset timer
|
||||||
self._metrics_queue = {} # reset queue
|
self._metrics_queue = {} # reset queue
|
||||||
|
|
||||||
def on_model_save(self, trainer):
|
def on_model_save(self, trainer):
|
||||||
# Upload checkpoints with rate limiting
|
# Upload checkpoints with rate limiting
|
||||||
is_best = trainer.best_fitness == trainer.fitness
|
is_best = trainer.best_fitness == trainer.fitness
|
||||||
if time() - self._timers["ckpt"] > self._rate_limits["ckpt"]:
|
if time() - self._timers['ckpt'] > self._rate_limits['ckpt']:
|
||||||
LOGGER.info(f"{PREFIX}Uploading checkpoint {self.model_id}")
|
LOGGER.info(f'{PREFIX}Uploading checkpoint {self.model_id}')
|
||||||
self._upload_model(trainer.epoch, trainer.last, is_best)
|
self._upload_model(trainer.epoch, trainer.last, is_best)
|
||||||
self._timers["ckpt"] = time() # reset timer
|
self._timers['ckpt'] = time() # reset timer
|
||||||
|
|
||||||
def on_train_end(self, trainer):
|
def on_train_end(self, trainer):
|
||||||
# Upload final model and metrics with exponential standoff
|
# Upload final model and metrics with exponential standoff
|
||||||
LOGGER.info(f"{PREFIX}Training completed successfully ✅")
|
LOGGER.info(f'{PREFIX}Training completed successfully ✅')
|
||||||
LOGGER.info(f"{PREFIX}Uploading final {self.model_id}")
|
LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')
|
||||||
|
|
||||||
# hack for fetching mAP
|
# hack for fetching mAP
|
||||||
mAP = trainer.metrics.get("metrics/mAP50-95(B)", 0)
|
mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
|
||||||
self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95
|
self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95
|
||||||
self.alive = False # stop heartbeats
|
self.alive = False # stop heartbeats
|
||||||
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
|
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
|
||||||
|
|
||||||
def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
||||||
# Upload a model to HUB
|
# Upload a model to HUB
|
||||||
file = None
|
file = None
|
||||||
if Path(weights).is_file():
|
if Path(weights).is_file():
|
||||||
with open(weights, "rb") as f:
|
with open(weights, 'rb') as f:
|
||||||
file = f.read()
|
file = f.read()
|
||||||
file_param = {"best.pt" if final else "last.pt": file}
|
file_param = {'best.pt' if final else 'last.pt': file}
|
||||||
endpoint = f"{self.api_url}/upload"
|
endpoint = f'{self.api_url}/upload'
|
||||||
data = {"epoch": epoch}
|
data = {'epoch': epoch}
|
||||||
if final:
|
if final:
|
||||||
data.update({"type": "final", "map": map})
|
data.update({'type': 'final', 'map': map})
|
||||||
else:
|
else:
|
||||||
data.update({"type": "epoch", "isBest": bool(is_best)})
|
data.update({'type': 'epoch', 'isBest': bool(is_best)})
|
||||||
|
|
||||||
smart_request(
|
smart_request(
|
||||||
endpoint,
|
endpoint,
|
||||||
@ -207,14 +207,14 @@ class HubTrainingSession:
|
|||||||
self.alive = True
|
self.alive = True
|
||||||
while self.alive:
|
while self.alive:
|
||||||
r = smart_request(
|
r = smart_request(
|
||||||
f"{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}",
|
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
||||||
json={
|
json={
|
||||||
"agent": AGENT_NAME,
|
'agent': AGENT_NAME,
|
||||||
"agentId": self.agent_id},
|
'agentId': self.agent_id},
|
||||||
headers=self.auth_header,
|
headers=self.auth_header,
|
||||||
retry=0,
|
retry=0,
|
||||||
code=5,
|
code=5,
|
||||||
thread=False,
|
thread=False,
|
||||||
)
|
)
|
||||||
self.agent_id = r.json().get("data", {}).get("agentId", None)
|
self.agent_id = r.json().get('data', {}).get('agentId', None)
|
||||||
sleep(self._rate_limits["heartbeat"])
|
sleep(self._rate_limits['heartbeat'])
|
||||||
|
@ -18,14 +18,14 @@ from ultralytics.yolo.utils.checks import check_online
|
|||||||
|
|
||||||
PREFIX = colorstr('Ultralytics: ')
|
PREFIX = colorstr('Ultralytics: ')
|
||||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||||
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
|
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
|
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
|
||||||
# Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
|
# Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
|
||||||
gib = 1 << 30 # bytes per GiB
|
gib = 1 << 30 # bytes per GiB
|
||||||
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
|
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
|
||||||
total, used, free = (x / gib for x in shutil.disk_usage("/")) # bytes
|
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
|
||||||
LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
|
LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
|
||||||
if data * sf < free:
|
if data * sf < free:
|
||||||
return True # sufficient space
|
return True # sufficient space
|
||||||
@ -57,7 +57,7 @@ def request_with_credentials(url: str) -> any:
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
""" % url))
|
""" % url))
|
||||||
return output.eval_js("_hub_tmp")
|
return output.eval_js('_hub_tmp')
|
||||||
|
|
||||||
|
|
||||||
# Deprecated TODO: eliminate this function?
|
# Deprecated TODO: eliminate this function?
|
||||||
@ -84,7 +84,7 @@ def split_key(key=''):
|
|||||||
return api_key, model_id
|
return api_key, model_id
|
||||||
|
|
||||||
|
|
||||||
def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post", verbose=True, **kwargs):
|
def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post', verbose=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
||||||
|
|
||||||
@ -128,7 +128,7 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
|
|||||||
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
|
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
|
||||||
f"Please retry after {h['Retry-After']}s."
|
f"Please retry after {h['Retry-After']}s."
|
||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
|
LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
|
||||||
if r.status_code not in retry_codes:
|
if r.status_code not in retry_codes:
|
||||||
return r
|
return r
|
||||||
time.sleep(2 ** i) # exponential standoff
|
time.sleep(2 ** i) # exponential standoff
|
||||||
@ -149,17 +149,17 @@ class Traces:
|
|||||||
self.rate_limit = 3.0 # rate limit (seconds)
|
self.rate_limit = 3.0 # rate limit (seconds)
|
||||||
self.t = 0.0 # rate limit timer (seconds)
|
self.t = 0.0 # rate limit timer (seconds)
|
||||||
self.metadata = {
|
self.metadata = {
|
||||||
"sys_argv_name": Path(sys.argv[0]).name,
|
'sys_argv_name': Path(sys.argv[0]).name,
|
||||||
"install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
||||||
"python": platform.python_version(),
|
'python': platform.python_version(),
|
||||||
"release": __version__,
|
'release': __version__,
|
||||||
"environment": ENVIRONMENT}
|
'environment': ENVIRONMENT}
|
||||||
self.enabled = SETTINGS['sync'] and \
|
self.enabled = SETTINGS['sync'] and \
|
||||||
RANK in {-1, 0} and \
|
RANK in {-1, 0} and \
|
||||||
check_online() and \
|
check_online() and \
|
||||||
not is_pytest_running() and \
|
not is_pytest_running() and \
|
||||||
not is_github_actions_ci() and \
|
not is_github_actions_ci() and \
|
||||||
(is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
|
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
|
||||||
|
|
||||||
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
|
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
|
||||||
"""
|
"""
|
||||||
|
@ -127,11 +127,11 @@ class AutoBackend(nn.Module):
|
|||||||
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
||||||
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
|
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
|
||||||
if network.get_parameters()[0].get_layout().empty:
|
if network.get_parameters()[0].get_layout().empty:
|
||||||
network.get_parameters()[0].set_layout(Layout("NCHW"))
|
network.get_parameters()[0].set_layout(Layout('NCHW'))
|
||||||
batch_dim = get_batch(network)
|
batch_dim = get_batch(network)
|
||||||
if batch_dim.is_static:
|
if batch_dim.is_static:
|
||||||
batch_size = batch_dim.get_length()
|
batch_size = batch_dim.get_length()
|
||||||
executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
|
executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for Intel NCS2
|
||||||
elif engine: # TensorRT
|
elif engine: # TensorRT
|
||||||
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
||||||
@ -184,7 +184,7 @@ class AutoBackend(nn.Module):
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
def wrap_frozen_graph(gd, inputs, outputs):
|
def wrap_frozen_graph(gd, inputs, outputs):
|
||||||
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
|
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
|
||||||
ge = x.graph.as_graph_element
|
ge = x.graph.as_graph_element
|
||||||
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
||||||
|
|
||||||
@ -198,7 +198,7 @@ class AutoBackend(nn.Module):
|
|||||||
gd = tf.Graph().as_graph_def() # TF GraphDef
|
gd = tf.Graph().as_graph_def() # TF GraphDef
|
||||||
with open(w, 'rb') as f:
|
with open(w, 'rb') as f:
|
||||||
gd.ParseFromString(f.read())
|
gd.ParseFromString(f.read())
|
||||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
|
||||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||||
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
||||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||||
@ -220,9 +220,9 @@ class AutoBackend(nn.Module):
|
|||||||
output_details = interpreter.get_output_details() # outputs
|
output_details = interpreter.get_output_details() # outputs
|
||||||
# load metadata
|
# load metadata
|
||||||
with contextlib.suppress(zipfile.BadZipFile):
|
with contextlib.suppress(zipfile.BadZipFile):
|
||||||
with zipfile.ZipFile(w, "r") as model:
|
with zipfile.ZipFile(w, 'r') as model:
|
||||||
meta_file = model.namelist()[0]
|
meta_file = model.namelist()[0]
|
||||||
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
|
||||||
stride, names = int(meta['stride']), meta['names']
|
stride, names = int(meta['stride']), meta['names']
|
||||||
elif tfjs: # TF.js
|
elif tfjs: # TF.js
|
||||||
raise NotImplementedError('YOLOv8 TF.js inference is not supported')
|
raise NotImplementedError('YOLOv8 TF.js inference is not supported')
|
||||||
@ -251,8 +251,8 @@ class AutoBackend(nn.Module):
|
|||||||
else:
|
else:
|
||||||
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
|
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
|
||||||
raise TypeError(f"model='{w}' is not a supported model format. "
|
raise TypeError(f"model='{w}' is not a supported model format. "
|
||||||
"See https://docs.ultralytics.com/tasks/detection/#export for help."
|
'See https://docs.ultralytics.com/tasks/detection/#export for help.'
|
||||||
f"\n\n{EXPORT_FORMATS_TABLE}")
|
f'\n\n{EXPORT_FORMATS_TABLE}')
|
||||||
|
|
||||||
# Load external metadata YAML
|
# Load external metadata YAML
|
||||||
if xml or saved_model or paddle:
|
if xml or saved_model or paddle:
|
||||||
@ -410,5 +410,5 @@ class AutoBackend(nn.Module):
|
|||||||
url = urlparse(p) # if url may be Triton inference server
|
url = urlparse(p) # if url may be Triton inference server
|
||||||
types = [s in Path(p).name for s in sf]
|
types = [s in Path(p).name for s in sf]
|
||||||
types[8] &= not types[9] # tflite &= not edgetpu
|
types[8] &= not types[9] # tflite &= not edgetpu
|
||||||
triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
|
triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
|
||||||
return types + [triton]
|
return types + [triton]
|
||||||
|
@ -99,7 +99,7 @@ class AutoShape(nn.Module):
|
|||||||
shape1.append([y * g for y in s])
|
shape1.append([y * g for y in s])
|
||||||
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
||||||
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
|
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
|
||||||
x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims] # pad
|
x = [LetterBox(shape1, auto=False)(image=im)['img'] for im in ims] # pad
|
||||||
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
||||||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ class BaseModel(nn.Module):
|
|||||||
weights (str): The weights to load into the model.
|
weights (str): The weights to load into the model.
|
||||||
"""
|
"""
|
||||||
# Force all tasks to implement this function
|
# Force all tasks to implement this function
|
||||||
raise NotImplementedError("This function needs to be implemented by derived classes!")
|
raise NotImplementedError('This function needs to be implemented by derived classes!')
|
||||||
|
|
||||||
|
|
||||||
class DetectionModel(BaseModel):
|
class DetectionModel(BaseModel):
|
||||||
@ -249,7 +249,7 @@ class SegmentationModel(DetectionModel):
|
|||||||
super().__init__(cfg, ch, nc, verbose)
|
super().__init__(cfg, ch, nc, verbose)
|
||||||
|
|
||||||
def _forward_augment(self, x):
|
def _forward_augment(self, x):
|
||||||
raise NotImplementedError("WARNING ⚠️ SegmentationModel has not supported augment inference yet!")
|
raise NotImplementedError('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')
|
||||||
|
|
||||||
|
|
||||||
class ClassificationModel(BaseModel):
|
class ClassificationModel(BaseModel):
|
||||||
@ -292,7 +292,7 @@ class ClassificationModel(BaseModel):
|
|||||||
self.info()
|
self.info()
|
||||||
|
|
||||||
def load(self, weights):
|
def load(self, weights):
|
||||||
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
||||||
csd = model.float().state_dict()
|
csd = model.float().state_dict()
|
||||||
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
||||||
self.load_state_dict(csd, strict=False) # load
|
self.load_state_dict(csd, strict=False) # load
|
||||||
@ -341,10 +341,10 @@ def torch_safe_load(weight):
|
|||||||
return torch.load(file, map_location='cpu') # load
|
return torch.load(file, map_location='cpu') # load
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
if e.name == 'omegaconf': # e.name is missing module name
|
if e.name == 'omegaconf': # e.name is missing module name
|
||||||
LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
|
LOGGER.warning(f'WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements.'
|
||||||
f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
|
f'\nAutoInstall will run now for {e.name} but this feature will be removed in the future.'
|
||||||
f"\nRecommend fixes are to train a new model using updated ultralytics package or to "
|
f'\nRecommend fixes are to train a new model using updated ultralytics package or to '
|
||||||
f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
f'download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0')
|
||||||
if e.name != 'models':
|
if e.name != 'models':
|
||||||
check_requirements(e.name) # install missing module
|
check_requirements(e.name) # install missing module
|
||||||
return torch.load(file, map_location='cpu') # load
|
return torch.load(file, map_location='cpu') # load
|
||||||
@ -489,13 +489,13 @@ def guess_model_task(model):
|
|||||||
|
|
||||||
def cfg2task(cfg):
|
def cfg2task(cfg):
|
||||||
# Guess from YAML dictionary
|
# Guess from YAML dictionary
|
||||||
m = cfg["head"][-1][-2].lower() # output module name
|
m = cfg['head'][-1][-2].lower() # output module name
|
||||||
if m in ["classify", "classifier", "cls", "fc"]:
|
if m in ['classify', 'classifier', 'cls', 'fc']:
|
||||||
return "classify"
|
return 'classify'
|
||||||
if m in ["detect"]:
|
if m in ['detect']:
|
||||||
return "detect"
|
return 'detect'
|
||||||
if m in ["segment"]:
|
if m in ['segment']:
|
||||||
return "segment"
|
return 'segment'
|
||||||
|
|
||||||
# Guess from model cfg
|
# Guess from model cfg
|
||||||
if isinstance(model, dict):
|
if isinstance(model, dict):
|
||||||
@ -513,22 +513,22 @@ def guess_model_task(model):
|
|||||||
|
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, Detect):
|
if isinstance(m, Detect):
|
||||||
return "detect"
|
return 'detect'
|
||||||
elif isinstance(m, Segment):
|
elif isinstance(m, Segment):
|
||||||
return "segment"
|
return 'segment'
|
||||||
elif isinstance(m, Classify):
|
elif isinstance(m, Classify):
|
||||||
return "classify"
|
return 'classify'
|
||||||
|
|
||||||
# Guess from model filename
|
# Guess from model filename
|
||||||
if isinstance(model, (str, Path)):
|
if isinstance(model, (str, Path)):
|
||||||
model = Path(model).stem
|
model = Path(model).stem
|
||||||
if '-seg' in model:
|
if '-seg' in model:
|
||||||
return "segment"
|
return 'segment'
|
||||||
elif '-cls' in model:
|
elif '-cls' in model:
|
||||||
return "classify"
|
return 'classify'
|
||||||
else:
|
else:
|
||||||
return "detect"
|
return 'detect'
|
||||||
|
|
||||||
# Unable to determine task from model
|
# Unable to determine task from model
|
||||||
raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
|
raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
|
||||||
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||||
|
@ -4,14 +4,14 @@ from ultralytics.tracker import BOTSORT, BYTETracker
|
|||||||
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
|
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||||
|
|
||||||
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
|
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
|
||||||
check_requirements('lap') # for linear_assignment
|
check_requirements('lap') # for linear_assignment
|
||||||
|
|
||||||
|
|
||||||
def on_predict_start(predictor):
|
def on_predict_start(predictor):
|
||||||
tracker = check_yaml(predictor.args.tracker)
|
tracker = check_yaml(predictor.args.tracker)
|
||||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||||
assert cfg.tracker_type in ["bytetrack", "botsort"], \
|
assert cfg.tracker_type in ['bytetrack', 'botsort'], \
|
||||||
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
|
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
|
||||||
trackers = []
|
trackers = []
|
||||||
for _ in range(predictor.dataset.bs):
|
for _ in range(predictor.dataset.bs):
|
||||||
@ -38,5 +38,5 @@ def on_predict_postprocess_end(predictor):
|
|||||||
|
|
||||||
|
|
||||||
def register_tracker(model):
|
def register_tracker(model):
|
||||||
model.add_callback("on_predict_start", on_predict_start)
|
model.add_callback('on_predict_start', on_predict_start)
|
||||||
model.add_callback("on_predict_postprocess_end", on_predict_postprocess_end)
|
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
|
||||||
|
@ -153,7 +153,7 @@ class STrack(BaseTrack):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
|
||||||
|
|
||||||
|
|
||||||
class BYTETracker:
|
class BYTETracker:
|
||||||
@ -206,7 +206,7 @@ class BYTETracker:
|
|||||||
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
|
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
|
||||||
# Predict the current location with KF
|
# Predict the current location with KF
|
||||||
self.multi_predict(strack_pool)
|
self.multi_predict(strack_pool)
|
||||||
if hasattr(self, "gmc"):
|
if hasattr(self, 'gmc'):
|
||||||
warp = self.gmc.apply(img, dets)
|
warp = self.gmc.apply(img, dets)
|
||||||
STrack.multi_gmc(strack_pool, warp)
|
STrack.multi_gmc(strack_pool, warp)
|
||||||
STrack.multi_gmc(unconfirmed, warp)
|
STrack.multi_gmc(unconfirmed, warp)
|
||||||
|
@ -50,14 +50,14 @@ class GMC:
|
|||||||
seqName = seqName[:-6]
|
seqName = seqName[:-6]
|
||||||
elif '-DPM' in seqName or '-SDP' in seqName:
|
elif '-DPM' in seqName or '-SDP' in seqName:
|
||||||
seqName = seqName[:-4]
|
seqName = seqName[:-4]
|
||||||
self.gmcFile = open(f"{filePath}/GMC-{seqName}.txt")
|
self.gmcFile = open(f'{filePath}/GMC-{seqName}.txt')
|
||||||
|
|
||||||
if self.gmcFile is None:
|
if self.gmcFile is None:
|
||||||
raise ValueError(f"Error: Unable to open GMC file in directory:{filePath}")
|
raise ValueError(f'Error: Unable to open GMC file in directory:{filePath}')
|
||||||
elif self.method in ['none', 'None']:
|
elif self.method in ['none', 'None']:
|
||||||
self.method = 'none'
|
self.method = 'none'
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Error: Unknown CMC method:{method}")
|
raise ValueError(f'Error: Unknown CMC method:{method}')
|
||||||
|
|
||||||
self.prevFrame = None
|
self.prevFrame = None
|
||||||
self.prevKeyPoints = None
|
self.prevKeyPoints = None
|
||||||
@ -302,7 +302,7 @@ class GMC:
|
|||||||
|
|
||||||
def applyFile(self, raw_frame, detections=None):
|
def applyFile(self, raw_frame, detections=None):
|
||||||
line = self.gmcFile.readline()
|
line = self.gmcFile.readline()
|
||||||
tokens = line.split("\t")
|
tokens = line.split('\t')
|
||||||
H = np.eye(2, 3, dtype=np.float_)
|
H = np.eye(2, 3, dtype=np.float_)
|
||||||
H[0, 0] = float(tokens[1])
|
H[0, 0] = float(tokens[1])
|
||||||
H[0, 1] = float(tokens[2])
|
H[0, 1] = float(tokens[2])
|
||||||
|
@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
from . import v8
|
from . import v8
|
||||||
|
|
||||||
__all__ = ["v8"]
|
__all__ = ['v8']
|
||||||
|
@ -142,8 +142,8 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
|||||||
string = ''
|
string = ''
|
||||||
for x in mismatched:
|
for x in mismatched:
|
||||||
matches = get_close_matches(x, base) # key list
|
matches = get_close_matches(x, base) # key list
|
||||||
matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
|
matches = [f'{k}={DEFAULT_CFG_DICT[k]}' if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
|
||||||
match_str = f"Similar arguments are i.e. {matches}." if matches else ''
|
match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
|
||||||
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
||||||
raise SyntaxError(string + CLI_HELP_MSG) from e
|
raise SyntaxError(string + CLI_HELP_MSG) from e
|
||||||
|
|
||||||
@ -163,10 +163,10 @@ def merge_equals_args(args: List[str]) -> List[str]:
|
|||||||
new_args = []
|
new_args = []
|
||||||
for i, arg in enumerate(args):
|
for i, arg in enumerate(args):
|
||||||
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
||||||
new_args[-1] += f"={args[i + 1]}"
|
new_args[-1] += f'={args[i + 1]}'
|
||||||
del args[i + 1]
|
del args[i + 1]
|
||||||
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
|
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
|
||||||
new_args.append(f"{arg}{args[i + 1]}")
|
new_args.append(f'{arg}{args[i + 1]}')
|
||||||
del args[i + 1]
|
del args[i + 1]
|
||||||
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
|
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
|
||||||
new_args[-1] += arg
|
new_args[-1] += arg
|
||||||
@ -223,7 +223,7 @@ def entrypoint(debug=''):
|
|||||||
k, v = a.split('=', 1) # split on first '=' sign
|
k, v = a.split('=', 1) # split on first '=' sign
|
||||||
assert v, f"missing '{k}' value"
|
assert v, f"missing '{k}' value"
|
||||||
if k == 'cfg': # custom.yaml passed
|
if k == 'cfg': # custom.yaml passed
|
||||||
LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
|
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
|
||||||
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
|
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
|
||||||
else:
|
else:
|
||||||
if v.lower() == 'none':
|
if v.lower() == 'none':
|
||||||
@ -237,7 +237,7 @@ def entrypoint(debug=''):
|
|||||||
v = eval(v)
|
v = eval(v)
|
||||||
overrides[k] = v
|
overrides[k] = v
|
||||||
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
||||||
check_cfg_mismatch(full_args_dict, {a: ""}, e)
|
check_cfg_mismatch(full_args_dict, {a: ''}, e)
|
||||||
|
|
||||||
elif a in tasks:
|
elif a in tasks:
|
||||||
overrides['task'] = a
|
overrides['task'] = a
|
||||||
@ -252,7 +252,7 @@ def entrypoint(debug=''):
|
|||||||
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
||||||
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
||||||
else:
|
else:
|
||||||
check_cfg_mismatch(full_args_dict, {a: ""})
|
check_cfg_mismatch(full_args_dict, {a: ''})
|
||||||
|
|
||||||
# Defaults
|
# Defaults
|
||||||
task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
|
task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
|
||||||
@ -287,8 +287,8 @@ def entrypoint(debug=''):
|
|||||||
task = model.task
|
task = model.task
|
||||||
overrides['task'] = task
|
overrides['task'] = task
|
||||||
if mode in {'predict', 'track'} and 'source' not in overrides:
|
if mode in {'predict', 'track'} and 'source' not in overrides:
|
||||||
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
|
overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else 'https://ultralytics.com/images/bus.jpg'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||||
elif mode in ('train', 'val'):
|
elif mode in ('train', 'val'):
|
||||||
if 'data' not in overrides:
|
if 'data' not in overrides:
|
||||||
@ -308,7 +308,7 @@ def entrypoint(debug=''):
|
|||||||
def copy_default_cfg():
|
def copy_default_cfg():
|
||||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||||
LOGGER.info(f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
|
||||||
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,11 +6,11 @@ from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
|||||||
from .dataset_wrappers import MixAndRectDataset
|
from .dataset_wrappers import MixAndRectDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseDataset",
|
'BaseDataset',
|
||||||
"ClassificationDataset",
|
'ClassificationDataset',
|
||||||
"MixAndRectDataset",
|
'MixAndRectDataset',
|
||||||
"SemanticDataset",
|
'SemanticDataset',
|
||||||
"YOLODataset",
|
'YOLODataset',
|
||||||
"build_classification_dataloader",
|
'build_classification_dataloader',
|
||||||
"build_dataloader",
|
'build_dataloader',
|
||||||
"load_inference_source",]
|
'load_inference_source',]
|
||||||
|
@ -55,11 +55,11 @@ class Compose:
|
|||||||
return self.transforms
|
return self.transforms
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
format_string = f"{self.__class__.__name__}("
|
format_string = f'{self.__class__.__name__}('
|
||||||
for t in self.transforms:
|
for t in self.transforms:
|
||||||
format_string += "\n"
|
format_string += '\n'
|
||||||
format_string += f" {t}"
|
format_string += f' {t}'
|
||||||
format_string += "\n)"
|
format_string += '\n)'
|
||||||
return format_string
|
return format_string
|
||||||
|
|
||||||
|
|
||||||
@ -86,11 +86,11 @@ class BaseMixTransform:
|
|||||||
if self.pre_transform is not None:
|
if self.pre_transform is not None:
|
||||||
for i, data in enumerate(mix_labels):
|
for i, data in enumerate(mix_labels):
|
||||||
mix_labels[i] = self.pre_transform(data)
|
mix_labels[i] = self.pre_transform(data)
|
||||||
labels["mix_labels"] = mix_labels
|
labels['mix_labels'] = mix_labels
|
||||||
|
|
||||||
# Mosaic or MixUp
|
# Mosaic or MixUp
|
||||||
labels = self._mix_transform(labels)
|
labels = self._mix_transform(labels)
|
||||||
labels.pop("mix_labels", None)
|
labels.pop('mix_labels', None)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
def _mix_transform(self, labels):
|
def _mix_transform(self, labels):
|
||||||
@ -109,7 +109,7 @@ class Mosaic(BaseMixTransform):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
|
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
|
||||||
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
|
assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
|
||||||
super().__init__(dataset=dataset, p=p)
|
super().__init__(dataset=dataset, p=p)
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.imgsz = imgsz
|
self.imgsz = imgsz
|
||||||
@ -120,15 +120,15 @@ class Mosaic(BaseMixTransform):
|
|||||||
|
|
||||||
def _mix_transform(self, labels):
|
def _mix_transform(self, labels):
|
||||||
mosaic_labels = []
|
mosaic_labels = []
|
||||||
assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
|
assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
|
||||||
assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
|
assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
|
||||||
s = self.imgsz
|
s = self.imgsz
|
||||||
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
|
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
|
labels_patch = (labels if i == 0 else labels['mix_labels'][i - 1]).copy()
|
||||||
# Load image
|
# Load image
|
||||||
img = labels_patch["img"]
|
img = labels_patch['img']
|
||||||
h, w = labels_patch.pop("resized_shape")
|
h, w = labels_patch.pop('resized_shape')
|
||||||
|
|
||||||
# place img in img4
|
# place img in img4
|
||||||
if i == 0: # top left
|
if i == 0: # top left
|
||||||
@ -152,15 +152,15 @@ class Mosaic(BaseMixTransform):
|
|||||||
labels_patch = self._update_labels(labels_patch, padw, padh)
|
labels_patch = self._update_labels(labels_patch, padw, padh)
|
||||||
mosaic_labels.append(labels_patch)
|
mosaic_labels.append(labels_patch)
|
||||||
final_labels = self._cat_labels(mosaic_labels)
|
final_labels = self._cat_labels(mosaic_labels)
|
||||||
final_labels["img"] = img4
|
final_labels['img'] = img4
|
||||||
return final_labels
|
return final_labels
|
||||||
|
|
||||||
def _update_labels(self, labels, padw, padh):
|
def _update_labels(self, labels, padw, padh):
|
||||||
"""Update labels"""
|
"""Update labels"""
|
||||||
nh, nw = labels["img"].shape[:2]
|
nh, nw = labels['img'].shape[:2]
|
||||||
labels["instances"].convert_bbox(format="xyxy")
|
labels['instances'].convert_bbox(format='xyxy')
|
||||||
labels["instances"].denormalize(nw, nh)
|
labels['instances'].denormalize(nw, nh)
|
||||||
labels["instances"].add_padding(padw, padh)
|
labels['instances'].add_padding(padw, padh)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
def _cat_labels(self, mosaic_labels):
|
def _cat_labels(self, mosaic_labels):
|
||||||
@ -169,16 +169,16 @@ class Mosaic(BaseMixTransform):
|
|||||||
cls = []
|
cls = []
|
||||||
instances = []
|
instances = []
|
||||||
for labels in mosaic_labels:
|
for labels in mosaic_labels:
|
||||||
cls.append(labels["cls"])
|
cls.append(labels['cls'])
|
||||||
instances.append(labels["instances"])
|
instances.append(labels['instances'])
|
||||||
final_labels = {
|
final_labels = {
|
||||||
"im_file": mosaic_labels[0]["im_file"],
|
'im_file': mosaic_labels[0]['im_file'],
|
||||||
"ori_shape": mosaic_labels[0]["ori_shape"],
|
'ori_shape': mosaic_labels[0]['ori_shape'],
|
||||||
"resized_shape": (self.imgsz * 2, self.imgsz * 2),
|
'resized_shape': (self.imgsz * 2, self.imgsz * 2),
|
||||||
"cls": np.concatenate(cls, 0),
|
'cls': np.concatenate(cls, 0),
|
||||||
"instances": Instances.concatenate(instances, axis=0),
|
'instances': Instances.concatenate(instances, axis=0),
|
||||||
"mosaic_border": self.border}
|
'mosaic_border': self.border}
|
||||||
final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
|
final_labels['instances'].clip(self.imgsz * 2, self.imgsz * 2)
|
||||||
return final_labels
|
return final_labels
|
||||||
|
|
||||||
|
|
||||||
@ -193,10 +193,10 @@ class MixUp(BaseMixTransform):
|
|||||||
def _mix_transform(self, labels):
|
def _mix_transform(self, labels):
|
||||||
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
|
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
|
||||||
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
||||||
labels2 = labels["mix_labels"][0]
|
labels2 = labels['mix_labels'][0]
|
||||||
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
|
labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
|
||||||
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
|
labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
|
||||||
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
|
labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@ -338,18 +338,18 @@ class RandomPerspective:
|
|||||||
Args:
|
Args:
|
||||||
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
|
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
|
||||||
"""
|
"""
|
||||||
if self.pre_transform and "mosaic_border" not in labels:
|
if self.pre_transform and 'mosaic_border' not in labels:
|
||||||
labels = self.pre_transform(labels)
|
labels = self.pre_transform(labels)
|
||||||
labels.pop("ratio_pad") # do not need ratio pad
|
labels.pop('ratio_pad') # do not need ratio pad
|
||||||
|
|
||||||
img = labels["img"]
|
img = labels['img']
|
||||||
cls = labels["cls"]
|
cls = labels['cls']
|
||||||
instances = labels.pop("instances")
|
instances = labels.pop('instances')
|
||||||
# make sure the coord formats are right
|
# make sure the coord formats are right
|
||||||
instances.convert_bbox(format="xyxy")
|
instances.convert_bbox(format='xyxy')
|
||||||
instances.denormalize(*img.shape[:2][::-1])
|
instances.denormalize(*img.shape[:2][::-1])
|
||||||
|
|
||||||
border = labels.pop("mosaic_border", self.border)
|
border = labels.pop('mosaic_border', self.border)
|
||||||
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
|
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
|
||||||
# M is affine matrix
|
# M is affine matrix
|
||||||
# scale for func:`box_candidates`
|
# scale for func:`box_candidates`
|
||||||
@ -365,7 +365,7 @@ class RandomPerspective:
|
|||||||
|
|
||||||
if keypoints is not None:
|
if keypoints is not None:
|
||||||
keypoints = self.apply_keypoints(keypoints, M)
|
keypoints = self.apply_keypoints(keypoints, M)
|
||||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
|
new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
|
||||||
# clip
|
# clip
|
||||||
new_instances.clip(*self.size)
|
new_instances.clip(*self.size)
|
||||||
|
|
||||||
@ -375,10 +375,10 @@ class RandomPerspective:
|
|||||||
i = self.box_candidates(box1=instances.bboxes.T,
|
i = self.box_candidates(box1=instances.bboxes.T,
|
||||||
box2=new_instances.bboxes.T,
|
box2=new_instances.bboxes.T,
|
||||||
area_thr=0.01 if len(segments) else 0.10)
|
area_thr=0.01 if len(segments) else 0.10)
|
||||||
labels["instances"] = new_instances[i]
|
labels['instances'] = new_instances[i]
|
||||||
labels["cls"] = cls[i]
|
labels['cls'] = cls[i]
|
||||||
labels["img"] = img
|
labels['img'] = img
|
||||||
labels["resized_shape"] = img.shape[:2]
|
labels['resized_shape'] = img.shape[:2]
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
||||||
@ -397,7 +397,7 @@ class RandomHSV:
|
|||||||
self.vgain = vgain
|
self.vgain = vgain
|
||||||
|
|
||||||
def __call__(self, labels):
|
def __call__(self, labels):
|
||||||
img = labels["img"]
|
img = labels['img']
|
||||||
if self.hgain or self.sgain or self.vgain:
|
if self.hgain or self.sgain or self.vgain:
|
||||||
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
||||||
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
||||||
@ -415,30 +415,30 @@ class RandomHSV:
|
|||||||
|
|
||||||
class RandomFlip:
|
class RandomFlip:
|
||||||
|
|
||||||
def __init__(self, p=0.5, direction="horizontal") -> None:
|
def __init__(self, p=0.5, direction='horizontal') -> None:
|
||||||
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
|
assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
|
||||||
assert 0 <= p <= 1.0
|
assert 0 <= p <= 1.0
|
||||||
|
|
||||||
self.p = p
|
self.p = p
|
||||||
self.direction = direction
|
self.direction = direction
|
||||||
|
|
||||||
def __call__(self, labels):
|
def __call__(self, labels):
|
||||||
img = labels["img"]
|
img = labels['img']
|
||||||
instances = labels.pop("instances")
|
instances = labels.pop('instances')
|
||||||
instances.convert_bbox(format="xywh")
|
instances.convert_bbox(format='xywh')
|
||||||
h, w = img.shape[:2]
|
h, w = img.shape[:2]
|
||||||
h = 1 if instances.normalized else h
|
h = 1 if instances.normalized else h
|
||||||
w = 1 if instances.normalized else w
|
w = 1 if instances.normalized else w
|
||||||
|
|
||||||
# Flip up-down
|
# Flip up-down
|
||||||
if self.direction == "vertical" and random.random() < self.p:
|
if self.direction == 'vertical' and random.random() < self.p:
|
||||||
img = np.flipud(img)
|
img = np.flipud(img)
|
||||||
instances.flipud(h)
|
instances.flipud(h)
|
||||||
if self.direction == "horizontal" and random.random() < self.p:
|
if self.direction == 'horizontal' and random.random() < self.p:
|
||||||
img = np.fliplr(img)
|
img = np.fliplr(img)
|
||||||
instances.fliplr(w)
|
instances.fliplr(w)
|
||||||
labels["img"] = np.ascontiguousarray(img)
|
labels['img'] = np.ascontiguousarray(img)
|
||||||
labels["instances"] = instances
|
labels['instances'] = instances
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@ -455,9 +455,9 @@ class LetterBox:
|
|||||||
def __call__(self, labels=None, image=None):
|
def __call__(self, labels=None, image=None):
|
||||||
if labels is None:
|
if labels is None:
|
||||||
labels = {}
|
labels = {}
|
||||||
img = labels.get("img") if image is None else image
|
img = labels.get('img') if image is None else image
|
||||||
shape = img.shape[:2] # current shape [height, width]
|
shape = img.shape[:2] # current shape [height, width]
|
||||||
new_shape = labels.pop("rect_shape", self.new_shape)
|
new_shape = labels.pop('rect_shape', self.new_shape)
|
||||||
if isinstance(new_shape, int):
|
if isinstance(new_shape, int):
|
||||||
new_shape = (new_shape, new_shape)
|
new_shape = (new_shape, new_shape)
|
||||||
|
|
||||||
@ -479,8 +479,8 @@ class LetterBox:
|
|||||||
|
|
||||||
dw /= 2 # divide padding into 2 sides
|
dw /= 2 # divide padding into 2 sides
|
||||||
dh /= 2
|
dh /= 2
|
||||||
if labels.get("ratio_pad"):
|
if labels.get('ratio_pad'):
|
||||||
labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh)) # for evaluation
|
labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation
|
||||||
|
|
||||||
if shape[::-1] != new_unpad: # resize
|
if shape[::-1] != new_unpad: # resize
|
||||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||||
@ -491,18 +491,18 @@ class LetterBox:
|
|||||||
|
|
||||||
if len(labels):
|
if len(labels):
|
||||||
labels = self._update_labels(labels, ratio, dw, dh)
|
labels = self._update_labels(labels, ratio, dw, dh)
|
||||||
labels["img"] = img
|
labels['img'] = img
|
||||||
labels["resized_shape"] = new_shape
|
labels['resized_shape'] = new_shape
|
||||||
return labels
|
return labels
|
||||||
else:
|
else:
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def _update_labels(self, labels, ratio, padw, padh):
|
def _update_labels(self, labels, ratio, padw, padh):
|
||||||
"""Update labels"""
|
"""Update labels"""
|
||||||
labels["instances"].convert_bbox(format="xyxy")
|
labels['instances'].convert_bbox(format='xyxy')
|
||||||
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
|
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
|
||||||
labels["instances"].scale(*ratio)
|
labels['instances'].scale(*ratio)
|
||||||
labels["instances"].add_padding(padw, padh)
|
labels['instances'].add_padding(padw, padh)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@ -513,11 +513,11 @@ class CopyPaste:
|
|||||||
|
|
||||||
def __call__(self, labels):
|
def __call__(self, labels):
|
||||||
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
|
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
|
||||||
im = labels["img"]
|
im = labels['img']
|
||||||
cls = labels["cls"]
|
cls = labels['cls']
|
||||||
h, w = im.shape[:2]
|
h, w = im.shape[:2]
|
||||||
instances = labels.pop("instances")
|
instances = labels.pop('instances')
|
||||||
instances.convert_bbox(format="xyxy")
|
instances.convert_bbox(format='xyxy')
|
||||||
instances.denormalize(w, h)
|
instances.denormalize(w, h)
|
||||||
if self.p and len(instances.segments):
|
if self.p and len(instances.segments):
|
||||||
n = len(instances)
|
n = len(instances)
|
||||||
@ -540,9 +540,9 @@ class CopyPaste:
|
|||||||
i = cv2.flip(im_new, 1).astype(bool)
|
i = cv2.flip(im_new, 1).astype(bool)
|
||||||
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
|
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
|
||||||
|
|
||||||
labels["img"] = im
|
labels['img'] = im
|
||||||
labels["cls"] = cls
|
labels['cls'] = cls
|
||||||
labels["instances"] = instances
|
labels['instances'] = instances
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@ -551,11 +551,11 @@ class Albumentations:
|
|||||||
def __init__(self, p=1.0):
|
def __init__(self, p=1.0):
|
||||||
self.p = p
|
self.p = p
|
||||||
self.transform = None
|
self.transform = None
|
||||||
prefix = colorstr("albumentations: ")
|
prefix = colorstr('albumentations: ')
|
||||||
try:
|
try:
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
|
|
||||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||||
|
|
||||||
T = [
|
T = [
|
||||||
A.Blur(p=0.01),
|
A.Blur(p=0.01),
|
||||||
@ -565,28 +565,28 @@ class Albumentations:
|
|||||||
A.RandomBrightnessContrast(p=0.0),
|
A.RandomBrightnessContrast(p=0.0),
|
||||||
A.RandomGamma(p=0.0),
|
A.RandomGamma(p=0.0),
|
||||||
A.ImageCompression(quality_lower=75, p=0.0),] # transforms
|
A.ImageCompression(quality_lower=75, p=0.0),] # transforms
|
||||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
|
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
||||||
|
|
||||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||||
except ImportError: # package not installed, skip
|
except ImportError: # package not installed, skip
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.info(f"{prefix}{e}")
|
LOGGER.info(f'{prefix}{e}')
|
||||||
|
|
||||||
def __call__(self, labels):
|
def __call__(self, labels):
|
||||||
im = labels["img"]
|
im = labels['img']
|
||||||
cls = labels["cls"]
|
cls = labels['cls']
|
||||||
if len(cls):
|
if len(cls):
|
||||||
labels["instances"].convert_bbox("xywh")
|
labels['instances'].convert_bbox('xywh')
|
||||||
labels["instances"].normalize(*im.shape[:2][::-1])
|
labels['instances'].normalize(*im.shape[:2][::-1])
|
||||||
bboxes = labels["instances"].bboxes
|
bboxes = labels['instances'].bboxes
|
||||||
# TODO: add supports of segments and keypoints
|
# TODO: add supports of segments and keypoints
|
||||||
if self.transform and random.random() < self.p:
|
if self.transform and random.random() < self.p:
|
||||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||||
labels["img"] = new["image"]
|
labels['img'] = new['image']
|
||||||
labels["cls"] = np.array(new["class_labels"])
|
labels['cls'] = np.array(new['class_labels'])
|
||||||
bboxes = np.array(new["bboxes"])
|
bboxes = np.array(new['bboxes'])
|
||||||
labels["instances"].update(bboxes=bboxes)
|
labels['instances'].update(bboxes=bboxes)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@ -594,7 +594,7 @@ class Albumentations:
|
|||||||
class Format:
|
class Format:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
bbox_format="xywh",
|
bbox_format='xywh',
|
||||||
normalize=True,
|
normalize=True,
|
||||||
return_mask=False,
|
return_mask=False,
|
||||||
return_keypoint=False,
|
return_keypoint=False,
|
||||||
@ -610,10 +610,10 @@ class Format:
|
|||||||
self.batch_idx = batch_idx # keep the batch indexes
|
self.batch_idx = batch_idx # keep the batch indexes
|
||||||
|
|
||||||
def __call__(self, labels):
|
def __call__(self, labels):
|
||||||
img = labels.pop("img")
|
img = labels.pop('img')
|
||||||
h, w = img.shape[:2]
|
h, w = img.shape[:2]
|
||||||
cls = labels.pop("cls")
|
cls = labels.pop('cls')
|
||||||
instances = labels.pop("instances")
|
instances = labels.pop('instances')
|
||||||
instances.convert_bbox(format=self.bbox_format)
|
instances.convert_bbox(format=self.bbox_format)
|
||||||
instances.denormalize(w, h)
|
instances.denormalize(w, h)
|
||||||
nl = len(instances)
|
nl = len(instances)
|
||||||
@ -625,17 +625,17 @@ class Format:
|
|||||||
else:
|
else:
|
||||||
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
|
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
|
||||||
img.shape[1] // self.mask_ratio)
|
img.shape[1] // self.mask_ratio)
|
||||||
labels["masks"] = masks
|
labels['masks'] = masks
|
||||||
if self.normalize:
|
if self.normalize:
|
||||||
instances.normalize(w, h)
|
instances.normalize(w, h)
|
||||||
labels["img"] = self._format_img(img)
|
labels['img'] = self._format_img(img)
|
||||||
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||||
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||||
if self.return_keypoint:
|
if self.return_keypoint:
|
||||||
labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
|
labels['keypoints'] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
|
||||||
# then we can use collate_fn
|
# then we can use collate_fn
|
||||||
if self.batch_idx:
|
if self.batch_idx:
|
||||||
labels["batch_idx"] = torch.zeros(nl)
|
labels['batch_idx'] = torch.zeros(nl)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
def _format_img(self, img):
|
def _format_img(self, img):
|
||||||
@ -676,15 +676,15 @@ def v8_transforms(dataset, imgsz, hyp):
|
|||||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||||
Albumentations(p=1.0),
|
Albumentations(p=1.0),
|
||||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
RandomFlip(direction='vertical', p=hyp.flipud),
|
||||||
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
|
RandomFlip(direction='horizontal', p=hyp.fliplr),]) # transforms
|
||||||
|
|
||||||
|
|
||||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||||
def classify_transforms(size=224):
|
def classify_transforms(size=224):
|
||||||
# Transforms to apply if albumentations not installed
|
# Transforms to apply if albumentations not installed
|
||||||
if not isinstance(size, int):
|
if not isinstance(size, int):
|
||||||
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
|
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
|
||||||
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||||
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||||
|
|
||||||
@ -701,17 +701,17 @@ def classify_albumentations(
|
|||||||
auto_aug=False,
|
auto_aug=False,
|
||||||
):
|
):
|
||||||
# YOLOv8 classification Albumentations (optional, only used if package is installed)
|
# YOLOv8 classification Albumentations (optional, only used if package is installed)
|
||||||
prefix = colorstr("albumentations: ")
|
prefix = colorstr('albumentations: ')
|
||||||
try:
|
try:
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
|
||||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||||
if augment: # Resize and crop
|
if augment: # Resize and crop
|
||||||
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
||||||
if auto_aug:
|
if auto_aug:
|
||||||
# TODO: implement AugMix, AutoAug & RandAug in albumentation
|
# TODO: implement AugMix, AutoAug & RandAug in albumentation
|
||||||
LOGGER.info(f"{prefix}auto augmentations are currently not supported")
|
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
|
||||||
else:
|
else:
|
||||||
if hflip > 0:
|
if hflip > 0:
|
||||||
T += [A.HorizontalFlip(p=hflip)]
|
T += [A.HorizontalFlip(p=hflip)]
|
||||||
@ -723,13 +723,13 @@ def classify_albumentations(
|
|||||||
else: # Use fixed crop for eval set (reproducibility)
|
else: # Use fixed crop for eval set (reproducibility)
|
||||||
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
||||||
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
||||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||||
return A.Compose(T)
|
return A.Compose(T)
|
||||||
|
|
||||||
except ImportError: # package not installed, skip
|
except ImportError: # package not installed, skip
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.info(f"{prefix}{e}")
|
LOGGER.info(f'{prefix}{e}')
|
||||||
|
|
||||||
|
|
||||||
class ClassifyLetterBox:
|
class ClassifyLetterBox:
|
||||||
|
@ -31,7 +31,7 @@ class BaseDataset(Dataset):
|
|||||||
cache=False,
|
cache=False,
|
||||||
augment=True,
|
augment=True,
|
||||||
hyp=None,
|
hyp=None,
|
||||||
prefix="",
|
prefix='',
|
||||||
rect=False,
|
rect=False,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
stride=32,
|
stride=32,
|
||||||
@ -63,7 +63,7 @@ class BaseDataset(Dataset):
|
|||||||
|
|
||||||
# cache stuff
|
# cache stuff
|
||||||
self.ims = [None] * self.ni
|
self.ims = [None] * self.ni
|
||||||
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
|
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
||||||
if cache:
|
if cache:
|
||||||
self.cache_images(cache)
|
self.cache_images(cache)
|
||||||
|
|
||||||
@ -77,21 +77,21 @@ class BaseDataset(Dataset):
|
|||||||
for p in img_path if isinstance(img_path, list) else [img_path]:
|
for p in img_path if isinstance(img_path, list) else [img_path]:
|
||||||
p = Path(p) # os-agnostic
|
p = Path(p) # os-agnostic
|
||||||
if p.is_dir(): # dir
|
if p.is_dir(): # dir
|
||||||
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
|
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
|
||||||
# f = list(p.rglob('*.*')) # pathlib
|
# f = list(p.rglob('*.*')) # pathlib
|
||||||
elif p.is_file(): # file
|
elif p.is_file(): # file
|
||||||
with open(p) as t:
|
with open(p) as t:
|
||||||
t = t.read().strip().splitlines()
|
t = t.read().strip().splitlines()
|
||||||
parent = str(p.parent) + os.sep
|
parent = str(p.parent) + os.sep
|
||||||
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
|
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
|
||||||
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
raise FileNotFoundError(f'{self.prefix}{p} does not exist')
|
||||||
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
|
||||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||||
assert im_files, f"{self.prefix}No images found"
|
assert im_files, f'{self.prefix}No images found'
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
|
||||||
return im_files
|
return im_files
|
||||||
|
|
||||||
def update_labels(self, include_class: Optional[list]):
|
def update_labels(self, include_class: Optional[list]):
|
||||||
@ -99,16 +99,16 @@ class BaseDataset(Dataset):
|
|||||||
include_class_array = np.array(include_class).reshape(1, -1)
|
include_class_array = np.array(include_class).reshape(1, -1)
|
||||||
for i in range(len(self.labels)):
|
for i in range(len(self.labels)):
|
||||||
if include_class:
|
if include_class:
|
||||||
cls = self.labels[i]["cls"]
|
cls = self.labels[i]['cls']
|
||||||
bboxes = self.labels[i]["bboxes"]
|
bboxes = self.labels[i]['bboxes']
|
||||||
segments = self.labels[i]["segments"]
|
segments = self.labels[i]['segments']
|
||||||
j = (cls == include_class_array).any(1)
|
j = (cls == include_class_array).any(1)
|
||||||
self.labels[i]["cls"] = cls[j]
|
self.labels[i]['cls'] = cls[j]
|
||||||
self.labels[i]["bboxes"] = bboxes[j]
|
self.labels[i]['bboxes'] = bboxes[j]
|
||||||
if segments:
|
if segments:
|
||||||
self.labels[i]["segments"] = segments[j]
|
self.labels[i]['segments'] = segments[j]
|
||||||
if self.single_cls:
|
if self.single_cls:
|
||||||
self.labels[i]["cls"][:, 0] = 0
|
self.labels[i]['cls'][:, 0] = 0
|
||||||
|
|
||||||
def load_image(self, i):
|
def load_image(self, i):
|
||||||
# Loads 1 image from dataset index 'i', returns (im, resized hw)
|
# Loads 1 image from dataset index 'i', returns (im, resized hw)
|
||||||
@ -119,7 +119,7 @@ class BaseDataset(Dataset):
|
|||||||
else: # read image
|
else: # read image
|
||||||
im = cv2.imread(f) # BGR
|
im = cv2.imread(f) # BGR
|
||||||
if im is None:
|
if im is None:
|
||||||
raise FileNotFoundError(f"Image Not Found {f}")
|
raise FileNotFoundError(f'Image Not Found {f}')
|
||||||
h0, w0 = im.shape[:2] # orig hw
|
h0, w0 = im.shape[:2] # orig hw
|
||||||
r = self.imgsz / max(h0, w0) # ratio
|
r = self.imgsz / max(h0, w0) # ratio
|
||||||
if r != 1: # if sizes are not equal
|
if r != 1: # if sizes are not equal
|
||||||
@ -132,17 +132,17 @@ class BaseDataset(Dataset):
|
|||||||
# cache images to memory or disk
|
# cache images to memory or disk
|
||||||
gb = 0 # Gigabytes of cached images
|
gb = 0 # Gigabytes of cached images
|
||||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||||
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
|
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
||||||
with ThreadPool(NUM_THREADS) as pool:
|
with ThreadPool(NUM_THREADS) as pool:
|
||||||
results = pool.imap(fcn, range(self.ni))
|
results = pool.imap(fcn, range(self.ni))
|
||||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||||
for i, x in pbar:
|
for i, x in pbar:
|
||||||
if cache == "disk":
|
if cache == 'disk':
|
||||||
gb += self.npy_files[i].stat().st_size
|
gb += self.npy_files[i].stat().st_size
|
||||||
else: # 'ram'
|
else: # 'ram'
|
||||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||||
gb += self.ims[i].nbytes
|
gb += self.ims[i].nbytes
|
||||||
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
|
pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
def cache_images_to_disk(self, i):
|
def cache_images_to_disk(self, i):
|
||||||
@ -155,7 +155,7 @@ class BaseDataset(Dataset):
|
|||||||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||||
nb = bi[-1] + 1 # number of batches
|
nb = bi[-1] + 1 # number of batches
|
||||||
|
|
||||||
s = np.array([x.pop("shape") for x in self.labels]) # hw
|
s = np.array([x.pop('shape') for x in self.labels]) # hw
|
||||||
ar = s[:, 0] / s[:, 1] # aspect ratio
|
ar = s[:, 0] / s[:, 1] # aspect ratio
|
||||||
irect = ar.argsort()
|
irect = ar.argsort()
|
||||||
self.im_files = [self.im_files[i] for i in irect]
|
self.im_files = [self.im_files[i] for i in irect]
|
||||||
@ -180,14 +180,14 @@ class BaseDataset(Dataset):
|
|||||||
|
|
||||||
def get_label_info(self, index):
|
def get_label_info(self, index):
|
||||||
label = self.labels[index].copy()
|
label = self.labels[index].copy()
|
||||||
label.pop("shape", None) # shape is for rect, remove it
|
label.pop('shape', None) # shape is for rect, remove it
|
||||||
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
|
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
|
||||||
label["ratio_pad"] = (
|
label['ratio_pad'] = (
|
||||||
label["resized_shape"][0] / label["ori_shape"][0],
|
label['resized_shape'][0] / label['ori_shape'][0],
|
||||||
label["resized_shape"][1] / label["ori_shape"][1],
|
label['resized_shape'][1] / label['ori_shape'][1],
|
||||||
) # for evaluation
|
) # for evaluation
|
||||||
if self.rect:
|
if self.rect:
|
||||||
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
||||||
label = self.update_labels_info(label)
|
label = self.update_labels_info(label)
|
||||||
return label
|
return label
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
|
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||||
self.iterator = super().__iter__()
|
self.iterator = super().__iter__()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -61,9 +61,9 @@ def seed_worker(worker_id):
|
|||||||
random.seed(worker_seed)
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
|
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
|
||||||
assert mode in ["train", "val"]
|
assert mode in ['train', 'val']
|
||||||
shuffle = mode == "train"
|
shuffle = mode == 'train'
|
||||||
if cfg.rect and shuffle:
|
if cfg.rect and shuffle:
|
||||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||||
shuffle = False
|
shuffle = False
|
||||||
@ -72,21 +72,21 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
|
|||||||
img_path=img_path,
|
img_path=img_path,
|
||||||
imgsz=cfg.imgsz,
|
imgsz=cfg.imgsz,
|
||||||
batch_size=batch,
|
batch_size=batch,
|
||||||
augment=mode == "train", # augmentation
|
augment=mode == 'train', # augmentation
|
||||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||||
rect=cfg.rect or rect, # rectangular batches
|
rect=cfg.rect or rect, # rectangular batches
|
||||||
cache=cfg.cache or None,
|
cache=cfg.cache or None,
|
||||||
single_cls=cfg.single_cls or False,
|
single_cls=cfg.single_cls or False,
|
||||||
stride=int(stride),
|
stride=int(stride),
|
||||||
pad=0.0 if mode == "train" else 0.5,
|
pad=0.0 if mode == 'train' else 0.5,
|
||||||
prefix=colorstr(f"{mode}: "),
|
prefix=colorstr(f'{mode}: '),
|
||||||
use_segments=cfg.task == "segment",
|
use_segments=cfg.task == 'segment',
|
||||||
use_keypoints=cfg.task == "keypoint",
|
use_keypoints=cfg.task == 'keypoint',
|
||||||
names=names)
|
names=names)
|
||||||
|
|
||||||
batch = min(batch, len(dataset))
|
batch = min(batch, len(dataset))
|
||||||
nd = torch.cuda.device_count() # number of CUDA devices
|
nd = torch.cuda.device_count() # number of CUDA devices
|
||||||
workers = cfg.workers if mode == "train" else cfg.workers * 2
|
workers = cfg.workers if mode == 'train' else cfg.workers * 2
|
||||||
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
|
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
|
||||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||||
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
|
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
|
||||||
@ -98,7 +98,7 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
|
|||||||
num_workers=nw,
|
num_workers=nw,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=PIN_MEMORY,
|
pin_memory=PIN_MEMORY,
|
||||||
collate_fn=getattr(dataset, "collate_fn", None),
|
collate_fn=getattr(dataset, 'collate_fn', None),
|
||||||
worker_init_fn=seed_worker,
|
worker_init_fn=seed_worker,
|
||||||
generator=generator), dataset
|
generator=generator), dataset
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ def check_source(source):
|
|||||||
from_img = True
|
from_img = True
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
|
'Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict')
|
||||||
|
|
||||||
return source, webcam, screenshot, from_img, in_memory
|
return source, webcam, screenshot, from_img, in_memory
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class LoadStreams:
|
|||||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
||||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||||
import pafy # noqa
|
import pafy # noqa
|
||||||
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
|
s = pafy.new(s).getbest(preftype='mp4').url # YouTube URL
|
||||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||||
if s == 0 and (is_colab() or is_kaggle()):
|
if s == 0 and (is_colab() or is_kaggle()):
|
||||||
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
|
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
|
||||||
@ -65,7 +65,7 @@ class LoadStreams:
|
|||||||
if not success or self.imgs[i] is None:
|
if not success or self.imgs[i] is None:
|
||||||
raise ConnectionError(f'{st}Failed to read images from {s}')
|
raise ConnectionError(f'{st}Failed to read images from {s}')
|
||||||
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
||||||
LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
|
LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
|
||||||
self.threads[i].start()
|
self.threads[i].start()
|
||||||
LOGGER.info('') # newline
|
LOGGER.info('') # newline
|
||||||
|
|
||||||
@ -145,11 +145,11 @@ class LoadScreenshots:
|
|||||||
|
|
||||||
# Parse monitor shape
|
# Parse monitor shape
|
||||||
monitor = self.sct.monitors[self.screen]
|
monitor = self.sct.monitors[self.screen]
|
||||||
self.top = monitor["top"] if top is None else (monitor["top"] + top)
|
self.top = monitor['top'] if top is None else (monitor['top'] + top)
|
||||||
self.left = monitor["left"] if left is None else (monitor["left"] + left)
|
self.left = monitor['left'] if left is None else (monitor['left'] + left)
|
||||||
self.width = width or monitor["width"]
|
self.width = width or monitor['width']
|
||||||
self.height = height or monitor["height"]
|
self.height = height or monitor['height']
|
||||||
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
|
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
@ -157,7 +157,7 @@ class LoadScreenshots:
|
|||||||
def __next__(self):
|
def __next__(self):
|
||||||
# mss screen capture: get raw pixels from the screen as np array
|
# mss screen capture: get raw pixels from the screen as np array
|
||||||
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
||||||
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
|
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
|
||||||
|
|
||||||
if self.transforms:
|
if self.transforms:
|
||||||
im = self.transforms(im0) # transforms
|
im = self.transforms(im0) # transforms
|
||||||
@ -172,7 +172,7 @@ class LoadScreenshots:
|
|||||||
class LoadImages:
|
class LoadImages:
|
||||||
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
|
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
|
||||||
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||||
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
|
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||||
path = Path(path).read_text().rsplit()
|
path = Path(path).read_text().rsplit()
|
||||||
files = []
|
files = []
|
||||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||||
@ -290,12 +290,12 @@ class LoadPilAndNumpy:
|
|||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
self.mode = 'image'
|
self.mode = 'image'
|
||||||
# generate fake paths
|
# generate fake paths
|
||||||
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
|
self.paths = [f'image{i}.jpg' for i in range(len(self.im0))]
|
||||||
self.bs = len(self.im0)
|
self.bs = len(self.im0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _single_check(im):
|
def _single_check(im):
|
||||||
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
|
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
|
||||||
if isinstance(im, Image.Image):
|
if isinstance(im, Image.Image):
|
||||||
im = np.asarray(im)[:, :, ::-1]
|
im = np.asarray(im)[:, :, ::-1]
|
||||||
im = np.ascontiguousarray(im) # contiguous
|
im = np.ascontiguousarray(im) # contiguous
|
||||||
@ -338,16 +338,16 @@ def autocast_list(source):
|
|||||||
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
|
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
|
||||||
files.append(im)
|
files.append(im)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
|
raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
|
||||||
f"See https://docs.ultralytics.com/predict for supported source types.")
|
f'See https://docs.ultralytics.com/predict for supported source types.')
|
||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
||||||
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
|
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
|
img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
|
||||||
dataset = LoadPilAndNumpy(im0=img)
|
dataset = LoadPilAndNumpy(im0=img)
|
||||||
for d in dataset:
|
for d in dataset:
|
||||||
print(d[0])
|
print(d[0])
|
||||||
|
@ -92,7 +92,7 @@ def exif_transpose(image):
|
|||||||
if method is not None:
|
if method is not None:
|
||||||
image = image.transpose(method)
|
image = image.transpose(method)
|
||||||
del exif[0x0112]
|
del exif[0x0112]
|
||||||
image.info["exif"] = exif.tobytes()
|
image.info['exif'] = exif.tobytes()
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
@ -217,11 +217,11 @@ class LoadScreenshots:
|
|||||||
|
|
||||||
# Parse monitor shape
|
# Parse monitor shape
|
||||||
monitor = self.sct.monitors[self.screen]
|
monitor = self.sct.monitors[self.screen]
|
||||||
self.top = monitor["top"] if top is None else (monitor["top"] + top)
|
self.top = monitor['top'] if top is None else (monitor['top'] + top)
|
||||||
self.left = monitor["left"] if left is None else (monitor["left"] + left)
|
self.left = monitor['left'] if left is None else (monitor['left'] + left)
|
||||||
self.width = width or monitor["width"]
|
self.width = width or monitor['width']
|
||||||
self.height = height or monitor["height"]
|
self.height = height or monitor['height']
|
||||||
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
|
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
@ -229,7 +229,7 @@ class LoadScreenshots:
|
|||||||
def __next__(self):
|
def __next__(self):
|
||||||
# mss screen capture: get raw pixels from the screen as np array
|
# mss screen capture: get raw pixels from the screen as np array
|
||||||
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
||||||
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
|
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
|
||||||
|
|
||||||
if self.transforms:
|
if self.transforms:
|
||||||
im = self.transforms(im0) # transforms
|
im = self.transforms(im0) # transforms
|
||||||
@ -244,7 +244,7 @@ class LoadScreenshots:
|
|||||||
class LoadImages:
|
class LoadImages:
|
||||||
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
||||||
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||||
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
|
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||||
path = Path(path).read_text().rsplit()
|
path = Path(path).read_text().rsplit()
|
||||||
files = []
|
files = []
|
||||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||||
@ -363,7 +363,7 @@ class LoadStreams:
|
|||||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
||||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||||
import pafy
|
import pafy
|
||||||
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
|
s = pafy.new(s).getbest(preftype='mp4').url # YouTube URL
|
||||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||||
if s == 0:
|
if s == 0:
|
||||||
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
|
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
|
||||||
@ -378,7 +378,7 @@ class LoadStreams:
|
|||||||
|
|
||||||
_, self.imgs[i] = cap.read() # guarantee first frame
|
_, self.imgs[i] = cap.read() # guarantee first frame
|
||||||
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
||||||
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
|
LOGGER.info(f'{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)')
|
||||||
self.threads[i].start()
|
self.threads[i].start()
|
||||||
LOGGER.info('') # newline
|
LOGGER.info('') # newline
|
||||||
|
|
||||||
@ -500,7 +500,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
# Display cache
|
# Display cache
|
||||||
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
||||||
if exists and LOCAL_RANK in {-1, 0}:
|
if exists and LOCAL_RANK in {-1, 0}:
|
||||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||||
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
||||||
if cache['msgs']:
|
if cache['msgs']:
|
||||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||||
@ -604,8 +604,8 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
mem = psutil.virtual_memory()
|
mem = psutil.virtual_memory()
|
||||||
cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question
|
cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question
|
||||||
if not cache:
|
if not cache:
|
||||||
LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
|
LOGGER.info(f'{prefix}{mem_required / gb:.1f}GB RAM required, '
|
||||||
f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
|
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
|
||||||
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
|
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
@ -615,7 +615,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
path.unlink() # remove *.cache file if exists
|
path.unlink() # remove *.cache file if exists
|
||||||
x = {} # dict
|
x = {} # dict
|
||||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||||
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
desc = f'{prefix}Scanning {path.parent / path.stem}...'
|
||||||
total = len(self.im_files)
|
total = len(self.im_files)
|
||||||
with ThreadPool(NUM_THREADS) as pool:
|
with ThreadPool(NUM_THREADS) as pool:
|
||||||
results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
|
results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
|
||||||
@ -629,7 +629,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
x[im_file] = [lb, shape, segments]
|
x[im_file] = [lb, shape, segments]
|
||||||
if msg:
|
if msg:
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
if msgs:
|
if msgs:
|
||||||
@ -1060,7 +1060,7 @@ class HUBDatasetStats():
|
|||||||
if zipped:
|
if zipped:
|
||||||
data['path'] = data_dir
|
data['path'] = data_dir
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception("error/HUB/dataset_stats/yaml_load") from e
|
raise Exception('error/HUB/dataset_stats/yaml_load') from e
|
||||||
|
|
||||||
check_det_dataset(data, autodownload) # download dataset if missing
|
check_det_dataset(data, autodownload) # download dataset if missing
|
||||||
self.hub_dir = Path(data['path'] + '-hub')
|
self.hub_dir = Path(data['path'] + '-hub')
|
||||||
@ -1187,7 +1187,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||||||
else: # read image
|
else: # read image
|
||||||
im = cv2.imread(f) # BGR
|
im = cv2.imread(f) # BGR
|
||||||
if self.album_transforms:
|
if self.album_transforms:
|
||||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
|
||||||
else:
|
else:
|
||||||
sample = self.torch_transforms(im)
|
sample = self.torch_transforms(im)
|
||||||
return sample, j
|
return sample, j
|
||||||
|
@ -28,7 +28,7 @@ class YOLODataset(BaseDataset):
|
|||||||
cache=False,
|
cache=False,
|
||||||
augment=True,
|
augment=True,
|
||||||
hyp=None,
|
hyp=None,
|
||||||
prefix="",
|
prefix='',
|
||||||
rect=False,
|
rect=False,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
stride=32,
|
stride=32,
|
||||||
@ -40,14 +40,14 @@ class YOLODataset(BaseDataset):
|
|||||||
self.use_segments = use_segments
|
self.use_segments = use_segments
|
||||||
self.use_keypoints = use_keypoints
|
self.use_keypoints = use_keypoints
|
||||||
self.names = names
|
self.names = names
|
||||||
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
|
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
|
||||||
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
|
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
|
||||||
|
|
||||||
def cache_labels(self, path=Path("./labels.cache")):
|
def cache_labels(self, path=Path('./labels.cache')):
|
||||||
# Cache dataset labels, check images and read shapes
|
# Cache dataset labels, check images and read shapes
|
||||||
x = {"labels": []}
|
x = {'labels': []}
|
||||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
||||||
total = len(self.im_files)
|
total = len(self.im_files)
|
||||||
with ThreadPool(NUM_THREADS) as pool:
|
with ThreadPool(NUM_THREADS) as pool:
|
||||||
results = pool.imap(func=verify_image_label,
|
results = pool.imap(func=verify_image_label,
|
||||||
@ -60,7 +60,7 @@ class YOLODataset(BaseDataset):
|
|||||||
ne += ne_f
|
ne += ne_f
|
||||||
nc += nc_f
|
nc += nc_f
|
||||||
if im_file:
|
if im_file:
|
||||||
x["labels"].append(
|
x['labels'].append(
|
||||||
dict(
|
dict(
|
||||||
im_file=im_file,
|
im_file=im_file,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
@ -69,68 +69,68 @@ class YOLODataset(BaseDataset):
|
|||||||
segments=segments,
|
segments=segments,
|
||||||
keypoints=keypoint,
|
keypoints=keypoint,
|
||||||
normalized=True,
|
normalized=True,
|
||||||
bbox_format="xywh"))
|
bbox_format='xywh'))
|
||||||
if msg:
|
if msg:
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
if msgs:
|
if msgs:
|
||||||
LOGGER.info("\n".join(msgs))
|
LOGGER.info('\n'.join(msgs))
|
||||||
if nf == 0:
|
if nf == 0:
|
||||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
|
LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
|
||||||
x["hash"] = get_hash(self.label_files + self.im_files)
|
x['hash'] = get_hash(self.label_files + self.im_files)
|
||||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||||
x["msgs"] = msgs # warnings
|
x['msgs'] = msgs # warnings
|
||||||
x["version"] = self.cache_version # cache version
|
x['version'] = self.cache_version # cache version
|
||||||
if is_dir_writeable(path.parent):
|
if is_dir_writeable(path.parent):
|
||||||
if path.exists():
|
if path.exists():
|
||||||
path.unlink() # remove *.cache file if exists
|
path.unlink() # remove *.cache file if exists
|
||||||
np.save(str(path), x) # save cache for next time
|
np.save(str(path), x) # save cache for next time
|
||||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
||||||
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
LOGGER.info(f'{self.prefix}New cache created: {path}')
|
||||||
else:
|
else:
|
||||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
|
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_labels(self):
    """Return the dataset labels, loading them from the on-disk cache when it is still valid.

    The cache file sits next to the label directory (``<labels dir>.cache``). It is reused only
    when its stored ``version`` equals ``self.cache_version`` and its stored ``hash`` equals the
    hash of the current label + image file lists; otherwise ``self.cache_labels`` rebuilds it.

    Side effects:
        Sets ``self.label_files`` and overwrites ``self.im_files`` with the image files that
        actually appear in the cached labels.

    Returns:
        The ``labels`` entry of the cache: a list of per-image dicts (the keys read here are
        ``im_file``, ``cls``, ``bboxes`` and ``segments``).

    Raises:
        FileNotFoundError: the cache reports zero images with labels found.
        ValueError: every label entry has an empty ``cls``.
    """
    self.label_files = img2label_paths(self.im_files)
    cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')

    # Explicit comparisons instead of `assert`: assertions are stripped under `python -O`,
    # which would make a stale or mismatched cache silently pass validation.
    try:
        # NOTE(security): allow_pickle=True unpickles arbitrary objects, so only load cache
        # files from a trusted location.
        cache = np.load(str(cache_path), allow_pickle=True).item()  # load dict
        exists = (cache['version'] == self.cache_version  # matches current version
                  and cache['hash'] == get_hash(self.label_files + self.im_files))  # identical hash
    except (FileNotFoundError, AssertionError, AttributeError):
        exists = False
    if not exists:
        cache = self.cache_labels(cache_path)  # run cache ops

    # Display cache
    nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
    if exists and LOCAL_RANK in {-1, 0}:
        d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
        tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
        if cache['msgs']:
            LOGGER.info('\n'.join(cache['msgs']))  # display warnings
    if nf == 0:  # number of labels found
        raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')

    # Read cache
    for k in ('hash', 'version', 'msgs'):  # remove bookkeeping items
        cache.pop(k)
    labels = cache['labels']
    self.im_files = [lb['im_file'] for lb in labels]  # update im_files

    # Check if the dataset is all boxes or all segments
    len_cls = sum(len(lb['cls']) for lb in labels)
    len_boxes = sum(len(lb['bboxes']) for lb in labels)
    len_segments = sum(len(lb['segments']) for lb in labels)
    if len_segments and len_boxes != len_segments:
        LOGGER.warning(
            f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
            f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
            'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
        for lb in labels:
            lb['segments'] = []
    if len_cls == 0:
        raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
    return labels
|
||||||
|
|
||||||
# TODO: use hyp config to set all these augmentations
|
# TODO: use hyp config to set all these augmentations
|
||||||
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
|
|||||||
else:
|
else:
|
||||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
||||||
transforms.append(
|
transforms.append(
|
||||||
Format(bbox_format="xywh",
|
Format(bbox_format='xywh',
|
||||||
normalize=True,
|
normalize=True,
|
||||||
return_mask=self.use_segments,
|
return_mask=self.use_segments,
|
||||||
return_keypoint=self.use_keypoints,
|
return_keypoint=self.use_keypoints,
|
||||||
@ -161,12 +161,12 @@ class YOLODataset(BaseDataset):
|
|||||||
"""custom your label format here"""
|
"""custom your label format here"""
|
||||||
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
||||||
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
||||||
bboxes = label.pop("bboxes")
|
bboxes = label.pop('bboxes')
|
||||||
segments = label.pop("segments")
|
segments = label.pop('segments')
|
||||||
keypoints = label.pop("keypoints", None)
|
keypoints = label.pop('keypoints', None)
|
||||||
bbox_format = label.pop("bbox_format")
|
bbox_format = label.pop('bbox_format')
|
||||||
normalized = label.pop("normalized")
|
normalized = label.pop('normalized')
|
||||||
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||||
return label
|
return label
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -176,15 +176,15 @@ class YOLODataset(BaseDataset):
|
|||||||
values = list(zip(*[list(b.values()) for b in batch]))
|
values = list(zip(*[list(b.values()) for b in batch]))
|
||||||
for i, k in enumerate(keys):
|
for i, k in enumerate(keys):
|
||||||
value = values[i]
|
value = values[i]
|
||||||
if k == "img":
|
if k == 'img':
|
||||||
value = torch.stack(value, 0)
|
value = torch.stack(value, 0)
|
||||||
if k in ["masks", "keypoints", "bboxes", "cls"]:
|
if k in ['masks', 'keypoints', 'bboxes', 'cls']:
|
||||||
value = torch.cat(value, 0)
|
value = torch.cat(value, 0)
|
||||||
new_batch[k] = value
|
new_batch[k] = value
|
||||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
new_batch['batch_idx'] = list(new_batch['batch_idx'])
|
||||||
for i in range(len(new_batch["batch_idx"])):
|
for i in range(len(new_batch['batch_idx'])):
|
||||||
new_batch["batch_idx"][i] += i # add target image index for build_targets()
|
new_batch['batch_idx'][i] += i # add target image index for build_targets()
|
||||||
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
|
new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
|
||||||
return new_batch
|
return new_batch
|
||||||
|
|
||||||
|
|
||||||
@ -202,9 +202,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||||||
super().__init__(root=root)
|
super().__init__(root=root)
|
||||||
self.torch_transforms = classify_transforms(imgsz)
|
self.torch_transforms = classify_transforms(imgsz)
|
||||||
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||||
self.cache_ram = cache is True or cache == "ram"
|
self.cache_ram = cache is True or cache == 'ram'
|
||||||
self.cache_disk = cache == "disk"
|
self.cache_disk = cache == 'disk'
|
||||||
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||||
@ -217,7 +217,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||||||
else: # read image
|
else: # read image
|
||||||
im = cv2.imread(f) # BGR
|
im = cv2.imread(f) # BGR
|
||||||
if self.album_transforms:
|
if self.album_transforms:
|
||||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
|
||||||
else:
|
else:
|
||||||
sample = self.torch_transforms(im)
|
sample = self.torch_transforms(im)
|
||||||
return {'img': sample, 'cls': j}
|
return {'img': sample, 'cls': j}
|
||||||
|
@ -25,15 +25,15 @@ class MixAndRectDataset:
|
|||||||
labels = deepcopy(self.dataset[index])
|
labels = deepcopy(self.dataset[index])
|
||||||
for transform in self.dataset.transforms.tolist():
|
for transform in self.dataset.transforms.tolist():
|
||||||
# mosaic and mixup
|
# mosaic and mixup
|
||||||
if hasattr(transform, "get_indexes"):
|
if hasattr(transform, 'get_indexes'):
|
||||||
indexes = transform.get_indexes(self.dataset)
|
indexes = transform.get_indexes(self.dataset)
|
||||||
if not isinstance(indexes, collections.abc.Sequence):
|
if not isinstance(indexes, collections.abc.Sequence):
|
||||||
indexes = [indexes]
|
indexes = [indexes]
|
||||||
mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
|
mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
|
||||||
labels["mix_labels"] = mix_labels
|
labels['mix_labels'] = mix_labels
|
||||||
if self.dataset.rect and isinstance(transform, LetterBox):
|
if self.dataset.rect and isinstance(transform, LetterBox):
|
||||||
transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
|
transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
|
||||||
labels = transform(labels)
|
labels = transform(labels)
|
||||||
if "mix_labels" in labels:
|
if 'mix_labels' in labels:
|
||||||
labels.pop("mix_labels")
|
labels.pop('mix_labels')
|
||||||
return labels
|
return labels
|
||||||
|
@ -18,32 +18,32 @@ from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
|
|||||||
from ultralytics.yolo.utils.downloads import download, safe_download
|
from ultralytics.yolo.utils.downloads import download, safe_download
|
||||||
from ultralytics.yolo.utils.ops import segments2boxes
|
from ultralytics.yolo.utils.ops import segments2boxes
|
||||||
|
|
||||||
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
|
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
||||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
|
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
|
||||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
|
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
||||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||||
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
||||||
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
||||||
|
|
||||||
# Get orientation exif tag
|
# Get orientation exif tag
|
||||||
for orientation in ExifTags.TAGS.keys():
|
for orientation in ExifTags.TAGS.keys():
|
||||||
if ExifTags.TAGS[orientation] == "Orientation":
|
if ExifTags.TAGS[orientation] == 'Orientation':
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def img2label_paths(img_paths):
    """Map image paths to their label-file paths.

    For each path, the last ``/images/`` component (using ``os.sep``) is replaced with
    ``/labels/`` and the file suffix is replaced with ``.txt``. Paths without an
    ``/images/`` component only get the suffix swap.

    Args:
        img_paths: iterable of image paths (``str`` or ``os.PathLike``).

    Returns:
        list[str]: one label path per input path, in the same order.
    """
    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
    label_paths = []
    for x in img_paths:
        x = os.fspath(x)  # accept pathlib.Path as well as str
        swapped = sb.join(x.rsplit(sa, 1))  # replace only the LAST /images/ occurrence
        # splitext only looks at the final path component, so a dot in a parent directory
        # (e.g. /data.v1/images/img) can no longer truncate the path.
        label_paths.append(os.path.splitext(swapped)[0] + '.txt')
    return label_paths
|
||||||
|
|
||||||
|
|
||||||
def get_hash(paths):
    """Return a single SHA-256 hex digest for a collection of paths (files or dirs).

    The digest covers the summed on-disk size of every path that exists, followed by the
    concatenation of all path strings, so it changes when a file's size changes or when
    the set/order of paths changes.

    Args:
        paths: iterable of ``str`` or ``os.PathLike`` entries. Missing paths contribute
            to the path string but not to the size.

    Returns:
        str: hexadecimal SHA-256 digest.
    """
    # Materialise once: the paths are walked twice below, which would silently drop the
    # path-string part for a one-shot iterator, and ''.join() rejects Path objects.
    paths = [os.fspath(p) for p in paths]
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
|
||||||
|
|
||||||
|
|
||||||
@ -61,21 +61,21 @@ def verify_image_label(args):
|
|||||||
# Verify one image-label pair
|
# Verify one image-label pair
|
||||||
im_file, lb_file, prefix, keypoint, num_cls = args
|
im_file, lb_file, prefix, keypoint, num_cls = args
|
||||||
# number (missing, found, empty, corrupt), message, segments, keypoints
|
# number (missing, found, empty, corrupt), message, segments, keypoints
|
||||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
|
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
|
||||||
try:
|
try:
|
||||||
# verify images
|
# verify images
|
||||||
im = Image.open(im_file)
|
im = Image.open(im_file)
|
||||||
im.verify() # PIL verify
|
im.verify() # PIL verify
|
||||||
shape = exif_size(im) # image size
|
shape = exif_size(im) # image size
|
||||||
shape = (shape[1], shape[0]) # hw
|
shape = (shape[1], shape[0]) # hw
|
||||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
|
||||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
|
||||||
if im.format.lower() in ("jpg", "jpeg"):
|
if im.format.lower() in ('jpg', 'jpeg'):
|
||||||
with open(im_file, "rb") as f:
|
with open(im_file, 'rb') as f:
|
||||||
f.seek(-2, 2)
|
f.seek(-2, 2)
|
||||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
if f.read() != b'\xff\xd9': # corrupt JPEG
|
||||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
|
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
|
||||||
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
|
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
|
||||||
|
|
||||||
# verify labels
|
# verify labels
|
||||||
if os.path.isfile(lb_file):
|
if os.path.isfile(lb_file):
|
||||||
@ -90,31 +90,31 @@ def verify_image_label(args):
|
|||||||
nl = len(lb)
|
nl = len(lb)
|
||||||
if nl:
|
if nl:
|
||||||
if keypoint:
|
if keypoint:
|
||||||
assert lb.shape[1] == 56, "labels require 56 columns each"
|
assert lb.shape[1] == 56, 'labels require 56 columns each'
|
||||||
assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
assert (lb[:, 5::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||||
assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
assert (lb[:, 6::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||||
kpts = np.zeros((lb.shape[0], 39))
|
kpts = np.zeros((lb.shape[0], 39))
|
||||||
for i in range(len(lb)):
|
for i in range(len(lb)):
|
||||||
kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3)) # remove occlusion param from GT
|
kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3)) # remove occlusion param from GT
|
||||||
kpts[i] = np.hstack((lb[i, :5], kpt))
|
kpts[i] = np.hstack((lb[i, :5], kpt))
|
||||||
lb = kpts
|
lb = kpts
|
||||||
assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
|
assert lb.shape[1] == 39, 'labels require 39 columns each after removing occlusion parameter'
|
||||||
else:
|
else:
|
||||||
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
|
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
|
||||||
assert (lb[:, 1:] <= 1).all(), \
|
assert (lb[:, 1:] <= 1).all(), \
|
||||||
f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
|
f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
|
||||||
# All labels
|
# All labels
|
||||||
max_cls = int(lb[:, 0].max()) # max label count
|
max_cls = int(lb[:, 0].max()) # max label count
|
||||||
assert max_cls <= num_cls, \
|
assert max_cls <= num_cls, \
|
||||||
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
|
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
|
||||||
f'Possible class labels are 0-{num_cls - 1}'
|
f'Possible class labels are 0-{num_cls - 1}'
|
||||||
assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
|
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
|
||||||
_, i = np.unique(lb, axis=0, return_index=True)
|
_, i = np.unique(lb, axis=0, return_index=True)
|
||||||
if len(i) < nl: # duplicate row check
|
if len(i) < nl: # duplicate row check
|
||||||
lb = lb[i] # remove duplicates
|
lb = lb[i] # remove duplicates
|
||||||
if segments:
|
if segments:
|
||||||
segments = [segments[x] for x in i]
|
segments = [segments[x] for x in i]
|
||||||
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
|
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
|
||||||
else:
|
else:
|
||||||
ne = 1 # label empty
|
ne = 1 # label empty
|
||||||
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||||
@ -127,7 +127,7 @@ def verify_image_label(args):
|
|||||||
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
nc = 1
|
nc = 1
|
||||||
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
|
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
|
||||||
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
||||||
|
|
||||||
|
|
||||||
@ -248,8 +248,8 @@ def check_det_dataset(dataset, autodownload=True):
|
|||||||
else: # python script
|
else: # python script
|
||||||
r = exec(s, {'yaml': data}) # return None
|
r = exec(s, {'yaml': data}) # return None
|
||||||
dt = f'({round(time.time() - t, 1)}s)'
|
dt = f'({round(time.time() - t, 1)}s)'
|
||||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
|
||||||
LOGGER.info(f"Dataset download {s}\n")
|
LOGGER.info(f'Dataset download {s}\n')
|
||||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
||||||
|
|
||||||
return data # dictionary
|
return data # dictionary
|
||||||
@ -284,9 +284,9 @@ def check_cls_dataset(dataset: str):
|
|||||||
download(url, dir=data_dir.parent)
|
download(url, dir=data_dir.parent)
|
||||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||||
LOGGER.info(s)
|
LOGGER.info(s)
|
||||||
train_set = data_dir / "train"
|
train_set = data_dir / 'train'
|
||||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
||||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||||
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
||||||
names = dict(enumerate(sorted(names)))
|
names = dict(enumerate(sorted(names)))
|
||||||
return {"train": train_set, "val": test_set, "nc": nc, "names": names}
|
return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}
|
||||||
|
@ -144,7 +144,7 @@ class Exporter:
|
|||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, model=None):
|
def __call__(self, model=None):
|
||||||
self.run_callbacks("on_export_start")
|
self.run_callbacks('on_export_start')
|
||||||
t = time.time()
|
t = time.time()
|
||||||
format = self.args.format.lower() # to lowercase
|
format = self.args.format.lower() # to lowercase
|
||||||
if format in {'tensorrt', 'trt'}: # engine aliases
|
if format in {'tensorrt', 'trt'}: # engine aliases
|
||||||
@ -207,7 +207,7 @@ class Exporter:
|
|||||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||||
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
|
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
|
||||||
self.metadata = {
|
self.metadata = {
|
||||||
'description': f"Ultralytics {self.pretty_name} model trained on {self.args.data}",
|
'description': f'Ultralytics {self.pretty_name} model trained on {self.args.data}',
|
||||||
'author': 'Ultralytics',
|
'author': 'Ultralytics',
|
||||||
'license': 'GPL-3.0 https://ultralytics.com/license',
|
'license': 'GPL-3.0 https://ultralytics.com/license',
|
||||||
'version': __version__,
|
'version': __version__,
|
||||||
@ -215,7 +215,7 @@ class Exporter:
|
|||||||
'names': model.names} # model metadata
|
'names': model.names} # model metadata
|
||||||
|
|
||||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
||||||
f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)")
|
f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
|
||||||
|
|
||||||
# Exports
|
# Exports
|
||||||
f = [''] * len(fmts) # exported filenames
|
f = [''] * len(fmts) # exported filenames
|
||||||
@ -259,15 +259,15 @@ class Exporter:
|
|||||||
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
|
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
|
||||||
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
||||||
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
|
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
|
||||||
data = f"data={self.args.data}" if model.task == 'segment' and format == 'pb' else ''
|
data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
f'\nExport complete ({time.time() - t:.1f}s)'
|
f'\nExport complete ({time.time() - t:.1f}s)'
|
||||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||||
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}"
|
f'\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}'
|
||||||
f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
|
f'\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}'
|
||||||
f"\nVisualize: https://netron.app")
|
f'\nVisualize: https://netron.app')
|
||||||
|
|
||||||
self.run_callbacks("on_export_end")
|
self.run_callbacks('on_export_end')
|
||||||
return f # return list of exported files/dirs
|
return f # return list of exported files/dirs
|
||||||
|
|
||||||
@try_export
|
@try_export
|
||||||
@ -277,7 +277,7 @@ class Exporter:
|
|||||||
f = self.file.with_suffix('.torchscript')
|
f = self.file.with_suffix('.torchscript')
|
||||||
|
|
||||||
ts = torch.jit.trace(self.model, self.im, strict=False)
|
ts = torch.jit.trace(self.model, self.im, strict=False)
|
||||||
d = {"shape": self.im.shape, "stride": int(max(self.model.stride)), "names": self.model.names}
|
d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
|
||||||
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
|
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
|
||||||
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
||||||
LOGGER.info(f'{prefix} optimizing for mobile...')
|
LOGGER.info(f'{prefix} optimizing for mobile...')
|
||||||
@ -354,7 +354,7 @@ class Exporter:
|
|||||||
|
|
||||||
ov_model = mo.convert_model(f_onnx,
|
ov_model = mo.convert_model(f_onnx,
|
||||||
model_name=self.pretty_name,
|
model_name=self.pretty_name,
|
||||||
framework="onnx",
|
framework='onnx',
|
||||||
compress_to_fp16=self.args.half) # export
|
compress_to_fp16=self.args.half) # export
|
||||||
ov.serialize(ov_model, f_ov) # save
|
ov.serialize(ov_model, f_ov) # save
|
||||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||||
@ -471,7 +471,7 @@ class Exporter:
|
|||||||
if self.args.dynamic:
|
if self.args.dynamic:
|
||||||
shape = self.im.shape
|
shape = self.im.shape
|
||||||
if shape[0] <= 1:
|
if shape[0] <= 1:
|
||||||
LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
|
LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
|
||||||
profile = builder.create_optimization_profile()
|
profile = builder.create_optimization_profile()
|
||||||
for inp in inputs:
|
for inp in inputs:
|
||||||
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
|
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
|
||||||
@ -509,8 +509,8 @@ class Exporter:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
|
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
|
||||||
import tensorflow as tf # noqa
|
import tensorflow as tf # noqa
|
||||||
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
|
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support'),
|
||||||
cmds="--extra-index-url https://pypi.ngc.nvidia.com")
|
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
|
||||||
|
|
||||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||||
f = str(self.file).replace(self.file.suffix, '_saved_model')
|
f = str(self.file).replace(self.file.suffix, '_saved_model')
|
||||||
@ -632,7 +632,7 @@ class Exporter:
|
|||||||
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
||||||
|
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
open(f, "wb").write(tflite_model)
|
open(f, 'wb').write(tflite_model)
|
||||||
return f, None
|
return f, None
|
||||||
|
|
||||||
@try_export
|
@try_export
|
||||||
@ -656,7 +656,7 @@ class Exporter:
|
|||||||
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
||||||
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
||||||
|
|
||||||
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}"
|
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}'
|
||||||
subprocess.run(cmd.split(), check=True)
|
subprocess.run(cmd.split(), check=True)
|
||||||
self._add_tflite_metadata(f)
|
self._add_tflite_metadata(f)
|
||||||
return f, None
|
return f, None
|
||||||
@ -707,8 +707,8 @@ class Exporter:
|
|||||||
|
|
||||||
# Creates input info.
|
# Creates input info.
|
||||||
input_meta = _metadata_fb.TensorMetadataT()
|
input_meta = _metadata_fb.TensorMetadataT()
|
||||||
input_meta.name = "image"
|
input_meta.name = 'image'
|
||||||
input_meta.description = "Input image to be detected."
|
input_meta.description = 'Input image to be detected.'
|
||||||
input_meta.content = _metadata_fb.ContentT()
|
input_meta.content = _metadata_fb.ContentT()
|
||||||
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
|
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
|
||||||
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
|
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
|
||||||
@ -716,8 +716,8 @@ class Exporter:
|
|||||||
|
|
||||||
# Creates output info.
|
# Creates output info.
|
||||||
output_meta = _metadata_fb.TensorMetadataT()
|
output_meta = _metadata_fb.TensorMetadataT()
|
||||||
output_meta.name = "output"
|
output_meta.name = 'output'
|
||||||
output_meta.description = "Coordinates of detected objects, class labels, and confidence score."
|
output_meta.description = 'Coordinates of detected objects, class labels, and confidence score.'
|
||||||
|
|
||||||
# Label file
|
# Label file
|
||||||
tmp_file = Path('/tmp/meta.txt')
|
tmp_file = Path('/tmp/meta.txt')
|
||||||
@ -868,8 +868,8 @@ class Exporter:
|
|||||||
|
|
||||||
|
|
||||||
def export(cfg=DEFAULT_CFG):
|
def export(cfg=DEFAULT_CFG):
|
||||||
cfg.model = cfg.model or "yolov8n.yaml"
|
cfg.model = cfg.model or 'yolov8n.yaml'
|
||||||
cfg.format = cfg.format or "torchscript"
|
cfg.format = cfg.format or 'torchscript'
|
||||||
|
|
||||||
# exporter = Exporter(cfg)
|
# exporter = Exporter(cfg)
|
||||||
#
|
#
|
||||||
@ -888,7 +888,7 @@ def export(cfg=DEFAULT_CFG):
|
|||||||
model.export(**vars(cfg))
|
model.export(**vars(cfg))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
CLI:
|
CLI:
|
||||||
yolo mode=export model=yolov8n.yaml format=onnx
|
yolo mode=export model=yolov8n.yaml format=onnx
|
||||||
|
@ -16,13 +16,13 @@ from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
|||||||
|
|
||||||
# Map head to model, trainer, validator, and predictor classes
|
# Map head to model, trainer, validator, and predictor classes
|
||||||
MODEL_MAP = {
|
MODEL_MAP = {
|
||||||
"classify": [
|
'classify': [
|
||||||
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
|
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
|
||||||
'yolo.TYPE.classify.ClassificationPredictor'],
|
'yolo.TYPE.classify.ClassificationPredictor'],
|
||||||
"detect": [
|
'detect': [
|
||||||
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
|
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
|
||||||
'yolo.TYPE.detect.DetectionPredictor'],
|
'yolo.TYPE.detect.DetectionPredictor'],
|
||||||
"segment": [
|
'segment': [
|
||||||
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
|
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
|
||||||
'yolo.TYPE.segment.SegmentationPredictor']}
|
'yolo.TYPE.segment.SegmentationPredictor']}
|
||||||
|
|
||||||
@ -34,7 +34,7 @@ class YOLO:
|
|||||||
A python interface which emulates a model-like behaviour by wrapping trainers.
|
A python interface which emulates a model-like behaviour by wrapping trainers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model='yolov8n.pt', type="v8") -> None:
|
def __init__(self, model='yolov8n.pt', type='v8') -> None:
|
||||||
"""
|
"""
|
||||||
Initializes the YOLO object.
|
Initializes the YOLO object.
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ class YOLO:
|
|||||||
suffix = Path(weights).suffix
|
suffix = Path(weights).suffix
|
||||||
if suffix == '.pt':
|
if suffix == '.pt':
|
||||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||||
self.task = self.model.args["task"]
|
self.task = self.model.args['task']
|
||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
self._reset_ckpt_args(self.overrides)
|
self._reset_ckpt_args(self.overrides)
|
||||||
else:
|
else:
|
||||||
@ -111,7 +111,7 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model, nn.Module):
|
if not isinstance(self.model, nn.Module):
|
||||||
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
|
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
|
||||||
f"PyTorch models can be used to train, val, predict and export, i.e. "
|
f'PyTorch models can be used to train, val, predict and export, i.e. '
|
||||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||||
|
|
||||||
@ -155,11 +155,11 @@ class YOLO:
|
|||||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||||
"""
|
"""
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides["conf"] = 0.25
|
overrides['conf'] = 0.25
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
overrides["mode"] = kwargs.get("mode", "predict")
|
overrides['mode'] = kwargs.get('mode', 'predict')
|
||||||
assert overrides["mode"] in ['track', 'predict']
|
assert overrides['mode'] in ['track', 'predict']
|
||||||
overrides["save"] = kwargs.get("save", False) # not save files by default
|
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||||
if not self.predictor:
|
if not self.predictor:
|
||||||
self.predictor = self.PredictorClass(overrides=overrides)
|
self.predictor = self.PredictorClass(overrides=overrides)
|
||||||
self.predictor.setup_model(model=self.model)
|
self.predictor.setup_model(model=self.model)
|
||||||
@ -173,7 +173,7 @@ class YOLO:
|
|||||||
from ultralytics.tracker.track import register_tracker
|
from ultralytics.tracker.track import register_tracker
|
||||||
register_tracker(self)
|
register_tracker(self)
|
||||||
# bytetrack-based method needs low confidence predictions as input
|
# bytetrack-based method needs low confidence predictions as input
|
||||||
conf = kwargs.get("conf") or 0.1
|
conf = kwargs.get('conf') or 0.1
|
||||||
kwargs['conf'] = conf
|
kwargs['conf'] = conf
|
||||||
kwargs['mode'] = 'track'
|
kwargs['mode'] = 'track'
|
||||||
return self.predict(source=source, stream=stream, **kwargs)
|
return self.predict(source=source, stream=stream, **kwargs)
|
||||||
@ -188,9 +188,9 @@ class YOLO:
|
|||||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||||
"""
|
"""
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides["rect"] = True # rect batches as default
|
overrides['rect'] = True # rect batches as default
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
overrides["mode"] = "val"
|
overrides['mode'] = 'val'
|
||||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||||
args.data = data or args.data
|
args.data = data or args.data
|
||||||
args.task = self.task
|
args.task = self.task
|
||||||
@ -234,18 +234,18 @@ class YOLO:
|
|||||||
self._check_is_pytorch_model()
|
self._check_is_pytorch_model()
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
if kwargs.get("cfg"):
|
if kwargs.get('cfg'):
|
||||||
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
||||||
overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
|
overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
|
||||||
overrides["task"] = self.task
|
overrides['task'] = self.task
|
||||||
overrides["mode"] = "train"
|
overrides['mode'] = 'train'
|
||||||
if not overrides.get("data"):
|
if not overrides.get('data'):
|
||||||
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
||||||
if overrides.get("resume"):
|
if overrides.get('resume'):
|
||||||
overrides["resume"] = self.ckpt_path
|
overrides['resume'] = self.ckpt_path
|
||||||
|
|
||||||
self.trainer = self.TrainerClass(overrides=overrides)
|
self.trainer = self.TrainerClass(overrides=overrides)
|
||||||
if not overrides.get("resume"): # manually set model only if not resuming
|
if not overrides.get('resume'): # manually set model only if not resuming
|
||||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||||
self.model = self.trainer.model
|
self.model = self.trainer.model
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
@ -267,9 +267,9 @@ class YOLO:
|
|||||||
|
|
||||||
def _assign_ops_from_task(self):
|
def _assign_ops_from_task(self):
|
||||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
||||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
|
||||||
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
|
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
|
||||||
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
|
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
|
||||||
return model_class, trainer_class, validator_class, predictor_class
|
return model_class, trainer_class, validator_class, predictor_class
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -292,7 +292,7 @@ class YOLO:
|
|||||||
Returns metrics if computed
|
Returns metrics if computed
|
||||||
"""
|
"""
|
||||||
if not self.metrics_data:
|
if not self.metrics_data:
|
||||||
LOGGER.info("No metrics data found! Run training or validation operation first.")
|
LOGGER.info('No metrics data found! Run training or validation operation first.')
|
||||||
|
|
||||||
return self.metrics_data
|
return self.metrics_data
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ class BasePredictor:
|
|||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f'{self.args.mode}'
|
||||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||||
if self.args.conf is None:
|
if self.args.conf is None:
|
||||||
self.args.conf = 0.25 # default conf=0.25
|
self.args.conf = 0.25 # default conf=0.25
|
||||||
@ -97,10 +97,10 @@ class BasePredictor:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def get_annotator(self, img):
|
def get_annotator(self, img):
|
||||||
raise NotImplementedError("get_annotator function needs to be implemented")
|
raise NotImplementedError('get_annotator function needs to be implemented')
|
||||||
|
|
||||||
def write_results(self, results, batch, print_string):
|
def write_results(self, results, batch, print_string):
|
||||||
raise NotImplementedError("print_results function needs to be implemented")
|
raise NotImplementedError('print_results function needs to be implemented')
|
||||||
|
|
||||||
def postprocess(self, preds, img, orig_img):
|
def postprocess(self, preds, img, orig_img):
|
||||||
return preds
|
return preds
|
||||||
@ -135,7 +135,7 @@ class BasePredictor:
|
|||||||
|
|
||||||
def stream_inference(self, source=None, model=None):
|
def stream_inference(self, source=None, model=None):
|
||||||
if self.args.verbose:
|
if self.args.verbose:
|
||||||
LOGGER.info("")
|
LOGGER.info('')
|
||||||
|
|
||||||
# setup model
|
# setup model
|
||||||
if not self.model:
|
if not self.model:
|
||||||
@ -152,9 +152,9 @@ class BasePredictor:
|
|||||||
self.done_warmup = True
|
self.done_warmup = True
|
||||||
|
|
||||||
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
|
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
|
||||||
self.run_callbacks("on_predict_start")
|
self.run_callbacks('on_predict_start')
|
||||||
for batch in self.dataset:
|
for batch in self.dataset:
|
||||||
self.run_callbacks("on_predict_batch_start")
|
self.run_callbacks('on_predict_batch_start')
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
path, im, im0s, vid_cap, s = batch
|
path, im, im0s, vid_cap, s = batch
|
||||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||||
@ -170,7 +170,7 @@ class BasePredictor:
|
|||||||
# postprocess
|
# postprocess
|
||||||
with self.dt[2]:
|
with self.dt[2]:
|
||||||
self.results = self.postprocess(preds, im, im0s)
|
self.results = self.postprocess(preds, im, im0s)
|
||||||
self.run_callbacks("on_predict_postprocess_end")
|
self.run_callbacks('on_predict_postprocess_end')
|
||||||
|
|
||||||
# visualize, save, write results
|
# visualize, save, write results
|
||||||
for i in range(len(im)):
|
for i in range(len(im)):
|
||||||
@ -186,7 +186,7 @@ class BasePredictor:
|
|||||||
|
|
||||||
if self.args.save:
|
if self.args.save:
|
||||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||||
self.run_callbacks("on_predict_batch_end")
|
self.run_callbacks('on_predict_batch_end')
|
||||||
yield from self.results
|
yield from self.results
|
||||||
|
|
||||||
# Print time (inference-only)
|
# Print time (inference-only)
|
||||||
@ -207,7 +207,7 @@ class BasePredictor:
|
|||||||
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||||
|
|
||||||
self.run_callbacks("on_predict_end")
|
self.run_callbacks('on_predict_end')
|
||||||
|
|
||||||
def setup_model(self, model):
|
def setup_model(self, model):
|
||||||
device = select_device(self.args.device)
|
device = select_device(self.args.device)
|
||||||
|
@ -36,7 +36,7 @@ class Results:
|
|||||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||||
self.probs = probs if probs is not None else None
|
self.probs = probs if probs is not None else None
|
||||||
self.names = names
|
self.names = names
|
||||||
self.comp = ["boxes", "masks", "probs"]
|
self.comp = ['boxes', 'masks', 'probs']
|
||||||
|
|
||||||
def pandas(self):
|
def pandas(self):
|
||||||
pass
|
pass
|
||||||
@ -97,7 +97,7 @@ class Results:
|
|||||||
return len(getattr(self, item))
|
return len(getattr(self, item))
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
str_out = ""
|
str_out = ''
|
||||||
for item in self.comp:
|
for item in self.comp:
|
||||||
if getattr(self, item) is None:
|
if getattr(self, item) is None:
|
||||||
continue
|
continue
|
||||||
@ -105,7 +105,7 @@ class Results:
|
|||||||
return str_out
|
return str_out
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
str_out = ""
|
str_out = ''
|
||||||
for item in self.comp:
|
for item in self.comp:
|
||||||
if getattr(self, item) is None:
|
if getattr(self, item) is None:
|
||||||
continue
|
continue
|
||||||
@ -187,7 +187,7 @@ class Boxes:
|
|||||||
if boxes.ndim == 1:
|
if boxes.ndim == 1:
|
||||||
boxes = boxes[None, :]
|
boxes = boxes[None, :]
|
||||||
n = boxes.shape[-1]
|
n = boxes.shape[-1]
|
||||||
assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}" # xyxy, (track_id), conf, cls
|
assert n in {6, 7}, f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls
|
||||||
# TODO
|
# TODO
|
||||||
self.is_track = n == 7
|
self.is_track = n == 7
|
||||||
self.boxes = boxes
|
self.boxes = boxes
|
||||||
@ -268,8 +268,8 @@ class Boxes:
|
|||||||
return self.boxes.__str__()
|
return self.boxes.__str__()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.boxes)}\n" +
|
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.boxes)}\n' +
|
||||||
f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}")
|
f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}')
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
boxes = self.boxes[idx]
|
boxes = self.boxes[idx]
|
||||||
@ -353,8 +353,8 @@ class Masks:
|
|||||||
return self.masks.__str__()
|
return self.masks.__str__()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" +
|
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.masks)}\n' +
|
||||||
f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}\n + {self.masks.__repr__()}")
|
f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n + {self.masks.__repr__()}')
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
masks = self.masks[idx]
|
masks = self.masks[idx]
|
||||||
@ -374,19 +374,19 @@ class Masks:
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
# test examples
|
# test examples
|
||||||
results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
|
results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
|
||||||
results = results.cuda()
|
results = results.cuda()
|
||||||
print("--cuda--pass--")
|
print('--cuda--pass--')
|
||||||
results = results.cpu()
|
results = results.cpu()
|
||||||
print("--cpu--pass--")
|
print('--cpu--pass--')
|
||||||
results = results.to("cuda:0")
|
results = results.to('cuda:0')
|
||||||
print("--to-cuda--pass--")
|
print('--to-cuda--pass--')
|
||||||
results = results.to("cpu")
|
results = results.to('cpu')
|
||||||
print("--to-cpu--pass--")
|
print('--to-cpu--pass--')
|
||||||
results = results.numpy()
|
results = results.numpy()
|
||||||
print("--numpy--pass--")
|
print('--numpy--pass--')
|
||||||
# box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
|
# box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
|
||||||
# box = box.cuda()
|
# box = box.cuda()
|
||||||
# box = box.cpu()
|
# box = box.cpu()
|
||||||
|
@ -90,7 +90,7 @@ class BaseTrainer:
|
|||||||
|
|
||||||
# Dirs
|
# Dirs
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f'{self.args.mode}'
|
||||||
if hasattr(self.args, 'save_dir'):
|
if hasattr(self.args, 'save_dir'):
|
||||||
self.save_dir = Path(self.args.save_dir)
|
self.save_dir = Path(self.args.save_dir)
|
||||||
else:
|
else:
|
||||||
@ -121,7 +121,7 @@ class BaseTrainer:
|
|||||||
try:
|
try:
|
||||||
if self.args.task == 'classify':
|
if self.args.task == 'classify':
|
||||||
self.data = check_cls_dataset(self.args.data)
|
self.data = check_cls_dataset(self.args.data)
|
||||||
elif self.args.data.endswith(".yaml") or self.args.task in ('detect', 'segment'):
|
elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
|
||||||
self.data = check_det_dataset(self.args.data)
|
self.data = check_det_dataset(self.args.data)
|
||||||
if 'yaml_file' in self.data:
|
if 'yaml_file' in self.data:
|
||||||
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
|
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
|
||||||
@ -175,7 +175,7 @@ class BaseTrainer:
|
|||||||
world_size = 0
|
world_size = 0
|
||||||
|
|
||||||
# Run subprocess if DDP training, else train normally
|
# Run subprocess if DDP training, else train normally
|
||||||
if world_size > 1 and "LOCAL_RANK" not in os.environ:
|
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
|
||||||
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
|
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
|
||||||
try:
|
try:
|
||||||
subprocess.run(cmd, check=True)
|
subprocess.run(cmd, check=True)
|
||||||
@ -191,15 +191,15 @@ class BaseTrainer:
|
|||||||
# os.environ['MASTER_PORT'] = '9020'
|
# os.environ['MASTER_PORT'] = '9020'
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
self.device = torch.device('cuda', rank)
|
self.device = torch.device('cuda', rank)
|
||||||
self.console.info(f"DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}")
|
self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||||
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
|
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
|
||||||
|
|
||||||
def _setup_train(self, rank, world_size):
|
def _setup_train(self, rank, world_size):
|
||||||
"""
|
"""
|
||||||
Builds dataloaders and optimizer on correct rank process.
|
Builds dataloaders and optimizer on correct rank process.
|
||||||
"""
|
"""
|
||||||
# model
|
# model
|
||||||
self.run_callbacks("on_pretrain_routine_start")
|
self.run_callbacks('on_pretrain_routine_start')
|
||||||
ckpt = self.setup_model()
|
ckpt = self.setup_model()
|
||||||
self.model = self.model.to(self.device)
|
self.model = self.model.to(self.device)
|
||||||
self.set_model_attributes()
|
self.set_model_attributes()
|
||||||
@ -234,16 +234,16 @@ class BaseTrainer:
|
|||||||
|
|
||||||
# dataloaders
|
# dataloaders
|
||||||
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
|
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
|
||||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
|
||||||
if rank in {0, -1}:
|
if rank in {0, -1}:
|
||||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
|
||||||
self.validator = self.get_validator()
|
self.validator = self.get_validator()
|
||||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
|
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
|
||||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||||
self.ema = ModelEMA(self.model)
|
self.ema = ModelEMA(self.model)
|
||||||
self.resume_training(ckpt)
|
self.resume_training(ckpt)
|
||||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||||
self.run_callbacks("on_pretrain_routine_end")
|
self.run_callbacks('on_pretrain_routine_end')
|
||||||
|
|
||||||
def _do_train(self, rank=-1, world_size=1):
|
def _do_train(self, rank=-1, world_size=1):
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -257,24 +257,24 @@ class BaseTrainer:
|
|||||||
nb = len(self.train_loader) # number of batches
|
nb = len(self.train_loader) # number of batches
|
||||||
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
||||||
last_opt_step = -1
|
last_opt_step = -1
|
||||||
self.run_callbacks("on_train_start")
|
self.run_callbacks('on_train_start')
|
||||||
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||||
f"Starting training for {self.epochs} epochs...")
|
f'Starting training for {self.epochs} epochs...')
|
||||||
if self.args.close_mosaic:
|
if self.args.close_mosaic:
|
||||||
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||||
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
||||||
for epoch in range(self.start_epoch, self.epochs):
|
for epoch in range(self.start_epoch, self.epochs):
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
self.run_callbacks("on_train_epoch_start")
|
self.run_callbacks('on_train_epoch_start')
|
||||||
self.model.train()
|
self.model.train()
|
||||||
if rank != -1:
|
if rank != -1:
|
||||||
self.train_loader.sampler.set_epoch(epoch)
|
self.train_loader.sampler.set_epoch(epoch)
|
||||||
pbar = enumerate(self.train_loader)
|
pbar = enumerate(self.train_loader)
|
||||||
# Update dataloader attributes (optional)
|
# Update dataloader attributes (optional)
|
||||||
if epoch == (self.epochs - self.args.close_mosaic):
|
if epoch == (self.epochs - self.args.close_mosaic):
|
||||||
self.console.info("Closing dataloader mosaic")
|
self.console.info('Closing dataloader mosaic')
|
||||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||||
self.train_loader.dataset.mosaic = False
|
self.train_loader.dataset.mosaic = False
|
||||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||||
@ -286,7 +286,7 @@ class BaseTrainer:
|
|||||||
self.tloss = None
|
self.tloss = None
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
for i, batch in pbar:
|
for i, batch in pbar:
|
||||||
self.run_callbacks("on_train_batch_start")
|
self.run_callbacks('on_train_batch_start')
|
||||||
# Warmup
|
# Warmup
|
||||||
ni = i + nb * epoch
|
ni = i + nb * epoch
|
||||||
if ni <= nw:
|
if ni <= nw:
|
||||||
@ -302,7 +302,7 @@ class BaseTrainer:
|
|||||||
# Forward
|
# Forward
|
||||||
with torch.cuda.amp.autocast(self.amp):
|
with torch.cuda.amp.autocast(self.amp):
|
||||||
batch = self.preprocess_batch(batch)
|
batch = self.preprocess_batch(batch)
|
||||||
preds = self.model(batch["img"])
|
preds = self.model(batch['img'])
|
||||||
self.loss, self.loss_items = self.criterion(preds, batch)
|
self.loss, self.loss_items = self.criterion(preds, batch)
|
||||||
if rank != -1:
|
if rank != -1:
|
||||||
self.loss *= world_size
|
self.loss *= world_size
|
||||||
@ -324,17 +324,17 @@ class BaseTrainer:
|
|||||||
if rank in {-1, 0}:
|
if rank in {-1, 0}:
|
||||||
pbar.set_description(
|
pbar.set_description(
|
||||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
|
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
|
||||||
self.run_callbacks('on_batch_end')
|
self.run_callbacks('on_batch_end')
|
||||||
if self.args.plots and ni in self.plot_idx:
|
if self.args.plots and ni in self.plot_idx:
|
||||||
self.plot_training_samples(batch, ni)
|
self.plot_training_samples(batch, ni)
|
||||||
|
|
||||||
self.run_callbacks("on_train_batch_end")
|
self.run_callbacks('on_train_batch_end')
|
||||||
|
|
||||||
self.lr = {f"lr/pg{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||||
|
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
self.run_callbacks("on_train_epoch_end")
|
self.run_callbacks('on_train_epoch_end')
|
||||||
|
|
||||||
if rank in {-1, 0}:
|
if rank in {-1, 0}:
|
||||||
|
|
||||||
@ -355,7 +355,7 @@ class BaseTrainer:
|
|||||||
tnow = time.time()
|
tnow = time.time()
|
||||||
self.epoch_time = tnow - self.epoch_time_start
|
self.epoch_time = tnow - self.epoch_time_start
|
||||||
self.epoch_time_start = tnow
|
self.epoch_time_start = tnow
|
||||||
self.run_callbacks("on_fit_epoch_end")
|
self.run_callbacks('on_fit_epoch_end')
|
||||||
|
|
||||||
# Early Stopping
|
# Early Stopping
|
||||||
if RANK != -1: # if DDP training
|
if RANK != -1: # if DDP training
|
||||||
@ -402,7 +402,7 @@ class BaseTrainer:
|
|||||||
"""
|
"""
|
||||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||||
"""
|
"""
|
||||||
return data["train"], data.get("val") or data.get("test")
|
return data['train'], data.get('val') or data.get('test')
|
||||||
|
|
||||||
def setup_model(self):
|
def setup_model(self):
|
||||||
"""
|
"""
|
||||||
@ -413,9 +413,9 @@ class BaseTrainer:
|
|||||||
|
|
||||||
model, weights = self.model, None
|
model, weights = self.model, None
|
||||||
ckpt = None
|
ckpt = None
|
||||||
if str(model).endswith(".pt"):
|
if str(model).endswith('.pt'):
|
||||||
weights, ckpt = attempt_load_one_weight(model)
|
weights, ckpt = attempt_load_one_weight(model)
|
||||||
cfg = ckpt["model"].yaml
|
cfg = ckpt['model'].yaml
|
||||||
else:
|
else:
|
||||||
cfg = model
|
cfg = model
|
||||||
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
||||||
@ -441,7 +441,7 @@ class BaseTrainer:
|
|||||||
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||||
"""
|
"""
|
||||||
metrics = self.validator(self)
|
metrics = self.validator(self)
|
||||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||||
if not self.best_fitness or self.best_fitness < fitness:
|
if not self.best_fitness or self.best_fitness < fitness:
|
||||||
self.best_fitness = fitness
|
self.best_fitness = fitness
|
||||||
return metrics, fitness
|
return metrics, fitness
|
||||||
@ -462,38 +462,38 @@ class BaseTrainer:
|
|||||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
raise NotImplementedError("get_validator function not implemented in trainer")
|
raise NotImplementedError('get_validator function not implemented in trainer')
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||||
"""
|
"""
|
||||||
Returns dataloader derived from torch.data.Dataloader.
|
Returns dataloader derived from torch.data.Dataloader.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
raise NotImplementedError('get_dataloader function not implemented in trainer')
|
||||||
|
|
||||||
def criterion(self, preds, batch):
|
def criterion(self, preds, batch):
|
||||||
"""
|
"""
|
||||||
Returns loss and individual loss items as Tensor.
|
Returns loss and individual loss items as Tensor.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("criterion function not implemented in trainer")
|
raise NotImplementedError('criterion function not implemented in trainer')
|
||||||
|
|
||||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||||
"""
|
"""
|
||||||
Returns a loss dict with labelled training loss items tensor
|
Returns a loss dict with labelled training loss items tensor
|
||||||
"""
|
"""
|
||||||
# Not needed for classification but necessary for segmentation & detection
|
# Not needed for classification but necessary for segmentation & detection
|
||||||
return {"loss": loss_items} if loss_items is not None else ["loss"]
|
return {'loss': loss_items} if loss_items is not None else ['loss']
|
||||||
|
|
||||||
def set_model_attributes(self):
|
def set_model_attributes(self):
|
||||||
"""
|
"""
|
||||||
To set or update model parameters before training.
|
To set or update model parameters before training.
|
||||||
"""
|
"""
|
||||||
self.model.names = self.data["names"]
|
self.model.names = self.data['names']
|
||||||
|
|
||||||
def build_targets(self, preds, targets):
|
def build_targets(self, preds, targets):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def progress_string(self):
|
def progress_string(self):
|
||||||
return ""
|
return ''
|
||||||
|
|
||||||
# TODO: may need to put these following functions into callback
|
# TODO: may need to put these following functions into callback
|
||||||
def plot_training_samples(self, batch, ni):
|
def plot_training_samples(self, batch, ni):
|
||||||
@ -529,7 +529,7 @@ class BaseTrainer:
|
|||||||
self.args = get_cfg(attempt_load_weights(last).args)
|
self.args = get_cfg(attempt_load_weights(last).args)
|
||||||
self.args.model, resume = str(last), True # reinstate
|
self.args.model, resume = str(last), True # reinstate
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
|
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
|
||||||
"i.e. 'yolo train resume model=path/to/last.pt'") from e
|
"i.e. 'yolo train resume model=path/to/last.pt'") from e
|
||||||
self.resume = resume
|
self.resume = resume
|
||||||
|
|
||||||
@ -557,7 +557,7 @@ class BaseTrainer:
|
|||||||
self.best_fitness = best_fitness
|
self.best_fitness = best_fitness
|
||||||
self.start_epoch = start_epoch
|
self.start_epoch = start_epoch
|
||||||
if start_epoch > (self.epochs - self.args.close_mosaic):
|
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||||
self.console.info("Closing dataloader mosaic")
|
self.console.info('Closing dataloader mosaic')
|
||||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||||
self.train_loader.dataset.mosaic = False
|
self.train_loader.dataset.mosaic = False
|
||||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||||
@ -602,5 +602,5 @@ class BaseTrainer:
|
|||||||
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
||||||
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
||||||
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
|
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
|
||||||
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
|
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
|
||||||
return optimizer
|
return optimizer
|
||||||
|
@ -62,7 +62,7 @@ class BaseValidator:
|
|||||||
self.jdict = None
|
self.jdict = None
|
||||||
|
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f'{self.args.mode}'
|
||||||
self.save_dir = save_dir or increment_path(Path(project) / name,
|
self.save_dir = save_dir or increment_path(Path(project) / name,
|
||||||
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
|
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
|
||||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
@ -92,7 +92,7 @@ class BaseValidator:
|
|||||||
else:
|
else:
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
self.run_callbacks('on_val_start')
|
self.run_callbacks('on_val_start')
|
||||||
assert model is not None, "Either trainer or model is needed for validation"
|
assert model is not None, 'Either trainer or model is needed for validation'
|
||||||
self.device = select_device(self.args.device, self.args.batch)
|
self.device = select_device(self.args.device, self.args.batch)
|
||||||
self.args.half &= self.device.type != 'cpu'
|
self.args.half &= self.device.type != 'cpu'
|
||||||
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
|
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
|
||||||
@ -108,7 +108,7 @@ class BaseValidator:
|
|||||||
self.logger.info(
|
self.logger.info(
|
||||||
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||||
|
|
||||||
if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
|
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
|
||||||
self.data = check_det_dataset(self.args.data)
|
self.data = check_det_dataset(self.args.data)
|
||||||
elif self.args.task == 'classify':
|
elif self.args.task == 'classify':
|
||||||
self.data = check_cls_dataset(self.args.data)
|
self.data = check_cls_dataset(self.args.data)
|
||||||
@ -142,7 +142,7 @@ class BaseValidator:
|
|||||||
|
|
||||||
# inference
|
# inference
|
||||||
with dt[1]:
|
with dt[1]:
|
||||||
preds = model(batch["img"])
|
preds = model(batch['img'])
|
||||||
|
|
||||||
# loss
|
# loss
|
||||||
with dt[2]:
|
with dt[2]:
|
||||||
@ -166,14 +166,14 @@ class BaseValidator:
|
|||||||
self.run_callbacks('on_val_end')
|
self.run_callbacks('on_val_end')
|
||||||
if self.training:
|
if self.training:
|
||||||
model.float()
|
model.float()
|
||||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
|
||||||
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
||||||
else:
|
else:
|
||||||
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
|
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
|
||||||
self.speed)
|
self.speed)
|
||||||
if self.args.save_json and self.jdict:
|
if self.args.save_json and self.jdict:
|
||||||
with open(str(self.save_dir / "predictions.json"), 'w') as f:
|
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
|
||||||
self.logger.info(f"Saving {f.name}...")
|
self.logger.info(f'Saving {f.name}...')
|
||||||
json.dump(self.jdict, f) # flatten and save
|
json.dump(self.jdict, f) # flatten and save
|
||||||
stats = self.eval_json(stats) # update stats
|
stats = self.eval_json(stats) # update stats
|
||||||
return stats
|
return stats
|
||||||
@ -183,7 +183,7 @@ class BaseValidator:
|
|||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size):
|
def get_dataloader(self, dataset_path, batch_size):
|
||||||
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
raise NotImplementedError('get_dataloader function not implemented for this validator')
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
return batch
|
return batch
|
||||||
|
@ -27,7 +27,7 @@ from ultralytics import __version__
|
|||||||
# Constants
|
# Constants
|
||||||
FILE = Path(__file__).resolve()
|
FILE = Path(__file__).resolve()
|
||||||
ROOT = FILE.parents[2] # YOLO
|
ROOT = FILE.parents[2] # YOLO
|
||||||
DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
|
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||||
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||||
@ -111,7 +111,7 @@ class IterableSimpleNamespace(SimpleNamespace):
|
|||||||
return iter(vars(self).items())
|
return iter(vars(self).items())
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
|
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
name = self.__class__.__name__
|
name = self.__class__.__name__
|
||||||
@ -288,7 +288,7 @@ def is_pytest_running():
|
|||||||
(bool): True if pytest is running, False otherwise.
|
(bool): True if pytest is running, False otherwise.
|
||||||
"""
|
"""
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
return "pytest" in sys.modules
|
return 'pytest' in sys.modules
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -336,7 +336,7 @@ def get_git_origin_url():
|
|||||||
"""
|
"""
|
||||||
if is_git_dir():
|
if is_git_dir():
|
||||||
with contextlib.suppress(subprocess.CalledProcessError):
|
with contextlib.suppress(subprocess.CalledProcessError):
|
||||||
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
|
origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
|
||||||
return origin.decode().strip()
|
return origin.decode().strip()
|
||||||
return None # if not git dir or on error
|
return None # if not git dir or on error
|
||||||
|
|
||||||
@ -350,7 +350,7 @@ def get_git_branch():
|
|||||||
"""
|
"""
|
||||||
if is_git_dir():
|
if is_git_dir():
|
||||||
with contextlib.suppress(subprocess.CalledProcessError):
|
with contextlib.suppress(subprocess.CalledProcessError):
|
||||||
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||||
return origin.decode().strip()
|
return origin.decode().strip()
|
||||||
return None # if not git dir or on error
|
return None # if not git dir or on error
|
||||||
|
|
||||||
@ -365,9 +365,9 @@ def get_latest_pypi_version(package_name='ultralytics'):
|
|||||||
Returns:
|
Returns:
|
||||||
str: The latest version of the package.
|
str: The latest version of the package.
|
||||||
"""
|
"""
|
||||||
response = requests.get(f"https://pypi.org/pypi/{package_name}/json")
|
response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.json()["info"]["version"]
|
return response.json()['info']['version']
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -424,28 +424,28 @@ def emojis(string=''):
|
|||||||
|
|
||||||
def colorstr(*input):
|
def colorstr(*input):
|
||||||
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
||||||
*args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
|
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
|
||||||
colors = {
|
colors = {
|
||||||
"black": "\033[30m", # basic colors
|
'black': '\033[30m', # basic colors
|
||||||
"red": "\033[31m",
|
'red': '\033[31m',
|
||||||
"green": "\033[32m",
|
'green': '\033[32m',
|
||||||
"yellow": "\033[33m",
|
'yellow': '\033[33m',
|
||||||
"blue": "\033[34m",
|
'blue': '\033[34m',
|
||||||
"magenta": "\033[35m",
|
'magenta': '\033[35m',
|
||||||
"cyan": "\033[36m",
|
'cyan': '\033[36m',
|
||||||
"white": "\033[37m",
|
'white': '\033[37m',
|
||||||
"bright_black": "\033[90m", # bright colors
|
'bright_black': '\033[90m', # bright colors
|
||||||
"bright_red": "\033[91m",
|
'bright_red': '\033[91m',
|
||||||
"bright_green": "\033[92m",
|
'bright_green': '\033[92m',
|
||||||
"bright_yellow": "\033[93m",
|
'bright_yellow': '\033[93m',
|
||||||
"bright_blue": "\033[94m",
|
'bright_blue': '\033[94m',
|
||||||
"bright_magenta": "\033[95m",
|
'bright_magenta': '\033[95m',
|
||||||
"bright_cyan": "\033[96m",
|
'bright_cyan': '\033[96m',
|
||||||
"bright_white": "\033[97m",
|
'bright_white': '\033[97m',
|
||||||
"end": "\033[0m", # misc
|
'end': '\033[0m', # misc
|
||||||
"bold": "\033[1m",
|
'bold': '\033[1m',
|
||||||
"underline": "\033[4m"}
|
'underline': '\033[4m'}
|
||||||
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
|
||||||
|
|
||||||
|
|
||||||
def remove_ansi_codes(string):
|
def remove_ansi_codes(string):
|
||||||
@ -466,21 +466,21 @@ def set_logging(name=LOGGING_NAME, verbose=True):
|
|||||||
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
||||||
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
|
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
|
||||||
logging.config.dictConfig({
|
logging.config.dictConfig({
|
||||||
"version": 1,
|
'version': 1,
|
||||||
"disable_existing_loggers": False,
|
'disable_existing_loggers': False,
|
||||||
"formatters": {
|
'formatters': {
|
||||||
name: {
|
name: {
|
||||||
"format": "%(message)s"}},
|
'format': '%(message)s'}},
|
||||||
"handlers": {
|
'handlers': {
|
||||||
name: {
|
name: {
|
||||||
"class": "logging.StreamHandler",
|
'class': 'logging.StreamHandler',
|
||||||
"formatter": name,
|
'formatter': name,
|
||||||
"level": level}},
|
'level': level}},
|
||||||
"loggers": {
|
'loggers': {
|
||||||
name: {
|
name: {
|
||||||
"level": level,
|
'level': level,
|
||||||
"handlers": [name],
|
'handlers': [name],
|
||||||
"propagate": False}}})
|
'propagate': False}}})
|
||||||
|
|
||||||
|
|
||||||
class TryExcept(contextlib.ContextDecorator):
|
class TryExcept(contextlib.ContextDecorator):
|
||||||
@ -521,10 +521,10 @@ def set_sentry():
|
|||||||
return None # do not send event
|
return None # do not send event
|
||||||
|
|
||||||
event['tags'] = {
|
event['tags'] = {
|
||||||
"sys_argv": sys.argv[0],
|
'sys_argv': sys.argv[0],
|
||||||
"sys_argv_name": Path(sys.argv[0]).name,
|
'sys_argv_name': Path(sys.argv[0]).name,
|
||||||
"install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
||||||
"os": ENVIRONMENT}
|
'os': ENVIRONMENT}
|
||||||
return event
|
return event
|
||||||
|
|
||||||
if SETTINGS['sync'] and \
|
if SETTINGS['sync'] and \
|
||||||
@ -533,24 +533,24 @@ def set_sentry():
|
|||||||
not is_pytest_running() and \
|
not is_pytest_running() and \
|
||||||
not is_github_actions_ci() and \
|
not is_github_actions_ci() and \
|
||||||
((is_pip_package() and not is_git_dir()) or
|
((is_pip_package() and not is_git_dir()) or
|
||||||
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
|
(get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git' and get_git_branch() == 'main')):
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
import sentry_sdk # noqa
|
import sentry_sdk # noqa
|
||||||
|
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
dsn="https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016",
|
dsn='https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016',
|
||||||
debug=False,
|
debug=False,
|
||||||
traces_sample_rate=1.0,
|
traces_sample_rate=1.0,
|
||||||
release=__version__,
|
release=__version__,
|
||||||
environment='production', # 'dev' or 'production'
|
environment='production', # 'dev' or 'production'
|
||||||
before_send=before_send,
|
before_send=before_send,
|
||||||
ignore_errors=[KeyboardInterrupt, FileNotFoundError])
|
ignore_errors=[KeyboardInterrupt, FileNotFoundError])
|
||||||
sentry_sdk.set_user({"id": SETTINGS['uuid']})
|
sentry_sdk.set_user({'id': SETTINGS['uuid']})
|
||||||
|
|
||||||
# Disable all sentry logging
|
# Disable all sentry logging
|
||||||
for logger in "sentry_sdk", "sentry_sdk.errors":
|
for logger in 'sentry_sdk', 'sentry_sdk.errors':
|
||||||
logging.getLogger(logger).setLevel(logging.CRITICAL)
|
logging.getLogger(logger).setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
@ -620,7 +620,7 @@ if WINDOWS:
|
|||||||
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
|
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
|
||||||
|
|
||||||
# Check first-install steps
|
# Check first-install steps
|
||||||
PREFIX = colorstr("Ultralytics: ")
|
PREFIX = colorstr('Ultralytics: ')
|
||||||
SETTINGS = get_settings()
|
SETTINGS = get_settings()
|
||||||
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
|
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
|
||||||
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
|
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
|
||||||
|
@ -11,7 +11,7 @@ except (ImportError, AssertionError):
|
|||||||
clearml = None
|
clearml = None
|
||||||
|
|
||||||
|
|
||||||
def _log_images(imgs_dict, group="", step=0):
|
def _log_images(imgs_dict, group='', step=0):
|
||||||
task = Task.current_task()
|
task = Task.current_task()
|
||||||
if task:
|
if task:
|
||||||
for k, v in imgs_dict.items():
|
for k, v in imgs_dict.items():
|
||||||
@ -20,7 +20,7 @@ def _log_images(imgs_dict, group="", step=0):
|
|||||||
|
|
||||||
def on_pretrain_routine_start(trainer):
|
def on_pretrain_routine_start(trainer):
|
||||||
# TODO: reuse existing task
|
# TODO: reuse existing task
|
||||||
task = Task.init(project_name=trainer.args.project or "YOLOv8",
|
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
|
||||||
task_name=trainer.args.name,
|
task_name=trainer.args.name,
|
||||||
tags=['YOLOv8'],
|
tags=['YOLOv8'],
|
||||||
output_uri=True,
|
output_uri=True,
|
||||||
@ -31,15 +31,15 @@ def on_pretrain_routine_start(trainer):
|
|||||||
|
|
||||||
def on_train_epoch_end(trainer):
|
def on_train_epoch_end(trainer):
|
||||||
if trainer.epoch == 1:
|
if trainer.epoch == 1:
|
||||||
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
|
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch)
|
||||||
|
|
||||||
|
|
||||||
def on_fit_epoch_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
model_info = {
|
model_info = {
|
||||||
"Parameters": get_num_params(trainer.model),
|
'Parameters': get_num_params(trainer.model),
|
||||||
"GFLOPs": round(get_flops(trainer.model), 3),
|
'GFLOPs': round(get_flops(trainer.model), 3),
|
||||||
"Inference speed (ms/img)": round(trainer.validator.speed[1], 3)}
|
'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)}
|
||||||
Task.current_task().connect(model_info, name='Model')
|
Task.current_task().connect(model_info, name='Model')
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ def on_train_end(trainer):
|
|||||||
|
|
||||||
|
|
||||||
callbacks = {
|
callbacks = {
|
||||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||||
"on_train_epoch_end": on_train_epoch_end,
|
'on_train_epoch_end': on_train_epoch_end,
|
||||||
"on_fit_epoch_end": on_fit_epoch_end,
|
'on_fit_epoch_end': on_fit_epoch_end,
|
||||||
"on_train_end": on_train_end} if clearml else {}
|
'on_train_end': on_train_end} if clearml else {}
|
||||||
|
@ -10,13 +10,13 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def on_pretrain_routine_start(trainer):
|
def on_pretrain_routine_start(trainer):
|
||||||
experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8")
|
experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
|
||||||
experiment.log_parameters(vars(trainer.args))
|
experiment.log_parameters(vars(trainer.args))
|
||||||
|
|
||||||
|
|
||||||
def on_train_epoch_end(trainer):
|
def on_train_epoch_end(trainer):
|
||||||
experiment = comet_ml.get_global_experiment()
|
experiment = comet_ml.get_global_experiment()
|
||||||
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
|
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
|
||||||
if trainer.epoch == 1:
|
if trainer.epoch == 1:
|
||||||
for f in trainer.save_dir.glob('train_batch*.jpg'):
|
for f in trainer.save_dir.glob('train_batch*.jpg'):
|
||||||
experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)
|
experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)
|
||||||
@ -27,19 +27,19 @@ def on_fit_epoch_end(trainer):
|
|||||||
experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
|
experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
|
||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
model_info = {
|
model_info = {
|
||||||
"model/parameters": get_num_params(trainer.model),
|
'model/parameters': get_num_params(trainer.model),
|
||||||
"model/GFLOPs": round(get_flops(trainer.model), 3),
|
'model/GFLOPs': round(get_flops(trainer.model), 3),
|
||||||
"model/speed(ms)": round(trainer.validator.speed[1], 3)}
|
'model/speed(ms)': round(trainer.validator.speed[1], 3)}
|
||||||
experiment.log_metrics(model_info, step=trainer.epoch + 1)
|
experiment.log_metrics(model_info, step=trainer.epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
def on_train_end(trainer):
|
def on_train_end(trainer):
|
||||||
experiment = comet_ml.get_global_experiment()
|
experiment = comet_ml.get_global_experiment()
|
||||||
experiment.log_model("YOLOv8", file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
|
experiment.log_model('YOLOv8', file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True)
|
||||||
|
|
||||||
|
|
||||||
callbacks = {
|
callbacks = {
|
||||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||||
"on_train_epoch_end": on_train_epoch_end,
|
'on_train_epoch_end': on_train_epoch_end,
|
||||||
"on_fit_epoch_end": on_fit_epoch_end,
|
'on_fit_epoch_end': on_fit_epoch_end,
|
||||||
"on_train_end": on_train_end} if comet_ml else {}
|
'on_train_end': on_train_end} if comet_ml else {}
|
||||||
|
@ -11,7 +11,7 @@ def on_pretrain_routine_end(trainer):
|
|||||||
session = getattr(trainer, 'hub_session', None)
|
session = getattr(trainer, 'hub_session', None)
|
||||||
if session:
|
if session:
|
||||||
# Start timer for upload rate limit
|
# Start timer for upload rate limit
|
||||||
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
|
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
|
||||||
session.t = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
|
session.t = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
|
||||||
|
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ def on_model_save(trainer):
|
|||||||
# Upload checkpoints with rate limiting
|
# Upload checkpoints with rate limiting
|
||||||
is_best = trainer.best_fitness == trainer.fitness
|
is_best = trainer.best_fitness == trainer.fitness
|
||||||
if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
|
if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
|
||||||
LOGGER.info(f"{PREFIX}Uploading checkpoint {session.model_id}")
|
LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
|
||||||
session.upload_model(trainer.epoch, trainer.last, is_best)
|
session.upload_model(trainer.epoch, trainer.last, is_best)
|
||||||
session.t['ckpt'] = time() # reset timer
|
session.t['ckpt'] = time() # reset timer
|
||||||
|
|
||||||
@ -40,11 +40,11 @@ def on_train_end(trainer):
|
|||||||
session = getattr(trainer, 'hub_session', None)
|
session = getattr(trainer, 'hub_session', None)
|
||||||
if session:
|
if session:
|
||||||
# Upload final model and metrics with exponential standoff
|
# Upload final model and metrics with exponential standoff
|
||||||
LOGGER.info(f"{PREFIX}Training completed successfully ✅\n"
|
LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
|
||||||
f"{PREFIX}Uploading final {session.model_id}")
|
f'{PREFIX}Uploading final {session.model_id}')
|
||||||
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
|
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
|
||||||
session.shutdown() # stop heartbeats
|
session.shutdown() # stop heartbeats
|
||||||
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
|
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
|
||||||
|
|
||||||
|
|
||||||
def on_train_start(trainer):
|
def on_train_start(trainer):
|
||||||
@ -64,11 +64,11 @@ def on_export_start(exporter):
|
|||||||
|
|
||||||
|
|
||||||
callbacks = {
|
callbacks = {
|
||||||
"on_pretrain_routine_end": on_pretrain_routine_end,
|
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||||
"on_fit_epoch_end": on_fit_epoch_end,
|
'on_fit_epoch_end': on_fit_epoch_end,
|
||||||
"on_model_save": on_model_save,
|
'on_model_save': on_model_save,
|
||||||
"on_train_end": on_train_end,
|
'on_train_end': on_train_end,
|
||||||
"on_train_start": on_train_start,
|
'on_train_start': on_train_start,
|
||||||
"on_val_start": on_val_start,
|
'on_val_start': on_val_start,
|
||||||
"on_predict_start": on_predict_start,
|
'on_predict_start': on_predict_start,
|
||||||
"on_export_start": on_export_start}
|
'on_export_start': on_export_start}
|
||||||
|
@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
|
|||||||
|
|
||||||
|
|
||||||
def on_batch_end(trainer):
|
def on_batch_end(trainer):
|
||||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
|
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
def on_fit_epoch_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
@ -24,6 +24,6 @@ def on_fit_epoch_end(trainer):
|
|||||||
|
|
||||||
|
|
||||||
callbacks = {
|
callbacks = {
|
||||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||||
"on_fit_epoch_end": on_fit_epoch_end,
|
'on_fit_epoch_end': on_fit_epoch_end,
|
||||||
"on_batch_end": on_batch_end}
|
'on_batch_end': on_batch_end}
|
||||||
|
@ -71,7 +71,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
|||||||
msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
|
msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
|
||||||
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
|
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
|
||||||
if max_dim != 1:
|
if max_dim != 1:
|
||||||
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
|
raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
|
||||||
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
|
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
|
||||||
imgsz = [max(imgsz)]
|
imgsz = [max(imgsz)]
|
||||||
# Make image size a multiple of the stride
|
# Make image size a multiple of the stride
|
||||||
@ -87,9 +87,9 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
|||||||
return sz
|
return sz
|
||||||
|
|
||||||
|
|
||||||
def check_version(current: str = "0.0.0",
|
def check_version(current: str = '0.0.0',
|
||||||
minimum: str = "0.0.0",
|
minimum: str = '0.0.0',
|
||||||
name: str = "version ",
|
name: str = 'version ',
|
||||||
pinned: bool = False,
|
pinned: bool = False,
|
||||||
hard: bool = False,
|
hard: bool = False,
|
||||||
verbose: bool = False) -> bool:
|
verbose: bool = False) -> bool:
|
||||||
@ -109,7 +109,7 @@ def check_version(current: str = "0.0.0",
|
|||||||
"""
|
"""
|
||||||
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
||||||
result = (current == minimum) if pinned else (current >= minimum) # bool
|
result = (current == minimum) if pinned else (current >= minimum) # bool
|
||||||
warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed"
|
warning_message = f'WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed'
|
||||||
if hard:
|
if hard:
|
||||||
assert result, emojis(warning_message) # assert min requirements met
|
assert result, emojis(warning_message) # assert min requirements met
|
||||||
if verbose and not result:
|
if verbose and not result:
|
||||||
@ -155,7 +155,7 @@ def check_online() -> bool:
|
|||||||
"""
|
"""
|
||||||
import socket
|
import socket
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
host = socket.gethostbyname("www.github.com")
|
host = socket.gethostbyname('www.github.com')
|
||||||
socket.create_connection((host, 80), timeout=2)
|
socket.create_connection((host, 80), timeout=2)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@ -182,7 +182,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
|||||||
file = None
|
file = None
|
||||||
if isinstance(requirements, Path): # requirements.txt file
|
if isinstance(requirements, Path): # requirements.txt file
|
||||||
file = requirements.resolve()
|
file = requirements.resolve()
|
||||||
assert file.exists(), f"{prefix} {file} not found, check failed."
|
assert file.exists(), f'{prefix} {file} not found, check failed.'
|
||||||
with file.open() as f:
|
with file.open() as f:
|
||||||
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
|
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
|
||||||
elif isinstance(requirements, str):
|
elif isinstance(requirements, str):
|
||||||
@ -200,7 +200,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
|||||||
if s and install and AUTOINSTALL: # check environment variable
|
if s and install and AUTOINSTALL: # check environment variable
|
||||||
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
||||||
try:
|
try:
|
||||||
assert check_online(), "AutoUpdate skipped (offline)"
|
assert check_online(), 'AutoUpdate skipped (offline)'
|
||||||
LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
|
LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
|
||||||
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
|
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
|
||||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
||||||
@ -217,19 +217,19 @@ def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''):
|
|||||||
for f in file if isinstance(file, (list, tuple)) else [file]:
|
for f in file if isinstance(file, (list, tuple)) else [file]:
|
||||||
s = Path(f).suffix.lower() # file suffix
|
s = Path(f).suffix.lower() # file suffix
|
||||||
if len(s):
|
if len(s):
|
||||||
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
|
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'
|
||||||
|
|
||||||
|
|
||||||
def check_yolov5u_filename(file: str):
|
def check_yolov5u_filename(file: str):
|
||||||
# Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
|
# Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
|
||||||
if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
|
if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
|
||||||
original_file = file
|
original_file = file
|
||||||
file = re.sub(r"(.*yolov5([nsmlx]))\.", "\\1u.", file) # i.e. yolov5n.pt -> yolov5nu.pt
|
file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file) # i.e. yolov5n.pt -> yolov5nu.pt
|
||||||
file = re.sub(r"(.*yolov3(|-tiny|-spp))\.", "\\1u.", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
|
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
|
||||||
if file != original_file:
|
if file != original_file:
|
||||||
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
|
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
|
||||||
f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
|
f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
|
||||||
f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n")
|
f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
|
||||||
return file
|
return file
|
||||||
|
|
||||||
|
|
||||||
@ -290,7 +290,7 @@ def check_yolo(verbose=True):
|
|||||||
# System info
|
# System info
|
||||||
gib = 1 << 30 # bytes per GiB
|
gib = 1 << 30 # bytes per GiB
|
||||||
ram = psutil.virtual_memory().total
|
ram = psutil.virtual_memory().total
|
||||||
total, used, free = shutil.disk_usage("/")
|
total, used, free = shutil.disk_usage('/')
|
||||||
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
|
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
|
||||||
with contextlib.suppress(Exception): # clear display if ipython is installed
|
with contextlib.suppress(Exception): # clear display if ipython is installed
|
||||||
from IPython import display
|
from IPython import display
|
||||||
|
@ -22,7 +22,7 @@ def find_free_network_port() -> int:
|
|||||||
|
|
||||||
|
|
||||||
def generate_ddp_file(trainer):
|
def generate_ddp_file(trainer):
|
||||||
import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
|
import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])
|
||||||
|
|
||||||
if not trainer.resume:
|
if not trainer.resume:
|
||||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||||
@ -32,9 +32,9 @@ def generate_ddp_file(trainer):
|
|||||||
trainer = {trainer.__class__.__name__}(cfg=cfg)
|
trainer = {trainer.__class__.__name__}(cfg=cfg)
|
||||||
trainer.train()'''
|
trainer.train()'''
|
||||||
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
||||||
with tempfile.NamedTemporaryFile(prefix="_temp_",
|
with tempfile.NamedTemporaryFile(prefix='_temp_',
|
||||||
suffix=f"{id(trainer)}.py",
|
suffix=f'{id(trainer)}.py',
|
||||||
mode="w+",
|
mode='w+',
|
||||||
encoding='utf-8',
|
encoding='utf-8',
|
||||||
dir=USER_CONFIG_DIR / 'DDP',
|
dir=USER_CONFIG_DIR / 'DDP',
|
||||||
delete=False) as file:
|
delete=False) as file:
|
||||||
@ -47,18 +47,18 @@ def generate_ddp_command(world_size, trainer):
|
|||||||
|
|
||||||
# Get file and args (do not use sys.argv due to security vulnerability)
|
# Get file and args (do not use sys.argv due to security vulnerability)
|
||||||
exclude_args = ['save_dir']
|
exclude_args = ['save_dir']
|
||||||
args = [f"{k}={v}" for k, v in vars(trainer.args).items() if k not in exclude_args]
|
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
|
||||||
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
|
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
|
||||||
|
|
||||||
# Build command
|
# Build command
|
||||||
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
|
||||||
cmd = [
|
cmd = [
|
||||||
sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
|
sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port',
|
||||||
f"{find_free_network_port()}", file] + args
|
f'{find_free_network_port()}', file] + args
|
||||||
return cmd, file
|
return cmd, file
|
||||||
|
|
||||||
|
|
||||||
def ddp_cleanup(trainer, file):
|
def ddp_cleanup(trainer, file):
|
||||||
# delete temp file if created
|
# delete temp file if created
|
||||||
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
|
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
|
||||||
os.remove(file)
|
os.remove(file)
|
||||||
|
@ -95,14 +95,14 @@ def safe_download(url,
|
|||||||
torch.hub.download_url_to_file(url, f, progress=progress)
|
torch.hub.download_url_to_file(url, f, progress=progress)
|
||||||
else:
|
else:
|
||||||
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
||||||
with request.urlopen(url) as response, tqdm(total=int(response.getheader("Content-Length", 0)),
|
with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
|
||||||
desc=desc,
|
desc=desc,
|
||||||
disable=not progress,
|
disable=not progress,
|
||||||
unit='B',
|
unit='B',
|
||||||
unit_scale=True,
|
unit_scale=True,
|
||||||
unit_divisor=1024,
|
unit_divisor=1024,
|
||||||
bar_format=TQDM_BAR_FORMAT) as pbar:
|
bar_format=TQDM_BAR_FORMAT) as pbar:
|
||||||
with open(f, "wb") as f_opened:
|
with open(f, 'wb') as f_opened:
|
||||||
for data in response:
|
for data in response:
|
||||||
f_opened.write(data)
|
f_opened.write(data)
|
||||||
pbar.update(len(data))
|
pbar.update(len(data))
|
||||||
@ -171,7 +171,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
|||||||
tag, assets = github_assets(repo) # latest release
|
tag, assets = github_assets(repo) # latest release
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
tag = subprocess.check_output(["git", "tag"]).decode().split()[-1]
|
tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]
|
||||||
except Exception:
|
except Exception:
|
||||||
tag = release
|
tag = release
|
||||||
|
|
||||||
|
@ -24,15 +24,15 @@ to_4tuple = _ntuple(4)
|
|||||||
# `xyxy` means left top and right bottom
|
# `xyxy` means left top and right bottom
|
||||||
# `xywh` means center x, center y and width, height(yolo format)
|
# `xywh` means center x, center y and width, height(yolo format)
|
||||||
# `ltwh` means left top and width, height(coco format)
|
# `ltwh` means left top and width, height(coco format)
|
||||||
_formats = ["xyxy", "xywh", "ltwh"]
|
_formats = ['xyxy', 'xywh', 'ltwh']
|
||||||
|
|
||||||
__all__ = ["Bboxes"]
|
__all__ = ['Bboxes']
|
||||||
|
|
||||||
|
|
||||||
class Bboxes:
|
class Bboxes:
|
||||||
"""Now only numpy is supported"""
|
"""Now only numpy is supported"""
|
||||||
|
|
||||||
def __init__(self, bboxes, format="xyxy") -> None:
|
def __init__(self, bboxes, format='xyxy') -> None:
|
||||||
assert format in _formats
|
assert format in _formats
|
||||||
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
||||||
assert bboxes.ndim == 2
|
assert bboxes.ndim == 2
|
||||||
@ -67,17 +67,17 @@ class Bboxes:
|
|||||||
assert format in _formats
|
assert format in _formats
|
||||||
if self.format == format:
|
if self.format == format:
|
||||||
return
|
return
|
||||||
elif self.format == "xyxy":
|
elif self.format == 'xyxy':
|
||||||
bboxes = xyxy2xywh(self.bboxes) if format == "xywh" else xyxy2ltwh(self.bboxes)
|
bboxes = xyxy2xywh(self.bboxes) if format == 'xywh' else xyxy2ltwh(self.bboxes)
|
||||||
elif self.format == "xywh":
|
elif self.format == 'xywh':
|
||||||
bboxes = xywh2xyxy(self.bboxes) if format == "xyxy" else xywh2ltwh(self.bboxes)
|
bboxes = xywh2xyxy(self.bboxes) if format == 'xyxy' else xywh2ltwh(self.bboxes)
|
||||||
else:
|
else:
|
||||||
bboxes = ltwh2xyxy(self.bboxes) if format == "xyxy" else ltwh2xywh(self.bboxes)
|
bboxes = ltwh2xyxy(self.bboxes) if format == 'xyxy' else ltwh2xywh(self.bboxes)
|
||||||
self.bboxes = bboxes
|
self.bboxes = bboxes
|
||||||
self.format = format
|
self.format = format
|
||||||
|
|
||||||
def areas(self):
|
def areas(self):
|
||||||
self.convert("xyxy")
|
self.convert('xyxy')
|
||||||
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
|
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
|
||||||
|
|
||||||
# def denormalize(self, w, h):
|
# def denormalize(self, w, h):
|
||||||
@ -128,7 +128,7 @@ class Bboxes:
|
|||||||
return len(self.bboxes)
|
return len(self.bboxes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
|
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
|
||||||
"""
|
"""
|
||||||
Concatenates a list of Boxes into a single Bboxes
|
Concatenates a list of Boxes into a single Bboxes
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ class Bboxes:
|
|||||||
return boxes_list[0]
|
return boxes_list[0]
|
||||||
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
||||||
|
|
||||||
def __getitem__(self, index) -> "Bboxes":
|
def __getitem__(self, index) -> 'Bboxes':
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
index: int, slice, or a BoolArray
|
index: int, slice, or a BoolArray
|
||||||
@ -158,13 +158,13 @@ class Bboxes:
|
|||||||
if isinstance(index, int):
|
if isinstance(index, int):
|
||||||
return Bboxes(self.bboxes[index].view(1, -1))
|
return Bboxes(self.bboxes[index].view(1, -1))
|
||||||
b = self.bboxes[index]
|
b = self.bboxes[index]
|
||||||
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
|
assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
|
||||||
return Bboxes(b)
|
return Bboxes(b)
|
||||||
|
|
||||||
|
|
||||||
class Instances:
|
class Instances:
|
||||||
|
|
||||||
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
|
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
bboxes (ndarray): bboxes with shape [N, 4].
|
bboxes (ndarray): bboxes with shape [N, 4].
|
||||||
@ -227,7 +227,7 @@ class Instances:
|
|||||||
|
|
||||||
def add_padding(self, padw, padh):
|
def add_padding(self, padw, padh):
|
||||||
# handle rect and mosaic situation
|
# handle rect and mosaic situation
|
||||||
assert not self.normalized, "you should add padding with absolute coordinates."
|
assert not self.normalized, 'you should add padding with absolute coordinates.'
|
||||||
self._bboxes.add(offset=(padw, padh, padw, padh))
|
self._bboxes.add(offset=(padw, padh, padw, padh))
|
||||||
self.segments[..., 0] += padw
|
self.segments[..., 0] += padw
|
||||||
self.segments[..., 1] += padh
|
self.segments[..., 1] += padh
|
||||||
@ -235,7 +235,7 @@ class Instances:
|
|||||||
self.keypoints[..., 0] += padw
|
self.keypoints[..., 0] += padw
|
||||||
self.keypoints[..., 1] += padh
|
self.keypoints[..., 1] += padh
|
||||||
|
|
||||||
def __getitem__(self, index) -> "Instances":
|
def __getitem__(self, index) -> 'Instances':
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
index: int, slice, or a BoolArray
|
index: int, slice, or a BoolArray
|
||||||
@ -256,7 +256,7 @@ class Instances:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def flipud(self, h):
|
def flipud(self, h):
|
||||||
if self._bboxes.format == "xyxy":
|
if self._bboxes.format == 'xyxy':
|
||||||
y1 = self.bboxes[:, 1].copy()
|
y1 = self.bboxes[:, 1].copy()
|
||||||
y2 = self.bboxes[:, 3].copy()
|
y2 = self.bboxes[:, 3].copy()
|
||||||
self.bboxes[:, 1] = h - y2
|
self.bboxes[:, 1] = h - y2
|
||||||
@ -268,7 +268,7 @@ class Instances:
|
|||||||
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
||||||
|
|
||||||
def fliplr(self, w):
|
def fliplr(self, w):
|
||||||
if self._bboxes.format == "xyxy":
|
if self._bboxes.format == 'xyxy':
|
||||||
x1 = self.bboxes[:, 0].copy()
|
x1 = self.bboxes[:, 0].copy()
|
||||||
x2 = self.bboxes[:, 2].copy()
|
x2 = self.bboxes[:, 2].copy()
|
||||||
self.bboxes[:, 0] = w - x2
|
self.bboxes[:, 0] = w - x2
|
||||||
@ -281,10 +281,10 @@ class Instances:
|
|||||||
|
|
||||||
def clip(self, w, h):
|
def clip(self, w, h):
|
||||||
ori_format = self._bboxes.format
|
ori_format = self._bboxes.format
|
||||||
self.convert_bbox(format="xyxy")
|
self.convert_bbox(format='xyxy')
|
||||||
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
||||||
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
||||||
if ori_format != "xyxy":
|
if ori_format != 'xyxy':
|
||||||
self.convert_bbox(format=ori_format)
|
self.convert_bbox(format=ori_format)
|
||||||
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
||||||
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
||||||
@ -304,7 +304,7 @@ class Instances:
|
|||||||
return len(self.bboxes)
|
return len(self.bboxes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
|
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
|
||||||
"""
|
"""
|
||||||
Concatenates a list of Boxes into a single Bboxes
|
Concatenates a list of Boxes into a single Bboxes
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ class VarifocalLoss(nn.Module):
|
|||||||
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
|
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
|
||||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") *
|
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
|
||||||
weight).sum()
|
weight).sum()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@ -52,5 +52,5 @@ class BboxLoss(nn.Module):
|
|||||||
tr = tl + 1 # target right
|
tr = tl + 1 # target right
|
||||||
wl = tr - target # weight left
|
wl = tr - target # weight left
|
||||||
wr = 1 - wl # weight right
|
wr = 1 - wl # weight right
|
||||||
return (F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl +
|
return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
|
||||||
F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr).mean(-1, keepdim=True)
|
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
|
||||||
|
@ -238,14 +238,14 @@ class ConfusionMatrix:
|
|||||||
nc, nn = self.nc, len(names) # number of classes, names
|
nc, nn = self.nc, len(names) # number of classes, names
|
||||||
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||||
ticklabels = (names + ['background']) if labels else "auto"
|
ticklabels = (names + ['background']) if labels else 'auto'
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||||
sn.heatmap(array,
|
sn.heatmap(array,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
annot=nc < 30,
|
annot=nc < 30,
|
||||||
annot_kws={
|
annot_kws={
|
||||||
"size": 8},
|
'size': 8},
|
||||||
cmap='Blues',
|
cmap='Blues',
|
||||||
fmt='.2f',
|
fmt='.2f',
|
||||||
square=True,
|
square=True,
|
||||||
@ -287,7 +287,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
|||||||
ax.set_ylabel('Precision')
|
ax.set_ylabel('Precision')
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
ax.set_ylim(0, 1)
|
ax.set_ylim(0, 1)
|
||||||
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
||||||
ax.set_title('Precision-Recall Curve')
|
ax.set_title('Precision-Recall Curve')
|
||||||
fig.savefig(save_dir, dpi=250)
|
fig.savefig(save_dir, dpi=250)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
@ -309,7 +309,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
|
|||||||
ax.set_ylabel(ylabel)
|
ax.set_ylabel(ylabel)
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
ax.set_ylim(0, 1)
|
ax.set_ylim(0, 1)
|
||||||
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
||||||
ax.set_title(f'{ylabel}-Confidence Curve')
|
ax.set_title(f'{ylabel}-Confidence Curve')
|
||||||
fig.savefig(save_dir, dpi=250)
|
fig.savefig(save_dir, dpi=250)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
@ -343,7 +343,7 @@ def compute_ap(recall, precision):
|
|||||||
return ap, mpre, mrec
|
return ap, mpre, mrec
|
||||||
|
|
||||||
|
|
||||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
|
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
|
||||||
""" Compute the average precision, given the recall and precision curves.
|
""" Compute the average precision, given the recall and precision curves.
|
||||||
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
|
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
|
||||||
# Arguments
|
# Arguments
|
||||||
@ -507,7 +507,7 @@ class Metric:
|
|||||||
|
|
||||||
class DetMetrics:
|
class DetMetrics:
|
||||||
|
|
||||||
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.plot = plot
|
self.plot = plot
|
||||||
self.names = names
|
self.names = names
|
||||||
@ -521,7 +521,7 @@ class DetMetrics:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
|
||||||
|
|
||||||
def mean_results(self):
|
def mean_results(self):
|
||||||
return self.box.mean_results()
|
return self.box.mean_results()
|
||||||
@ -543,12 +543,12 @@ class DetMetrics:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def results_dict(self):
|
def results_dict(self):
|
||||||
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||||
|
|
||||||
|
|
||||||
class SegmentMetrics:
|
class SegmentMetrics:
|
||||||
|
|
||||||
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.plot = plot
|
self.plot = plot
|
||||||
self.names = names
|
self.names = names
|
||||||
@ -563,7 +563,7 @@ class SegmentMetrics:
|
|||||||
plot=self.plot,
|
plot=self.plot,
|
||||||
save_dir=self.save_dir,
|
save_dir=self.save_dir,
|
||||||
names=self.names,
|
names=self.names,
|
||||||
prefix="Mask")[2:]
|
prefix='Mask')[2:]
|
||||||
self.seg.nc = len(self.names)
|
self.seg.nc = len(self.names)
|
||||||
self.seg.update(results_mask)
|
self.seg.update(results_mask)
|
||||||
results_box = ap_per_class(tp_b,
|
results_box = ap_per_class(tp_b,
|
||||||
@ -573,15 +573,15 @@ class SegmentMetrics:
|
|||||||
plot=self.plot,
|
plot=self.plot,
|
||||||
save_dir=self.save_dir,
|
save_dir=self.save_dir,
|
||||||
names=self.names,
|
names=self.names,
|
||||||
prefix="Box")[2:]
|
prefix='Box')[2:]
|
||||||
self.box.nc = len(self.names)
|
self.box.nc = len(self.names)
|
||||||
self.box.update(results_box)
|
self.box.update(results_box)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return [
|
return [
|
||||||
"metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
|
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
|
||||||
"metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"]
|
'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
|
||||||
|
|
||||||
def mean_results(self):
|
def mean_results(self):
|
||||||
return self.box.mean_results() + self.seg.mean_results()
|
return self.box.mean_results() + self.seg.mean_results()
|
||||||
@ -604,7 +604,7 @@ class SegmentMetrics:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def results_dict(self):
|
def results_dict(self):
|
||||||
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||||
|
|
||||||
|
|
||||||
class ClassifyMetrics:
|
class ClassifyMetrics:
|
||||||
@ -626,8 +626,8 @@ class ClassifyMetrics:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def results_dict(self):
|
def results_dict(self):
|
||||||
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
|
return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
return ['metrics/accuracy_top1', 'metrics/accuracy_top5']
|
||||||
|
@ -715,4 +715,4 @@ def clean_str(s):
|
|||||||
Returns:
|
Returns:
|
||||||
(str): a string with special characters replaced by an underscore _
|
(str): a string with special characters replaced by an underscore _
|
||||||
"""
|
"""
|
||||||
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
|
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
|
||||||
|
@ -61,7 +61,7 @@ def DDP_model(model):
|
|||||||
|
|
||||||
def select_device(device='', batch=0, newline=False):
|
def select_device(device='', batch=0, newline=False):
|
||||||
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
||||||
s = f"Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} "
|
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||||
device = str(device).lower()
|
device = str(device).lower()
|
||||||
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
||||||
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||||
@ -74,15 +74,15 @@ def select_device(device='', batch=0, newline=False):
|
|||||||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
||||||
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
|
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
|
||||||
LOGGER.info(s)
|
LOGGER.info(s)
|
||||||
install = "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " \
|
install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
|
||||||
"CUDA devices are seen by torch.\n" if torch.cuda.device_count() == 0 else ""
|
'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
|
||||||
raise ValueError(f"Invalid CUDA 'device={device}' requested."
|
raise ValueError(f"Invalid CUDA 'device={device}' requested."
|
||||||
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
|
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
|
||||||
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
|
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
|
||||||
f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
|
f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
|
||||||
f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
|
f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
|
||||||
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
|
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
|
||||||
f"{install}")
|
f'{install}')
|
||||||
|
|
||||||
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||||
@ -177,7 +177,7 @@ def model_info(model, verbose=False, imgsz=640):
|
|||||||
fused = ' (fused)' if model.is_fused() else ''
|
fused = ' (fused)' if model.is_fused() else ''
|
||||||
fs = f', {flops:.1f} GFLOPs' if flops else ''
|
fs = f', {flops:.1f} GFLOPs' if flops else ''
|
||||||
m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
|
m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
|
||||||
LOGGER.info(f"{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
|
LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
|
||||||
|
|
||||||
|
|
||||||
def get_num_params(model):
|
def get_num_params(model):
|
||||||
|
@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
from ultralytics.yolo.v8 import classify, detect, segment
|
from ultralytics.yolo.v8 import classify, detect, segment
|
||||||
|
|
||||||
__all__ = ["classify", "segment", "detect"]
|
__all__ = ['classify', 'segment', 'detect']
|
||||||
|
@ -4,4 +4,4 @@ from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predic
|
|||||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
|
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
|
||||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
|
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
|
||||||
|
|
||||||
__all__ = ["ClassificationPredictor", "predict", "ClassificationTrainer", "train", "ClassificationValidator", "val"]
|
__all__ = ['ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val']
|
||||||
|
@ -28,7 +28,7 @@ class ClassificationPredictor(BasePredictor):
|
|||||||
|
|
||||||
def write_results(self, idx, results, batch):
|
def write_results(self, idx, results, batch):
|
||||||
p, im, im0 = batch
|
p, im, im0 = batch
|
||||||
log_string = ""
|
log_string = ''
|
||||||
if len(im.shape) == 3:
|
if len(im.shape) == 3:
|
||||||
im = im[None] # expand for batch dim
|
im = im[None] # expand for batch dim
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
@ -65,9 +65,9 @@ class ClassificationPredictor(BasePredictor):
|
|||||||
|
|
||||||
|
|
||||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||||
source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
|
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else 'https://ultralytics.com/images/bus.jpg'
|
||||||
|
|
||||||
args = dict(model=model, source=source)
|
args = dict(model=model, source=source)
|
||||||
if use_python:
|
if use_python:
|
||||||
@ -78,5 +78,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
predictor.predict_cli()
|
predictor.predict_cli()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
predict()
|
predict()
|
||||||
|
@ -16,14 +16,14 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
overrides["task"] = "classify"
|
overrides['task'] = 'classify'
|
||||||
super().__init__(cfg, overrides)
|
super().__init__(cfg, overrides)
|
||||||
|
|
||||||
def set_model_attributes(self):
|
def set_model_attributes(self):
|
||||||
self.model.names = self.data["names"]
|
self.model.names = self.data['names']
|
||||||
|
|
||||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights)
|
model.load(weights)
|
||||||
|
|
||||||
@ -53,11 +53,11 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
|
|
||||||
model = str(self.model)
|
model = str(self.model)
|
||||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||||
if model.endswith(".pt"):
|
if model.endswith('.pt'):
|
||||||
self.model, _ = attempt_load_one_weight(model, device='cpu')
|
self.model, _ = attempt_load_one_weight(model, device='cpu')
|
||||||
for p in self.model.parameters():
|
for p in self.model.parameters():
|
||||||
p.requires_grad = True # for training
|
p.requires_grad = True # for training
|
||||||
elif model.endswith(".yaml"):
|
elif model.endswith('.yaml'):
|
||||||
self.model = self.get_model(cfg=model)
|
self.model = self.get_model(cfg=model)
|
||||||
elif model in torchvision.models.__dict__:
|
elif model in torchvision.models.__dict__:
|
||||||
pretrained = True
|
pretrained = True
|
||||||
@ -67,15 +67,15 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
|
|
||||||
return # dont return ckpt. Classification doesn't support resume
|
return # dont return ckpt. Classification doesn't support resume
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||||
loader = build_classification_dataloader(path=dataset_path,
|
loader = build_classification_dataloader(path=dataset_path,
|
||||||
imgsz=self.args.imgsz,
|
imgsz=self.args.imgsz,
|
||||||
batch_size=batch_size if mode == "train" else (batch_size * 2),
|
batch_size=batch_size if mode == 'train' else (batch_size * 2),
|
||||||
augment=mode == "train",
|
augment=mode == 'train',
|
||||||
rank=rank,
|
rank=rank,
|
||||||
workers=self.args.workers)
|
workers=self.args.workers)
|
||||||
# Attach inference transforms
|
# Attach inference transforms
|
||||||
if mode != "train":
|
if mode != 'train':
|
||||||
if is_parallel(self.model):
|
if is_parallel(self.model):
|
||||||
self.model.module.transforms = loader.dataset.torch_transforms
|
self.model.module.transforms = loader.dataset.torch_transforms
|
||||||
else:
|
else:
|
||||||
@ -83,8 +83,8 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
return loader
|
return loader
|
||||||
|
|
||||||
def preprocess_batch(self, batch):
|
def preprocess_batch(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device)
|
batch['img'] = batch['img'].to(self.device)
|
||||||
batch["cls"] = batch["cls"].to(self.device)
|
batch['cls'] = batch['cls'].to(self.device)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def progress_string(self):
|
def progress_string(self):
|
||||||
@ -96,7 +96,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
|
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
|
||||||
|
|
||||||
def criterion(self, preds, batch):
|
def criterion(self, preds, batch):
|
||||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction='sum') / self.args.nbs
|
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
|
||||||
loss_items = loss.detach()
|
loss_items = loss.detach()
|
||||||
return loss, loss_items
|
return loss, loss_items
|
||||||
|
|
||||||
@ -112,12 +112,12 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
# else:
|
# else:
|
||||||
# return keys
|
# return keys
|
||||||
|
|
||||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||||
"""
|
"""
|
||||||
Returns a loss dict with labelled training loss items tensor
|
Returns a loss dict with labelled training loss items tensor
|
||||||
"""
|
"""
|
||||||
# Not needed for classification but necessary for segmentation & detection
|
# Not needed for classification but necessary for segmentation & detection
|
||||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
keys = [f'{prefix}/{x}' for x in self.loss_names]
|
||||||
if loss_items is None:
|
if loss_items is None:
|
||||||
return keys
|
return keys
|
||||||
loss_items = [round(float(loss_items), 5)]
|
loss_items = [round(float(loss_items), 5)]
|
||||||
@ -140,8 +140,8 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
|
|
||||||
|
|
||||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||||
data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
|
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
|
||||||
device = cfg.device if cfg.device is not None else ''
|
device = cfg.device if cfg.device is not None else ''
|
||||||
|
|
||||||
args = dict(model=model, data=data, device=device)
|
args = dict(model=model, data=data, device=device)
|
||||||
@ -153,5 +153,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
train()
|
train()
|
||||||
|
@ -21,14 +21,14 @@ class ClassificationValidator(BaseValidator):
|
|||||||
self.targets = []
|
self.targets = []
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
batch['img'] = batch['img'].to(self.device, non_blocking=True)
|
||||||
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
|
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
|
||||||
batch["cls"] = batch["cls"].to(self.device)
|
batch['cls'] = batch['cls'].to(self.device)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
self.pred.append(preds.argsort(1, descending=True)[:, :5])
|
self.pred.append(preds.argsort(1, descending=True)[:, :5])
|
||||||
self.targets.append(batch["cls"])
|
self.targets.append(batch['cls'])
|
||||||
|
|
||||||
def get_stats(self):
|
def get_stats(self):
|
||||||
self.metrics.process(self.targets, self.pred)
|
self.metrics.process(self.targets, self.pred)
|
||||||
@ -42,12 +42,12 @@ class ClassificationValidator(BaseValidator):
|
|||||||
|
|
||||||
def print_results(self):
|
def print_results(self):
|
||||||
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
||||||
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5))
|
||||||
|
|
||||||
|
|
||||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||||
data = cfg.data or "mnist160"
|
data = cfg.data or 'mnist160'
|
||||||
|
|
||||||
args = dict(model=model, data=data)
|
args = dict(model=model, data=data)
|
||||||
if use_python:
|
if use_python:
|
||||||
@ -58,5 +58,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
validator(model=args['model'])
|
validator(model=args['model'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
val()
|
val()
|
||||||
|
@ -4,4 +4,4 @@ from .predict import DetectionPredictor, predict
|
|||||||
from .train import DetectionTrainer, train
|
from .train import DetectionTrainer, train
|
||||||
from .val import DetectionValidator, val
|
from .val import DetectionValidator, val
|
||||||
|
|
||||||
__all__ = ["DetectionPredictor", "predict", "DetectionTrainer", "train", "DetectionValidator", "val"]
|
__all__ = ['DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val']
|
||||||
|
@ -37,7 +37,7 @@ class DetectionPredictor(BasePredictor):
|
|||||||
|
|
||||||
def write_results(self, idx, results, batch):
|
def write_results(self, idx, results, batch):
|
||||||
p, im, im0 = batch
|
p, im, im0 = batch
|
||||||
log_string = ""
|
log_string = ''
|
||||||
if len(im.shape) == 3:
|
if len(im.shape) == 3:
|
||||||
im = im[None] # expand for batch dim
|
im = im[None] # expand for batch dim
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
@ -69,7 +69,7 @@ class DetectionPredictor(BasePredictor):
|
|||||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||||
if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
|
if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
|
||||||
c = int(cls) # integer class
|
c = int(cls) # integer class
|
||||||
name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
|
name = f'id:{int(d.id.item())} {self.model.names[c]}' if d.id is not None else self.model.names[c]
|
||||||
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
||||||
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||||
if self.args.save_crop:
|
if self.args.save_crop:
|
||||||
@ -82,9 +82,9 @@ class DetectionPredictor(BasePredictor):
|
|||||||
|
|
||||||
|
|
||||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n.pt"
|
model = cfg.model or 'yolov8n.pt'
|
||||||
source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
|
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else 'https://ultralytics.com/images/bus.jpg'
|
||||||
|
|
||||||
args = dict(model=model, source=source)
|
args = dict(model=model, source=source)
|
||||||
if use_python:
|
if use_python:
|
||||||
@ -95,5 +95,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
predictor.predict_cli()
|
predictor.predict_cli()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
predict()
|
predict()
|
||||||
|
@ -20,7 +20,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
|||||||
# BaseTrainer python usage
|
# BaseTrainer python usage
|
||||||
class DetectionTrainer(BaseTrainer):
|
class DetectionTrainer(BaseTrainer):
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
|
def get_dataloader(self, dataset_path, batch_size, mode='train', rank=0):
|
||||||
# TODO: manage splits differently
|
# TODO: manage splits differently
|
||||||
# calculate stride - check if model is initialized
|
# calculate stride - check if model is initialized
|
||||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||||
@ -29,21 +29,21 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
stride=gs,
|
stride=gs,
|
||||||
hyp=vars(self.args),
|
hyp=vars(self.args),
|
||||||
augment=mode == "train",
|
augment=mode == 'train',
|
||||||
cache=self.args.cache,
|
cache=self.args.cache,
|
||||||
pad=0 if mode == "train" else 0.5,
|
pad=0 if mode == 'train' else 0.5,
|
||||||
rect=self.args.rect or mode == "val",
|
rect=self.args.rect or mode == 'val',
|
||||||
rank=rank,
|
rank=rank,
|
||||||
workers=self.args.workers,
|
workers=self.args.workers,
|
||||||
close_mosaic=self.args.close_mosaic != 0,
|
close_mosaic=self.args.close_mosaic != 0,
|
||||||
prefix=colorstr(f'{mode}: '),
|
prefix=colorstr(f'{mode}: '),
|
||||||
shuffle=mode == "train",
|
shuffle=mode == 'train',
|
||||||
seed=self.args.seed)[0] if self.args.v5loader else \
|
seed=self.args.seed)[0] if self.args.v5loader else \
|
||||||
build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode,
|
build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode,
|
||||||
rect=mode == "val", names=self.data['names'])[0]
|
rect=mode == 'val', names=self.data['names'])[0]
|
||||||
|
|
||||||
def preprocess_batch(self, batch):
|
def preprocess_batch(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def set_model_attributes(self):
|
def set_model_attributes(self):
|
||||||
@ -51,13 +51,13 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
# self.args.box *= 3 / nl # scale to layers
|
# self.args.box *= 3 / nl # scale to layers
|
||||||
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||||
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||||
self.model.nc = self.data["nc"] # attach number of classes to model
|
self.model.nc = self.data['nc'] # attach number of classes to model
|
||||||
self.model.names = self.data["names"] # attach class names to model
|
self.model.names = self.data['names'] # attach class names to model
|
||||||
self.model.args = self.args # attach hyperparameters to model
|
self.model.args = self.args # attach hyperparameters to model
|
||||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||||
|
|
||||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
model = DetectionModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights)
|
model.load(weights)
|
||||||
|
|
||||||
@ -75,12 +75,12 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
self.compute_loss = Loss(de_parallel(self.model))
|
self.compute_loss = Loss(de_parallel(self.model))
|
||||||
return self.compute_loss(preds, batch)
|
return self.compute_loss(preds, batch)
|
||||||
|
|
||||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||||
"""
|
"""
|
||||||
Returns a loss dict with labelled training loss items tensor
|
Returns a loss dict with labelled training loss items tensor
|
||||||
"""
|
"""
|
||||||
# Not needed for classification but necessary for segmentation & detection
|
# Not needed for classification but necessary for segmentation & detection
|
||||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
keys = [f'{prefix}/{x}' for x in self.loss_names]
|
||||||
if loss_items is not None:
|
if loss_items is not None:
|
||||||
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||||
return dict(zip(keys, loss_items))
|
return dict(zip(keys, loss_items))
|
||||||
@ -92,12 +92,12 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||||
|
|
||||||
def plot_training_samples(self, batch, ni):
|
def plot_training_samples(self, batch, ni):
|
||||||
plot_images(images=batch["img"],
|
plot_images(images=batch['img'],
|
||||||
batch_idx=batch["batch_idx"],
|
batch_idx=batch['batch_idx'],
|
||||||
cls=batch["cls"].squeeze(-1),
|
cls=batch['cls'].squeeze(-1),
|
||||||
bboxes=batch["bboxes"],
|
bboxes=batch['bboxes'],
|
||||||
paths=batch["im_file"],
|
paths=batch['im_file'],
|
||||||
fname=self.save_dir / f"train_batch{ni}.jpg")
|
fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||||
|
|
||||||
def plot_metrics(self):
|
def plot_metrics(self):
|
||||||
plot_results(file=self.csv) # save results.png
|
plot_results(file=self.csv) # save results.png
|
||||||
@ -169,7 +169,7 @@ class Loss:
|
|||||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||||
|
|
||||||
# targets
|
# targets
|
||||||
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||||
@ -201,8 +201,8 @@ class Loss:
|
|||||||
|
|
||||||
|
|
||||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n.pt"
|
model = cfg.model or 'yolov8n.pt'
|
||||||
data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
|
data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
|
||||||
device = cfg.device if cfg.device is not None else ''
|
device = cfg.device if cfg.device is not None else ''
|
||||||
|
|
||||||
args = dict(model=model, data=data, device=device)
|
args = dict(model=model, data=data, device=device)
|
||||||
@ -214,5 +214,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
train()
|
train()
|
||||||
|
@ -28,13 +28,13 @@ class DetectionValidator(BaseValidator):
|
|||||||
self.niou = self.iouv.numel()
|
self.niou = self.iouv.numel()
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
batch['img'] = batch['img'].to(self.device, non_blocking=True)
|
||||||
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
|
||||||
for k in ["batch_idx", "cls", "bboxes"]:
|
for k in ['batch_idx', 'cls', 'bboxes']:
|
||||||
batch[k] = batch[k].to(self.device)
|
batch[k] = batch[k].to(self.device)
|
||||||
|
|
||||||
nb = len(batch["img"])
|
nb = len(batch['img'])
|
||||||
self.lb = [torch.cat([batch["cls"], batch["bboxes"]], dim=-1)[batch["batch_idx"] == i]
|
self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i]
|
||||||
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
|
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
@ -54,7 +54,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
self.stats = []
|
self.stats = []
|
||||||
|
|
||||||
def get_desc(self):
|
def get_desc(self):
|
||||||
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)")
|
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
|
||||||
|
|
||||||
def postprocess(self, preds):
|
def postprocess(self, preds):
|
||||||
preds = ops.non_max_suppression(preds,
|
preds = ops.non_max_suppression(preds,
|
||||||
@ -69,11 +69,11 @@ class DetectionValidator(BaseValidator):
|
|||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
# Metrics
|
# Metrics
|
||||||
for si, pred in enumerate(preds):
|
for si, pred in enumerate(preds):
|
||||||
idx = batch["batch_idx"] == si
|
idx = batch['batch_idx'] == si
|
||||||
cls = batch["cls"][idx]
|
cls = batch['cls'][idx]
|
||||||
bbox = batch["bboxes"][idx]
|
bbox = batch['bboxes'][idx]
|
||||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||||
shape = batch["ori_shape"][si]
|
shape = batch['ori_shape'][si]
|
||||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
|
|
||||||
@ -88,16 +88,16 @@ class DetectionValidator(BaseValidator):
|
|||||||
if self.args.single_cls:
|
if self.args.single_cls:
|
||||||
pred[:, 5] = 0
|
pred[:, 5] = 0
|
||||||
predn = pred.clone()
|
predn = pred.clone()
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
|
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
|
||||||
ratio_pad=batch["ratio_pad"][si]) # native-space pred
|
ratio_pad=batch['ratio_pad'][si]) # native-space pred
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
if nl:
|
if nl:
|
||||||
height, width = batch["img"].shape[2:]
|
height, width = batch['img'].shape[2:]
|
||||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
||||||
(width, height, width, height), device=self.device) # target boxes
|
(width, height, width, height), device=self.device) # target boxes
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape,
|
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
|
||||||
ratio_pad=batch["ratio_pad"][si]) # native-space labels
|
ratio_pad=batch['ratio_pad'][si]) # native-space labels
|
||||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||||
correct_bboxes = self._process_batch(predn, labelsn)
|
correct_bboxes = self._process_batch(predn, labelsn)
|
||||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||||
@ -107,7 +107,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
|
|
||||||
# Save
|
# Save
|
||||||
if self.args.save_json:
|
if self.args.save_json:
|
||||||
self.pred_to_json(predn, batch["im_file"][si])
|
self.pred_to_json(predn, batch['im_file'][si])
|
||||||
# if self.args.save_txt:
|
# if self.args.save_txt:
|
||||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||||
|
|
||||||
@ -120,7 +120,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
|
|
||||||
def print_results(self):
|
def print_results(self):
|
||||||
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
|
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
|
||||||
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||||
if self.nt_per_class.sum() == 0:
|
if self.nt_per_class.sum() == 0:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
|
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
|
||||||
@ -175,21 +175,21 @@ class DetectionValidator(BaseValidator):
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
seed=self.args.seed)[0] if self.args.v5loader else \
|
seed=self.args.seed)[0] if self.args.v5loader else \
|
||||||
build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, names=self.data['names'],
|
build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, names=self.data['names'],
|
||||||
mode="val")[0]
|
mode='val')[0]
|
||||||
|
|
||||||
def plot_val_samples(self, batch, ni):
|
def plot_val_samples(self, batch, ni):
|
||||||
plot_images(batch["img"],
|
plot_images(batch['img'],
|
||||||
batch["batch_idx"],
|
batch['batch_idx'],
|
||||||
batch["cls"].squeeze(-1),
|
batch['cls'].squeeze(-1),
|
||||||
batch["bboxes"],
|
batch['bboxes'],
|
||||||
paths=batch["im_file"],
|
paths=batch['im_file'],
|
||||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||||
names=self.names)
|
names=self.names)
|
||||||
|
|
||||||
def plot_predictions(self, batch, preds, ni):
|
def plot_predictions(self, batch, preds, ni):
|
||||||
plot_images(batch["img"],
|
plot_images(batch['img'],
|
||||||
*output_to_target(preds, max_det=15),
|
*output_to_target(preds, max_det=15),
|
||||||
paths=batch["im_file"],
|
paths=batch['im_file'],
|
||||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||||
names=self.names) # pred
|
names=self.names) # pred
|
||||||
|
|
||||||
@ -207,8 +207,8 @@ class DetectionValidator(BaseValidator):
|
|||||||
|
|
||||||
def eval_json(self, stats):
|
def eval_json(self, stats):
|
||||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||||
anno_json = self.data['path'] / "annotations/instances_val2017.json" # annotations
|
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
|
||||||
pred_json = self.save_dir / "predictions.json" # predictions
|
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||||
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
||||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||||
check_requirements('pycocotools>=2.0.6')
|
check_requirements('pycocotools>=2.0.6')
|
||||||
@ -216,7 +216,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
from pycocotools.cocoeval import COCOeval # noqa
|
from pycocotools.cocoeval import COCOeval # noqa
|
||||||
|
|
||||||
for x in anno_json, pred_json:
|
for x in anno_json, pred_json:
|
||||||
assert x.is_file(), f"{x} file not found"
|
assert x.is_file(), f'{x} file not found'
|
||||||
anno = COCO(str(anno_json)) # init annotations api
|
anno = COCO(str(anno_json)) # init annotations api
|
||||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||||
eval = COCOeval(anno, pred, 'bbox')
|
eval = COCOeval(anno, pred, 'bbox')
|
||||||
@ -232,8 +232,8 @@ class DetectionValidator(BaseValidator):
|
|||||||
|
|
||||||
|
|
||||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n.pt"
|
model = cfg.model or 'yolov8n.pt'
|
||||||
data = cfg.data or "coco128.yaml"
|
data = cfg.data or 'coco128.yaml'
|
||||||
|
|
||||||
args = dict(model=model, data=data)
|
args = dict(model=model, data=data)
|
||||||
if use_python:
|
if use_python:
|
||||||
@ -244,5 +244,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
validator(model=args['model'])
|
validator(model=args['model'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
val()
|
val()
|
||||||
|
@ -4,4 +4,4 @@ from .predict import SegmentationPredictor, predict
|
|||||||
from .train import SegmentationTrainer, train
|
from .train import SegmentationTrainer, train
|
||||||
from .val import SegmentationValidator, val
|
from .val import SegmentationValidator, val
|
||||||
|
|
||||||
__all__ = ["SegmentationPredictor", "predict", "SegmentationTrainer", "train", "SegmentationValidator", "val"]
|
__all__ = ['SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val']
|
||||||
|
@ -39,7 +39,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
|
|
||||||
def write_results(self, idx, results, batch):
|
def write_results(self, idx, results, batch):
|
||||||
p, im, im0 = batch
|
p, im, im0 = batch
|
||||||
log_string = ""
|
log_string = ''
|
||||||
if len(im.shape) == 3:
|
if len(im.shape) == 3:
|
||||||
im = im[None] # expand for batch dim
|
im = im[None] # expand for batch dim
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
@ -84,7 +84,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
|
|
||||||
if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
|
if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
|
||||||
c = int(cls) # integer class
|
c = int(cls) # integer class
|
||||||
name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
|
name = f'id:{int(d.id.item())} {self.model.names[c]}' if d.id is not None else self.model.names[c]
|
||||||
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
||||||
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
|
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
|
||||||
if self.args.save_crop:
|
if self.args.save_crop:
|
||||||
@ -97,9 +97,9 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
|
|
||||||
|
|
||||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n-seg.pt"
|
model = cfg.model or 'yolov8n-seg.pt'
|
||||||
source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
|
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else 'https://ultralytics.com/images/bus.jpg'
|
||||||
|
|
||||||
args = dict(model=model, source=source)
|
args = dict(model=model, source=source)
|
||||||
if use_python:
|
if use_python:
|
||||||
@ -110,5 +110,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
predictor.predict_cli()
|
predictor.predict_cli()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
predict()
|
predict()
|
||||||
|
@ -20,11 +20,11 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
|
|||||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
overrides["task"] = "segment"
|
overrides['task'] = 'segment'
|
||||||
super().__init__(cfg, overrides)
|
super().__init__(cfg, overrides)
|
||||||
|
|
||||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights)
|
model.load(weights)
|
||||||
|
|
||||||
@ -43,13 +43,13 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
|
|||||||
return self.compute_loss(preds, batch)
|
return self.compute_loss(preds, batch)
|
||||||
|
|
||||||
def plot_training_samples(self, batch, ni):
|
def plot_training_samples(self, batch, ni):
|
||||||
images = batch["img"]
|
images = batch['img']
|
||||||
masks = batch["masks"]
|
masks = batch['masks']
|
||||||
cls = batch["cls"].squeeze(-1)
|
cls = batch['cls'].squeeze(-1)
|
||||||
bboxes = batch["bboxes"]
|
bboxes = batch['bboxes']
|
||||||
paths = batch["im_file"]
|
paths = batch['im_file']
|
||||||
batch_idx = batch["batch_idx"]
|
batch_idx = batch['batch_idx']
|
||||||
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg")
|
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||||
|
|
||||||
def plot_metrics(self):
|
def plot_metrics(self):
|
||||||
plot_results(file=self.csv, segment=True) # save results.png
|
plot_results(file=self.csv, segment=True) # save results.png
|
||||||
@ -80,15 +80,15 @@ class SegLoss(Loss):
|
|||||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||||
|
|
||||||
# targets
|
# targets
|
||||||
batch_idx = batch["batch_idx"].view(-1, 1)
|
batch_idx = batch['batch_idx'].view(-1, 1)
|
||||||
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||||
|
|
||||||
masks = batch["masks"].to(self.device).float()
|
masks = batch['masks'].to(self.device).float()
|
||||||
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
||||||
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
|
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
|
||||||
|
|
||||||
# pboxes
|
# pboxes
|
||||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||||
@ -135,13 +135,13 @@ class SegLoss(Loss):
|
|||||||
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
|
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
|
||||||
# Mask loss for one image
|
# Mask loss for one image
|
||||||
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
|
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
|
||||||
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
|
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
|
||||||
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
||||||
|
|
||||||
|
|
||||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n-seg.pt"
|
model = cfg.model or 'yolov8n-seg.pt'
|
||||||
data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist")
|
||||||
device = cfg.device if cfg.device is not None else ''
|
device = cfg.device if cfg.device is not None else ''
|
||||||
|
|
||||||
args = dict(model=model, data=data, device=device)
|
args = dict(model=model, data=data, device=device)
|
||||||
@ -153,5 +153,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
train()
|
train()
|
||||||
|
@ -24,7 +24,7 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
batch = super().preprocess(batch)
|
batch = super().preprocess(batch)
|
||||||
batch["masks"] = batch["masks"].to(self.device).float()
|
batch['masks'] = batch['masks'].to(self.device).float()
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def init_metrics(self, model):
|
def init_metrics(self, model):
|
||||||
@ -37,8 +37,8 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
self.process = ops.process_mask # faster
|
self.process = ops.process_mask # faster
|
||||||
|
|
||||||
def get_desc(self):
|
def get_desc(self):
|
||||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
|
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
|
||||||
"R", "mAP50", "mAP50-95)")
|
'R', 'mAP50', 'mAP50-95)')
|
||||||
|
|
||||||
def postprocess(self, preds):
|
def postprocess(self, preds):
|
||||||
p = ops.non_max_suppression(preds[0],
|
p = ops.non_max_suppression(preds[0],
|
||||||
@ -55,11 +55,11 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
# Metrics
|
# Metrics
|
||||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||||
idx = batch["batch_idx"] == si
|
idx = batch['batch_idx'] == si
|
||||||
cls = batch["cls"][idx]
|
cls = batch['cls'][idx]
|
||||||
bbox = batch["bboxes"][idx]
|
bbox = batch['bboxes'][idx]
|
||||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||||
shape = batch["ori_shape"][si]
|
shape = batch['ori_shape'][si]
|
||||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
@ -74,23 +74,23 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
|
|
||||||
# Masks
|
# Masks
|
||||||
midx = [si] if self.args.overlap_mask else idx
|
midx = [si] if self.args.overlap_mask else idx
|
||||||
gt_masks = batch["masks"][midx]
|
gt_masks = batch['masks'][midx]
|
||||||
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:])
|
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
|
||||||
|
|
||||||
# Predictions
|
# Predictions
|
||||||
if self.args.single_cls:
|
if self.args.single_cls:
|
||||||
pred[:, 5] = 0
|
pred[:, 5] = 0
|
||||||
predn = pred.clone()
|
predn = pred.clone()
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
|
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
|
||||||
ratio_pad=batch["ratio_pad"][si]) # native-space pred
|
ratio_pad=batch['ratio_pad'][si]) # native-space pred
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
if nl:
|
if nl:
|
||||||
height, width = batch["img"].shape[2:]
|
height, width = batch['img'].shape[2:]
|
||||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
||||||
(width, height, width, height), device=self.device) # target boxes
|
(width, height, width, height), device=self.device) # target boxes
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape,
|
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
|
||||||
ratio_pad=batch["ratio_pad"][si]) # native-space labels
|
ratio_pad=batch['ratio_pad'][si]) # native-space labels
|
||||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||||
correct_bboxes = self._process_batch(predn, labelsn)
|
correct_bboxes = self._process_batch(predn, labelsn)
|
||||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||||
@ -112,11 +112,11 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
|
|
||||||
# Save
|
# Save
|
||||||
if self.args.save_json:
|
if self.args.save_json:
|
||||||
pred_masks = ops.scale_image(batch["img"][si].shape[1:],
|
pred_masks = ops.scale_image(batch['img'][si].shape[1:],
|
||||||
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
|
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
|
||||||
shape,
|
shape,
|
||||||
ratio_pad=batch["ratio_pad"][si])
|
ratio_pad=batch['ratio_pad'][si])
|
||||||
self.pred_to_json(predn, batch["im_file"][si], pred_masks)
|
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
|
||||||
# if self.args.save_txt:
|
# if self.args.save_txt:
|
||||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
|
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
|
||||||
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
|
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
|
||||||
if gt_masks.shape[1:] != pred_masks.shape[1:]:
|
if gt_masks.shape[1:] != pred_masks.shape[1:]:
|
||||||
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
|
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
|
||||||
gt_masks = gt_masks.gt_(0.5)
|
gt_masks = gt_masks.gt_(0.5)
|
||||||
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
|
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
|
||||||
else: # boxes
|
else: # boxes
|
||||||
@ -158,20 +158,20 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||||
|
|
||||||
def plot_val_samples(self, batch, ni):
|
def plot_val_samples(self, batch, ni):
|
||||||
plot_images(batch["img"],
|
plot_images(batch['img'],
|
||||||
batch["batch_idx"],
|
batch['batch_idx'],
|
||||||
batch["cls"].squeeze(-1),
|
batch['cls'].squeeze(-1),
|
||||||
batch["bboxes"],
|
batch['bboxes'],
|
||||||
batch["masks"],
|
batch['masks'],
|
||||||
paths=batch["im_file"],
|
paths=batch['im_file'],
|
||||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||||
names=self.names)
|
names=self.names)
|
||||||
|
|
||||||
def plot_predictions(self, batch, preds, ni):
|
def plot_predictions(self, batch, preds, ni):
|
||||||
plot_images(batch["img"],
|
plot_images(batch['img'],
|
||||||
*output_to_target(preds[0], max_det=15),
|
*output_to_target(preds[0], max_det=15),
|
||||||
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
|
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
|
||||||
paths=batch["im_file"],
|
paths=batch['im_file'],
|
||||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||||
names=self.names) # pred
|
names=self.names) # pred
|
||||||
self.plot_masks.clear()
|
self.plot_masks.clear()
|
||||||
@ -182,8 +182,8 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
from pycocotools.mask import encode # noqa
|
from pycocotools.mask import encode # noqa
|
||||||
|
|
||||||
def single_encode(x):
|
def single_encode(x):
|
||||||
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
|
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
|
||||||
rle["counts"] = rle["counts"].decode("utf-8")
|
rle['counts'] = rle['counts'].decode('utf-8')
|
||||||
return rle
|
return rle
|
||||||
|
|
||||||
stem = Path(filename).stem
|
stem = Path(filename).stem
|
||||||
@ -203,8 +203,8 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
|
|
||||||
def eval_json(self, stats):
|
def eval_json(self, stats):
|
||||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||||
anno_json = self.data['path'] / "annotations/instances_val2017.json" # annotations
|
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
|
||||||
pred_json = self.save_dir / "predictions.json" # predictions
|
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||||
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
||||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||||
check_requirements('pycocotools>=2.0.6')
|
check_requirements('pycocotools>=2.0.6')
|
||||||
@ -212,7 +212,7 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
from pycocotools.cocoeval import COCOeval # noqa
|
from pycocotools.cocoeval import COCOeval # noqa
|
||||||
|
|
||||||
for x in anno_json, pred_json:
|
for x in anno_json, pred_json:
|
||||||
assert x.is_file(), f"{x} file not found"
|
assert x.is_file(), f'{x} file not found'
|
||||||
anno = COCO(str(anno_json)) # init annotations api
|
anno = COCO(str(anno_json)) # init annotations api
|
||||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||||
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
|
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
|
||||||
@ -231,8 +231,8 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
|
|
||||||
|
|
||||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||||
model = cfg.model or "yolov8n-seg.pt"
|
model = cfg.model or 'yolov8n-seg.pt'
|
||||||
data = cfg.data or "coco128-seg.yaml"
|
data = cfg.data or 'coco128-seg.yaml'
|
||||||
|
|
||||||
args = dict(model=model, data=data)
|
args = dict(model=model, data=data)
|
||||||
if use_python:
|
if use_python:
|
||||||
@ -243,5 +243,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
validator(model=args['model'])
|
validator(model=args['model'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
val()
|
val()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user