ultralytics 8.0.30 - Docker, rect, data=*.zip updates (#832)

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
commit 64f247d692 (parent 09265b17d7)
.github/workflows/docker.yaml (vendored): 6 lines changed

@@ -29,7 +29,7 @@ jobs:
           password: ${{ secrets.DOCKERHUB_TOKEN }}

       - name: Build and push arm64 image
-        uses: docker/build-push-action@v3
+        uses: docker/build-push-action@v4
         continue-on-error: true
         with:
           context: .
@@ -39,7 +39,7 @@ jobs:
           tags: ultralytics/ultralytics:latest-arm64

       - name: Build and push CPU image
-        uses: docker/build-push-action@v3
+        uses: docker/build-push-action@v4
         continue-on-error: true
         with:
           context: .
@@ -48,7 +48,7 @@ jobs:
           tags: ultralytics/ultralytics:latest-cpu

       - name: Build and push GPU image
-        uses: docker/build-push-action@v3
+        uses: docker/build-push-action@v4
         continue-on-error: true
         with:
           context: .
@@ -26,11 +26,9 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 # Install pip packages
 COPY requirements.txt .
 RUN python3 -m pip install --upgrade pip wheel
-RUN pip install --no-cache ultralytics gsutil notebook \
-    tensorflow-aarch64
-    # tensorflowjs \
-    # onnx onnx-simplifier onnxruntime \
-    # coremltools openvino-dev>=2022.3 \
+RUN pip install --no-cache ultralytics albumentations gsutil notebook \
+    coremltools onnx onnx-simplifier onnxruntime openvino-dev>=2022.3
+    # tensorflow-aarch64 tensorflowjs \

 # Cleanup
 ENV DEBIAN_FRONTEND teletype
@@ -108,6 +108,7 @@ task.
 | overlap_mask | True  | masks should overlap during training (segment train only) |
 | mask_ratio   | 4     | mask downsample ratio (segment train only)                 |
 | dropout      | 0.0   | use dropout regularization (classify train only)           |
+| val          | True  | validate/test during training                              |

 ### Prediction

@@ -148,7 +149,6 @@ validation dataset and to detect and prevent overfitting.

 | Key         | Value | Description                                                                  |
 |-------------|-------|------------------------------------------------------------------------------|
-| val         | True  | validate/test during training                                                |
 | save_json   | False | save results to JSON file                                                    |
 | save_hybrid | False | save hybrid version of labels (labels + additional predictions)             |
 | conf        | 0.001 | object confidence threshold for detection (default 0.25 predict, 0.001 val) |
@@ -157,6 +157,7 @@ validation dataset and to detect and prevent overfitting.
 | half  | True  | use half precision (FP16)         |
 | dnn   | False | use OpenCV DNN for ONNX inference |
 | plots | False | show plots during training        |
+| rect  | False | support rectangular evaluation    |

 ### Export

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = "8.0.29"
+__version__ = "8.0.30"

 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops
@@ -338,8 +338,9 @@ def torch_safe_load(weight):
         if e.name == 'omegaconf':  # e.name is missing module name
             LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
                            f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
-                           f"\nRecommend fixes are to train a new model using updated ultraltyics package or to "
+                           f"\nRecommend fixes are to train a new model using updated ultralytics package or to "
                            f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
-        check_requirements(e.name)  # install missing module
+        if e.name != 'models':
+            check_requirements(e.name)  # install missing module
         return torch.load(file, map_location='cpu')  # load
@@ -25,7 +25,7 @@ seed: 0  # random seed for reproducibility
 deterministic: True  # whether to enable deterministic mode
 single_cls: False  # train multi-class data as single-class
 image_weights: False  # use weighted image selection for training
-rect: False  # support rectangular training
+rect: False  # support rectangular training if mode='train', support rectangular evaluation if mode='val'
 cos_lr: False  # use cosine learning rate scheduler
 close_mosaic: 10  # disable mosaic augmentation for final 10 epochs
 resume: False  # resume training from last checkpoint
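Note: the reworded comment reflects that a single rect flag now covers rectangular training and rectangular evaluation. A minimal usage sketch (model and dataset names are illustrative):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    model.train(data="coco128.yaml", rect=True)  # rectangular training batches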
@@ -61,7 +61,7 @@ def seed_worker(worker_id):
     random.seed(worker_seed)


-def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
+def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_path=None, rank=-1, mode="train"):
     assert mode in ["train", "val"]
     shuffle = mode == "train"
     if cfg.rect and shuffle:
@@ -75,7 +75,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
                           batch_size=batch_size,
                           augment=mode == "train",  # augmentation
                           hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
-                          rect=cfg.rect if mode == "train" else True,  # rectangular batches
+                          rect=cfg.rect or rect,  # rectangular batches
                           cache=cfg.cache or None,
                           single_cls=cfg.single_cls or False,
                           stride=int(stride),
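Note: the new rect keyword lets the caller force rectangular batches instead of hard-coding them for validation. A sketch of the before/after logic, with hypothetical helper names:

    def rect_before(cfg_rect: bool, mode: str) -> bool:
        return cfg_rect if mode == "train" else True  # val was always rectangular

    def rect_after(cfg_rect: bool, rect: bool) -> bool:
        return cfg_rect or rect  # rectangular only when config or caller asks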
@@ -113,13 +113,15 @@ class YOLODataset(BaseDataset):
             tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
             if cache["msgs"]:
                 LOGGER.info("\n".join(cache["msgs"]))  # display warnings
-        assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
+        if nf == 0:  # number of labels found
+            raise FileNotFoundError(f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}")

         # Read cache
         [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
         labels = cache["labels"]

         # Check if the dataset is all boxes or all segments
+        len_cls = sum(len(lb["cls"]) for lb in labels)
         len_boxes = sum(len(lb["bboxes"]) for lb in labels)
         len_segments = sum(len(lb["segments"]) for lb in labels)
         if len_segments and len_boxes != len_segments:
@@ -129,8 +131,8 @@ class YOLODataset(BaseDataset):
                 "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
             for lb in labels:
                 lb["segments"] = []
-        nl = len(np.concatenate([label["cls"] for label in labels], 0))  # number of labels
-        assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
+        if len_cls == 0:
+            raise ValueError(f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}")
         return labels

     # TODO: use hyp config to set all these augmentations
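Note: swapping asserts for raised exceptions matters because python -O strips assert statements entirely, silently skipping the guard, and FileNotFoundError/ValueError give callers a typed error to catch. A self-contained demonstration of the difference:

    import subprocess, sys

    code = "nf = 0\nassert nf > 0, 'No labels found'"
    subprocess.run([sys.executable, '-O', '-c', code])  # exit 0: assert stripped
    subprocess.run([sys.executable, '-c', code])        # AssertionError in child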
@@ -192,7 +192,7 @@ def check_det_dataset(dataset, autodownload=True):
     # Download (optional)
     extract_dir = ''
     if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
-        download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
+        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
         data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
         extract_dir, autodownload = data.parent, False

@@ -211,7 +211,8 @@ def check_det_dataset(dataset, autodownload=True):
         data['nc'] = len(data['names'])

     # Resolve paths
-    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
+    path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset root
+
     if not path.is_absolute():
         path = (DATASETS_DIR / path).resolve()
     data['path'] = path  # download scripts
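Note: these two changes are the core of the data=*.zip feature: the archive is downloaded and extracted under DATASETS_DIR, the first *.yaml found inside becomes the dataset config, and its parent directory becomes the dataset root. A minimal sketch (the archive name is illustrative; it must contain a dataset YAML):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    model.train(data="coco128.zip")  # a local archive or a URL ending in .zip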
@@ -156,6 +156,7 @@ class YOLO:
             **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
         """
         overrides = self.overrides.copy()
+        overrides["rect"] = True  # rect batches as default
         overrides.update(kwargs)
         overrides["mode"] = "val"
         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
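Note: because the default is applied before overrides.update(kwargs), an explicit kwarg still wins:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    model.val(data="coco128.yaml")              # rect=True by default now
    model.val(data="coco128.yaml", rect=False)  # user override still honored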
@@ -116,13 +116,16 @@ class BaseTrainer:

         # Model and Dataloaders.
         self.model = self.args.model
-        self.data = self.args.data
-        if self.data.endswith(".yaml"):
-            self.data = check_det_dataset(self.data)
-        elif self.args.task == 'classify':
-            self.data = check_cls_dataset(self.data)
-        else:
-            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
+        try:
+            if self.args.task == 'classify':
+                self.data = check_cls_dataset(self.args.data)
+            elif self.args.data.endswith(".yaml") or self.args.task in ('detect', 'segment'):
+                self.data = check_det_dataset(self.args.data)
+                if 'yaml_file' in self.data:
+                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e

         self.trainset, self.testset = self.get_dataset(self.data)
         self.ema = None
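Note: dataset resolution now dispatches on task first and wraps everything in one try/except, so any failure surfaces as a FileNotFoundError carrying the original cause. Illustrative:

    from ultralytics import YOLO

    try:
        YOLO("yolov8n.pt").train(data="missing.yaml")  # nonexistent dataset
    except FileNotFoundError as e:
        print(e)  # "Dataset 'missing.yaml' error ❌ <underlying exception>"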
@@ -117,6 +117,8 @@ class BaseValidator:

             if self.device.type == 'cpu':
                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
+            if not pt:
+                self.args.rect = False
             self.dataloader = self.dataloader or \
                 self.get_dataloader(self.data.get("val") or self.data.set("test"), self.args.batch)

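Note: rectangular evaluation yields variable input shapes, which only native PyTorch checkpoints handle; exported backends are generally compiled for one fixed shape, so rect is switched off when pt is false. Illustrative, assuming an exported model file exists:

    from ultralytics import YOLO

    model = YOLO("yolov8n.onnx")    # exported weights, pt=False internally
    model.val(data="coco128.yaml")  # validator forces rect=False for this backend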
@@ -491,6 +491,7 @@ def set_sentry():
            ((is_pip_package() and not is_git_dir()) or
             (get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):

+        import hashlib
         import sentry_sdk  # noqa
         from ultralytics import __version__

@@ -502,13 +503,14 @@ def set_sentry():
                         environment='production',  # 'dev' or 'production'
                         before_send=before_send,
                         ignore_errors=[KeyboardInterrupt, FileNotFoundError])
+        sentry_sdk.set_user({"id": SETTINGS['uuid']})

         # Disable all sentry logging
         for logger in "sentry_sdk", "sentry_sdk.errors":
             logging.getLogger(logger).setLevel(logging.CRITICAL)


-def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
+def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'):
     """
     Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
@@ -519,6 +521,7 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
     Returns:
         dict: Dictionary of settings key-value pairs.
     """
+    import hashlib
     from ultralytics.yolo.utils.checks import check_version
     from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first

@@ -530,7 +533,7 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
         'weights_dir': str(root / 'weights'),  # default weights directory.
         'runs_dir': str(root / 'runs'),  # default runs directory.
         'sync': True,  # sync analytics to help with YOLO development
-        'uuid': uuid.getnode(),  # device UUID to align analytics
+        'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),  # anonymized uuid hash
         'settings_version': version}  # Ultralytics settings version

     with torch_distributed_zero_first(RANK):
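Note: the settings UUID moves from the raw uuid.getnode() value (a MAC-derived integer) to its SHA-256 digest, keeping a stable per-machine id without storing a hardware identifier. Self-contained demo:

    import hashlib
    import uuid

    raw = uuid.getnode()  # MAC-derived 48-bit integer
    anon = hashlib.sha256(str(raw).encode()).hexdigest()
    print(anon)  # stable on the same machine; MAC not recoverable from it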
@@ -544,10 +547,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
             and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
             and check_version(settings['settings_version'], version)
         if not correct:
-            LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. '
-                           '\nThis is normal and may be due to a recent ultralytics package update, '
-                           'but may have overwritten previous settings. '
-                           f"\nYou may view and update settings directly in '{file}'")
+            LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. This is normal and may be due to a '
+                           'recent ultralytics package update, but may have overwritten previous settings. '
+                           f"\nView and update settings with 'yolo settings' or at '{file}'")
             settings = defaults  # merge **defaults with **settings (prefer **settings)
             yaml_save(file, settings)  # save updated defaults
@@ -247,7 +247,7 @@ def check_file(file, suffix=''):
         if Path(file).is_file():
             LOGGER.info(f'Found {url} locally at {file}')  # file already exists
         else:
-            downloads.safe_download(url=url, file=file)
+            downloads.safe_download(url=url, file=file, unzip=False)
         return file
     else:  # search
         files = []
@@ -28,6 +28,19 @@ def is_url(url, check=True):
     return False


+def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
+    """
+    Unzip a *.zip file to path/, excluding files containing strings in exclude list
+    Replaces: ZipFile(file).extractall(path=path)
+    """
+    if path is None:
+        path = Path(file).parent  # default path
+    with ZipFile(file) as zipObj:
+        for f in zipObj.namelist():  # list all archived filenames in the zip
+            if all(x not in f for x in exclude):
+                zipObj.extract(f, path=path)
+
+
 def safe_download(url,
                   file=None,
                   dir=None,
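Note: a quick way to exercise the new helper (throwaway file names; assumes unzip_file is importable from ultralytics.yolo.utils.downloads):

    from pathlib import Path
    from zipfile import ZipFile

    from ultralytics.yolo.utils.downloads import unzip_file

    zip_path = Path("demo.zip")
    with ZipFile(zip_path, "w") as z:
        z.writestr("data/a.txt", "hello")
        z.writestr("__MACOSX/data/._a.txt", "metadata junk")

    unzip_file(zip_path)              # extracts data/a.txt next to demo.zip
    unzip_file(zip_path, path="out")  # or into out/, skipping __MACOSX entries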
@@ -96,13 +109,14 @@ def safe_download(url,
                 LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')

     if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
-        LOGGER.info(f'Unzipping {f}...')
+        unzip_dir = dir or f.parent  # unzip to dir if provided else unzip in place
+        LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
         if f.suffix == '.zip':
-            ZipFile(f).extractall(path=f.parent)  # unzip
+            unzip_file(file=f, path=unzip_dir)  # unzip
         elif f.suffix == '.tar':
-            subprocess.run(['tar', 'xf', f, '--directory', f.parent], check=True)  # unzip
+            subprocess.run(['tar', 'xf', f, '--directory', unzip_dir], check=True)  # unzip
         elif f.suffix == '.gz':
-            subprocess.run(['tar', 'xfz', f, '--directory', f.parent], check=True)  # unzip
+            subprocess.run(['tar', 'xfz', f, '--directory', unzip_dir], check=True)  # unzip
         if delete:
             f.unlink()  # remove zip
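Note: with unzip_dir = dir or f.parent, the dir argument now controls where archives are extracted as well as where they are saved; previously extraction always landed beside the zip. Illustrative call (placeholder URL; keyword defaults assumed from the signature above):

    from ultralytics.yolo.utils.downloads import safe_download

    # download demo.zip into datasets/ and extract it there, keeping the archive
    safe_download(url="https://example.com/demo.zip", dir="datasets", unzip=True, delete=False)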
@@ -33,14 +33,14 @@ class DetectionTrainer(BaseTrainer):
                              augment=mode == "train",
                              cache=self.args.cache,
                              pad=0 if mode == "train" else 0.5,
-                             rect=self.args.rect,
+                             rect=self.args.rect or mode == "val",
                              rank=rank,
                              workers=self.args.workers,
                              close_mosaic=self.args.close_mosaic != 0,
                              prefix=colorstr(f'{mode}: '),
                              shuffle=mode == "train",
                              seed=self.args.seed)[0] if self.args.v5loader else \
-            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
+            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode, rect=mode == "val")[0]

     def preprocess_batch(self, batch):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
@@ -22,7 +22,6 @@ class DetectionValidator(BaseValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
         super().__init__(dataloader, save_dir, pbar, logger, args)
         self.args.task = 'detect'
-        self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
         self.is_coco = False
         self.class_map = None
         self.metrics = DetMetrics(save_dir=self.save_dir)
@@ -172,7 +171,7 @@ class DetectionValidator(BaseValidator):
                                 hyp=vars(self.args),
                                 cache=False,
                                 pad=0.5,
-                                rect=True,
+                                rect=self.args.rect,
                                 workers=self.args.workers,
                                 prefix=colorstr(f'{self.args.mode}: '),
                                 shuffle=False,