hd_model_test() validates the model on new data. It takes an already tuned model,
evaluates it on the validation (new test) set, calculates the metrics, and plots the
probability plot and ROC curve based on the new data.
Usage
hd_model_test(
  model_object,
  train_set,
  test_set,
  variable = "Disease",
  metadata_cols = NULL,
  case,
  control = NULL,
  balance_groups = TRUE,
  palette = NULL,
  seed = 123
)
Arguments
- model_object
An hd_model object coming from hd_model_rreg() or hd_model_rf() (binary or multiclass classification).
- train_set
The training set as an HDAnalyzeR object or a dataset in wide format with sample ID as its first column and class column as its second column.
- test_set
The validation/test set as an HDAnalyzeR object or a dataset in wide format with sample ID as its first column and class column as its second column.
- variable
The name of the metadata variable containing the case and control groups. Default is "Disease".
- metadata_cols
The metadata variables to include in the analysis. Default is NULL.
- case
The case class.
- control
The control groups. If NULL, it will be set to all other unique values of the variable that are not the case. Default is NULL.
- balance_groups
Whether to balance the groups in the train set. It is only valid in binary classification settings. Default is TRUE.
- palette
The color palette for the classes. If it is a character, it should be one of the palettes from hd_palettes(). Default is NULL.
- seed
Seed for reproducibility. Default is 123.
Value
The model object containing the validation set, the metrics, the ROC curve, the probability plot, and the confusion matrix for the new data.
Details
In order to run this function, the train and test sets should be in exactly
the same format, meaning that they must have the same columns in the same order.
Some function arguments, such as case/control, variable, and metadata_cols, should
also be the same. If the data contain missing values, KNN (k=5) imputation
will be used to impute them. If case is provided, the model will be a binary
classification model. If case is NULL, the model will be a multiclass classification model.
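The requirement that both sets share the same columns in the same order can be checked before calling hd_model_test(). The following is a minimal sketch in base R, assuming both sets are plain data frames in wide format. same_format() is a hypothetical helper, not part of HDAnalyzeR, and the toy data frames are made up for illustration.

```r
# Hypothetical helper (not part of HDAnalyzeR): TRUE only when both data
# frames have identical column names in identical order.
same_format <- function(train, test) {
  identical(colnames(train), colnames(test))
}

# Toy wide-format data: sample ID first, class second, then features.
train_df <- data.frame(DAid = c("S1", "S2"), Disease = c("AML", "CLL"),
                       ABL1 = c(1.2, 0.4), ACE2 = c(0.3, 2.1))
test_df  <- data.frame(DAid = "S3", Disease = "AML", ABL1 = 0.9, ACE2 = 1.5)

# Same columns, same order.
same_format(train_df, test_df)

# Same columns, but the two feature columns are swapped.
shuffled <- test_df[, c(1, 2, 4, 3)]
same_format(train_df, shuffled)

# If only the order differs, reorder the test set to match the train set.
test_fixed <- shuffled[, colnames(train_df)]
same_format(train_df, test_fixed)
```

If the column names themselves differ, reordering is not enough and the two sets need to be rebuilt from the same feature list.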
In multiclass models, the groups in the train set are not balanced, and sensitivity and specificity are calculated via macro-averaging. If the model is run against a continuous variable, the palette will be ignored.
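To make the macro-averaging step concrete, here is a minimal sketch in base R. It assumes the usual definition (compute one-vs-rest sensitivity and specificity per class, then take the unweighted mean); the function's internal implementation is not shown on this page, and the 3-class confusion matrix below is made up for illustration.

```r
# Illustrative 3-class confusion matrix (rows = prediction, columns = truth).
# The counts are invented for this sketch.
cm <- matrix(c(8, 1, 1,
               2, 6, 2,
               0, 3, 7),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("A", "B", "C"),
                             Truth = c("A", "B", "C")))

tp <- diag(cm)                 # samples of each class predicted correctly
fn <- colSums(cm) - tp         # samples of the class predicted as another class
fp <- rowSums(cm) - tp         # samples of other classes predicted as the class
tn <- sum(cm) - tp - fn - fp   # everything else

# One-vs-rest sensitivity and specificity per class.
sens_per_class <- tp / (tp + fn)
spec_per_class <- tn / (tn + fp)

# Macro-average: unweighted mean over the classes.
macro_sensitivity <- mean(sens_per_class)
macro_specificity <- mean(spec_per_class)
```

Because every class contributes equally to the mean, a small class with poor sensitivity lowers the macro-average as much as a large one would.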
Examples
# Initialize an HDAnalyzeR object
hd_object <- hd_initialize(example_data, example_metadata)
# Split the data for training and validation sets
dat <- hd_object$data
train_indices <- sample(1:nrow(dat), size = floor(0.8 * nrow(dat)))
train_data <- dat[train_indices, ]
validation_data <- dat[-train_indices, ]
hd_object_train <- hd_initialize(train_data, example_metadata, is_wide = TRUE)
hd_object_val <- hd_initialize(validation_data, example_metadata, is_wide = TRUE)
# Split the training set into training and inner test sets
hd_split <- hd_split_data(hd_object_train, variable = "Disease")
#> Warning: Too little data to stratify.
#> • Resampling will be unstratified.
# Run the regularized regression model pipeline
model_object <- hd_model_rreg(hd_split,
                              variable = "Disease",
                              case = "AML",
                              grid_size = 5,
                              palette = "cancers12",
                              verbose = FALSE)
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.
# Run the model evaluation pipeline
hd_model_test(model_object, hd_object_train, hd_object_val, case = "AML", palette = "cancers12")
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.
#> $train_data
#> # A tibble: 52 × 102
#> DAid Disease AARSD1 ABL1 ACAA1 ACAN ACE2 ACOX1 ACP5 ACP6
#> <chr> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 DA00014 1 6.34 7.25e+0 5.12 0.0193 1.29 0.370 -0.382 0.830
#> 2 DA00026 1 4.92 1.89e+0 0.560 0.558 2.39 0.455 0.743 -0.955
#> 3 DA00023 1 2.92 -7.06e-5 0.602 1.59 0.198 1.61 0.283 2.35
#> 4 DA00039 1 4.26 5.72e-1 -1.97 -0.433 0.208 0.790 -0.236 1.52
#> 5 DA00013 1 1.31 2.52e+0 1.11 0.997 4.56 -1.35 0.833 2.33
#> 6 DA00005 1 5.01 5.05e+0 0.128 0.401 -0.933 -0.584 0.0265 1.16
#> 7 DA00022 1 7.07 5.67e+0 3.68 -0.458 3.09 0.690 0.649 2.17
#> 8 DA00042 1 3.23 3.12e+0 4.20 -1.05 1.15 0.957 0.491 1.73
#> 9 DA00011 1 3.48 4.96e+0 3.50 -0.338 4.48 1.26 2.18 1.62
#> 10 DA00040 1 NA NA 0.0831 0.858 1.38 0.183 1.33 0.606
#> # ℹ 42 more rows
#> # ℹ 92 more variables: ACTA2 <dbl>, ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>,
#> # ADA2 <dbl>, ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>,
#> # ADAMTS15 <dbl>, ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>,
#> # ADGRE2 <dbl>, ADGRE5 <dbl>, ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>,
#> # ADM <dbl>, AGER <dbl>, AGR2 <dbl>, AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>,
#> # AGXT <dbl>, AHCY <dbl>, AHSP <dbl>, AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, …
#>
#> $test_data
#> # A tibble: 117 × 102
#> DAid Disease AARSD1 ABL1 ACAA1 ACAN ACE2 ACOX1 ACP5 ACP6
#> <chr> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 DA00415 0 2.81 1.53 -0.242 0.269 0.943 0.350 1.78 -0.0618
#> 2 DA00179 0 3.95 3.29 1.27 -0.0771 0.112 0.361 0.0589 -0.525
#> 3 DA00229 0 2.86 3.89 3.42 1.26 0.883 2.70 1.13 2.60
#> 4 DA00244 0 1.68 0.150 1.50 0.669 0.873 0.107 0.775 1.48
#> 5 DA00426 0 3.79 4.16 3.24 0.935 0.250 -0.592 0.517 3.97
#> 6 DA00211 0 1.82 0.807 1.85 -0.0552 0.924 1.08 0.403 0.487
#> 7 DA00555 0 4.20 2.65 1.71 0.864 1.10 1.17 0.997 2.27
#> 8 DA00217 0 NA NA 2.04 1.86 0.0900 -0.258 0.788 1.98
#> 9 DA00041 1 2.19 1.66 -0.0167 -0.567 3.77 0.369 1.38 1.09
#> 10 DA00431 0 3.24 2.10 1.17 0.221 3.36 0.664 1.96 0.978
#> # ℹ 107 more rows
#> # ℹ 92 more variables: ACTA2 <dbl>, ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>,
#> # ADA2 <dbl>, ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>,
#> # ADAMTS15 <dbl>, ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>,
#> # ADGRE2 <dbl>, ADGRE5 <dbl>, ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>,
#> # ADM <dbl>, AGER <dbl>, AGR2 <dbl>, AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>,
#> # AGXT <dbl>, AHCY <dbl>, AHSP <dbl>, AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, …
#>
#> $model_type
#> [1] "binary_class"
#>
#> $final_workflow
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: logistic_reg()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 5 Recipe Steps
#>
#> • step_dummy()
#> • step_nzv()
#> • step_normalize()
#> • step_corr()
#> • step_impute_knn()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Logistic Regression Model Specification (classification)
#>
#> Main Arguments:
#> penalty = 1.51967190723888e-09
#> mixture = 0.936502232041676
#>
#> Computational engine: glmnet
#>
#>
#> $metrics
#> $metrics$accuracy
#> [1] 0.6666667
#>
#> $metrics$sensitivity
#> [1] 0.7142857
#>
#> $metrics$specificity
#> [1] 0.6601942
#>
#> $metrics$auc
#> [1] 0.7274619
#>
#> $metrics$confusion_matrix
#> Truth
#> Prediction 0 1
#> 0 68 4
#> 1 35 10
#>
#>
#> $roc_curve
#>
#> $probability_plot
#>
#> $mixture
#> [1] 0.9365022
#>
#> $features
#> # A tibble: 28 × 4
#> Feature Importance Sign Scaled_Importance
#> <fct> <dbl> <chr> <dbl>
#> 1 ADGRG1 1.76 POS 1
#> 2 ANGPT1 1.55 NEG 0.881
#> 3 B4GALT1 1.46 POS 0.829
#> 4 AGR3 1.18 NEG 0.675
#> 5 AGRN 1.06 POS 0.605
#> 6 ATP6V1D 0.885 NEG 0.504
#> 7 ATXN10 0.842 POS 0.480
#> 8 ARG1 0.756 POS 0.431
#> 9 AMIGO2 0.740 NEG 0.422
#> 10 ADA2 0.731 NEG 0.416
#> # ℹ 18 more rows
#>
#> $feat_imp_plot
#>
#> $validation_data
#> # A tibble: 118 × 102
#> DAid Disease AARSD1 ABL1 ACAA1 ACAN ACE2 ACOX1 ACP5 ACP6 ACTA2
#> <chr> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 DA000… 1 3.39 2.76 1.71 0.0333 1.76 -0.919 1.54 2.15 2.81
#> 2 DA000… 1 1.42 1.25 -0.816 -0.459 0.826 -0.902 0.647 1.30 0.798
#> 3 DA000… 1 4.39 3.34 -0.452 -0.868 0.395 1.71 1.49 -0.0285 0.200
#> 4 DA000… 1 3.31 1.90 NA -0.926 0.408 0.687 1.03 0.612 2.19
#> 5 DA000… 1 1.46 0.832 -2.73 -0.371 2.27 0.0234 0.144 0.826 1.98
#> 6 DA000… 1 2.62 2.48 0.537 -0.215 1.82 0.290 1.27 1.11 0.206
#> 7 DA000… 1 2.47 2.16 -0.486 NA 0.386 NA 1.38 0.536 1.86
#> 8 DA000… 1 4.39 3.31 0.454 0.290 2.68 0.116 -1.32 0.945 2.14
#> 9 DA000… 1 0.964 2.94 1.55 1.67 2.50 0.164 1.83 1.46 3.03
#> 10 DA000… 1 3.03 0.390 1.83 0.983 2.60 0.113 0.504 1.42 1.22
#> # ℹ 108 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> # ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> # ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> # ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> # AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> # AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#>
#> $test_metrics
#> $test_metrics$accuracy
#> [1] 0.720339
#>
#> $test_metrics$sensitivity
#> [1] 0.9
#>
#> $test_metrics$specificity
#> [1] 0.7037037
#>
#> $test_metrics$auc
#> [1] 0.912963
#>
#> $test_metrics$confusion_matrix
#> Truth
#> Prediction 0 1
#> 0 76 1
#> 1 32 9
#>
#>
#> $test_roc_curve
#>
#> $test_probability_plot
#>
#> attr(,"class")
#> [1] "hd_model"
# Run the pipeline against a continuous variable
# Split the training set into training and inner test sets
hd_split <- hd_split_data(hd_object_train, variable = "Age")
# Run the regularized regression model pipeline
model_object <- hd_model_rreg(hd_split,
                              variable = "Age",
                              case = "AML",
                              grid_size = 2,
                              cv_sets = 2,
                              plot_title = NULL,
                              verbose = FALSE)
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.
# Run the model evaluation pipeline
hd_model_test(model_object, hd_object_train, hd_object_val, variable = "Age", case = NULL)
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.
#> $train_data
#> # A tibble: 349 × 102
#> DAid Age AARSD1 ABL1 ACAA1 ACAN ACE2 ACOX1 ACP5 ACP6 ACTA2
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 DA00374 47 3.61 2.56e+0 1.80 0.890 1.64 1.35 1.35 1.51 2.01
#> 2 DA00355 43 2.97 3.93e+0 4.80 0.976 5.38 1.14 2.25 2.75 2.30
#> 3 DA00026 44 4.92 1.89e+0 0.560 0.558 2.39 0.455 0.743 -0.955 0.458
#> 4 DA00426 50 3.79 4.16e+0 3.24 0.935 0.250 -0.592 0.517 3.97 2.98
#> 5 DA00211 49 1.82 8.07e-1 1.85 -0.0552 0.924 1.08 0.403 0.487 0.374
#> 6 DA00023 42 2.92 -7.06e-5 0.602 1.59 0.198 1.61 0.283 2.35 2.11
#> 7 DA00585 43 3.46 1.84e+0 2.28 1.11 1.49 0.303 1.31 0.972 -0.425
#> 8 DA00223 44 3.57 1.72e+0 1.88 0.535 0.631 0.732 1.57 1.35 0.801
#> 9 DA00034 42 3.45 2.91e+0 1.31 0.423 0.647 1.40 0.691 0.720 1.95
#> 10 DA00308 54 3.25 2.03e-2 0.445 1.93 1.45 -0.193 1.95 1.32 2.23
#> # ℹ 339 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> # ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> # ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> # ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> # AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> # AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#>
#> $test_data
#> # A tibble: 119 × 102
#> DAid Age AARSD1 ABL1 ACAA1 ACAN ACE2 ACOX1 ACP5 ACP6 ACTA2
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 DA00526 80 2.70 1.61 -0.00585 0.271 0.777 0.457 0.809 0.200 1.54
#> 2 DA00118 50 3.54 4.07 1.20 1.50 0.615 0.454 1.20 1.33 1.48
#> 3 DA00229 52 2.86 3.89 3.42 1.26 0.883 2.70 1.13 2.60 5.13
#> 4 DA00544 62 2.18 1.12 0.595 1.19 0.606 0.231 1.21 1.25 0.896
#> 5 DA00490 47 3.61 2.69 2.49 1.40 1.74 0.374 1.22 1.24 2.54
#> 6 DA00166 85 3.00 1.86 2.18 -0.193 0.735 1.36 1.85 1.94 2.05
#> 7 DA00217 89 NA NA 2.04 1.86 0.0900 -0.258 0.788 1.98 0.563
#> 8 DA00290 50 3.31 0.583 1.09 0.390 0.323 0.0373 1.58 1.60 2.42
#> 9 DA00141 50 3.86 4.03 0.633 0.388 1.00 -0.0532 1.40 2.01 1.05
#> 10 DA00041 74 2.19 1.66 -0.0167 -0.567 3.77 0.369 1.38 1.09 2.09
#> # ℹ 109 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> # ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> # ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> # ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> # AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> # AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#>
#> $model_type
#> [1] "regression"
#>
#> $final_workflow
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 5 Recipe Steps
#>
#> • step_dummy()
#> • step_nzv()
#> • step_normalize()
#> • step_corr()
#> • step_impute_knn()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Linear Regression Model Specification (regression)
#>
#> Main Arguments:
#> penalty = 0.0118952086562253
#> mixture = 0.232995977951214
#>
#> Computational engine: glmnet
#>
#>
#> $metrics
#> $metrics$rmse
#> [1] 16.67657
#>
#> $metrics$rsq
#> [1] 0.0002165823
#>
#>
#> $comparison_plot
#>
#> $mixture
#> [1] 0.232996
#>
#> $features
#> # A tibble: 99 × 4
#> Feature Importance Sign Scaled_Importance
#> <fct> <dbl> <chr> <dbl>
#> 1 AREG 4.13 POS 1
#> 2 ALDH3A1 3.26 NEG 0.790
#> 3 APBB1IP 3.24 NEG 0.783
#> 4 ALDH1A1 2.79 POS 0.676
#> 5 ANGPT2 2.39 NEG 0.577
#> 6 APEX1 2.26 POS 0.545
#> 7 ACAN 2.10 POS 0.508
#> 8 ANXA11 2.00 POS 0.483
#> 9 ADA2 1.95 NEG 0.471
#> 10 ADAMTS13 1.92 NEG 0.465
#> # ℹ 89 more rows
#>
#> $feat_imp_plot
#>
#> $validation_data
#> # A tibble: 118 × 102
#> DAid Age AARSD1 ABL1 ACAA1 ACAN ACE2 ACOX1 ACP5 ACP6 ACTA2
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 DA00001 42 3.39 2.76 1.71 0.0333 1.76 -0.919 1.54 2.15 2.81
#> 2 DA00002 69 1.42 1.25 -0.816 -0.459 0.826 -0.902 0.647 1.30 0.798
#> 3 DA00009 80 4.39 3.34 -0.452 -0.868 0.395 1.71 1.49 -0.0285 0.200
#> 4 DA00015 47 3.31 1.90 NA -0.926 0.408 0.687 1.03 0.612 2.19
#> 5 DA00017 44 1.46 0.832 -2.73 -0.371 2.27 0.0234 0.144 0.826 1.98
#> 6 DA00018 75 2.62 2.48 0.537 -0.215 1.82 0.290 1.27 1.11 0.206
#> 7 DA00028 78 2.47 2.16 -0.486 NA 0.386 NA 1.38 0.536 1.86
#> 8 DA00035 59 4.39 3.31 0.454 0.290 2.68 0.116 -1.32 0.945 2.14
#> 9 DA00044 72 0.964 2.94 1.55 1.67 2.50 0.164 1.83 1.46 3.03
#> 10 DA00046 62 3.03 0.390 1.83 0.983 2.60 0.113 0.504 1.42 1.22
#> # ℹ 108 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> # ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> # ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> # ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> # AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> # AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#>
#> $test_metrics
#> $test_metrics$rmse
#> [1] 18.98604
#>
#> $test_metrics$rsq
#> [1] 0.02416074
#>
#>
#> $test_comparison_plot
#>
#> attr(,"class")
#> [1] "hd_model"
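The binary-classification values under $test_metrics in the first example can be reproduced by hand from the printed confusion matrix. The sketch below uses base R only and copies the counts from that output; treating class 1 as the case is an inference from the printed sensitivity, not something stated on this page.

```r
# Confusion matrix copied from $test_metrics$confusion_matrix in the first
# example (rows = prediction, columns = truth). Class "1" is assumed to be
# the case (positive) class.
cm <- matrix(c(76, 1,
               32, 9),
             nrow = 2, byrow = TRUE,
             dimnames = list(Prediction = c("0", "1"), Truth = c("0", "1")))

accuracy    <- sum(diag(cm)) / sum(cm)        # (76 + 9) / 118
sensitivity <- cm["1", "1"] / sum(cm[, "1"])  # 9 / (9 + 1)
specificity <- cm["0", "0"] / sum(cm[, "0"])  # 76 / (76 + 32)

round(c(accuracy = accuracy,
        sensitivity = sensitivity,
        specificity = specificity), 7)
```

The three values agree with the printed $test_metrics$accuracy, $test_metrics$sensitivity, and $test_metrics$specificity. The AUC cannot be recovered this way because it depends on the predicted probabilities, not only on the thresholded counts.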