hd_model_test() validates a model on new data. It takes an already tuned model, evaluates it on the validation (new test) set, calculates the metrics, and plots the predicted probabilities and the ROC curve for the new data.

Usage

hd_model_test(
  model_object,
  train_set,
  test_set,
  variable = "Disease",
  metadata_cols = NULL,
  case,
  control = NULL,
  balance_groups = TRUE,
  palette = NULL,
  seed = 123
)

Arguments

model_object

An hd_model object produced by hd_model_rreg() or hd_model_rf() for binary or multiclass classification.

train_set

The training set as an HDAnalyzeR object or a dataset in wide format with sample ID as its first column and class column as its second column.

test_set

The validation/test set as an HDAnalyzeR object or a dataset in wide format with sample ID as its first column and class column as its second column.

variable

The name of the metadata variable containing the case and control groups. Default is "Disease".

metadata_cols

The metadata variables to include in the analysis. Default is NULL.

case

The case class.

control

The control groups. If NULL, it will be set to all other unique values of the variable that are not the case. Default is NULL.

balance_groups

Whether to balance the groups in the train set. It is only valid in binary classification settings. Default is TRUE.

palette

The color palette for the classes. If it is a character, it should be one of the palettes from hd_palettes(). Default is NULL.

seed

Seed for reproducibility. Default is 123.
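
The default for control described above (all values of the variable other than the case) can be illustrated with a minimal base R sketch. The label vector is hypothetical, and the snippet is only an illustration of the described behavior, not the package's internal code.

```r
# Hypothetical class labels; not taken from the package data.
disease <- c("AML", "CLL", "AML", "MYEL", "CLL")

case <- "AML"
control <- NULL

# When control is NULL, use every other unique value of the variable.
if (is.null(control)) {
  control <- setdiff(unique(disease), case)
}

print(control)
```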

Value

The model object containing the validation set, the metrics, the ROC curve, the probability plot, and the confusion matrix for the new data.

Details

To run this function, the train and test sets must be in exactly the same format: they must have the same columns in the same order. Function arguments such as case, control, variable, and metadata_cols should also be the same. If the data contain missing values, they will be imputed with KNN (k = 5). If case is provided, the model will be a binary classification model. If case is NULL, the model will be a multiclass classification model.
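
One way to check the column requirement before calling the function is sketched below in base R. The helper align_to_train() is hypothetical and not part of HDAnalyzeR, and the toy data frames and their column names are assumptions made for illustration.

```r
# Hypothetical helper (not part of HDAnalyzeR): make sure the test set has
# the same columns as the train set, in the same order.
align_to_train <- function(train, test) {
  missing_cols <- setdiff(names(train), names(test))
  extra_cols <- setdiff(names(test), names(train))
  if (length(missing_cols) > 0 || length(extra_cols) > 0) {
    stop(
      "Column mismatch. Missing in test: ",
      paste(missing_cols, collapse = ", "),
      "; extra in test: ",
      paste(extra_cols, collapse = ", ")
    )
  }
  # Same column set: reorder the test columns to match the train set.
  test[, names(train), drop = FALSE]
}

# Toy wide-format data: sample ID first, class second, then features.
train <- data.frame(DAid = c("S1", "S2"), Disease = c("AML", "CLL"),
                    P1 = c(1.2, 0.4), P2 = c(3.1, 2.2))
test <- data.frame(DAid = "S3", P2 = 2.9, Disease = "AML", P1 = 0.8)

test_aligned <- align_to_train(train, test)
print(identical(names(test_aligned), names(train)))
```

Stopping on a mismatch rather than silently dropping columns keeps a missing feature from going unnoticed.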

In multiclass models, the groups in the train set are not balanced, and sensitivity and specificity are calculated via macro-averaging. If the model is run against a continuous variable, the palette is ignored.
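
Macro-averaging computes each metric per class in a one-vs-rest fashion and then takes the unweighted mean across classes. The base R sketch below shows the idea on a made-up three-class confusion matrix. The function name macro_sens_spec() and the counts are assumptions, and this is not necessarily how the package computes the values internally.

```r
# Made-up confusion matrix (rows = Prediction, columns = Truth).
cm <- matrix(
  c(8, 1, 1,
    2, 6, 2,
    0, 1, 9),
  nrow = 3,
  dimnames = list(Prediction = c("A", "B", "C"), Truth = c("A", "B", "C"))
)

# Hypothetical helper: one-vs-rest sensitivity and specificity per class,
# followed by an unweighted (macro) average over the classes.
macro_sens_spec <- function(cm) {
  classes <- rownames(cm)
  per_class <- sapply(classes, function(k) {
    tp <- cm[k, k]
    fn <- sum(cm[, k]) - tp  # truly class k, predicted as another class
    fp <- sum(cm[k, ]) - tp  # predicted as class k, truly another class
    tn <- sum(cm) - tp - fn - fp
    c(sensitivity = tp / (tp + fn), specificity = tn / (tn + fp))
  })
  rowMeans(per_class)
}

res <- macro_sens_spec(cm)
print(res)
```

Because every class contributes equally to the mean, small classes weigh as much as large ones.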

Examples

# Initialize an HDAnalyzeR object
hd_object <- hd_initialize(example_data, example_metadata)

# Split the data for training and validation sets
dat <- hd_object$data
train_indices <- sample(seq_len(nrow(dat)), size = floor(0.8 * nrow(dat)))
train_data <- dat[train_indices, ]
validation_data <- dat[-train_indices, ]

hd_object_train <- hd_initialize(train_data, example_metadata, is_wide = TRUE)
hd_object_val <- hd_initialize(validation_data, example_metadata, is_wide = TRUE)

# Split the training set into training and inner test sets
hd_split <- hd_split_data(hd_object_train, variable = "Disease")
#> Warning: Too little data to stratify.
#>  Resampling will be unstratified.

# Run the regularized regression model pipeline
model_object <- hd_model_rreg(hd_split,
                              variable = "Disease",
                              case = "AML",
                              grid_size = 5,
                              palette = "cancers12",
                              verbose = FALSE)
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.

# Run the model evaluation pipeline
hd_model_test(model_object, hd_object_train, hd_object_val, case = "AML", palette = "cancers12")
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.
#> $train_data
#> # A tibble: 52 × 102
#>    DAid    Disease AARSD1      ABL1   ACAA1    ACAN   ACE2  ACOX1    ACP5   ACP6
#>    <chr>   <fct>    <dbl>     <dbl>   <dbl>   <dbl>  <dbl>  <dbl>   <dbl>  <dbl>
#>  1 DA00014 1         6.34   7.25e+0  5.12    0.0193  1.29   0.370 -0.382   0.830
#>  2 DA00026 1         4.92   1.89e+0  0.560   0.558   2.39   0.455  0.743  -0.955
#>  3 DA00023 1         2.92  -7.06e-5  0.602   1.59    0.198  1.61   0.283   2.35 
#>  4 DA00039 1         4.26   5.72e-1 -1.97   -0.433   0.208  0.790 -0.236   1.52 
#>  5 DA00013 1         1.31   2.52e+0  1.11    0.997   4.56  -1.35   0.833   2.33 
#>  6 DA00005 1         5.01   5.05e+0  0.128   0.401  -0.933 -0.584  0.0265  1.16 
#>  7 DA00022 1         7.07   5.67e+0  3.68   -0.458   3.09   0.690  0.649   2.17 
#>  8 DA00042 1         3.23   3.12e+0  4.20   -1.05    1.15   0.957  0.491   1.73 
#>  9 DA00011 1         3.48   4.96e+0  3.50   -0.338   4.48   1.26   2.18    1.62 
#> 10 DA00040 1        NA     NA        0.0831  0.858   1.38   0.183  1.33    0.606
#> # ℹ 42 more rows
#> # ℹ 92 more variables: ACTA2 <dbl>, ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>,
#> #   ADA2 <dbl>, ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>,
#> #   ADAMTS15 <dbl>, ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>,
#> #   ADGRE2 <dbl>, ADGRE5 <dbl>, ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>,
#> #   ADM <dbl>, AGER <dbl>, AGR2 <dbl>, AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>,
#> #   AGXT <dbl>, AHCY <dbl>, AHSP <dbl>, AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, …
#> 
#> $test_data
#> # A tibble: 117 × 102
#>    DAid    Disease AARSD1   ABL1   ACAA1    ACAN   ACE2  ACOX1   ACP5    ACP6
#>    <chr>   <fct>    <dbl>  <dbl>   <dbl>   <dbl>  <dbl>  <dbl>  <dbl>   <dbl>
#>  1 DA00415 0         2.81  1.53  -0.242   0.269  0.943   0.350 1.78   -0.0618
#>  2 DA00179 0         3.95  3.29   1.27   -0.0771 0.112   0.361 0.0589 -0.525 
#>  3 DA00229 0         2.86  3.89   3.42    1.26   0.883   2.70  1.13    2.60  
#>  4 DA00244 0         1.68  0.150  1.50    0.669  0.873   0.107 0.775   1.48  
#>  5 DA00426 0         3.79  4.16   3.24    0.935  0.250  -0.592 0.517   3.97  
#>  6 DA00211 0         1.82  0.807  1.85   -0.0552 0.924   1.08  0.403   0.487 
#>  7 DA00555 0         4.20  2.65   1.71    0.864  1.10    1.17  0.997   2.27  
#>  8 DA00217 0        NA    NA      2.04    1.86   0.0900 -0.258 0.788   1.98  
#>  9 DA00041 1         2.19  1.66  -0.0167 -0.567  3.77    0.369 1.38    1.09  
#> 10 DA00431 0         3.24  2.10   1.17    0.221  3.36    0.664 1.96    0.978 
#> # ℹ 107 more rows
#> # ℹ 92 more variables: ACTA2 <dbl>, ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>,
#> #   ADA2 <dbl>, ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>,
#> #   ADAMTS15 <dbl>, ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>,
#> #   ADGRE2 <dbl>, ADGRE5 <dbl>, ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>,
#> #   ADM <dbl>, AGER <dbl>, AGR2 <dbl>, AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>,
#> #   AGXT <dbl>, AHCY <dbl>, AHSP <dbl>, AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, …
#> 
#> $model_type
#> [1] "binary_class"
#> 
#> $final_workflow
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: logistic_reg()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 5 Recipe Steps
#> 
#> • step_dummy()
#> • step_nzv()
#> • step_normalize()
#> • step_corr()
#> • step_impute_knn()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Logistic Regression Model Specification (classification)
#> 
#> Main Arguments:
#>   penalty = 1.51967190723888e-09
#>   mixture = 0.936502232041676
#> 
#> Computational engine: glmnet 
#> 
#> 
#> $metrics
#> $metrics$accuracy
#> [1] 0.6666667
#> 
#> $metrics$sensitivity
#> [1] 0.7142857
#> 
#> $metrics$specificity
#> [1] 0.6601942
#> 
#> $metrics$auc
#> [1] 0.7274619
#> 
#> $metrics$confusion_matrix
#>           Truth
#> Prediction  0  1
#>          0 68  4
#>          1 35 10
#> 
#> 
#> $roc_curve

#> 
#> $probability_plot

#> 
#> $mixture
#> [1] 0.9365022
#> 
#> $features
#> # A tibble: 28 × 4
#>    Feature Importance Sign  Scaled_Importance
#>    <fct>        <dbl> <chr>             <dbl>
#>  1 ADGRG1       1.76  POS               1    
#>  2 ANGPT1       1.55  NEG               0.881
#>  3 B4GALT1      1.46  POS               0.829
#>  4 AGR3         1.18  NEG               0.675
#>  5 AGRN         1.06  POS               0.605
#>  6 ATP6V1D      0.885 NEG               0.504
#>  7 ATXN10       0.842 POS               0.480
#>  8 ARG1         0.756 POS               0.431
#>  9 AMIGO2       0.740 NEG               0.422
#> 10 ADA2         0.731 NEG               0.416
#> # ℹ 18 more rows
#> 
#> $feat_imp_plot

#> 
#> $validation_data
#> # A tibble: 118 × 102
#>    DAid   Disease AARSD1  ABL1  ACAA1    ACAN  ACE2   ACOX1   ACP5    ACP6 ACTA2
#>    <chr>  <fct>    <dbl> <dbl>  <dbl>   <dbl> <dbl>   <dbl>  <dbl>   <dbl> <dbl>
#>  1 DA000… 1        3.39  2.76   1.71   0.0333 1.76  -0.919   1.54   2.15   2.81 
#>  2 DA000… 1        1.42  1.25  -0.816 -0.459  0.826 -0.902   0.647  1.30   0.798
#>  3 DA000… 1        4.39  3.34  -0.452 -0.868  0.395  1.71    1.49  -0.0285 0.200
#>  4 DA000… 1        3.31  1.90  NA     -0.926  0.408  0.687   1.03   0.612  2.19 
#>  5 DA000… 1        1.46  0.832 -2.73  -0.371  2.27   0.0234  0.144  0.826  1.98 
#>  6 DA000… 1        2.62  2.48   0.537 -0.215  1.82   0.290   1.27   1.11   0.206
#>  7 DA000… 1        2.47  2.16  -0.486 NA      0.386 NA       1.38   0.536  1.86 
#>  8 DA000… 1        4.39  3.31   0.454  0.290  2.68   0.116  -1.32   0.945  2.14 
#>  9 DA000… 1        0.964 2.94   1.55   1.67   2.50   0.164   1.83   1.46   3.03 
#> 10 DA000… 1        3.03  0.390  1.83   0.983  2.60   0.113   0.504  1.42   1.22 
#> # ℹ 108 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> #   ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> #   ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> #   ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> #   AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> #   AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#> 
#> $test_metrics
#> $test_metrics$accuracy
#> [1] 0.720339
#> 
#> $test_metrics$sensitivity
#> [1] 0.9
#> 
#> $test_metrics$specificity
#> [1] 0.7037037
#> 
#> $test_metrics$auc
#> [1] 0.912963
#> 
#> $test_metrics$confusion_matrix
#>           Truth
#> Prediction  0  1
#>          0 76  1
#>          1 32  9
#> 
#> 
#> $test_roc_curve

#> 
#> $test_probability_plot

#> 
#> attr(,"class")
#> [1] "hd_model"
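
The reported validation metrics can be reproduced by hand from $test_metrics$confusion_matrix in the output above. The base R sketch below copies those four counts and assumes that class "1" is the case (positive) class, which is consistent with the printed values.

```r
# Counts copied from $test_metrics$confusion_matrix above
# (rows = Prediction, columns = Truth). Assumption: "1" is the case class.
cm <- matrix(
  c(76, 32, 1, 9),
  nrow = 2,
  dimnames = list(Prediction = c("0", "1"), Truth = c("0", "1"))
)

accuracy <- sum(diag(cm)) / sum(cm)          # correct predictions / all samples
sensitivity <- cm["1", "1"] / sum(cm[, "1"]) # true cases predicted as cases
specificity <- cm["0", "0"] / sum(cm[, "0"]) # true controls predicted as controls

print(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity))
```

The three values agree with $test_metrics$accuracy, $test_metrics$sensitivity, and $test_metrics$specificity shown above.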

# Run the pipeline against a continuous variable
# Split the training set into training and inner test sets
hd_split <- hd_split_data(hd_object_train, variable = "Age")

# Run the regularized regression model pipeline
model_object <- hd_model_rreg(hd_split,
                              variable = "Age",
                              case = "AML",
                              grid_size = 2,
                              cv_sets = 2,
                              plot_title = NULL,
                              verbose = FALSE)
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.

# Run the model evaluation pipeline
hd_model_test(model_object, hd_object_train, hd_object_val, variable = "Age", case = NULL)
#> The groups in the train set are balanced. If you do not want to balance the groups, set `balance_groups = FALSE`.
#> $train_data
#> # A tibble: 349 × 102
#>    DAid      Age AARSD1      ABL1 ACAA1    ACAN  ACE2  ACOX1  ACP5   ACP6  ACTA2
#>    <chr>   <dbl>  <dbl>     <dbl> <dbl>   <dbl> <dbl>  <dbl> <dbl>  <dbl>  <dbl>
#>  1 DA00374    47   3.61   2.56e+0 1.80   0.890  1.64   1.35  1.35   1.51   2.01 
#>  2 DA00355    43   2.97   3.93e+0 4.80   0.976  5.38   1.14  2.25   2.75   2.30 
#>  3 DA00026    44   4.92   1.89e+0 0.560  0.558  2.39   0.455 0.743 -0.955  0.458
#>  4 DA00426    50   3.79   4.16e+0 3.24   0.935  0.250 -0.592 0.517  3.97   2.98 
#>  5 DA00211    49   1.82   8.07e-1 1.85  -0.0552 0.924  1.08  0.403  0.487  0.374
#>  6 DA00023    42   2.92  -7.06e-5 0.602  1.59   0.198  1.61  0.283  2.35   2.11 
#>  7 DA00585    43   3.46   1.84e+0 2.28   1.11   1.49   0.303 1.31   0.972 -0.425
#>  8 DA00223    44   3.57   1.72e+0 1.88   0.535  0.631  0.732 1.57   1.35   0.801
#>  9 DA00034    42   3.45   2.91e+0 1.31   0.423  0.647  1.40  0.691  0.720  1.95 
#> 10 DA00308    54   3.25   2.03e-2 0.445  1.93   1.45  -0.193 1.95   1.32   2.23 
#> # ℹ 339 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> #   ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> #   ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> #   ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> #   AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> #   AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#> 
#> $test_data
#> # A tibble: 119 × 102
#>    DAid      Age AARSD1   ABL1    ACAA1   ACAN   ACE2   ACOX1  ACP5  ACP6 ACTA2
#>    <chr>   <dbl>  <dbl>  <dbl>    <dbl>  <dbl>  <dbl>   <dbl> <dbl> <dbl> <dbl>
#>  1 DA00526    80   2.70  1.61  -0.00585  0.271 0.777   0.457  0.809 0.200 1.54 
#>  2 DA00118    50   3.54  4.07   1.20     1.50  0.615   0.454  1.20  1.33  1.48 
#>  3 DA00229    52   2.86  3.89   3.42     1.26  0.883   2.70   1.13  2.60  5.13 
#>  4 DA00544    62   2.18  1.12   0.595    1.19  0.606   0.231  1.21  1.25  0.896
#>  5 DA00490    47   3.61  2.69   2.49     1.40  1.74    0.374  1.22  1.24  2.54 
#>  6 DA00166    85   3.00  1.86   2.18    -0.193 0.735   1.36   1.85  1.94  2.05 
#>  7 DA00217    89  NA    NA      2.04     1.86  0.0900 -0.258  0.788 1.98  0.563
#>  8 DA00290    50   3.31  0.583  1.09     0.390 0.323   0.0373 1.58  1.60  2.42 
#>  9 DA00141    50   3.86  4.03   0.633    0.388 1.00   -0.0532 1.40  2.01  1.05 
#> 10 DA00041    74   2.19  1.66  -0.0167  -0.567 3.77    0.369  1.38  1.09  2.09 
#> # ℹ 109 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> #   ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> #   ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> #   ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> #   AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> #   AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#> 
#> $model_type
#> [1] "regression"
#> 
#> $final_workflow
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 5 Recipe Steps
#> 
#> • step_dummy()
#> • step_nzv()
#> • step_normalize()
#> • step_corr()
#> • step_impute_knn()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Linear Regression Model Specification (regression)
#> 
#> Main Arguments:
#>   penalty = 0.0118952086562253
#>   mixture = 0.232995977951214
#> 
#> Computational engine: glmnet 
#> 
#> 
#> $metrics
#> $metrics$rmse
#> [1] 16.67657
#> 
#> $metrics$rsq
#> [1] 0.0002165823
#> 
#> 
#> $comparison_plot

#> 
#> $mixture
#> [1] 0.232996
#> 
#> $features
#> # A tibble: 99 × 4
#>    Feature  Importance Sign  Scaled_Importance
#>    <fct>         <dbl> <chr>             <dbl>
#>  1 AREG           4.13 POS               1    
#>  2 ALDH3A1        3.26 NEG               0.790
#>  3 APBB1IP        3.24 NEG               0.783
#>  4 ALDH1A1        2.79 POS               0.676
#>  5 ANGPT2         2.39 NEG               0.577
#>  6 APEX1          2.26 POS               0.545
#>  7 ACAN           2.10 POS               0.508
#>  8 ANXA11         2.00 POS               0.483
#>  9 ADA2           1.95 NEG               0.471
#> 10 ADAMTS13       1.92 NEG               0.465
#> # ℹ 89 more rows
#> 
#> $feat_imp_plot

#> 
#> $validation_data
#> # A tibble: 118 × 102
#>    DAid      Age AARSD1  ABL1  ACAA1    ACAN  ACE2   ACOX1   ACP5    ACP6 ACTA2
#>    <chr>   <dbl>  <dbl> <dbl>  <dbl>   <dbl> <dbl>   <dbl>  <dbl>   <dbl> <dbl>
#>  1 DA00001    42  3.39  2.76   1.71   0.0333 1.76  -0.919   1.54   2.15   2.81 
#>  2 DA00002    69  1.42  1.25  -0.816 -0.459  0.826 -0.902   0.647  1.30   0.798
#>  3 DA00009    80  4.39  3.34  -0.452 -0.868  0.395  1.71    1.49  -0.0285 0.200
#>  4 DA00015    47  3.31  1.90  NA     -0.926  0.408  0.687   1.03   0.612  2.19 
#>  5 DA00017    44  1.46  0.832 -2.73  -0.371  2.27   0.0234  0.144  0.826  1.98 
#>  6 DA00018    75  2.62  2.48   0.537 -0.215  1.82   0.290   1.27   1.11   0.206
#>  7 DA00028    78  2.47  2.16  -0.486 NA      0.386 NA       1.38   0.536  1.86 
#>  8 DA00035    59  4.39  3.31   0.454  0.290  2.68   0.116  -1.32   0.945  2.14 
#>  9 DA00044    72  0.964 2.94   1.55   1.67   2.50   0.164   1.83   1.46   3.03 
#> 10 DA00046    62  3.03  0.390  1.83   0.983  2.60   0.113   0.504  1.42   1.22 
#> # ℹ 108 more rows
#> # ℹ 91 more variables: ACTN4 <dbl>, ACY1 <dbl>, ADA <dbl>, ADA2 <dbl>,
#> #   ADAM15 <dbl>, ADAM23 <dbl>, ADAM8 <dbl>, ADAMTS13 <dbl>, ADAMTS15 <dbl>,
#> #   ADAMTS16 <dbl>, ADAMTS8 <dbl>, ADCYAP1R1 <dbl>, ADGRE2 <dbl>, ADGRE5 <dbl>,
#> #   ADGRG1 <dbl>, ADGRG2 <dbl>, ADH4 <dbl>, ADM <dbl>, AGER <dbl>, AGR2 <dbl>,
#> #   AGR3 <dbl>, AGRN <dbl>, AGRP <dbl>, AGXT <dbl>, AHCY <dbl>, AHSP <dbl>,
#> #   AIF1 <dbl>, AIFM1 <dbl>, AK1 <dbl>, AKR1B1 <dbl>, AKR1C4 <dbl>, …
#> 
#> $test_metrics
#> $test_metrics$rmse
#> [1] 18.98604
#> 
#> $test_metrics$rsq
#> [1] 0.02416074
#> 
#> 
#> $test_comparison_plot

#> 
#> attr(,"class")
#> [1] "hd_model"