Trackear métricas e hiperparámetros con MLflow es la forma más limpia de comparar modelos baseline entrenados con o sin validación cruzada. Aquí verás cómo extender la función de tracking para registrar ROC-AUC, precision, recall y los mejores hiperparámetros, sin importar si tu modelo se entrenó con búsqueda de grilla o no.
¿Cómo añadir métricas a la lista de tracking en MLflow?
La idea es centralizar todas las métricas en una sola lista y luego registrarlas en MLflow para no perder trazabilidad entre experimentos.
Partimos de una lista vacía llamada metric y le aplicamos extend cada vez que calculamos nuevas métricas. Por ejemplo, después de calcular el ROC-AUC tanto en train como en test, los añadimos a esa lista. Aunque el logger.info ya muestra los valores en consola, ese print es solo informativo: no guarda nada en MLflow.
Para que el registro sea persistente, usamos mlflow.log_metric con un nombre claro y el valor calculado:
roc_auc_train para la métrica de entrenamiento.
roc_auc_test para la métrica de prueba.
- Precision y recall en sus versiones train y test.
¿Cuál es la diferencia entre logger.info y mlflow.log_metric? logger.info solo imprime el valor en consola como información. mlflow.log_metric guarda la métrica dentro del experimento de MLflow, lo que te permite compararla después entre corridas.
¿Cómo calcular precision y recall ponderados sobre train y test?
Una vez entrenado el modelo, generamos las predicciones con el método predict aplicado por separado a los datos de entrenamiento y a los de prueba.
Con esas predicciones calculamos precision y recall usando el promedio ponderado, es decir, el parámetro weighted. Este promedio toma en cuenta el desbalance entre clases, algo clave cuando trabajas con datasets donde una clase domina sobre la otra. Después, cada métrica se registra con mlflow.log_metric, diferenciando siempre entre la versión train y la versión test para mantener limpia la comparación.
¿Por qué usar weighted en lugar del promedio simple?
El promedio ponderado pesa cada clase según su frecuencia, así que la métrica refleja mejor el rendimiento real cuando las clases no están balanceadas. Si usaras el promedio simple, una clase minoritaria con mal desempeño se vería diluida.
¿Cómo trackear hiperparámetros cuando usas validación cruzada?
Aquí viene la parte interesante. La función de tracking recibe un parámetro llamado use_cv que indica si el modelo fue entrenado con búsqueda de grilla y validación cruzada o no.
El detalle técnico es que un modelo entrenado con GridSearchCV no expone sus hiperparámetros con get_params directamente, sino que los guarda en el atributo best_params_. Si intentas llamar al método equivocado, te lanzará un error. Para resolverlo, manejamos los dos casos con un bloque condicional:
- Si
use_cv es verdadero, asignas best_params = model.best_params_ porque el modelo pasó por una búsqueda de grilla.
- Si
use_cv es falso, asignas best_params = model.get_params() porque el modelo se entrenó de forma directa sin grilla.
De esta forma, sin importar el camino que haya seguido el entrenamiento, siempre obtienes un diccionario con los hiperparámetros listos para registrar.
¿Qué hace mlflow.log_params? Recibe un diccionario y guarda cada par clave valor como hiperparámetro del experimento. Es ideal cuando tienes muchos parámetros porque los registra todos en una sola llamada.
¿Cómo manejar errores con try y except en el tracking?
Envolver la lógica en un bloque try y except te protege de caídas silenciosas. Si algo falla al acceder a best_params_ o get_params, el except puede capturar la excepción y dejar un logger.info con el detalle del error. Aunque la estructura if y else ya cubre los dos escenarios principales, sumar el try actúa como una red de seguridad extra.
Finalmente, fuera del condicional, llamas a mlflow.log_params(best_params) para que el registro ocurra siempre, sin duplicar código y sin importar qué rama se haya ejecutado.
Conceptos clave que aparecen en la clase
Algunos términos vale la pena tenerlos claros para seguir el flujo del código:
- ROC-AUC: métrica que mide qué tan bien el modelo separa las clases. Se registra tanto en train como en test.
- Precision y recall ponderados: métricas de clasificación calculadas con
average="weighted" para considerar el peso de cada clase.
mlflow.log_metric: registra valores numéricos individuales en el experimento.
mlflow.log_params: registra un diccionario completo de hiperparámetros.
best_params_ vs get_params: el primero existe solo en modelos entrenados con GridSearchCV; el segundo está disponible en cualquier estimador de scikit-learn.
use_cv: bandera que indica si el modelo pasó por validación cruzada con búsqueda de grilla.
¿Estás aplicando este patrón en tus propios experimentos? Cuéntame en los comentarios cómo organizas tu función de tracking y qué métricas no pueden faltar en tus baselines.