diff --git a/docs/changes/48.maintenance.md b/docs/changes/48.maintenance.md new file mode 100644 index 0000000..fe5cf67 --- /dev/null +++ b/docs/changes/48.maintenance.md @@ -0,0 +1 @@ +Add early stopping and an AUC evaluation metric to classification; increase the number of estimators and adjust max depth, learning rate, and parallel jobs. diff --git a/src/eventdisplay_ml/hyper_parameters.py b/src/eventdisplay_ml/hyper_parameters.py index f8568fb..4cfe4dc 100644 --- a/src/eventdisplay_ml/hyper_parameters.py +++ b/src/eventdisplay_ml/hyper_parameters.py @@ -29,14 +29,15 @@ "model": None, "hyper_parameters": { "objective": "binary:logistic", - "eval_metric": "logloss", # TODO AUC ? - "n_estimators": 100, # TODO probably too low - "max_depth": 6, - "learning_rate": 0.1, + "eval_metric": ["logloss", "auc"], + "n_estimators": 5000, + "early_stopping_rounds": 50, + "max_depth": 7, + "learning_rate": 0.05, "subsample": 0.8, "colsample_bytree": 0.8, "random_state": None, - "n_jobs": 8, + "n_jobs": 48, }, } } diff --git a/src/eventdisplay_ml/models.py b/src/eventdisplay_ml/models.py index 68c4c10..cc244c0 100644 --- a/src/eventdisplay_ml/models.py +++ b/src/eventdisplay_ml/models.py @@ -572,6 +572,7 @@ def train_classification(df, model_configs): _logger.info(f"Features ({len(x_data.columns)}): {', '.join(x_data.columns)}") model_configs["features"] = list(x_data.columns) y_data = full_df["label"] + x_train, x_test, y_train, y_test = train_test_split( x_data, y_data, @@ -585,7 +586,7 @@ for name, cfg in model_configs.get("models", {}).items(): _logger.info(f"Training {name}") model = xgb.XGBClassifier(**cfg.get("hyper_parameters", {})) - model.fit(x_train, y_train) + model.fit(x_train, y_train, eval_set=[(x_test, y_test)], verbose=True) evaluate_classification_model(model, x_test, y_test, full_df, x_data.columns.tolist(), name) cfg["model"] = model cfg["efficiency"] = evaluation_efficiency(name, model, x_test, y_test) @@ -631,7 +632,7 @@ def _log_energy_bin_counts(df): _logger.info(f"Energy bin weights 
(inverse-count, normalized): {inverse_counts}") - # Calculate multiplicity weights (inverse frequency) + # Calculate multiplicity weights (prioritize higher-multiplicity events) mult_counts = df["DispNImages"].value_counts() _logger.info("Training events per multiplicity:") for mult, count in mult_counts.items():