1 change: 1 addition & 0 deletions docs/changes/48.maintenance.md
@@ -0,0 +1 @@
+Add early stopping to classification. Increase number of estimators.
11 changes: 6 additions & 5 deletions src/eventdisplay_ml/hyper_parameters.py
@@ -29,14 +29,15 @@
"model": None,
"hyper_parameters": {
"objective": "binary:logistic",
"eval_metric": "logloss", # TODO AUC ?
"n_estimators": 100, # TODO probably too low
"max_depth": 6,
"learning_rate": 0.1,
"eval_metric": ["logloss", "auc"],
"n_estimators": 5000,
"early_stopping_rounds": 50,
"max_depth": 7,
"learning_rate": 0.05,
"subsample": 0.8,
"colsample_bytree": 0.8,
"random_state": None,
"n_jobs": 8,
"n_jobs": 48,
},
}
}
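Taken together, these settings change the training strategy from a small fixed ensemble to a large upper bound that is trimmed by early stopping on a validation set. A minimal sketch of how the new hyper-parameters behave in xgboost (>= 1.6, where `eval_metric` and `early_stopping_rounds` are constructor arguments); the synthetic data and variable names here are illustrative only, not the project's pipeline:

```python
# Minimal sketch (synthetic data) of how the updated hyper-parameters
# interact with early stopping in xgboost >= 1.6.
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 10))
y = (X[:, 0] + 0.5 * rng.normal(size=2000) > 0).astype(int)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

model = xgb.XGBClassifier(
    objective="binary:logistic",
    eval_metric=["logloss", "auc"],  # both tracked per round; the last one (auc) drives stopping
    n_estimators=5000,               # upper bound on boosting rounds
    early_stopping_rounds=50,        # stop once validation auc has not improved for 50 rounds
    max_depth=7,
    learning_rate=0.05,
    subsample=0.8,
    colsample_bytree=0.8,
    n_jobs=8,                        # the PR uses 48; scaled down for a laptop-sized sketch
)
# early_stopping_rounds requires an eval_set at fit time
model.fit(x_train, y_train, eval_set=[(x_test, y_test)], verbose=False)
print(f"Rounds actually used: {model.best_iteration + 1} of {model.n_estimators}")
```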
5 changes: 3 additions & 2 deletions src/eventdisplay_ml/models.py
@@ -572,6 +572,7 @@ def train_classification(df, model_configs):
_logger.info(f"Features ({len(x_data.columns)}): {', '.join(x_data.columns)}")
model_configs["features"] = list(x_data.columns)
y_data = full_df["label"]

x_train, x_test, y_train, y_test = train_test_split(
x_data,
y_data,
@@ -585,7 +586,7 @@
for name, cfg in model_configs.get("models", {}).items():
_logger.info(f"Training {name}")
model = xgb.XGBClassifier(**cfg.get("hyper_parameters", {}))
-model.fit(x_train, y_train)
+model.fit(x_train, y_train, eval_set=[(x_test, y_test)], verbose=True)
evaluate_classification_model(model, x_test, y_test, full_df, x_data.columns.tolist(), name)
cfg["model"] = model
cfg["efficiency"] = evaluation_efficiency(name, model, x_test, y_test)
@@ -631,7 +632,7 @@ def _log_energy_bin_counts(df):

_logger.info(f"Energy bin weights (inverse-count, normalized): {inverse_counts}")

-# Calculate multiplicity weights (inverse frequency)
+# Calculate multiplicity weights (prioritize higher-multiplicity events)
mult_counts = df["DispNImages"].value_counts()
_logger.info("Training events per multiplicity:")
for mult, count in mult_counts.items():
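The reworded comment describes the intent of the weighting; the calculation itself is outside this hunk. As an illustration only, a sketch under the assumption that per-event weights simply scale with `DispNImages` (normalized to mean 1) so that higher-multiplicity events carry more weight in training:

```python
# Illustrative sketch only -- not the project's actual weighting code.
# Assumption: weights grow with image multiplicity (DispNImages) and are
# normalized so that the average weight is 1.
import pandas as pd

def multiplicity_weights(df: pd.DataFrame) -> pd.Series:
    """Per-event weights proportional to DispNImages, mean-normalized."""
    weights = df["DispNImages"].astype(float)
    return weights / weights.mean()

# Example:
# df = pd.DataFrame({"DispNImages": [2, 2, 3, 4]})
# multiplicity_weights(df) -> [0.727, 0.727, 1.091, 1.455] (4-image event weighted highest)
```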