diff --git a/R/umap.R b/R/umap.R index 78976bc1..12de981c 100644 --- a/R/umap.R +++ b/R/umap.R @@ -257,8 +257,16 @@ prep.step_umap <- function(x, training, info = NULL, ...) { check_number_decimal(x$min_dist, arg = "min_dist") check_number_decimal(x$learn_rate, min = 0, arg = "learn_rate") check_number_whole(x$epochs, min = 0, allow_null = TRUE, arg = "epochs") - rlang::arg_match0(x$initial, initial_umap_values, arg_nm = "initial") - check_number_decimal(x$target_weight, min = 0, max = 1, arg = "target_weight") + if (!is.null(x$initial)) { + rlang::arg_match0(x$initial, initial_umap_values, arg_nm = "initial") + } + check_number_decimal( + x$target_weight, + min = 0, + max = 1, + allow_null = TRUE, + arg = "target_weight" + ) if (length(col_names) > 0) { if (length(x$outcome) > 0) { @@ -267,7 +275,7 @@ prep.step_umap <- function(x, training, info = NULL, ...) { y_name <- NULL } x$neighbors <- min(nrow(training) - 1, x$neighbors) - x$num_comp <- min(length(col_names) - 1, x$num_comp) + x$num_comp <- min(length(col_names), x$num_comp) if (is.null(x$initial)) { x$initial <- "spectral" diff --git a/tests/testthat/_snaps/umap.md b/tests/testthat/_snaps/umap.md index ccb11ce3..c5c2ad7e 100644 --- a/tests/testthat/_snaps/umap.md +++ b/tests/testthat/_snaps/umap.md @@ -1,3 +1,13 @@ +# check_name() is used + + Code + prep(rec, training = dat) + Condition + Error in `step_umap()`: + Caused by error in `bake()`: + ! Name collision occurred. The following variable names already exist: + * `UMAP1` + # bad args Code @@ -59,7 +69,7 @@ Condition Error in `step_umap()`: Caused by error in `prep()`: - ! `target_weight` must be a number between 0 and 1, not the number -4. + ! `target_weight` must be a number between 0 and 1 or `NULL`, not the number -4. --- @@ -69,6 +79,14 @@ Error in `step_umap()`: ! `prefix` must be a single string, not `NULL`. +# bake method errors when needed non-standard role columns are missing + + Code + bake(rec_trained, new_data = tr[, -4]) + Condition + Error in `step_umap()`: + ! The following required column is missing from `new_data`: Petal.Width. + # empty printing Code @@ -104,3 +122,45 @@ -- Operations * UMAP embedding for: | Trained +# keep_original_cols - can prep recipes with it missing + + Code + rec <- prep(rec) + Condition + Warning: + `keep_original_cols` was added to `step_umap()` after this recipe was created. + i Regenerate your recipe to avoid this warning. + +# printing + + Code + print(rec) + Message + + -- Recipe ---------------------------------------------------------------------- + + -- Inputs + Number of variables by role + predictor: 4 + + -- Operations + * UMAP embedding for: all_predictors() + +--- + + Code + prep(rec) + Message + + -- Recipe ---------------------------------------------------------------------- + + -- Inputs + Number of variables by role + predictor: 4 + + -- Training information + Training data contained 133 data points and no incomplete rows. + + -- Operations + * UMAP embedding for: Sepal.Length, Sepal.Width, Petal.Length, ... | Trained + diff --git a/tests/testthat/test-umap.R b/tests/testthat/test-umap.R index 89c02610..abcdeb6b 100644 --- a/tests/testthat/test-umap.R +++ b/tests/testthat/test-umap.R @@ -267,9 +267,15 @@ test_that("backwards compatible for initial and target_weight args (#213)", { rec$steps[[1]]$initial <- NULL rec$steps[[1]]$target_weight <- NULL + remove_fit_times <- function(x) { + x$fit_times <- NULL + x$steps[[1]]$object$nn_index <- NULL + x + } + expect_identical( - prep(rec), - exp_res + remove_fit_times(prep(rec)), + remove_fit_times(exp_res) ) })