nest_cv applies rsample::vfold_cv to the nested data frame in each row of a data.frame or data.table and returns an expanded data.table with one row per split, containing the split object together with the corresponding training and validation sets.

nest_cv(
  nest_dt,
  v = 10L,
  repeats = 1L,
  strata = NULL,
  breaks = 4L,
  pool = 0.1,
  ...
)

Arguments

nest_dt

A data.frame or data.table containing at least one nested data.frame/data.table column.

v

Number of folds. Must be an integer >= 2. Default is 10.

repeats

Number of repeats. Must be an integer >= 1. Default is 1.

strata

A single character string specifying the stratification column name. Set to NULL for no stratification. Default is NULL.

breaks

Number of bins for stratifying a numeric variable. Only used when strata is non-NULL. Default is 4.

pool

Proportion threshold for pooling small strata. Only used when strata is non-NULL. Default is 0.1.

...

Additional arguments passed to rsample::vfold_cv.
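The way v and repeats determine the number of splits can be illustrated with a minimal base-R sketch. It is not how rsample::vfold_cv is implemented: repeated_vfold_sketch() is a hypothetical helper that assigns rows to folds uniformly at random and ignores strata, breaks and pool. In each repeat every row lands in exactly one of v folds; fold k's rows form the validation set and the remaining rows form the training set, so one nested data frame yields v * repeats splits.

```r
# Hypothetical helper (not part of rsample or this package): returns one
# entry per split holding training and validation row indices.
repeated_vfold_sketch <- function(n, v = 10L, repeats = 1L) {
  stopifnot(v >= 2L, repeats >= 1L, n >= v)
  splits <- list()
  for (r in seq_len(repeats)) {
    # Shuffle a balanced sequence of fold labels 1..v across the n rows
    fold_of_row <- sample(rep_len(seq_len(v), n))
    for (k in seq_len(v)) {
      validate <- which(fold_of_row == k)
      splits[[length(splits) + 1L]] <- list(
        repeat_id = r,
        fold_id   = k,
        train     = setdiff(seq_len(n), validate),
        validate  = validate
      )
    }
  }
  splits
}

set.seed(1)
s <- repeated_vfold_sketch(n = 150L, v = 2L, repeats = 2L)
length(s)                          # 4 splits (v * repeats)
lengths(lapply(s, `[[`, "train"))  # 75 75 75 75
```

With 150 rows and v = 2, each split has 75 training and 75 validation rows, which matches the 75x75 splits shown in the Examples.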

Value

A data.table with the following columns:

  • All non-nested columns from nest_dt (broadcast across CV rows).

  • splits — cross-validation split objects from rsample::vfold_cv.

  • id (and id2 for repeated CV) — fold identifiers.

  • train — list column of training data frames for each split.

  • validate — list column of validation data frames for each split.
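As a hedged usage sketch, the train and validate list columns can be consumed pairwise with Map(). The table cv_res below is a hypothetical hand-built stand-in with the same train/validate layout (it has no splits column), so the snippet runs without nest_cv; the linear model and the RMSE metric are illustrative choices, not part of the package.

```r
library(data.table)

# Hypothetical stand-in for a nest_cv() result: one row per fold with list
# columns of training and validation data, built by hand from iris
dt <- as.data.table(iris)
fold_of_row <- rep_len(1:2, nrow(dt))
cv_res <- data.table(
  id       = c("Fold1", "Fold2"),
  train    = list(dt[fold_of_row != 1L], dt[fold_of_row != 2L]),
  validate = list(dt[fold_of_row == 1L], dt[fold_of_row == 2L])
)

# Fit on each training set and score on the matching validation set
cv_res[, rmse := unlist(Map(function(tr, va) {
  fit <- lm(Sepal.Length ~ Petal.Length, data = tr)
  sqrt(mean((va$Sepal.Length - predict(fit, newdata = va))^2))
}, train, validate))]
cv_res[, .(id, rmse)]
```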

Details

The function performs the following steps:

  1. Validates that nest_dt is a non-empty data.frame or data.table with at least one nested column whose elements all inherit from data.frame.

  2. Selects the target nested column: prefers a column named "data"; otherwise falls back to the first detected nested column.

  3. When strata is specified, verifies that the column exists in every nested data frame before calling rsample::vfold_cv.

  4. Iterates over each row, applies vfold_cv via do.call, expands the resulting splits into a data.table, and broadcasts the row's non-nested metadata columns across all CV rows.

  5. Combines all per-row results with rbindlist in a single pass.
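The steps above can be sketched in R as follows. This is an illustration under stated assumptions, not the package source: simple_vfold() is a hypothetical stand-in for rsample::vfold_cv() (random folds only, no repeats or strata, and no splits column in its output), so step 3 and the conditional forwarding of breaks and pool are omitted; nest_cv_sketch() is likewise hypothetical. What it shows is the structure: choose the target nested column, call the splitter once per row via do.call(), broadcast that row's metadata, and combine everything with a single rbindlist() call.

```r
library(data.table)

# Hypothetical stand-in for rsample::vfold_cv(): random v-fold split that
# returns one row per fold with list columns of training/validation sets
simple_vfold <- function(data, v = 10L) {
  data <- as.data.table(data)
  fold_of_row <- sample(rep_len(seq_len(v), nrow(data)))
  data.table(
    id       = paste0("Fold", seq_len(v)),
    train    = lapply(seq_len(v), function(k) data[fold_of_row != k]),
    validate = lapply(seq_len(v), function(k) data[fold_of_row == k])
  )
}

# Hypothetical sketch of steps 1, 2, 4 and 5 (the strata check is omitted)
nest_cv_sketch <- function(nest_dt, v = 10L, ...) {
  nest_dt <- as.data.table(nest_dt)
  # Step 1: nested columns are list columns whose elements are all data frames
  is_nested <- vapply(nest_dt, function(col) {
    is.list(col) && all(vapply(col, is.data.frame, logical(1)))
  }, logical(1))
  if (nrow(nest_dt) == 0L || !any(is_nested)) {
    stop("nest_dt must be non-empty and contain a nested data frame column")
  }
  # Step 2: prefer a column named "data", else use the first nested column
  nested_cols <- names(nest_dt)[is_nested]
  target <- if ("data" %in% nested_cols) "data" else nested_cols[1L]
  meta_cols <- names(nest_dt)[!is_nested]
  # Step 4: split each row's own data frame, then broadcast its metadata
  per_row <- lapply(seq_len(nrow(nest_dt)), function(i) {
    args <- list(data = nest_dt[[target]][[i]], v = v, ...)
    folds <- do.call(simple_vfold, args)
    cbind(nest_dt[rep(i, nrow(folds)), ..meta_cols], folds)
  })
  # Step 5: combine all per-row results in one pass
  rbindlist(per_row)
}

set.seed(42)
toy <- data.table(
  name = c("a", "b"),
  data = list(data.table(x = 1:10), data.table(x = 1:20))
)
res <- nest_cv_sketch(toy, v = 2L)
res[, .(name, id, n_train = vapply(train, nrow, integer(1)))]
# n_train is 5, 5, 10, 10: each row was split using its own data
```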

Note

  • nest_dt must contain at least one nested column of data.frames or data.tables.

  • as.data.table() is used instead of data.table::copy(): if the input is already a data.table, no copy is made.

  • strata must be a column name present in all nested data frames.

  • breaks and pool are forwarded to rsample::vfold_cv only when strata is non-NULL, avoiding invalid argument errors.

  • Looping over rows and combining the results with rbindlist fixes a silent bug in naive chained [ approaches, in which every row incorrectly received the first row's CV splits.
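The shared-splits failure mode can be reproduced with a toy example. This is a hypothetical reconstruction, not the code nest_cv previously used: make_folds() is a deterministic stand-in for rsample::vfold_cv(). Inside j, the list column data refers to all rows at once, so data[[1]] is always the first row's data frame; the splitter runs once and its result is recycled to every row. Building each row's folds from its own data frame and combining the rows with rbindlist avoids this.

```r
library(data.table)

# Hypothetical deterministic toy splitter (stand-in for rsample::vfold_cv()):
# returns the validation row indices of each of v folds
make_folds <- function(df, v = 2L) {
  split(seq_len(nrow(df)), rep_len(seq_len(v), nrow(df)))
}

nest_dt <- data.table(
  name = c("small", "large"),
  data = list(data.table(x = 1:4), data.table(x = 1:10))
)

# Naive: j is evaluated once for the whole table, so data[[1]] is always the
# first row's data frame and its folds are recycled to every row
naive <- copy(nest_dt)[, folds := list(list(make_folds(data[[1]])))]

# Per-row: each row's folds come from that row's own data frame
fixed <- rbindlist(lapply(seq_len(nrow(nest_dt)), function(i) {
  out <- nest_dt[i, .(name)]
  out[, folds := list(list(make_folds(nest_dt$data[[i]])))]
}))

# Largest validation index per row: the naive version never exceeds 4,
# so rows 5-10 of the "large" table are never used for validation
sapply(naive$folds, function(f) max(unlist(f)))  # 4 4
sapply(fixed$folds, function(f) max(unlist(f)))  # 4 10
```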


Examples

# Example: cross-validation on a nested data.table

# Setup test data
dt_nest <- w2l_nest(
  data = iris,                   # Input dataset
  cols2l = 1:2                   # Columns reshaped to long format, nested by name
)

# Example 1: Basic 2-fold cross-validation
nest_cv(
  nest_dt = dt_nest,             # Input nested data.table
  v = 2                          # Number of folds (2-fold CV)
)
#>            name                     splits     id              train
#>          <char>                     <list> <char>             <list>
#> 1: Sepal.Length <vfold_split[75x75x150x4]>  Fold1 <data.table[75x4]>
#> 2: Sepal.Length <vfold_split[75x75x150x4]>  Fold2 <data.table[75x4]>
#> 3:  Sepal.Width <vfold_split[75x75x150x4]>  Fold1 <data.table[75x4]>
#> 4:  Sepal.Width <vfold_split[75x75x150x4]>  Fold2 <data.table[75x4]>
#>              validate
#>                <list>
#> 1: <data.table[75x4]>
#> 2: <data.table[75x4]>
#> 3: <data.table[75x4]>
#> 4: <data.table[75x4]>

# Example 2: Repeated 2-fold cross-validation
nest_cv(
  nest_dt = dt_nest,             # Input nested data.table
  v = 2,                         # Number of folds (2-fold CV)
  repeats = 2                    # Number of repetitions
)
#>            name                     splits      id    id2              train
#>          <char>                     <list>  <char> <char>             <list>
#> 1: Sepal.Length <vfold_split[75x75x150x4]> Repeat1  Fold1 <data.table[75x4]>
#> 2: Sepal.Length <vfold_split[75x75x150x4]> Repeat1  Fold2 <data.table[75x4]>
#> 3: Sepal.Length <vfold_split[75x75x150x4]> Repeat2  Fold1 <data.table[75x4]>
#> 4: Sepal.Length <vfold_split[75x75x150x4]> Repeat2  Fold2 <data.table[75x4]>
#> 5:  Sepal.Width <vfold_split[75x75x150x4]> Repeat1  Fold1 <data.table[75x4]>
#> 6:  Sepal.Width <vfold_split[75x75x150x4]> Repeat1  Fold2 <data.table[75x4]>
#> 7:  Sepal.Width <vfold_split[75x75x150x4]> Repeat2  Fold1 <data.table[75x4]>
#> 8:  Sepal.Width <vfold_split[75x75x150x4]> Repeat2  Fold2 <data.table[75x4]>
#>              validate
#>                <list>
#> 1: <data.table[75x4]>
#> 2: <data.table[75x4]>
#> 3: <data.table[75x4]>
#> 4: <data.table[75x4]>
#> 5: <data.table[75x4]>
#> 6: <data.table[75x4]>
#> 7: <data.table[75x4]>
#> 8: <data.table[75x4]>