Polars datamodule
PyTorch Lightning DataModule for loading a dataset using Polars.
PolarsDataModule
Bases: LightningDataModule
PyTorch Lightning DataModule for loading a dataset using Polars.
Source code in src/data/polars_datamodule.py
__init__(data_path, output_column, batch_size=32, num_workers=0, test_size=0.2)
Initialize the PolarsDataModule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_path | str | Path to the dataset. | required |
| output_column | str | Column name that contains the labels. | required |
| batch_size | int | Batch size for the dataloaders. | 32 |
| num_workers | int | Number of workers for the dataloaders. | 0 |
| test_size | float | Fraction of the dataset to be used for validation. | 0.2 |
setup(stage='')
Load and split the dataset into train and validation sets.
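The page does not show how the split is performed, so the following is only a minimal sketch of one plausible approach: shuffle the row indices and reserve a `test_size` fraction for validation. The helper name `split_indices` and the `seed` parameter are hypothetical and are not part of the documented API.

```python
import random


def split_indices(n_rows, test_size=0.2, seed=0):
    """Partition range(n_rows) into shuffled train and validation index lists.

    Hypothetical helper for illustration; the real setup() may split differently.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be strictly between 0 and 1")
    indices = list(range(n_rows))
    # A dedicated Random instance keeps the shuffle reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(indices)
    n_val = int(round(n_rows * test_size))
    # The first n_val shuffled indices go to validation, the rest to training.
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(10, test_size=0.2)
print(len(train_idx), len(val_idx))  # prints "8 2"
```

With the default `test_size=0.2`, a 10-row dataset would yield 8 training rows and 2 validation rows under this scheme.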
train_dataloader()
Create and return the train dataloader.
val_dataloader()
Create and return the validation dataloader.
PolarsDataset
Bases: Dataset
Custom PyTorch Dataset wrapping a Polars DataFrame.
__getitem__(idx)
Return the features and label for the given index.
__init__(df, output_column)
Initialize the PolarsDataset.
__len__()
Return the number of rows in the dataset.