ultralytics 8.0.46 TFLite and Benchmarks updates (#1141)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2023-02-25 09:24:14 -08:00 committed by GitHub
parent 3765f4f6d9
commit a82ee2c779
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 130 additions and 74 deletions

View File

@@ -55,18 +55,20 @@ jobs:
       - name: Benchmark DetectionModel
         shell: python
         run: |
-          from ultralytics.yolo.utils.benchmarks import run_benchmarks
-          run_benchmarks(model='${{ matrix.model }}.pt', imgsz=160, half=False, hard_fail=False)
+          from ultralytics.yolo.utils.benchmarks import benchmark
+          benchmark(model='${{ matrix.model }}.pt', imgsz=160, half=False, hard_fail=0.20)
       - name: Benchmark SegmentationModel
         shell: python
         run: |
-          from ultralytics.yolo.utils.benchmarks import run_benchmarks
-          run_benchmarks(model='${{ matrix.model }}-seg.pt', imgsz=160, half=False, hard_fail=False)
+          from ultralytics.yolo.utils.benchmarks import benchmark
+          benchmark(model='${{ matrix.model }}-seg.pt', imgsz=160, half=False, hard_fail=0.14)
       - name: Benchmark ClassificationModel
         shell: python
         run: |
-          from ultralytics.yolo.utils.benchmarks import run_benchmarks
-          run_benchmarks(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=False)
+          from ultralytics.yolo.utils.benchmarks import benchmark
+          benchmark(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=0.70)
+      - name: Benchmark Summary
+        run: cat benchmarks.log

   Tests:
     timeout-minutes: 60
@@ -88,10 +90,10 @@ jobs:
       - uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
-      - name: Get cache dir
-        # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
+      - name: Get cache dir  # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
         id: pip-cache
-        run: echo "::set-output name=dir::$(pip cache dir)"
+        run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+        shell: bash  # for Windows compatibility
       - name: Cache pip
         uses: actions/cache@v3
         with:

View File

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = '8.0.45'
+__version__ = '8.0.46'

 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils.checks import check_yolo as checks

View File

@@ -254,8 +254,8 @@ def entrypoint(debug=''):
             else:
                 check_cfg_mismatch(full_args_dict, {a: ''})

-    # Defaults
-    task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
+    # Check keys
+    check_cfg_mismatch(full_args_dict, overrides)

     # Mode
     mode = overrides.get('mode', None)
@@ -279,8 +279,9 @@ def entrypoint(debug=''):
         model = YOLO(model)

     # Task
-    task = overrides.get('task', None)
-    if task is not None and task not in TASKS:
-        raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
-    else:
-        model.task = task
+    task = overrides.get('task', model.task)
+    if task is not None:
+        if task not in TASKS:
+            raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
+        else:
+            model.task = task
@@ -292,8 +293,9 @@ def entrypoint(debug=''):
             LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
     elif mode in ('train', 'val'):
         if 'data' not in overrides:
-            overrides['data'] = task2data.get(overrides['task'], DEFAULT_CFG.data)
-            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.")
+            task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
+            overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
+            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
     elif mode == 'export':
         if 'format' not in overrides:
             overrides['format'] = DEFAULT_CFG.format or 'torchscript'
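
The net effect of the 'data' fallback above: when a train/val command omits 'data', the dataset now follows the resolved task instead of a module-level default. A standalone sketch of that lookup, where default_task and default_data stand in for the DEFAULT_CFG values (which are not shown in this hunk):

# Standalone sketch of the task-to-dataset fallback introduced above.
# 'default_task' and 'default_data' are placeholders for DEFAULT_CFG.task and DEFAULT_CFG.data.
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')

def resolve_data(task, default_task='detect', default_data='coco128.yaml'):
    # mirrors: overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
    return task2data.get(task or default_task, default_data)

print(resolve_data('segment'))  # coco128-seg.yaml
print(resolve_data(None))       # falls back to the default task's dataset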

View File

@@ -16,10 +16,28 @@ from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image
 class YOLODataset(BaseDataset):
     cache_version = '1.0.1'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
     rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
-    """YOLO Dataset.
+    """
+    Dataset class for loading object detection and/or segmentation labels in YOLO format.
+
     Args:
-        img_path (str): image path.
-        prefix (str): prefix.
+        img_path (str): path to the folder containing images.
+        imgsz (int): image size (default: 640).
+        cache (bool): if True, a cache file of the labels is created to speed up future creation of dataset instances
+            (default: False).
+        augment (bool): if True, data augmentation is applied (default: True).
+        hyp (dict): hyperparameters to apply data augmentation (default: None).
+        prefix (str): prefix to print in log messages (default: '').
+        rect (bool): if True, rectangular training is used (default: False).
+        batch_size (int): size of batches (default: None).
+        stride (int): stride (default: 32).
+        pad (float): padding (default: 0.0).
+        single_cls (bool): if True, single class training is used (default: False).
+        use_segments (bool): if True, segmentation masks are used as labels (default: False).
+        use_keypoints (bool): if True, keypoints are used as labels (default: False).
+        names (list): class names (default: None).
+
+    Returns:
+        A PyTorch dataset object that can be used for training an object detection or segmentation model.
     """

     def __init__(self,
@@ -44,7 +62,12 @@ class YOLODataset(BaseDataset):
         super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)

     def cache_labels(self, path=Path('./labels.cache')):
-        # Cache dataset labels, check images and read shapes
+        """Cache dataset labels, check images and read shapes.
+        Args:
+            path (Path): path where to save the cache file (default: Path('./labels.cache')).
+        Returns:
+            (dict): labels.
+        """
         x = {'labels': []}
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
         desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
@@ -119,9 +142,8 @@ class YOLODataset(BaseDataset):
         self.im_files = [lb['im_file'] for lb in labels]  # update im_files

         # Check if the dataset is all boxes or all segments
-        len_cls = sum(len(lb['cls']) for lb in labels)
-        len_boxes = sum(len(lb['bboxes']) for lb in labels)
-        len_segments = sum(len(lb['segments']) for lb in labels)
+        lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
+        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
         if len_segments and len_boxes != len_segments:
             LOGGER.warning(
                 f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
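
The last hunk collapses three separate sum() passes over the labels into one pass via a generator of per-label tuples. A self-contained illustration of the consistency check, using placeholder label dicts:

# Self-contained illustration of the refactored box/segment count check above.
labels = [
    {'cls': [0, 1], 'bboxes': [[0.5, 0.5, 0.2, 0.2], [0.1, 0.1, 0.05, 0.05]], 'segments': []},
    {'cls': [2], 'bboxes': [[0.3, 0.3, 0.1, 0.1]], 'segments': []},
]
lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
print(len_cls, len_boxes, len_segments)  # 3 3 0 -> counts agree, so no warning would be emitted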

View File

@@ -294,7 +294,7 @@ class Exporter:
         # YOLOv8 ONNX export
         requirements = ['onnx>=1.12.0']
         if self.args.simplify:
-            requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
+            requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
         check_requirements(requirements)
         import onnx  # noqa
@@ -513,8 +513,8 @@ class Exporter:
         cuda = torch.cuda.is_available()
         check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
         import tensorflow as tf  # noqa
-        check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
-                            'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
+        check_requirements(('onnx', 'onnx2tf>=1.7.7', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26',
+                            'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
                            cmds='--extra-index-url https://pypi.ngc.nvidia.com')

         LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
@@ -529,7 +529,7 @@ class Exporter:
         # Export to TF
         int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
-        cmd = f'onnx2tf -i {f_onnx} -o {f} --non_verbose {int8}'
+        cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}'
         LOGGER.info(f'\n{prefix} running {cmd}')
         subprocess.run(cmd, shell=True)
         yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
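
These hunks cover the dependency pins and onnx2tf flags behind the TFLite path; from the user side the export is driven through the model API. A usage sketch, assuming ultralytics 8.0.46 is installed (the TensorFlow/onnx2tf dependencies are resolved by check_requirements at run time):

# Usage sketch for the TFLite export path touched above (assumes ultralytics 8.0.46 installed).
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
model.export(format='tflite', imgsz=320)    # PyTorch -> ONNX -> onnx2tf -> SavedModel + .tflite
# model.export(format='tflite', int8=True)  # adds '-oiqt -qt per-tensor' for per-tensor INT8 quantization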

View File

@@ -9,8 +9,9 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
                                   guess_model_task, nn)
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
-from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
+from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
+                                    is_git_dir, is_pip_package, yaml_load)
+from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml
 from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
 from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@@ -150,6 +151,13 @@ class YOLO:
                 f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
                 f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")

+    def _check_pip_update(self):
+        """
+        Inform user of ultralytics package update availability
+        """
+        if is_pip_package():
+            check_pip_update()
+
     def reset(self):
         """
         Resets the model modules.
@@ -189,6 +197,10 @@ class YOLO:
         Returns:
             (List[ultralytics.yolo.engine.results.Results]): The prediction results.
         """
+        if source is None:
+            source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
+            LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
         overrides = self.overrides.copy()
         overrides['conf'] = 0.25
         overrides.update(kwargs)  # prefer kwargs
@@ -251,11 +263,12 @@ class YOLO:
         Args:
             **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
         """
-        from ultralytics.yolo.utils.benchmarks import run_benchmarks
+        self._check_is_pytorch_model()
+        from ultralytics.yolo.utils.benchmarks import benchmark

         overrides = self.model.args.copy()
         overrides.update(kwargs)
         overrides = {**DEFAULT_CFG_DICT, **overrides}  # fill in missing overrides keys with defaults
-        return run_benchmarks(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
+        return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])

     def export(self, **kwargs):
         """
@@ -283,6 +296,7 @@ class YOLO:
             **kwargs (Any): Any number of arguments representing the training configuration.
         """
         self._check_is_pytorch_model()
+        self._check_pip_update()
         overrides = self.overrides.copy()
         overrides.update(kwargs)
         if kwargs.get('cfg'):
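
Two user-visible behaviors come out of this file: predict() without a source now warns and falls back to the bundled assets (or the bus.jpg URL outside a git checkout), and the benchmark helper is reached through the model wrapper whose body is edited in the fourth hunk. A short usage sketch, assuming ultralytics 8.0.46 and that the wrapper method is YOLO.benchmark (its def line is just above the hunk and not shown here):

# Usage sketch for the model-level changes above (assumes ultralytics 8.0.46 and a local yolov8n.pt).
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
results = model.predict()        # no source given: warns, then uses ROOT/'assets' or the bus.jpg URL
df = model.benchmark(imgsz=160)  # delegates to utils.benchmarks.benchmark() and returns a DataFrame
print(df)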

View File

@@ -178,7 +178,12 @@ class BasePredictor:
             self.run_callbacks('on_predict_postprocess_end')

             # visualize, save, write results
-            for i in range(len(im)):
+            n = len(im)
+            for i in range(n):
+                self.results[i].speed = {
+                    'preprocess': self.dt[0].dt * 1E3 / n,
+                    'inference': self.dt[1].dt * 1E3 / n,
+                    'postprocess': self.dt[2].dt * 1E3 / n}
                 p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
                     else (path, im0s.copy())
                 p = Path(p)
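
Each Results object now carries its own per-image timing split, in milliseconds averaged over the batch. Reading it back, a minimal sketch assuming ultralytics 8.0.46:

# Reading the per-image speed profile attached by the predictor loop above (times are in ms).
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
for r in model.predict('https://ultralytics.com/images/bus.jpg'):
    print(r.speed)  # e.g. {'preprocess': 1.2, 'inference': 45.3, 'postprocess': 0.8}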

View File

@@ -354,22 +354,6 @@ def get_git_branch():
         return None  # if not git dir or on error


-def get_latest_pypi_version(package_name='ultralytics'):
-    """
-    Returns the latest version of a PyPI package without downloading or installing it.
-
-    Parameters:
-        package_name (str): The name of the package to find the latest version for.
-
-    Returns:
-        str: The latest version of the package.
-    """
-    response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
-    if response.status_code == 200:
-        return response.json()['info']['version']
-    return None
-
-
 def get_default_args(func):
     """Returns a dictionary of default arguments for a function.
@@ -611,7 +595,7 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
 # Run below code on yolo/utils init ------------------------------------------------------------------------------------

 # Set logger
-set_logging(LOGGING_NAME)  # run before defining LOGGER
+set_logging(LOGGING_NAME, verbose=VERBOSE)  # run before defining LOGGER
 LOGGER = logging.getLogger(LOGGING_NAME)  # define globally (used in train.py, val.py, detect.py, etc.)
 if WINDOWS:
     for fn in LOGGER.info, LOGGER.warning:
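
set_logging() now receives the package-level VERBOSE flag instead of always configuring an INFO-level logger. A hedged sketch of silencing global logging, under the assumption that VERBOSE is driven by the YOLO_VERBOSE environment variable (not shown in this hunk, but present in nearby releases):

# Sketch only: silence the package LOGGER at import time.
# Assumption: VERBOSE is read from the YOLO_VERBOSE environment variable.
import os
os.environ['YOLO_VERBOSE'] = 'false'  # must be set before the first ultralytics import

from ultralytics import YOLO  # set_logging(LOGGING_NAME, verbose=False) runs during this import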

View File

@@ -37,11 +37,7 @@ from ultralytics.yolo.utils.files import file_size
 from ultralytics.yolo.utils.torch_utils import select_device


-def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
-                   imgsz=640,
-                   half=False,
-                   device='cpu',
-                   hard_fail=False):
+def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=0.30):
     device = select_device(device, verbose=False)
     if isinstance(model, (str, Path)):
         model = YOLO(model)
@@ -52,6 +48,7 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
         try:
             assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
             assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
+            assert i != 11 or model.task != 'classify', 'paddle-classify bug'
             if 'cpu' in device.type:
                 assert cpu, 'inference not supported on CPU'
@@ -85,26 +82,28 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
                 y.append([name, '', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
         except Exception as e:
             if hard_fail:
-                assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}'
+                assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}'
             LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
             y.append([name, '', None, None, None])  # mAP, t_inference

     # Print results
-    LOGGER.info('\n')
     check_yolo(device=device)  # print system info
-    c = ['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'] if map else ['Format', 'Export', '', '']
+    c = ['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)']
     df = pd.DataFrame(y, columns=c)
-    LOGGER.info(f'\nBenchmarks complete for {Path(model.ckpt_path).name} on {data} at imgsz={imgsz} '
-                f'({time.time() - t0:.2f}s)')
-    LOGGER.info(str(df if map else df.iloc[:, :2]))

-    if hard_fail and isinstance(hard_fail, str):
+    name = Path(model.ckpt_path).name
+    s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
+    LOGGER.info(s)
+    with open('benchmarks.log', 'a') as f:
+        f.write(s)
+
+    if hard_fail and isinstance(hard_fail, float):
         metrics = df[key].array  # values to compare to floor
-        floor = eval(hard_fail)  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
-        assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: metric < floor {floor}'
+        floor = hard_fail  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
+        assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}'
     return df


 if __name__ == '__main__':
-    run_benchmarks()
+    benchmark()
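
The renamed benchmark() can also be called directly; hard_fail is now a float metric floor (the run asserts that every successfully exported format scores above it), and the formatted summary is appended to benchmarks.log, which the new CI "Benchmark Summary" step prints. A minimal sketch, assuming ultralytics 8.0.46 and a local yolov8n.pt:

# Direct use of the renamed benchmark() (sketch, assumes ultralytics 8.0.46 is installed).
from ultralytics.yolo.utils.benchmarks import benchmark

df = benchmark(model='yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=0.20)
print(df)  # per-format size, metric and inference time; the same table is appended to benchmarks.log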

View File

@@ -16,6 +16,7 @@ import cv2
 import numpy as np
 import pkg_resources as pkg
 import psutil
+import requests
 import torch
 from matplotlib import font_manager
@@ -117,6 +118,31 @@ def check_version(current: str = '0.0.0',
     return result


+def check_latest_pypi_version(package_name='ultralytics'):
+    """
+    Returns the latest version of a PyPI package without downloading or installing it.
+
+    Parameters:
+        package_name (str): The name of the package to find the latest version for.
+
+    Returns:
+        str: The latest version of the package.
+    """
+    response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
+    if response.status_code == 200:
+        return response.json()['info']['version']
+    return None
+
+
+def check_pip_update():
+    from ultralytics import __version__
+    latest = check_latest_pypi_version()
+    latest = '9.0.0'
+    if pkg.parse_version(__version__) < pkg.parse_version(latest):
+        LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
+                    f"Update with 'pip install -U ultralytics'")
+
+
 def check_font(font='Arial.ttf'):
     """
     Find font locally or download to user's configuration directory if it does not already exist.
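
The PyPI version lookup moves here from yolo/utils/__init__.py and gains a check_pip_update() wrapper. Note that, as committed, the hunk re-pins latest = '9.0.0' right after the real lookup, so in this snapshot the upgrade hint is logged regardless of what PyPI returns. A usage sketch (requires network access to pypi.org):

# Usage sketch for the relocated version check (assumes ultralytics 8.0.46 and network access).
from ultralytics.yolo.utils.checks import check_latest_pypi_version, check_pip_update

print(check_latest_pypi_version('ultralytics'))  # latest version string from the PyPI JSON API
check_pip_update()  # with the pinned latest = '9.0.0' above, this currently always logs the upgrade hint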

View File

@@ -1,10 +1,12 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

 import os
+import re
 import shutil
 import socket
 import sys
 import tempfile
+from pathlib import Path

 from . import USER_CONFIG_DIR
 from .torch_utils import TORCH_1_9
@@ -22,12 +24,12 @@ def find_free_network_port() -> int:


 def generate_ddp_file(trainer):
-    import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])
+    module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)

     content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
-    from ultralytics.{import_path} import {trainer.__class__.__name__}
+    from {module} import {name}

-    trainer = {trainer.__class__.__name__}(cfg=cfg)
+    trainer = {name}(cfg=cfg)
     trainer.train()'''
     (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
     with tempfile.NamedTemporaryFile(prefix='_temp_',
@@ -41,12 +43,12 @@ def generate_ddp_file(trainer):

 def generate_ddp_command(world_size, trainer):
-    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
-    file = os.path.abspath(sys.argv[0])
-    using_cli = not file.endswith('.py')
+    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
     if not trainer.resume:
         shutil.rmtree(trainer.save_dir)  # remove the save_dir
-    if using_cli:
+    file = str(Path(sys.argv[0]).resolve())
+    safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$')  # allowed characters and maximum of 128 characters
+    if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')):  # using CLI
         file = generate_ddp_file(trainer)
     dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
     port = find_free_network_port()
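
generate_ddp_file() now derives the import line from the trainer class's __module__ and __name__, and generate_ddp_command() only falls back to a generated temp file when sys.argv[0] is not a safe, existing .py path. A standalone illustration of both pieces (FakeTrainer is a placeholder class; no ultralytics import needed):

# Standalone illustration of the two mechanisms above (FakeTrainer is a placeholder class).
import re
import sys
from pathlib import Path

class FakeTrainer:  # stands in for a real trainer class such as a detection trainer
    pass

module, name = f'{FakeTrainer.__module__}.{FakeTrainer.__name__}'.rsplit('.', 1)
print(f'from {module} import {name}')  # the import line written into the temporary DDP file

file = str(Path(sys.argv[0]).resolve())
safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$')  # allowed characters, 128-char maximum
print(bool(safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')))  # True -> reuse argv[0]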