mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-26 19:25:39 +08:00 
			
		
		
		
	Set workers=0 for MPS Train and Val modes (#4697)
				
					
				
			This commit is contained in:
		
							parent
							
								
									2bc6e647c7
								
							
						
					
					
						commit
						a4fabfdacf
					
				
							
								
								
									
										1
									
								
								.github/workflows/codeql.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/codeql.yaml
									
									
									
									
										vendored
									
									
								
							| @ -5,6 +5,7 @@ name: "CodeQL" | |||||||
| on: | on: | ||||||
|   schedule: |   schedule: | ||||||
|     - cron: '0 0 1 * *' |     - cron: '0 0 1 * *' | ||||||
|  |   workflow_dispatch: | ||||||
| 
 | 
 | ||||||
| jobs: | jobs: | ||||||
|   analyze: |   analyze: | ||||||
|  | |||||||
| @ -107,7 +107,7 @@ class BaseTrainer: | |||||||
|             print_args(vars(self.args)) |             print_args(vars(self.args)) | ||||||
| 
 | 
 | ||||||
|         # Device |         # Device | ||||||
|         if self.device.type == 'cpu': |         if self.device.type in ('cpu', 'mps'): | ||||||
|             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading |             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading | ||||||
| 
 | 
 | ||||||
|         # Model and Dataset |         # Model and Dataset | ||||||
|  | |||||||
| @ -144,7 +144,7 @@ class BaseValidator: | |||||||
|             else: |             else: | ||||||
|                 raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) |                 raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) | ||||||
| 
 | 
 | ||||||
|             if self.device.type == 'cpu': |             if self.device.type in ('cpu', 'mps'): | ||||||
|                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading |                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading | ||||||
|             if not pt: |             if not pt: | ||||||
|                 self.args.rect = False |                 self.args.rect = False | ||||||
|  | |||||||
| @ -196,17 +196,27 @@ def add_integration_callbacks(instance): | |||||||
|         instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary |         instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary | ||||||
|             of callback lists. |             of callback lists. | ||||||
|     """ |     """ | ||||||
|     from .clearml import callbacks as clearml_cb | 
 | ||||||
|  |     # Load HUB callbacks | ||||||
|  |     from .hub import callbacks | ||||||
|  | 
 | ||||||
|  |     # Load training callbacks | ||||||
|  |     if 'Trainer' in instance.__class__.__name__: | ||||||
|  |         from .clearml import callbacks as clear_cb | ||||||
|         from .comet import callbacks as comet_cb |         from .comet import callbacks as comet_cb | ||||||
|         from .dvc import callbacks as dvc_cb |         from .dvc import callbacks as dvc_cb | ||||||
|     from .hub import callbacks as hub_cb |  | ||||||
|         from .mlflow import callbacks as mlflow_cb |         from .mlflow import callbacks as mlflow_cb | ||||||
|         from .neptune import callbacks as neptune_cb |         from .neptune import callbacks as neptune_cb | ||||||
|         from .raytune import callbacks as tune_cb |         from .raytune import callbacks as tune_cb | ||||||
|     from .tensorboard import callbacks as tensorboard_cb |         from .tensorboard import callbacks as tb_cb | ||||||
|         from .wb import callbacks as wb_cb |         from .wb import callbacks as wb_cb | ||||||
|  |         callbacks.update({**clear_cb, **comet_cb, **dvc_cb, **mlflow_cb, **neptune_cb, **tune_cb, **tb_cb, **wb_cb}) | ||||||
| 
 | 
 | ||||||
|     for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb: |     # Load export callbacks (patch to avoid CoreML protobuf error) | ||||||
|         for k, v in x.items(): |     if 'Exporter' in instance.__class__.__name__: | ||||||
|  |         from .tensorboard import callbacks as tb_cb | ||||||
|  |         callbacks.update(tb_cb) | ||||||
|  | 
 | ||||||
|  |     for k, v in callbacks.items(): | ||||||
|         if v not in instance.callbacks[k]:  # prevent duplicate callbacks addition |         if v not in instance.callbacks[k]:  # prevent duplicate callbacks addition | ||||||
|             instance.callbacks[k].append(v)  # callback[name].append(func) |             instance.callbacks[k].append(v)  # callback[name].append(func) | ||||||
|  | |||||||
| @ -1,12 +1,6 @@ | |||||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||||
| 
 | 
 | ||||||
| import re |  | ||||||
| 
 |  | ||||||
| import matplotlib.image as mpimg |  | ||||||
| import matplotlib.pyplot as plt |  | ||||||
| 
 |  | ||||||
| from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING | ||||||
| from ultralytics.utils.torch_utils import model_info_for_loggers |  | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     assert not TESTS_RUNNING  # do not log pytest |     assert not TESTS_RUNNING  # do not log pytest | ||||||
| @ -15,8 +9,8 @@ try: | |||||||
|     from clearml import Task |     from clearml import Task | ||||||
|     from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO |     from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO | ||||||
|     from clearml.binding.matplotlib_bind import PatchedMatplotlib |     from clearml.binding.matplotlib_bind import PatchedMatplotlib | ||||||
| 
 |  | ||||||
|     assert hasattr(clearml, '__version__')  # verify package is not directory |     assert hasattr(clearml, '__version__')  # verify package is not directory | ||||||
|  | 
 | ||||||
| except (ImportError, AssertionError): | except (ImportError, AssertionError): | ||||||
|     clearml = None |     clearml = None | ||||||
| 
 | 
 | ||||||
| @ -29,6 +23,8 @@ def _log_debug_samples(files, title='Debug Samples') -> None: | |||||||
|         files (list): A list of file paths in PosixPath format. |         files (list): A list of file paths in PosixPath format. | ||||||
|         title (str): A title that groups together images with the same values. |         title (str): A title that groups together images with the same values. | ||||||
|     """ |     """ | ||||||
|  |     import re | ||||||
|  | 
 | ||||||
|     if task := Task.current_task(): |     if task := Task.current_task(): | ||||||
|         for f in files: |         for f in files: | ||||||
|             if f.exists(): |             if f.exists(): | ||||||
| @ -48,6 +44,9 @@ def _log_plot(title, plot_path) -> None: | |||||||
|         title (str): The title of the plot. |         title (str): The title of the plot. | ||||||
|         plot_path (str): The path to the saved image file. |         plot_path (str): The path to the saved image file. | ||||||
|     """ |     """ | ||||||
|  |     import matplotlib.image as mpimg | ||||||
|  |     import matplotlib.pyplot as plt | ||||||
|  | 
 | ||||||
|     img = mpimg.imread(plot_path) |     img = mpimg.imread(plot_path) | ||||||
|     fig = plt.figure() |     fig = plt.figure() | ||||||
|     ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks |     ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks | ||||||
| @ -103,6 +102,7 @@ def on_fit_epoch_end(trainer): | |||||||
|                                         value=trainer.epoch_time, |                                         value=trainer.epoch_time, | ||||||
|                                         iteration=trainer.epoch) |                                         iteration=trainer.epoch) | ||||||
|         if trainer.epoch == 0: |         if trainer.epoch == 0: | ||||||
|  |             from ultralytics.utils.torch_utils import model_info_for_loggers | ||||||
|             for k, v in model_info_for_loggers(trainer).items(): |             for k, v in model_info_for_loggers(trainer).items(): | ||||||
|                 task.get_logger().report_single_value(k, v) |                 task.get_logger().report_single_value(k, v) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,10 +1,6 @@ | |||||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||||
| 
 | 
 | ||||||
| import os |  | ||||||
| from pathlib import Path |  | ||||||
| 
 |  | ||||||
| from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops | from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops | ||||||
| from ultralytics.utils.torch_utils import model_info_for_loggers |  | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     assert not TESTS_RUNNING  # do not log pytest |     assert not TESTS_RUNNING  # do not log pytest | ||||||
| @ -12,8 +8,9 @@ try: | |||||||
|     import comet_ml |     import comet_ml | ||||||
| 
 | 
 | ||||||
|     assert hasattr(comet_ml, '__version__')  # verify package is not directory |     assert hasattr(comet_ml, '__version__')  # verify package is not directory | ||||||
| except (ImportError, AssertionError): | 
 | ||||||
|     comet_ml = None |     import os | ||||||
|  |     from pathlib import Path | ||||||
| 
 | 
 | ||||||
|     # Ensures certain logging functions only run for supported tasks |     # Ensures certain logging functions only run for supported tasks | ||||||
|     COMET_SUPPORTED_TASKS = ['detect'] |     COMET_SUPPORTED_TASKS = ['detect'] | ||||||
| @ -24,6 +21,9 @@ LABEL_PLOT_NAMES = 'labels', 'labels_correlogram' | |||||||
| 
 | 
 | ||||||
|     _comet_image_prediction_count = 0 |     _comet_image_prediction_count = 0 | ||||||
| 
 | 
 | ||||||
|  | except (ImportError, AssertionError): | ||||||
|  |     comet_ml = None | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def _get_comet_mode(): | def _get_comet_mode(): | ||||||
|     return os.getenv('COMET_MODE', 'online') |     return os.getenv('COMET_MODE', 'online') | ||||||
| @ -327,6 +327,7 @@ def on_fit_epoch_end(trainer): | |||||||
|     experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) |     experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) | ||||||
|     experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) |     experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) | ||||||
|     if curr_epoch == 1: |     if curr_epoch == 1: | ||||||
|  |         from ultralytics.utils.torch_utils import model_info_for_loggers | ||||||
|         experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) |         experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) | ||||||
| 
 | 
 | ||||||
|     if not save_assets: |     if not save_assets: | ||||||
|  | |||||||
| @ -1,37 +1,37 @@ | |||||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||||
| 
 | 
 | ||||||
| import os |  | ||||||
| import re |  | ||||||
| from pathlib import Path |  | ||||||
| 
 |  | ||||||
| import pkg_resources as pkg |  | ||||||
| 
 |  | ||||||
| from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING | ||||||
| from ultralytics.utils.torch_utils import model_info_for_loggers |  | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     assert not TESTS_RUNNING  # do not log pytest |     assert not TESTS_RUNNING  # do not log pytest | ||||||
|     assert SETTINGS['dvc'] is True  # verify integration is enabled |     assert SETTINGS['dvc'] is True  # verify integration is enabled | ||||||
|     from importlib.metadata import version |  | ||||||
| 
 |  | ||||||
|     import dvclive |     import dvclive | ||||||
| 
 | 
 | ||||||
|  |     assert hasattr(dvclive, '__version__')  # verify package is not directory | ||||||
|  | 
 | ||||||
|  |     import os | ||||||
|  |     import re | ||||||
|  |     from importlib.metadata import version | ||||||
|  |     from pathlib import Path | ||||||
|  | 
 | ||||||
|  |     import pkg_resources as pkg | ||||||
|  | 
 | ||||||
|     ver = version('dvclive') |     ver = version('dvclive') | ||||||
|     if pkg.parse_version(ver) < pkg.parse_version('2.11.0'): |     if pkg.parse_version(ver) < pkg.parse_version('2.11.0'): | ||||||
|         LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).') |         LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).') | ||||||
|         dvclive = None  # noqa: F811 |         dvclive = None  # noqa: F811 | ||||||
| except (ImportError, AssertionError, TypeError): |  | ||||||
|     dvclive = None |  | ||||||
| 
 | 
 | ||||||
|     # DVCLive logger instance |     # DVCLive logger instance | ||||||
|     live = None |     live = None | ||||||
|     _processed_plots = {} |     _processed_plots = {} | ||||||
| 
 | 
 | ||||||
| # `on_fit_epoch_end` is called on final validation (probably need to be fixed) |     # `on_fit_epoch_end` is called on final validation (probably need to be fixed) for now this is the way we | ||||||
| # for now this is the way we distinguish final evaluation of the best model vs |     # distinguish final evaluation of the best model vs last epoch validation | ||||||
| # last epoch validation |  | ||||||
|     _training_epoch = False |     _training_epoch = False | ||||||
| 
 | 
 | ||||||
|  | except (ImportError, AssertionError, TypeError): | ||||||
|  |     dvclive = None | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def _log_images(path, prefix=''): | def _log_images(path, prefix=''): | ||||||
|     if live: |     if live: | ||||||
| @ -103,6 +103,7 @@ def on_fit_epoch_end(trainer): | |||||||
|             live.log_metric(metric, value) |             live.log_metric(metric, value) | ||||||
| 
 | 
 | ||||||
|         if trainer.epoch == 0: |         if trainer.epoch == 0: | ||||||
|  |             from ultralytics.utils.torch_utils import model_info_for_loggers | ||||||
|             for metric, value in model_info_for_loggers(trainer).items(): |             for metric, value in model_info_for_loggers(trainer).items(): | ||||||
|                 live.log_metric(metric, value, plot=False) |                 live.log_metric(metric, value, plot=False) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -5,7 +5,6 @@ from time import time | |||||||
| 
 | 
 | ||||||
| from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events | from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events | ||||||
| from ultralytics.utils import LOGGER, SETTINGS | from ultralytics.utils import LOGGER, SETTINGS | ||||||
| from ultralytics.utils.torch_utils import model_info_for_loggers |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def on_pretrain_routine_end(trainer): | def on_pretrain_routine_end(trainer): | ||||||
| @ -24,6 +23,7 @@ def on_fit_epoch_end(trainer): | |||||||
|         # Upload metrics after val end |         # Upload metrics after val end | ||||||
|         all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} |         all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} | ||||||
|         if trainer.epoch == 0: |         if trainer.epoch == 0: | ||||||
|  |             from ultralytics.utils.torch_utils import model_info_for_loggers | ||||||
|             all_plots = {**all_plots, **model_info_for_loggers(trainer)} |             all_plots = {**all_plots, **model_info_for_loggers(trainer)} | ||||||
|         session.metrics_queue[trainer.epoch] = json.dumps(all_plots) |         session.metrics_queue[trainer.epoch] = json.dumps(all_plots) | ||||||
|         if time() - session.timers['metrics'] > session.rate_limits['metrics']: |         if time() - session.timers['metrics'] > session.rate_limits['metrics']: | ||||||
|  | |||||||
| @ -1,17 +1,16 @@ | |||||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||||
| 
 | 
 | ||||||
| import os | from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr | ||||||
| import re |  | ||||||
| from pathlib import Path |  | ||||||
| 
 |  | ||||||
| from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr |  | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     assert not TESTS_RUNNING  # do not log pytest |     assert not TESTS_RUNNING  # do not log pytest | ||||||
|     assert SETTINGS['mlflow'] is True  # verify integration is enabled |     assert SETTINGS['mlflow'] is True  # verify integration is enabled | ||||||
|     import mlflow |     import mlflow | ||||||
| 
 |  | ||||||
|     assert hasattr(mlflow, '__version__')  # verify package is not directory |     assert hasattr(mlflow, '__version__')  # verify package is not directory | ||||||
|  | 
 | ||||||
|  |     import os | ||||||
|  |     import re | ||||||
|  | 
 | ||||||
| except (ImportError, AssertionError): | except (ImportError, AssertionError): | ||||||
|     mlflow = None |     mlflow = None | ||||||
| 
 | 
 | ||||||
| @ -56,11 +55,10 @@ def on_fit_epoch_end(trainer): | |||||||
| def on_train_end(trainer): | def on_train_end(trainer): | ||||||
|     """Called at end of train loop to log model artifact info.""" |     """Called at end of train loop to log model artifact info.""" | ||||||
|     if mlflow: |     if mlflow: | ||||||
|         root_dir = Path(__file__).resolve().parents[3] |  | ||||||
|         run.log_artifact(trainer.last) |         run.log_artifact(trainer.last) | ||||||
|         run.log_artifact(trainer.best) |         run.log_artifact(trainer.best) | ||||||
|         run.pyfunc.log_model(artifact_path=experiment_name, |         run.pyfunc.log_model(artifact_path=experiment_name, | ||||||
|                              code_path=[str(root_dir)], |                              code_path=[str(ROOT.parent)], | ||||||
|                              artifacts={'model_path': str(trainer.save_dir)}, |                              artifacts={'model_path': str(trainer.save_dir)}, | ||||||
|                              python_model=run.pyfunc.PythonModel()) |                              python_model=run.pyfunc.PythonModel()) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,10 +1,6 @@ | |||||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||||
| 
 | 
 | ||||||
| import matplotlib.image as mpimg |  | ||||||
| import matplotlib.pyplot as plt |  | ||||||
| 
 |  | ||||||
| from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING | from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING | ||||||
| from ultralytics.utils.torch_utils import model_info_for_loggers |  | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     assert not TESTS_RUNNING  # do not log pytest |     assert not TESTS_RUNNING  # do not log pytest | ||||||
| @ -13,11 +9,12 @@ try: | |||||||
|     from neptune.types import File |     from neptune.types import File | ||||||
| 
 | 
 | ||||||
|     assert hasattr(neptune, '__version__') |     assert hasattr(neptune, '__version__') | ||||||
| except (ImportError, AssertionError): |  | ||||||
|     neptune = None |  | ||||||
| 
 | 
 | ||||||
|     run = None  # NeptuneAI experiment logger instance |     run = None  # NeptuneAI experiment logger instance | ||||||
| 
 | 
 | ||||||
|  | except (ImportError, AssertionError): | ||||||
|  |     neptune = None | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def _log_scalars(scalars, step=0): | def _log_scalars(scalars, step=0): | ||||||
|     """Log scalars to the NeptuneAI experiment logger.""" |     """Log scalars to the NeptuneAI experiment logger.""" | ||||||
| @ -42,6 +39,9 @@ def _log_plot(title, plot_path): | |||||||
|         title (str) Title of the plot |         title (str) Title of the plot | ||||||
|         plot_path (PosixPath or str) Path to the saved image file |         plot_path (PosixPath or str) Path to the saved image file | ||||||
|         """ |         """ | ||||||
|  |     import matplotlib.image as mpimg | ||||||
|  |     import matplotlib.pyplot as plt | ||||||
|  | 
 | ||||||
|     img = mpimg.imread(plot_path) |     img = mpimg.imread(plot_path) | ||||||
|     fig = plt.figure() |     fig = plt.figure() | ||||||
|     ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks |     ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks | ||||||
| @ -70,6 +70,7 @@ def on_train_epoch_end(trainer): | |||||||
| def on_fit_epoch_end(trainer): | def on_fit_epoch_end(trainer): | ||||||
|     """Callback function called at end of each fit (train+val) epoch.""" |     """Callback function called at end of each fit (train+val) epoch.""" | ||||||
|     if run and trainer.epoch == 0: |     if run and trainer.epoch == 0: | ||||||
|  |         from ultralytics.utils.torch_utils import model_info_for_loggers | ||||||
|         run['Configuration/Model'] = model_info_for_loggers(trainer) |         run['Configuration/Model'] = model_info_for_loggers(trainer) | ||||||
|     _log_scalars(trainer.metrics, trainer.epoch + 1) |     _log_scalars(trainer.metrics, trainer.epoch + 1) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ try: | |||||||
|     import ray |     import ray | ||||||
|     from ray import tune |     from ray import tune | ||||||
|     from ray.air import session |     from ray.air import session | ||||||
|  | 
 | ||||||
| except (ImportError, AssertionError): | except (ImportError, AssertionError): | ||||||
|     tune = None |     tune = None | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -8,12 +8,12 @@ try: | |||||||
| 
 | 
 | ||||||
|     assert not TESTS_RUNNING  # do not log pytest |     assert not TESTS_RUNNING  # do not log pytest | ||||||
|     assert SETTINGS['tensorboard'] is True  # verify integration is enabled |     assert SETTINGS['tensorboard'] is True  # verify integration is enabled | ||||||
|  |     WRITER = None  # TensorBoard SummaryWriter instance | ||||||
|  | 
 | ||||||
| except (ImportError, AssertionError, TypeError): | except (ImportError, AssertionError, TypeError): | ||||||
|     # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows |     # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows | ||||||
|     SummaryWriter = None |     SummaryWriter = None | ||||||
| 
 | 
 | ||||||
| WRITER = None  # TensorBoard SummaryWriter instance |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def _log_scalars(scalars, step=0): | def _log_scalars(scalars, step=0): | ||||||
|     """Logs scalar values to TensorBoard.""" |     """Logs scalar values to TensorBoard.""" | ||||||
| @ -72,4 +72,4 @@ callbacks = { | |||||||
|     'on_pretrain_routine_start': on_pretrain_routine_start, |     'on_pretrain_routine_start': on_pretrain_routine_start, | ||||||
|     'on_train_start': on_train_start, |     'on_train_start': on_train_start, | ||||||
|     'on_fit_epoch_end': on_fit_epoch_end, |     'on_fit_epoch_end': on_fit_epoch_end, | ||||||
|     'on_batch_end': on_batch_end} |     'on_batch_end': on_batch_end} if SummaryWriter else {} | ||||||
|  | |||||||
| @ -9,11 +9,12 @@ try: | |||||||
|     import wandb as wb |     import wandb as wb | ||||||
| 
 | 
 | ||||||
|     assert hasattr(wb, '__version__') |     assert hasattr(wb, '__version__') | ||||||
| except (ImportError, AssertionError): |  | ||||||
|     wb = None |  | ||||||
| 
 | 
 | ||||||
|     _processed_plots = {} |     _processed_plots = {} | ||||||
| 
 | 
 | ||||||
|  | except (ImportError, AssertionError): | ||||||
|  |     wb = None | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def _log_plots(plots, step): | def _log_plots(plots, step): | ||||||
|     for name, params in plots.items(): |     for name, params in plots.items(): | ||||||
|  | |||||||
| @ -273,7 +273,7 @@ def safe_download(url, | |||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     # Check if the URL is a Google Drive link |     # Check if the URL is a Google Drive link | ||||||
|     gdrive = 'drive.google.com' in url |     gdrive = url.startswith('https://drive.google.com/') | ||||||
|     if gdrive: |     if gdrive: | ||||||
|         url, file = get_google_drive_file_info(url) |         url, file = get_google_drive_file_info(url) | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher