Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions R/metrics-binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,32 @@ assert_input_binary <- function(observed, predicted) {
assert_factor(observed, n.levels = 2, min.len = 1)
assert_numeric(predicted, lower = 0, upper = 1)
assert_dims_ok_scalar(observed, predicted)

# Warn if factor levels appear to be in counterintuitive order.
# Predictions represent P(outcome = highest factor level). If the levels
# are e.g. c("1", "0"), the highest level is "0", meaning predictions are
# interpreted as P(outcome = "0"), which is almost certainly unintended.
lvls <- levels(observed)
counterintuitive <- FALSE
if (setequal(lvls, c("0", "1")) && lvls[1] == "1") {
counterintuitive <- TRUE
} else if (setequal(lvls, c("TRUE", "FALSE")) && lvls[1] == "TRUE") {
counterintuitive <- TRUE
}
if (counterintuitive) {
#nolint start: keyword_quote_linter
cli_warn(c(
"!" = "Factor levels of {.var observed} appear to be in
counterintuitive order: {.val {lvls}}.",
"i" = "Predictions will be interpreted as the probability of
observing {.val {lvls[2]}} (the highest factor level).",
"i" = "If this is not intended, consider reordering the factor levels,
e.g. {.code factor(observed, levels = c({.val {lvls[2]}},
{.val {lvls[1]}}))}"
))
#nolint end
}

return(invisible(NULL))
}

Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-class-forecast-binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,27 @@ test_that("test_columns_not_present() is no longer exported or defined", {
})


test_that("as_forecast_binary() warns when data has reversed 0-1 factor levels", {
dt <- data.table(
model = "m1",
id = 1:4,
observed = factor(c(0, 1, 1, 0), levels = c("1", "0")),
predicted = c(0.1, 0.9, 0.8, 0.2)
)
expect_warning(
as_forecast_binary(dt),
"counterintuitive"
)
})

test_that("score() produces correct results with standard 0-1 factor levels", {
# example_binary has standard levels c("0", "1"), should not warn about levels
expect_no_warning(
suppressMessages(score(as_forecast_binary(example_binary)))
)
})


# ==============================================================================
# score.forecast_binary() # nolint: commented_code_linter
# ==============================================================================
Expand Down
58 changes: 58 additions & 0 deletions tests/testthat/test-metrics-binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,64 @@ test_that("function throws an error for wrong input formats", {
})


# ==============================================================================
# Test factor level order warning
# ==============================================================================
test_that("assert_input_binary() warns when 0-1 factor levels are in counterintuitive order", {
observed_rev <- factor(c(0, 1, 1, 0, 1), levels = c("1", "0"))
predicted_rev <- c(0.1, 0.9, 0.8, 0.2, 0.7)
expect_warning(
assert_input_binary(observed_rev, predicted_rev),
"counterintuitive"
)
})

test_that("assert_input_binary() does not warn for standard 0-1 level order", {
observed_std <- factor(c(0, 1, 1, 0, 1), levels = c("0", "1"))
predicted_std <- c(0.1, 0.9, 0.8, 0.2, 0.7)
expect_no_warning(assert_input_binary(observed_std, predicted_std))
})

test_that("assert_input_binary() does not warn for non-numeric factor levels", {
observed_ab <- factor(c("a", "b", "b", "a"), levels = c("a", "b"))
predicted_ab <- c(0.3, 0.7, 0.6, 0.4)
expect_no_warning(assert_input_binary(observed_ab, predicted_ab))
})

test_that("brier_score() produces different results with reversed factor levels", {
observed_correct <- factor(c(0, 1, 1, 0), levels = c("0", "1"))
observed_reversed <- factor(c(0, 1, 1, 0), levels = c("1", "0"))
predicted_bs <- c(0.1, 0.9, 0.8, 0.2)

scores_correct <- brier_score(observed_correct, predicted_bs)
expect_equal(scores_correct, c(0.01, 0.01, 0.04, 0.04)) # nolint: expect_identical_linter

expect_warning(
scores_reversed <- brier_score(observed_reversed, predicted_bs),
"counterintuitive"
)
expect_false(all(scores_correct == scores_reversed))
})

test_that("logs_binary() warns with reversed 0-1 factor levels", {
observed_reversed <- factor(c(0, 1, 1, 0), levels = c("1", "0"))
predicted_lb <- c(0.1, 0.9, 0.8, 0.2)
expect_warning(
logs_binary(observed_reversed, predicted_lb),
"counterintuitive"
)
})

test_that("assert_input_binary() warns for TRUE-FALSE levels in counterintuitive order", {
observed_tf <- factor(c(TRUE, FALSE, TRUE), levels = c("TRUE", "FALSE"))
predicted_tf <- c(0.8, 0.2, 0.9)
expect_warning(
assert_input_binary(observed_tf, predicted_tf),
"counterintuitive"
)
})


# ==============================================================================
# Test Binary Metrics
# ==============================================================================
Expand Down
Loading