nest_cv applies rsample::vfold_cv to each nested data frame within a
data.table and returns an expanded table with one row per cross-validation
split, holding that split's training and validation sets.
nest_cv(
nest_dt,
v = 10L,
repeats = 1L,
strata = NULL,
breaks = 4L,
pool = 0.1,
...
)
nest_dt: A data.frame or data.table containing at least one nested
data.frame/data.table column.
v: Number of folds. Must be an integer >= 2. Default is 10.
repeats: Number of repeats. Must be an integer >= 1. Default is 1.
strata: A single character string naming the stratification column.
Set to NULL for no stratification. Default is NULL.
breaks: Number of bins used to stratify a numeric variable. Only used
when strata is non-NULL. Default is 4.
pool: Proportion threshold below which small strata are pooled. Only used
when strata is non-NULL. Default is 0.1.
...: Additional arguments passed to rsample::vfold_cv.
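The strata, breaks, and pool arguments are forwarded to rsample::vfold_cv(). The sketch below shows their effect by calling rsample directly on iris rather than going through nest_cv. The column choices, fold count, and seeds are illustrative assumptions.

```r
library(rsample)

set.seed(42)
# Stratify on a factor: folds are assigned within each Species level,
# so every assessment (validation) set keeps a similar class mix.
folds <- vfold_cv(iris, v = 5, strata = "Species")
species_counts <- lapply(folds$splits, function(s) table(testing(s)$Species))
species_counts[[1]]

set.seed(42)
# Stratify on a numeric column: rsample first bins it into quantile-based
# groups (number of bins set by breaks); pool merges groups that hold less
# than that proportion of the rows.
folds_num <- vfold_cv(iris, v = 5, strata = "Sepal.Length",
                      breaks = 4, pool = 0.1)
nrow(folds_num)
```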
A data.table with the following columns:
All non-nested columns from nest_dt, repeated across each of that row's CV splits.
splits: cross-validation split objects from rsample::vfold_cv.
id: fold identifier (e.g. Fold1). For repeated CV (repeats > 1), id holds the
repeat (e.g. Repeat1) and id2 holds the fold.
train: list column of training data frames, one per split.
validate: list column of validation data frames, one per split.
The function performs the following steps:
1. Validates that nest_dt is a non-empty data.frame or data.table with at
   least one nested column whose elements all inherit from data.frame.
2. Selects the target nested column: a column named "data" if present,
   otherwise the first detected nested column.
3. When strata is specified, verifies that the column exists in every
   nested data frame before calling rsample::vfold_cv.
4. Iterates over the rows: for each, calls vfold_cv via do.call, expands the
   resulting splits into a data.table, and broadcasts the row's non-nested
   metadata columns across all of its CV rows.
5. Combines all per-row results with rbindlist in a single pass.
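The steps above can be sketched roughly as follows. This is a hypothetical re-implementation, not the package's source: nest_cv_sketch is an illustrative name, data.table and rsample are assumed to be installed, and the real function may differ in its validation details and error messages.

```r
library(data.table)
library(rsample)

# Hypothetical sketch of nest_cv's steps; not the package's actual code.
nest_cv_sketch <- function(nest_dt, v = 10L, repeats = 1L, strata = NULL,
                           breaks = 4L, pool = 0.1, ...) {
  extra <- list(...)
  # 1. Coerce, then find list columns whose elements are all data frames
  dt <- as.data.table(nest_dt)
  is_nested <- vapply(dt, function(col) {
    is.list(col) && length(col) > 0L &&
      all(vapply(col, is.data.frame, logical(1)))
  }, logical(1))
  if (nrow(dt) == 0L || !any(is_nested)) {
    stop("nest_dt must be non-empty with at least one nested data.frame column")
  }
  nested_cols <- names(dt)[is_nested]

  # 2. Prefer a column named "data", else the first nested column
  target <- if ("data" %in% nested_cols) "data" else nested_cols[1L]
  meta_cols <- setdiff(names(dt), nested_cols)

  # 3. Check up front that the strata column exists in every nested frame
  if (!is.null(strata)) {
    has_col <- vapply(dt[[target]], function(d) strata %in% names(d), logical(1))
    if (!all(has_col)) stop("strata column '", strata, "' missing in some rows")
  }

  # 4. One vfold_cv() call per row; breaks/pool forwarded only when stratifying
  per_row <- lapply(seq_len(nrow(dt)), function(i) {
    args <- c(list(data = dt[[target]][[i]], v = v, repeats = repeats), extra)
    if (!is.null(strata)) {
      args <- c(args, list(strata = strata, breaks = breaks, pool = pool))
    }
    cv <- as.data.table(do.call(rsample::vfold_cv, args))
    cv[, train := lapply(splits, rsample::training)]
    cv[, validate := lapply(splits, rsample::testing)]
    # Broadcast this row's non-nested columns across its CV rows
    if (length(meta_cols) > 0L) {
      cv <- cbind(dt[rep(i, nrow(cv)), meta_cols, with = FALSE], cv)
    }
    cv
  })

  # 5. Bind all per-row results in a single pass
  rbindlist(per_row)
}
```

Building each row's result independently and binding once at the end is what keeps every row tied to its own splits.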
nest_dt must contain at least one nested column of data.frames
or data.tables.
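as.data.table() is used to coerce the input instead of an explicit
data.table::copy(). Note that as.data.table() on a data.table also returns a
copy (setDT() is the by-reference alternative), so changes to the coerced
table do not affect the caller's object.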
strata must be a column name present in all nested data frames.
breaks and pool are forwarded to rsample::vfold_cv only when
strata is non-NULL, avoiding invalid argument errors.
The per-row loop with rbindlist fixes a silent bug in naive approaches
built on chained data.table [ calls, in which every row incorrectly
received the first row's CV splits.
rsample::vfold_cv() — underlying cross-validation function
rsample::training() — extract training set from a split
rsample::testing() — extract test/validation set from a split
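A short sketch of how training() and testing() are typically used to score a model fold by fold, calling rsample directly on mtcars. It assumes that the train and validate columns produced by nest_cv hold the same subsets these helpers return for each split; the model formula and seed are arbitrary illustrations.

```r
library(rsample)

set.seed(1)
folds <- vfold_cv(mtcars, v = 4)

# For each split, fit on the training() rows and score on the testing() rows
fold_rmse <- vapply(folds$splits, function(s) {
  fit  <- lm(mpg ~ wt, data = training(s))
  pred <- predict(fit, newdata = testing(s))
  sqrt(mean((testing(s)$mpg - pred)^2))
}, numeric(1))

fold_rmse
mean(fold_rmse)
```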
# Example: cross-validation on a nested data.table
# Set up test data
dt_nest <- w2l_nest(
data = iris, # Input dataset
cols2l = 1:2 # Reshape columns 1:2 (Sepal.Length, Sepal.Width) to long form and nest
)
# Example 1: Basic 2-fold cross-validation
nest_cv(
nest_dt = dt_nest, # Input nested data.table
v = 2 # Number of folds (2-fold CV)
)
#> name splits id train
#> <char> <list> <char> <list>
#> 1: Sepal.Length <vfold_split[75x75x150x4]> Fold1 <data.table[75x4]>
#> 2: Sepal.Length <vfold_split[75x75x150x4]> Fold2 <data.table[75x4]>
#> 3: Sepal.Width <vfold_split[75x75x150x4]> Fold1 <data.table[75x4]>
#> 4: Sepal.Width <vfold_split[75x75x150x4]> Fold2 <data.table[75x4]>
#> validate
#> <list>
#> 1: <data.table[75x4]>
#> 2: <data.table[75x4]>
#> 3: <data.table[75x4]>
#> 4: <data.table[75x4]>
# Example 2: Repeated 2-fold cross-validation
nest_cv(
nest_dt = dt_nest, # Input nested data.table
v = 2, # Number of folds (2-fold CV)
repeats = 2 # Number of repetitions
)
#> name splits id id2 train
#> <char> <list> <char> <char> <list>
#> 1: Sepal.Length <vfold_split[75x75x150x4]> Repeat1 Fold1 <data.table[75x4]>
#> 2: Sepal.Length <vfold_split[75x75x150x4]> Repeat1 Fold2 <data.table[75x4]>
#> 3: Sepal.Length <vfold_split[75x75x150x4]> Repeat2 Fold1 <data.table[75x4]>
#> 4: Sepal.Length <vfold_split[75x75x150x4]> Repeat2 Fold2 <data.table[75x4]>
#> 5: Sepal.Width <vfold_split[75x75x150x4]> Repeat1 Fold1 <data.table[75x4]>
#> 6: Sepal.Width <vfold_split[75x75x150x4]> Repeat1 Fold2 <data.table[75x4]>
#> 7: Sepal.Width <vfold_split[75x75x150x4]> Repeat2 Fold1 <data.table[75x4]>
#> 8: Sepal.Width <vfold_split[75x75x150x4]> Repeat2 Fold2 <data.table[75x4]>
#> validate
#> <list>
#> 1: <data.table[75x4]>
#> 2: <data.table[75x4]>
#> 3: <data.table[75x4]>
#> 4: <data.table[75x4]>
#> 5: <data.table[75x4]>
#> 6: <data.table[75x4]>
#> 7: <data.table[75x4]>
#> 8: <data.table[75x4]>