Polars datamodule
PyTorch Lightning DataModule for loading a dataset with Polars.
PolarsDataModule
Bases: LightningDataModule
A LightningDataModule that loads a dataset with Polars, splits it into train and validation sets, and provides a dataloader for each.
Source code in src/data/polars_datamodule.py
__init__(data_path, output_column, batch_size=32, num_workers=0, test_size=0.2)
Initialize the PolarsDataModule.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_path | str | Path to the dataset. | required |
output_column | str | Column name that contains the labels. | required |
batch_size | int | Batch size for the dataloaders. | 32 |
num_workers | int | Number of workers for the dataloaders. | 0 |
test_size | float | Fraction of the dataset to be used for validation. | 0.2 |
setup(stage='')
Load and split the dataset into train and validation sets.
train_dataloader()
Create and return the train dataloader.
val_dataloader()
Create and return the validation dataloader.
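As a rough illustration of what the two dataloader methods might return, the sketch below builds a pair of `DataLoader` objects from `batch_size` and `num_workers`. The `TensorDataset` inputs are stand-ins for the train and validation splits, and shuffling only the training loader is a common convention assumed here, not something this page confirms.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in tensors; in the module these would come from the two splits.
train_ds = TensorDataset(torch.randn(8, 2), torch.zeros(8, dtype=torch.long))
val_ds = TensorDataset(torch.randn(2, 2), torch.zeros(2, dtype=torch.long))

batch_size, num_workers = 4, 0

# Assumption: shuffle training batches, keep validation order fixed.
train_loader = DataLoader(
    train_ds, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
val_loader = DataLoader(
    val_ds, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

# Each batch is a (features, labels) pair.
features, labels = next(iter(train_loader))
```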
PolarsDataset
Bases: Dataset
Custom PyTorch Dataset wrapping a Polars DataFrame.
__getitem__(idx)
Return the features and label for the given index.
__init__(df, output_column)
Initialize the PolarsDataset.
__len__()
Return the number of rows in the dataset.