diff --git a/.gitignore b/.gitignore
index c8987d84..64badb1c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,7 @@ coverage.xml
 *.py,cover
 .hypothesis/
 .pytest_cache/
+mlruns/
 
 # Translations
 *.mo
diff --git a/tests/test_integrations.py b/tests/test_integrations.py
index 9af173a9..bcb2a04e 100644
--- a/tests/test_integrations.py
+++ b/tests/test_integrations.py
@@ -29,6 +29,42 @@ def test_mlflow():
     SETTINGS["mlflow"] = True
     YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=3, plots=False, device="cpu")
 
+
+@pytest.mark.skipif(not check_requirements('mlflow', install=False), reason='mlflow not installed')
+def test_mlflow_keep_run_active():
+    """Test that MLFLOW_KEEP_RUN_ACTIVE controls whether the MLflow run is ended after training."""
+    import os
+
+    import mlflow
+
+    SETTINGS['mlflow'] = True
+    os.environ['MLFLOW_RUN'] = 'Test Run'
+    try:
+        # Test with MLFLOW_KEEP_RUN_ACTIVE=True
+        os.environ['MLFLOW_KEEP_RUN_ACTIVE'] = 'True'
+        YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=1, plots=False, device='cpu')
+        active_run = mlflow.active_run()
+        assert active_run is not None, "MLflow run should be active when MLFLOW_KEEP_RUN_ACTIVE=True"
+        assert active_run.info.status == 'RUNNING', "MLflow run should be active when MLFLOW_KEEP_RUN_ACTIVE=True"
+        run_id = active_run.info.run_id
+
+        # Test with MLFLOW_KEEP_RUN_ACTIVE=False (presumably the callback reuses the still-active run - confirm)
+        os.environ['MLFLOW_KEEP_RUN_ACTIVE'] = 'False'
+        YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=1, plots=False, device='cpu')
+        status = mlflow.get_run(run_id=run_id).info.status
+        assert status == 'FINISHED', "MLflow run should be ended when MLFLOW_KEEP_RUN_ACTIVE=False"
+
+        # Test with MLFLOW_KEEP_RUN_ACTIVE not set: the run from this training must not be left open
+        os.environ.pop('MLFLOW_KEEP_RUN_ACTIVE', None)
+        YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=1, plots=False, device='cpu')
+        active_run = mlflow.active_run()
+        assert active_run is None, "MLflow run should be ended by default when MLFLOW_KEEP_RUN_ACTIVE is not set"
+    finally:
+        # Do not leak environment variables or a dangling active run into other tests
+        os.environ.pop('MLFLOW_RUN', None)
+        os.environ.pop('MLFLOW_KEEP_RUN_ACTIVE', None)
+        mlflow.end_run()
+
 
 @pytest.mark.skipif(not check_requirements("tritonclient", install=False), reason="tritonclient[all] not installed")
 def test_triton():
diff --git a/ultralytics/utils/callbacks/mlflow.py b/ultralytics/utils/callbacks/mlflow.py
index 1eaf2c3f..e5546200 100644
--- a/ultralytics/utils/callbacks/mlflow.py
+++ b/ultralytics/utils/callbacks/mlflow.py
@@ -58,6 +58,7 @@ def on_pretrain_routine_end(trainer):
         MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
         MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
         MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
+        MLFLOW_KEEP_RUN_ACTIVE: Keep the MLflow run active after training if 'true'. If not set, defaults to False.
     """
     global mlflow
 
@@ -107,8 +108,14 @@ def on_train_end(trainer):
         for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
             if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
                 mlflow.log_artifact(str(f))
+        # Only an explicit "true" (case-insensitive) keeps the run open; unset or any other value ends it
+        keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
+        if keep_run_active:
+            LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
+        else:
+            mlflow.end_run()
+            LOGGER.debug(f"{PREFIX}mlflow run ended")
 
-        mlflow.end_run()
         LOGGER.info(
             f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
             f"{PREFIX}disable with 'yolo settings mlflow=False'"