mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-25 10:25:39 +08:00 
			
		
		
		
	ultralytics 8.0.198 MLflow fix, tests and Docs page  (#5357)
				
					
				
			Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									12e3eef844
								
							
						
					
					
						commit
						5b3c4cfc0e
					
				| @ -24,7 +24,7 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of | ||||
| 
 | ||||
| - [Ultralytics HUB](https://hub.ultralytics.com): Access and contribute to a community of pre-trained Ultralytics models. | ||||
| 
 | ||||
| - [MLFlow](https://mlflow.org/): Streamline the entire ML lifecycle of Ultralytics models, from experimentation and reproducibility to deployment. | ||||
| - [MLFlow](mlflow.md): Streamline the entire ML lifecycle of Ultralytics models, from experimentation and reproducibility to deployment. | ||||
| 
 | ||||
| - [Neptune](https://neptune.ai/): Maintain a comprehensive log of your ML experiments with Ultralytics in this metadata store designed for MLOps. | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										112
									
								
								docs/integrations/mlflow.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								docs/integrations/mlflow.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,112 @@ | ||||
| --- | ||||
| comments: true | ||||
| description: Uncover the utility of MLflow for effective experiment logging in your Ultralytics YOLO projects. | ||||
| keywords: ultralytics docs, YOLO, MLflow, experiment logging, metrics tracking, parameter logging, artifact logging | ||||
| --- | ||||
| 
 | ||||
| # MLflow Integration for Ultralytics YOLO | ||||
| 
 | ||||
| <img width="1024" src="https://user-images.githubusercontent.com/26833433/274929143-05e37e72-c355-44be-a842-b358592340b7.png" alt="MLflow ecosystem"> | ||||
| 
 | ||||
| ## Introduction | ||||
| 
 | ||||
| Experiment logging is a crucial aspect of machine learning workflows that enables tracking of various metrics, parameters, and artifacts. It helps to enhance model reproducibility, debug issues, and improve model performance. [Ultralytics](https://ultralytics.com) YOLO, known for its real-time object detection capabilities, now offers integration with [MLflow](https://mlflow.org/), an open-source platform for complete machine learning lifecycle management. | ||||
| 
 | ||||
| This documentation page is a comprehensive guide to setting up and utilizing the MLflow logging capabilities for your Ultralytics YOLO project. | ||||
| 
 | ||||
| ## What is MLflow? | ||||
| 
 | ||||
| [MLflow](https://mlflow.org/) is an open-source platform developed by [Databricks](https://www.databricks.com/) for managing the end-to-end machine learning lifecycle. It includes tools for tracking experiments, packaging code into reproducible runs, and sharing and deploying models. MLflow is designed to work with any machine learning library and programming language. | ||||
| 
 | ||||
| ## Features | ||||
| 
 | ||||
| - **Metrics Logging**: Logs metrics at the end of each epoch and at the end of the training. | ||||
| - **Parameter Logging**: Logs all the parameters used in the training. | ||||
| - **Artifacts Logging**: Logs model artifacts, including weights and configuration files, at the end of the training. | ||||
| 
 | ||||
| ## Setup and Prerequisites | ||||
| 
 | ||||
| Ensure MLflow is installed. If not, install it using pip: | ||||
| 
 | ||||
| ```bash | ||||
| pip install mlflow | ||||
| ``` | ||||
| 
 | ||||
| Make sure that MLflow logging is enabled in Ultralytics settings. Usually, this is controlled by the settings `mflow` key. See the [settings](https://docs.ultralytics.com/quickstart/#ultralytics-settings) page for more info. | ||||
| 
 | ||||
| !!! example "Update Ultralytics MLflow Settings" | ||||
| 
 | ||||
|     === "Python" | ||||
|         Within the Python environment, call the `update` method on the `settings` object to change your settings: | ||||
|         ```python | ||||
|         from ultralytics import settings | ||||
| 
 | ||||
|         # Update a setting | ||||
|         settings.update({'mlflow': True}) | ||||
| 
 | ||||
|         # Reset settings to default values | ||||
|         settings.reset() | ||||
|         ``` | ||||
| 
 | ||||
|     === "CLI" | ||||
|         If you prefer using the command-line interface, the following commands will allow you to modify your settings: | ||||
|         ```bash | ||||
|         # Update a setting | ||||
|         yolo settings runs_dir='/path/to/runs' | ||||
| 
 | ||||
|         # Reset settings to default values | ||||
|         yolo settings reset | ||||
|         ``` | ||||
| 
 | ||||
| ## How to Use | ||||
| 
 | ||||
| ### Commands | ||||
| 
 | ||||
| 1. **Set a Project Name**: You can set the project name via an environment variable: | ||||
|     ```bash | ||||
|     export MLFLOW_EXPERIMENT_NAME=<your_experiment_name> | ||||
|     ``` | ||||
|    Or use the `project=<project>` argument when training a YOLO model, i.e. `yolo train project=my_project`. | ||||
| 
 | ||||
| 2. **Set a Run Name**: Similar to setting a project name, you can set the run name via an environment variable: | ||||
|     ```bash | ||||
|     export MLFLOW_RUN=<your_run_name> | ||||
|     ``` | ||||
|    Or use the `name=<name>` argument when training a YOLO model, i.e. `yolo train project=my_project name=my_name`. | ||||
| 
 | ||||
| 3. **Start Local MLflow Server**: To start tracking, use: | ||||
|     ```bash | ||||
|     mlflow server --backend-store-uri runs/mlflow' | ||||
|     ``` | ||||
|    This will start a local server at http://127.0.0.1:5000 by default and save all mlflow logs to the 'runs/mlflow' directory. To specify a different URI, set the `MLFLOW_TRACKING_URI` environment variable. | ||||
| 
 | ||||
| 4. **Kill MLflow Server Instances**: To stop all running MLflow instances, run: | ||||
|     ```bash | ||||
|     ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9 | ||||
|     ``` | ||||
| 
 | ||||
| ### Logging | ||||
| 
 | ||||
| The logging is taken care of by the `on_pretrain_routine_end`, `on_fit_epoch_end`, and `on_train_end` callback functions. These functions are automatically called during the respective stages of the training process, and they handle the logging of parameters, metrics, and artifacts. | ||||
| 
 | ||||
| ## Examples | ||||
| 
 | ||||
| 1. **Logging Custom Metrics**: You can add custom metrics to be logged by modifying the `trainer.metrics` dictionary before `on_fit_epoch_end` is called. | ||||
| 
 | ||||
| 2. **View Experiment**: To view your logs, navigate to your MLflow server (usually http://127.0.0.1:5000) and select your experiment and run. | ||||
|    <img width="1024" src="https://user-images.githubusercontent.com/26833433/274933329-3127aa8c-4491-48ea-81df-ed09a5837f2a.png" alt="YOLO MLflow Experiment"> | ||||
| 
 | ||||
| 3. **View Run**: Runs are individual models inside an experiment. Click on a Run and see the Run details, including uploaded artifacts and model weights. | ||||
|    <img width="1024" src="https://user-images.githubusercontent.com/26833433/274933337-ac61371c-2867-4099-a733-147a2583b3de.png" alt="YOLO MLflow Run"> | ||||
| 
 | ||||
| ## Disabling MLflow | ||||
| 
 | ||||
| To turn off MLflow logging: | ||||
| 
 | ||||
| ```bash | ||||
| yolo settings mlflow=False | ||||
| ``` | ||||
| 
 | ||||
| ## Conclusion | ||||
| 
 | ||||
| MLflow logging integration with Ultralytics YOLO offers a streamlined way to keep track of your machine learning experiments. It empowers you to monitor performance metrics and manage artifacts effectively, thus aiding in robust model development and deployment. For further details please visit the MLflow [official documentation](https://mlflow.org/docs/latest/index.html). | ||||
| @ -229,6 +229,7 @@ nav: | ||||
|       - OpenVINO: integrations/openvino.md | ||||
|       - Ray Tune: integrations/ray-tune.md | ||||
|       - Roboflow: integrations/roboflow.md | ||||
|       - MLflow: integrations/mlflow.md | ||||
|   - Usage: | ||||
|       - CLI: usage/cli.md | ||||
|       - Python: usage/python.md | ||||
|  | ||||
| @ -1,16 +1,13 @@ | ||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||
| 
 | ||||
| import contextlib | ||||
| 
 | ||||
| import pytest | ||||
| import torch | ||||
| 
 | ||||
| from ultralytics import YOLO, download | ||||
| from ultralytics.utils import ASSETS, DATASETS_DIR, WEIGHTS_DIR | ||||
| from ultralytics.utils.checks import cuda_device_count, cuda_is_available | ||||
| from ultralytics.utils import ASSETS, DATASETS_DIR, WEIGHTS_DIR, checks | ||||
| 
 | ||||
| CUDA_IS_AVAILABLE = cuda_is_available() | ||||
| CUDA_DEVICE_COUNT = cuda_device_count() | ||||
| CUDA_IS_AVAILABLE = checks.cuda_is_available() | ||||
| CUDA_DEVICE_COUNT = checks.cuda_device_count() | ||||
| 
 | ||||
| MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt'  # test spaces in path | ||||
| DATA = 'coco8.yaml' | ||||
| @ -107,20 +104,6 @@ def test_predict_sam(): | ||||
|     predictor.reset_image() | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') | ||||
| def test_model_ray_tune(): | ||||
|     """Tune YOLO model with Ray optimization library.""" | ||||
|     with contextlib.suppress(RuntimeError):  # RuntimeError may be caused by out-of-memory | ||||
|         YOLO('yolov8n-cls.yaml').tune(use_ray=True, | ||||
|                                       data='imagenet10', | ||||
|                                       grace_period=1, | ||||
|                                       iterations=1, | ||||
|                                       imgsz=32, | ||||
|                                       epochs=1, | ||||
|                                       plots=False, | ||||
|                                       device='cpu') | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') | ||||
| def test_model_tune(): | ||||
|     """Tune YOLO model for performance.""" | ||||
|  | ||||
							
								
								
									
										26
									
								
								tests/test_integrations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								tests/test_integrations.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | ||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| from ultralytics import YOLO | ||||
| from ultralytics.utils import SETTINGS, checks | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.skipif(not checks.check_requirements('ray', install=False), reason='RayTune not installed') | ||||
| def test_model_ray_tune(): | ||||
|     """Tune YOLO model with Ray optimization library.""" | ||||
|     YOLO('yolov8n-cls.yaml').tune(use_ray=True, | ||||
|                                   data='imagenet10', | ||||
|                                   grace_period=1, | ||||
|                                   iterations=1, | ||||
|                                   imgsz=32, | ||||
|                                   epochs=1, | ||||
|                                   plots=False, | ||||
|                                   device='cpu') | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.skipif(not checks.check_requirements('mlflow', install=False), reason='MLflow not installed') | ||||
| def test_mlflow(): | ||||
|     """Test training with MLflow tracking enabled.""" | ||||
|     SETTINGS['mlflow'] = True | ||||
|     YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=3, plots=False, device='cpu') | ||||
| @ -1,6 +1,6 @@ | ||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||
| 
 | ||||
| __version__ = '8.0.197' | ||||
| __version__ = '8.0.198' | ||||
| 
 | ||||
| from ultralytics.models import RTDETR, SAM, YOLO | ||||
| from ultralytics.models.fastsam import FastSAM | ||||
|  | ||||
| @ -7,9 +7,9 @@ from pathlib import Path | ||||
| from types import SimpleNamespace | ||||
| from typing import Dict, List, Union | ||||
| 
 | ||||
| from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, SETTINGS, | ||||
|                                SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks, colorstr, | ||||
|                                deprecation_warn, yaml_load, yaml_print) | ||||
| from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, RUNS_DIR, | ||||
|                                SETTINGS, SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks, | ||||
|                                colorstr, deprecation_warn, yaml_load, yaml_print) | ||||
| 
 | ||||
| # Define valid tasks and modes | ||||
| MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' | ||||
| @ -153,8 +153,7 @@ def get_save_dir(args, name=None): | ||||
|     else: | ||||
|         from ultralytics.utils.files import increment_path | ||||
| 
 | ||||
|         project = args.project or (ROOT / | ||||
|                                    '../tests/tmp/runs' if TESTS_RUNNING else Path(SETTINGS['runs_dir'])) / args.task | ||||
|         project = args.project or (ROOT.parent / 'tests/tmp/runs' if TESTS_RUNNING else RUNS_DIR) / args.task | ||||
|         name = name or args.name or f'{args.mode}' | ||||
|         save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True) | ||||
| 
 | ||||
|  | ||||
| @ -91,6 +91,7 @@ class BaseTrainer: | ||||
| 
 | ||||
|         # Dirs | ||||
|         self.save_dir = get_save_dir(self.args) | ||||
|         self.args.name = self.save_dir.name  # update name for loggers | ||||
|         self.wdir = self.save_dir / 'weights'  # weights dir | ||||
|         if RANK in (-1, 0): | ||||
|             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir | ||||
|  | ||||
| @ -930,7 +930,8 @@ def url2file(url): | ||||
| PREFIX = colorstr('Ultralytics: ') | ||||
| SETTINGS = SettingsManager()  # initialize settings | ||||
| DATASETS_DIR = Path(SETTINGS['datasets_dir'])  # global datasets directory | ||||
| WEIGHTS_DIR = Path(SETTINGS['weights_dir']) | ||||
| WEIGHTS_DIR = Path(SETTINGS['weights_dir'])  # global weights directory | ||||
| RUNS_DIR = Path(SETTINGS['runs_dir'])  # global runs directory | ||||
| ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \ | ||||
|     'Docker' if is_docker() else platform.system() | ||||
| TESTS_RUNNING = is_pytest_running() or is_github_actions_ci() | ||||
|  | ||||
| @ -1,64 +1,104 @@ | ||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||
| """ | ||||
| MLflow Logging for Ultralytics YOLO. | ||||
| 
 | ||||
| from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr | ||||
| This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts. | ||||
| For setting up, a tracking URI should be specified. The logging can be customized using environment variables. | ||||
| 
 | ||||
| Commands: | ||||
|     1. To set a project name: | ||||
|         `export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument | ||||
| 
 | ||||
|     2. To set a run name: | ||||
|         `export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument | ||||
| 
 | ||||
|     3. To start a local MLflow server: | ||||
|         mlflow server --backend-store-uri runs/mlflow | ||||
|        It will by default start a local server at http://127.0.0.1:5000. | ||||
|        To specify a different URI, set the MLFLOW_TRACKING_URI environment variable. | ||||
| 
 | ||||
|     4. To kill all running MLflow server instances: | ||||
|         ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9 | ||||
| """ | ||||
| 
 | ||||
| from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr | ||||
| 
 | ||||
| try: | ||||
|     assert not TESTS_RUNNING  # do not log pytest | ||||
|     import os | ||||
| 
 | ||||
|     assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '')  # do not log pytest | ||||
|     assert SETTINGS['mlflow'] is True  # verify integration is enabled | ||||
|     import mlflow | ||||
| 
 | ||||
|     assert hasattr(mlflow, '__version__')  # verify package is not directory | ||||
|     PREFIX = colorstr('MLFlow:') | ||||
|     import os | ||||
|     import re | ||||
|     from pathlib import Path | ||||
|     PREFIX = colorstr('MLflow: ') | ||||
| 
 | ||||
| except (ImportError, AssertionError): | ||||
|     mlflow = None | ||||
| 
 | ||||
| 
 | ||||
| def on_pretrain_routine_end(trainer): | ||||
|     """Logs training parameters to MLflow.""" | ||||
|     global mlflow, run, experiment_name | ||||
|     """ | ||||
|     Log training parameters to MLflow at the end of the pretraining routine. | ||||
| 
 | ||||
|     if os.environ.get('MLFLOW_TRACKING_URI') is None: | ||||
|         mlflow = None | ||||
|     This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI, | ||||
|     experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters | ||||
|     from the trainer. | ||||
| 
 | ||||
|     if mlflow: | ||||
|         mlflow_location = os.environ['MLFLOW_TRACKING_URI']  # "http://192.168.xxx.xxx:5000" | ||||
|         LOGGER.debug(f'{PREFIX} tracking uri: {mlflow_location}') | ||||
|         mlflow.set_tracking_uri(mlflow_location) | ||||
|     Args: | ||||
|         trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log. | ||||
| 
 | ||||
|     Global: | ||||
|         mlflow: The imported mlflow module to use for logging. | ||||
| 
 | ||||
|     Environment Variables: | ||||
|         MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'. | ||||
|         MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project. | ||||
|         MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name. | ||||
|     """ | ||||
|     global mlflow | ||||
| 
 | ||||
|     uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow') | ||||
|     LOGGER.debug(f'{PREFIX} tracking uri: {uri}') | ||||
|     mlflow.set_tracking_uri(uri) | ||||
| 
 | ||||
|     # Set experiment and run names | ||||
|     experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8' | ||||
|     run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name | ||||
|         experiment = mlflow.set_experiment(experiment_name)  # change since mlflow does this now by default | ||||
|     mlflow.set_experiment(experiment_name) | ||||
| 
 | ||||
|     mlflow.autolog() | ||||
|         prefix = colorstr('MLFlow: ') | ||||
|     try: | ||||
|             run, active_run = mlflow, mlflow.active_run() | ||||
|             if not active_run: | ||||
|                 active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name) | ||||
|             LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}') | ||||
|             run.log_params(trainer.args) | ||||
|         except Exception as err: | ||||
|             LOGGER.error(f'{prefix}Failing init - {repr(err)}') | ||||
|             LOGGER.warning(f'{prefix}Continuing without Mlflow') | ||||
|         active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name) | ||||
|         LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}') | ||||
|         if Path(uri).is_dir(): | ||||
|             LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'") | ||||
|         LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'") | ||||
|         mlflow.log_params(dict(trainer.args)) | ||||
|     except Exception as e: | ||||
|         LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n' | ||||
|                        f'{PREFIX}WARNING ⚠️ Not tracking this run') | ||||
| 
 | ||||
| 
 | ||||
| def on_fit_epoch_end(trainer): | ||||
|     """Logs training metrics to Mlflow.""" | ||||
|     """Log training metrics at the end of each fit epoch to MLflow.""" | ||||
|     if mlflow: | ||||
|         metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()} | ||||
|         run.log_metrics(metrics=metrics_dict, step=trainer.epoch) | ||||
|         sanitized_metrics = {k.replace('(', '').replace(')', ''): float(v) for k, v in trainer.metrics.items()} | ||||
|         mlflow.log_metrics(metrics=sanitized_metrics, step=trainer.epoch) | ||||
| 
 | ||||
| 
 | ||||
| def on_train_end(trainer): | ||||
|     """Called at end of train loop to log model artifact info.""" | ||||
|     """Log model artifacts at the end of the training.""" | ||||
|     if mlflow: | ||||
|         run.log_artifact(trainer.last) | ||||
|         run.log_artifact(trainer.best) | ||||
|         run.log_artifact(trainer.save_dir) | ||||
|         mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt | ||||
|         for f in trainer.save_dir.glob('*'):  # log all other files in save_dir | ||||
|             if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}: | ||||
|                 mlflow.log_artifact(str(f)) | ||||
| 
 | ||||
|         mlflow.end_run() | ||||
|         LOGGER.debug(f'{PREFIX} ending run') | ||||
|         LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n' | ||||
|                     f"{PREFIX}disable with 'yolo settings mlflow=False'") | ||||
| 
 | ||||
| 
 | ||||
| callbacks = { | ||||
|  | ||||
| @ -19,7 +19,7 @@ except (ImportError, AssertionError): | ||||
|     wb = None | ||||
| 
 | ||||
| 
 | ||||
| def _custom_table(x, y, classes, title='Precision Recall Curve', x_axis_title='Recall', y_axis_title='Precision'): | ||||
| def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'): | ||||
|     """ | ||||
|     Create and log a custom metric visualization to wandb.plot.pr_curve. | ||||
| 
 | ||||
| @ -39,7 +39,7 @@ def _custom_table(x, y, classes, title='Precision Recall Curve', x_axis_title='R | ||||
|     """ | ||||
|     df = pd.DataFrame({'class': classes, 'y': y, 'x': x}).round(3) | ||||
|     fields = {'x': 'x', 'y': 'y', 'class': 'class'} | ||||
|     string_fields = {'title': title, 'x-axis-title': x_axis_title, 'y-axis-title': y_axis_title} | ||||
|     string_fields = {'title': title, 'x-axis-title': x_title, 'y-axis-title': y_title} | ||||
|     return wb.plot_table('wandb/area-under-curve/v0', | ||||
|                          wb.Table(dataframe=df), | ||||
|                          fields=fields, | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher