The nest_cv function applies cross-validation splits to nested data frames or data tables within a data table. It uses the rsample package's vfold_cv function to create cross-validation splits for predictive modeling and analysis on nested datasets.
nest_cv(
nest_dt,
v = 10,
repeats = 1,
strata = NULL,
breaks = 4,
pool = 0.1,
...
)
nest_dt: A data.frame or data.table containing at least one nested data.frame or data.table column. Multi-level nested structures are supported.
v: The number of partitions of the data set.
repeats: The number of times to repeat the V-fold partitioning.
strata: A variable in the nested data (single character or name) used to conduct stratified sampling. When not NULL, each resample is created within the stratification variable. Numeric strata are binned into quartiles.
breaks: A single number giving the number of bins desired to stratify a numeric stratification variable.
pool: A proportion of data used to determine if a particular group is too small and should be pooled into another group. We do not recommend decreasing this argument below its default of 0.1 because of the dangers of stratifying groups that are too small.
...: These dots are for future extensions and must be empty.
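As a rough illustration of how v and repeats shape the result, the sketch below calls rsample::vfold_cv directly on iris instead of going through nest_cv. Using iris as the input is an assumption made only for this example.

```r
library(rsample)

set.seed(1)

# 2-fold cross-validation repeated twice: v * repeats = 4 resamples in total
folds <- vfold_cv(iris, v = 2, repeats = 2)

# With repeats > 1, rsample labels each resample with two identifiers:
# `id` holds the repeat and `id2` holds the fold within that repeat
names(folds)
unique(folds$id)
unique(folds$id2)
```

This matches the id and id2 columns that appear in the repeated cross-validation example output.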
A data.table containing the cross-validation splits for each nested dataset. It includes:
Original non-nested columns from nest_dt.
splits: The cross-validation split objects returned by rsample::vfold_cv.
id (and id2 when repeats is greater than 1): The resample identifiers returned by rsample::vfold_cv.
train: The training data for each split.
validate: The validation data for each split.
The function performs the following steps:
Checks if the input nest_dt is non-empty and contains at least one nested column of data.frames or data.tables.
Identifies the nested columns and non-nested columns within nest_dt.
Applies rsample::vfold_cv to each nested data frame in the specified nested column(s), creating the cross-validation splits.
Expands the cross-validation splits and associates them with the non-nested columns.
Extracts the training and validation data for each split and adds them to the output data table.
If the strata parameter is provided, stratified sampling is performed during the cross-validation. Additional arguments can be passed to rsample::vfold_cv via ....
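The steps above can be sketched with data.table and rsample directly. This is a minimal illustration under stated assumptions, not the package's actual implementation: the nested input is built by hand with melt, and the column names name and dat are hypothetical choices for this example.

```r
library(data.table)
library(rsample)

# Build a hypothetical nested input: one row per group, with a list column
# (`dat`) holding that group's data.table.
long_dt <- melt(
  as.data.table(iris),
  measure.vars  = c("Sepal.Length", "Sepal.Width"),
  variable.name = "name"
)
nested <- long_dt[, .(dat = list(.SD)), by = name]

set.seed(1)

# For each group: create the folds, then expand to one row per fold and
# extract the training and validation sets from each split object.
cv_dt <- nested[, {
  folds <- vfold_cv(dat[[1]], v = 2)
  .(
    splits   = folds$splits,
    id       = folds$id,
    train    = lapply(folds$splits, training),
    validate = lapply(folds$splits, testing)
  )
}, by = name]

cv_dt
```

Each group contributes v rows, so two groups with v = 2 give four rows, each carrying the group key alongside its split, training set, and validation set.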
The nest_dt must contain at least one nested column of data.frames or data.tables.
The function converts nest_dt to a data.table internally to ensure efficient data manipulation.
The strata parameter should be a column name present in the nested data frames.
If strata is specified, ensure that the specified column exists in all nested data frames.
The breaks and pool parameters are used when strata is a numeric variable and control how stratification is handled.
Additional arguments passed through ... are forwarded to rsample::vfold_cv.
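To make the stratification notes concrete, the sketch below calls rsample::vfold_cv directly on iris rather than through nest_cv. The choice of iris and of the Species and Sepal.Length columns is an assumption for illustration; with nest_cv the same column would need to exist in every nested data frame.

```r
library(rsample)

set.seed(1)

# Stratify on a factor: each fold preserves the class proportions
folds_factor <- vfold_cv(iris, v = 5, strata = "Species")

# Stratify on a numeric column: values are first binned (`breaks` bins),
# and bins holding less than `pool` of the data are pooled with others
folds_numeric <- vfold_cv(
  iris,
  v      = 5,
  strata = "Sepal.Length",
  breaks = 4,
  pool   = 0.1
)

# Class counts in the training portion of the first stratified split
species_counts <- table(training(folds_factor$splits[[1]])$Species)
species_counts
```

Because iris has 50 rows per species and v = 5, each training set keeps an equal number of rows from every species.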
rsample::vfold_cv(): Underlying cross-validation function
rsample::training(): Extract training set
rsample::testing(): Extract test set
# Example: Cross-validation for nested data.table demonstrations
# Setup test data
dt_nest <- w2l_nest(
data = iris, # Input dataset
cols2l = 1:2 # Nest first 2 columns
)
# Example 1: Basic 2-fold cross-validation
nest_cv(
nest_dt = dt_nest, # Input nested data.table
v = 2 # Number of folds (2-fold CV)
)
#> name splits id train
#> <fctr> <list> <char> <list>
#> 1: Sepal.Length <vfold_split[75x75x150x4]> Fold1 <data.table[75x4]>
#> 2: Sepal.Length <vfold_split[75x75x150x4]> Fold2 <data.table[75x4]>
#> 3: Sepal.Width <vfold_split[75x75x150x4]> Fold1 <data.table[75x4]>
#> 4: Sepal.Width <vfold_split[75x75x150x4]> Fold2 <data.table[75x4]>
#> validate
#> <list>
#> 1: <data.table[75x4]>
#> 2: <data.table[75x4]>
#> 3: <data.table[75x4]>
#> 4: <data.table[75x4]>
# Example 2: Repeated 2-fold cross-validation
nest_cv(
nest_dt = dt_nest, # Input nested data.table
v = 2, # Number of folds (2-fold CV)
repeats = 2 # Number of repetitions
)
#> name splits id id2 train
#> <fctr> <list> <char> <char> <list>
#> 1: Sepal.Length <vfold_split[75x75x150x4]> Repeat1 Fold1 <data.table[75x4]>
#> 2: Sepal.Length <vfold_split[75x75x150x4]> Repeat1 Fold2 <data.table[75x4]>
#> 3: Sepal.Length <vfold_split[75x75x150x4]> Repeat2 Fold1 <data.table[75x4]>
#> 4: Sepal.Length <vfold_split[75x75x150x4]> Repeat2 Fold2 <data.table[75x4]>
#> 5: Sepal.Width <vfold_split[75x75x150x4]> Repeat1 Fold1 <data.table[75x4]>
#> 6: Sepal.Width <vfold_split[75x75x150x4]> Repeat1 Fold2 <data.table[75x4]>
#> 7: Sepal.Width <vfold_split[75x75x150x4]> Repeat2 Fold1 <data.table[75x4]>
#> 8: Sepal.Width <vfold_split[75x75x150x4]> Repeat2 Fold2 <data.table[75x4]>
#> validate
#> <list>
#> 1: <data.table[75x4]>
#> 2: <data.table[75x4]>
#> 3: <data.table[75x4]>
#> 4: <data.table[75x4]>
#> 5: <data.table[75x4]>
#> 6: <data.table[75x4]>
#> 7: <data.table[75x4]>
#> 8: <data.table[75x4]>