mirror of https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00

Set workers=0 for MPS Train and Val modes (#4697)

This commit is contained in:
parent 2bc6e647c7
commit a4fabfdacf
.github/workflows/codeql.yaml (vendored): 1 change
@@ -5,6 +5,7 @@ name: "CodeQL"
 on:
   schedule:
     - cron: '0 0 1 * *'
+  workflow_dispatch:
 
 jobs:
   analyze:
@@ -107,7 +107,7 @@ class BaseTrainer:
             print_args(vars(self.args))
 
         # Device
-        if self.device.type == 'cpu':
+        if self.device.type in ('cpu', 'mps'):
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
 
         # Model and Dataset
@@ -144,7 +144,7 @@ class BaseValidator:
             else:
                 raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
 
-            if self.device.type == 'cpu':
+            if self.device.type in ('cpu', 'mps'):
                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
             if not pt:
                 self.args.rect = False
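
Both hunks above apply the same guard: when the selected device is 'cpu' or 'mps', dataloader workers are forced to 0 so batches are loaded in the main process. A minimal sketch of the effect on a torch DataLoader, using a dummy dataset and an illustrative helper rather than the actual Ultralytics dataloader code:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def build_loader(dataset, batch_size, device_type, workers=8):
        # Mirror the commit's guard: on 'cpu' and 'mps' batches are loaded in the
        # main process, since extra worker processes mostly add overhead there.
        if device_type in ('cpu', 'mps'):
            workers = 0
        return DataLoader(dataset, batch_size=batch_size, num_workers=workers)

    # Dummy data just to show the resulting loader configuration
    data = TensorDataset(torch.zeros(8, 3, 32, 32), torch.zeros(8))
    print(build_loader(data, batch_size=4, device_type='mps').num_workers)   # 0
    print(build_loader(data, batch_size=4, device_type='cuda').num_workers)  # 8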
@@ -196,17 +196,27 @@ def add_integration_callbacks(instance):
         instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
             of callback lists.
     """
-    from .clearml import callbacks as clearml_cb
-    from .comet import callbacks as comet_cb
-    from .dvc import callbacks as dvc_cb
-    from .hub import callbacks as hub_cb
-    from .mlflow import callbacks as mlflow_cb
-    from .neptune import callbacks as neptune_cb
-    from .raytune import callbacks as tune_cb
-    from .tensorboard import callbacks as tensorboard_cb
-    from .wb import callbacks as wb_cb
-
-    for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb:
-        for k, v in x.items():
+    # Load HUB callbacks
+    from .hub import callbacks
+
+    # Load training callbacks
+    if 'Trainer' in instance.__class__.__name__:
+        from .clearml import callbacks as clear_cb
+        from .comet import callbacks as comet_cb
+        from .dvc import callbacks as dvc_cb
+        from .mlflow import callbacks as mlflow_cb
+        from .neptune import callbacks as neptune_cb
+        from .raytune import callbacks as tune_cb
+        from .tensorboard import callbacks as tb_cb
+        from .wb import callbacks as wb_cb
+        callbacks.update({**clear_cb, **comet_cb, **dvc_cb, **mlflow_cb, **neptune_cb, **tune_cb, **tb_cb, **wb_cb})
+
+    # Load export callbacks (patch to avoid CoreML protobuf error)
+    if 'Exporter' in instance.__class__.__name__:
+        from .tensorboard import callbacks as tb_cb
+        callbacks.update(tb_cb)
+
+    for k, v in callbacks.items():
         if v not in instance.callbacks[k]:  # prevent duplicate callbacks addition
             instance.callbacks[k].append(v)  # callback[name].append(func)
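
The refactor above loads the HUB callbacks first, merges the logger integrations only for Trainer instances, and then registers each hook at most once on the instance. A rough standalone sketch of that merge-and-register step, with made-up callback dicts standing in for the real integration modules:

    from collections import defaultdict

    # Stand-ins for per-integration callback dicts (hub, tensorboard, ...);
    # the real ones live in ultralytics.utils.callbacks.*
    def upload_model(trainer): ...
    def log_batch(trainer): ...

    hub_cb = {'on_model_save': upload_model}
    tb_cb = {'on_batch_end': log_batch}

    class DummyTrainer:
        def __init__(self):
            self.callbacks = defaultdict(list)  # hook name -> list of functions

    def add_integration_callbacks(instance):
        callbacks = dict(hub_cb)     # HUB callbacks are always loaded
        if 'Trainer' in instance.__class__.__name__:
            callbacks.update(tb_cb)  # training-only integrations
        for k, v in callbacks.items():
            if v not in instance.callbacks[k]:  # prevent duplicate callbacks addition
                instance.callbacks[k].append(v)

    trainer = DummyTrainer()
    add_integration_callbacks(trainer)
    add_integration_callbacks(trainer)                    # re-registering is a no-op
    print([len(v) for v in trainer.callbacks.values()])   # [1, 1]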
@@ -1,12 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-import re
-
-import matplotlib.image as mpimg
-import matplotlib.pyplot as plt
-
 from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
-from ultralytics.utils.torch_utils import model_info_for_loggers
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
@@ -15,8 +9,8 @@ try:
     from clearml import Task
     from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
     from clearml.binding.matplotlib_bind import PatchedMatplotlib
 
     assert hasattr(clearml, '__version__')  # verify package is not directory
 
 except (ImportError, AssertionError):
     clearml = None
@@ -29,6 +23,8 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
         files (list): A list of file paths in PosixPath format.
         title (str): A title that groups together images with the same values.
     """
+    import re
+
     if task := Task.current_task():
         for f in files:
             if f.exists():
@@ -48,6 +44,9 @@ def _log_plot(title, plot_path) -> None:
         title (str): The title of the plot.
         plot_path (str): The path to the saved image file.
     """
+    import matplotlib.image as mpimg
+    import matplotlib.pyplot as plt
+
     img = mpimg.imread(plot_path)
     fig = plt.figure()
     ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
@@ -103,6 +102,7 @@ def on_fit_epoch_end(trainer):
                                       value=trainer.epoch_time,
                                       iteration=trainer.epoch)
         if trainer.epoch == 0:
+            from ultralytics.utils.torch_utils import model_info_for_loggers
             for k, v in model_info_for_loggers(trainer).items():
                 task.get_logger().report_single_value(k, v)
 
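
The clearml.py hunks all follow one pattern: imports of re, matplotlib and model_info_for_loggers move from module scope into the functions that use them, so importing the callbacks package stays cheap even when those libraries are slow to load. A generic sketch of that deferred-import pattern (the function below is illustrative, not the ClearML integration itself):

    def log_plot(title, plot_path):
        """Log a saved plot image; heavy libraries are imported only on first use."""
        import matplotlib.image as mpimg  # deferred: paid only when a plot is logged
        import matplotlib.pyplot as plt

        img = mpimg.imread(plot_path)
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
        ax.imshow(img)
        return fig

Importing the module that defines log_plot() costs nothing extra; the matplotlib price is only paid on the first call.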
@@ -1,10 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-import os
-from pathlib import Path
-
 from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
-from ultralytics.utils.torch_utils import model_info_for_loggers
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
@@ -12,8 +8,9 @@ try:
     import comet_ml
 
     assert hasattr(comet_ml, '__version__')  # verify package is not directory
-except (ImportError, AssertionError):
-    comet_ml = None
+
+    import os
+    from pathlib import Path
 
     # Ensures certain logging functions only run for supported tasks
     COMET_SUPPORTED_TASKS = ['detect']
@@ -24,6 +21,9 @@ LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
 
     _comet_image_prediction_count = 0
 
+except (ImportError, AssertionError):
+    comet_ml = None
+
 
 def _get_comet_mode():
     return os.getenv('COMET_MODE', 'online')
@@ -327,6 +327,7 @@ def on_fit_epoch_end(trainer):
     experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
     experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
     if curr_epoch == 1:
+        from ultralytics.utils.torch_utils import model_info_for_loggers
         experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
 
     if not save_assets:
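
In comet.py the optional-dependency guard is widened: the helper imports and module constants now live inside the try block, and the except clause moves to the end so any failure along the way simply disables the integration. A compact sketch of that pattern, using a hypothetical package name:

    try:
        import some_optional_logger  # hypothetical optional dependency
        assert hasattr(some_optional_logger, '__version__')  # verify package, not a bare directory

        import os

        # Constants that only matter when the logger is available
        SUPPORTED_TASKS = ['detect']
        _image_prediction_count = 0

    except (ImportError, AssertionError):
        some_optional_logger = None  # any failure above silently disables the integration

    def get_mode():
        # Callers check the module flag first, so `os` is only touched when setup succeeded
        return os.getenv('LOGGER_MODE', 'online') if some_optional_logger else None

    print(get_mode())  # None here, because the hypothetical package is not installed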
@@ -1,37 +1,37 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-import os
-import re
-from pathlib import Path
-
-import pkg_resources as pkg
-
 from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
-from ultralytics.utils.torch_utils import model_info_for_loggers
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
     assert SETTINGS['dvc'] is True  # verify integration is enabled
-    from importlib.metadata import version
-
     import dvclive
 
+    assert hasattr(dvclive, '__version__')  # verify package is not directory
+
+    import os
+    import re
+    from importlib.metadata import version
+    from pathlib import Path
+
+    import pkg_resources as pkg
+
     ver = version('dvclive')
     if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
         LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
         dvclive = None  # noqa: F811
-except (ImportError, AssertionError, TypeError):
-    dvclive = None
 
-# DVCLive logger instance
-live = None
-_processed_plots = {}
+    # DVCLive logger instance
+    live = None
+    _processed_plots = {}
 
-# `on_fit_epoch_end` is called on final validation (probably need to be fixed)
-# for now this is the way we distinguish final evaluation of the best model vs
-# last epoch validation
-_training_epoch = False
+    # `on_fit_epoch_end` is called on final validation (probably need to be fixed) for now this is the way we
+    # distinguish final evaluation of the best model vs last epoch validation
+    _training_epoch = False
+
+except (ImportError, AssertionError, TypeError):
+    dvclive = None
 
 
 def _log_images(path, prefix=''):
     if live:
@@ -103,6 +103,7 @@ def on_fit_epoch_end(trainer):
             live.log_metric(metric, value)
 
         if trainer.epoch == 0:
+            from ultralytics.utils.torch_utils import model_info_for_loggers
             for metric, value in model_info_for_loggers(trainer).items():
                 live.log_metric(metric, value, plot=False)
 
@@ -5,7 +5,6 @@ from time import time
 
 from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
 from ultralytics.utils import LOGGER, SETTINGS
-from ultralytics.utils.torch_utils import model_info_for_loggers
 
 
 def on_pretrain_routine_end(trainer):
@@ -24,6 +23,7 @@ def on_fit_epoch_end(trainer):
         # Upload metrics after val end
         all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
         if trainer.epoch == 0:
+            from ultralytics.utils.torch_utils import model_info_for_loggers
             all_plots = {**all_plots, **model_info_for_loggers(trainer)}
         session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
         if time() - session.timers['metrics'] > session.rate_limits['metrics']:
@@ -1,17 +1,16 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-import os
-import re
-from pathlib import Path
-
-from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
+from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
     assert SETTINGS['mlflow'] is True  # verify integration is enabled
     import mlflow
 
     assert hasattr(mlflow, '__version__')  # verify package is not directory
+
+    import os
+    import re
 
 except (ImportError, AssertionError):
     mlflow = None
@@ -56,11 +55,10 @@ def on_fit_epoch_end(trainer):
 def on_train_end(trainer):
     """Called at end of train loop to log model artifact info."""
     if mlflow:
-        root_dir = Path(__file__).resolve().parents[3]
         run.log_artifact(trainer.last)
         run.log_artifact(trainer.best)
         run.pyfunc.log_model(artifact_path=experiment_name,
-                             code_path=[str(root_dir)],
+                             code_path=[str(ROOT.parent)],
                              artifacts={'model_path': str(trainer.save_dir)},
                              python_model=run.pyfunc.PythonModel())
 
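
The mlflow.py change replaces a hand-counted Path(__file__).resolve().parents[3] with the ROOT constant that ultralytics.utils already exports, logging ROOT.parent as the code path. A small self-contained illustration of why both resolve to the same directory for a file nested three levels inside the package (paths are made up):

    from pathlib import Path

    # Hypothetical layout mirroring ultralytics/utils/callbacks/mlflow.py
    pkg_file = Path('/opt/project/ultralytics/utils/callbacks/mlflow.py')

    old_root = pkg_file.parents[3]           # climb a fixed number of parents -> /opt/project
    ROOT = Path('/opt/project/ultralytics')  # stand-in for ultralytics.utils.ROOT
    new_root = ROOT.parent                   # one named constant instead     -> /opt/project

    print(old_root == new_root)  # True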
@@ -1,10 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-import matplotlib.image as mpimg
-import matplotlib.pyplot as plt
-
 from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
-from ultralytics.utils.torch_utils import model_info_for_loggers
 
 try:
     assert not TESTS_RUNNING  # do not log pytest
@@ -13,11 +9,12 @@ try:
     from neptune.types import File
 
     assert hasattr(neptune, '__version__')
-except (ImportError, AssertionError):
-    neptune = None
 
     run = None  # NeptuneAI experiment logger instance
 
+except (ImportError, AssertionError):
+    neptune = None
+
 
 def _log_scalars(scalars, step=0):
     """Log scalars to the NeptuneAI experiment logger."""
@@ -42,6 +39,9 @@ def _log_plot(title, plot_path):
         title (str) Title of the plot
         plot_path (PosixPath or str) Path to the saved image file
     """
+    import matplotlib.image as mpimg
+    import matplotlib.pyplot as plt
+
     img = mpimg.imread(plot_path)
     fig = plt.figure()
     ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
@@ -70,6 +70,7 @@ def on_train_epoch_end(trainer):
 def on_fit_epoch_end(trainer):
     """Callback function called at end of each fit (train+val) epoch."""
     if run and trainer.epoch == 0:
+        from ultralytics.utils.torch_utils import model_info_for_loggers
         run['Configuration/Model'] = model_info_for_loggers(trainer)
     _log_scalars(trainer.metrics, trainer.epoch + 1)
 
@@ -7,6 +7,7 @@ try:
     import ray
     from ray import tune
     from ray.air import session
+
 except (ImportError, AssertionError):
     tune = None
 
@@ -8,12 +8,12 @@ try:
 
     assert not TESTS_RUNNING  # do not log pytest
     assert SETTINGS['tensorboard'] is True  # verify integration is enabled
+    WRITER = None  # TensorBoard SummaryWriter instance
 
 except (ImportError, AssertionError, TypeError):
     # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
     SummaryWriter = None
 
-WRITER = None  # TensorBoard SummaryWriter instance
 
 
 def _log_scalars(scalars, step=0):
     """Logs scalar values to TensorBoard."""
@@ -72,4 +72,4 @@ callbacks = {
     'on_pretrain_routine_start': on_pretrain_routine_start,
     'on_train_start': on_train_start,
     'on_fit_epoch_end': on_fit_epoch_end,
-    'on_batch_end': on_batch_end}
+    'on_batch_end': on_batch_end} if SummaryWriter else {}
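
Because WRITER is now assigned inside the try block, the module only carries TensorBoard state when SummaryWriter imported cleanly, and the callbacks dict above is likewise published empty when it did not, so nothing gets registered. A small sketch of that conditional-export idiom (the hook functions are placeholders):

    try:
        from torch.utils.tensorboard import SummaryWriter
    except (ImportError, TypeError):
        # TypeError covers the protobuf 'Descriptors cannot be created directly' failure
        SummaryWriter = None

    def on_batch_end(trainer): ...
    def on_fit_epoch_end(trainer): ...

    # Publish hooks only when the backend is importable; otherwise export nothing.
    callbacks = {
        'on_fit_epoch_end': on_fit_epoch_end,
        'on_batch_end': on_batch_end} if SummaryWriter else {}

    print(sorted(callbacks))  # [] when tensorboard is unavailable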
@@ -9,11 +9,12 @@ try:
     import wandb as wb
 
     assert hasattr(wb, '__version__')
-except (ImportError, AssertionError):
-    wb = None
 
     _processed_plots = {}
 
+except (ImportError, AssertionError):
+    wb = None
+
 
 def _log_plots(plots, step):
     for name, params in plots.items():
@@ -273,7 +273,7 @@ def safe_download(url,
     """
 
     # Check if the URL is a Google Drive link
-    gdrive = 'drive.google.com' in url
+    gdrive = url.startswith('https://drive.google.com/')
     if gdrive:
         url, file = get_google_drive_file_info(url)
 
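
The Google Drive detection in safe_download() is tightened from a substring test to a prefix test, so only URLs that actually begin with the Drive host are rewritten. A tiny before/after comparison with made-up URLs:

    urls = [
        'https://drive.google.com/file/d/abc123/view',  # real Drive link: both tests match
        'https://example.com/?next=drive.google.com',   # substring test used to match this too
    ]

    for url in urls:
        old_match = 'drive.google.com' in url                    # old check
        new_match = url.startswith('https://drive.google.com/')  # new check
        print(url, '| old:', old_match, '| new:', new_match)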