
MNIST module

Simple MNIST model.

MNISTLitModule #

Bases: LightningModule

Example of a LightningModule for MNIST classification.

A LightningModule implements 7 key methods:

def __init__(self):
    # Define initialization code here.

def setup(self, stage):
    # Things to set up before each stage: 'fit', 'validate', 'test', 'predict'.
    # This hook is called on every process when using DDP.

def training_step(self, batch, batch_idx):
    # The complete training step.

def validation_step(self, batch, batch_idx):
    # The complete validation step.

def test_step(self, batch, batch_idx):
    # The complete test step.

def predict_step(self, batch, batch_idx):
    # The complete predict step.

def configure_optimizers(self):
    # Define and configure optimizers and LR schedulers.
Docs

https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
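
As a quick orientation (not part of the generated reference itself), here is a minimal usage sketch. The toy `net`, the `functools.partial` factories, and the `lightning.Trainer` import are illustrative assumptions, not defaults shipped with this module:

```python
from functools import partial

import torch
from lightning import Trainer  # or: from pytorch_lightning import Trainer

from src.models.mnist_module import MNISTLitModule

# Toy network for illustration only; any module mapping (B, 1, 28, 28) -> (B, 10) works.
net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

model = MNISTLitModule(
    net=net,
    # optimizer/scheduler are passed as callables (e.g. partials);
    # configure_optimizers() calls them with the parameters / optimizer later.
    optimizer=partial(torch.optim.Adam, lr=1e-3),
    scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min"),
    compile_model=False,
)

trainer = Trainer(max_epochs=1)
# trainer.fit(model, datamodule=...)  # supply an MNIST LightningDataModule here
```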

Source code in src/models/mnist_module.py
class MNISTLitModule(LightningModule):
    """Example of a `LightningModule` for MNIST classification.

    A `LightningModule` implements 7 key methods:

    ```python
    def __init__(self):
        # Define initialization code here.

    def setup(self, stage):
        # Things to set up before each stage: 'fit', 'validate', 'test', 'predict'.
        # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
        # The complete training step.

    def validation_step(self, batch, batch_idx):
        # The complete validation step.

    def test_step(self, batch, batch_idx):
        # The complete test step.

    def predict_step(self, batch, batch_idx):
        # The complete predict step.

    def configure_optimizers(self):
        # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile_model: bool,
    ) -> None:
        """Initialize a `MNISTLitModule`.

        Args:
            net: The model to train.
            optimizer: The optimizer to use for training.
            scheduler: The learning rate scheduler to use for training.
            compile_model: Whether or not to compile the model.
        """
        super().__init__()

        # this line allows access to init params via the 'self.hparams' attribute
        # and also ensures init params are stored in the checkpoint
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # metric objects for calculating and averaging accuracy across batches
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation accuracy
        self.val_acc_best = MaxMetric()

    @typechecked
    def forward(self, x: TensorType["batch", 1, 28, 28]) -> TensorType["batch", 10]:  # noqa
        """Perform a forward pass through the model.

        Args:
            x: A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.

        Returns:
            A tensor of shape (batch_size, 10) representing the logits for each class.
        """
        return self.net(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth making sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_acc.reset()
        self.val_acc_best.reset()

    @typechecked
    def model_step(self, x: TensorType["batch", 1, 28, 28], y: TensorType["batch"]):  # noqa
        """Perform a single model step.

        Args:
            x: Tensor of shape [batch, 1, 28, 28] representing the images.
            y: Tensor of shape [batch] representing the classes.

        Returns:
            A tuple containing:
                - loss: A scalar loss tensor (averaged over the batch)
                - preds: A tensor of predicted class indices (batch_size,)
                - targets: A tensor of true class labels (batch_size,)
        """
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    @typechecked
    def training_step(self, batch: Any) -> TensorType[()]:
        """Perform a single training step.

        Args:
            batch: A tuple containing input images and target labels.

        Returns:
            A scalar loss tensor.
        """
        x, y = batch
        loss, preds, targets = self.model_step(x, y)
        self.train_loss(loss)
        self.train_acc(preds, targets)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when a training epoch ends."""
        pass

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        Args:
            batch: A batch of data (a tuple) containing the input tensor of images and target
                labels.
            batch_idx: The index of the current batch.
        """
        x, y = batch
        loss, preds, targets = self.model_step(x, y)

        # update and log metrics
        self.val_loss(loss)
        self.val_acc(preds, targets)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        acc = self.val_acc.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        Args:
            batch: A batch of data (a tuple) containing the input tensor of images and target
                labels.
            batch_idx: The index of the current batch.
        """
        x, y = batch
        loss, preds, targets = self.model_step(x, y)

        # update and log metrics
        self.test_loss(loss)
        self.test_acc(preds, targets)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        Args:
            stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile_model and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.

        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        Returns:
            A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}

__init__(net, optimizer, scheduler, compile_model) #

Initialize a MNISTLitModule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| net | Module | The model to train. | required |
| optimizer | Optimizer | The optimizer to use for training. | required |
| scheduler | lr_scheduler | The learning rate scheduler to use for training. | required |
| compile_model | bool | Whether or not to compile the model. | required |
Source code in src/models/mnist_module.py
def __init__(
    self,
    net: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler,
    compile_model: bool,
) -> None:
    """Initialize a `MNISTLitModule`.

    Args:
        net: The model to train.
        optimizer: The optimizer to use for training.
        scheduler: The learning rate scheduler to use for training.
        compile_model: Whether or not to compile the model.
    """
    super().__init__()

    # this line allows access to init params via the 'self.hparams' attribute
    # and also ensures init params are stored in the checkpoint
    self.save_hyperparameters(logger=False)

    self.net = net

    # loss function
    self.criterion = torch.nn.CrossEntropyLoss()

    # metric objects for calculating and averaging accuracy across batches
    self.train_acc = Accuracy(task="multiclass", num_classes=10)
    self.val_acc = Accuracy(task="multiclass", num_classes=10)
    self.test_acc = Accuracy(task="multiclass", num_classes=10)

    # for averaging loss across batches
    self.train_loss = MeanMetric()
    self.val_loss = MeanMetric()
    self.test_loss = MeanMetric()

    # for tracking best so far validation accuracy
    self.val_acc_best = MaxMetric()

configure_optimizers() #

Choose what optimizers and learning-rate schedulers to use in your optimization.

Normally you'd need one. But in the case of GANs or similar you might have multiple.

Examples:

https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | A dict containing the configured optimizers and learning-rate schedulers to be used for training. |

Source code in src/models/mnist_module.py
def configure_optimizers(self) -> dict[str, Any]:
    """Choose what optimizers and learning-rate schedulers to use in your optimization.

    Normally you'd need one. But in the case of GANs or similar you might have multiple.

    Examples:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

    Returns:
        A dict containing the configured optimizers and learning-rate schedulers to be used for training.
    """
    optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
    if self.hparams.scheduler is not None:
        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }
    return {"optimizer": optimizer}

forward(x) #

Perform a forward pass through the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | TensorType[batch, 1, 28, 28] | A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images. | required |

Returns:

| Type | Description |
| --- | --- |
| TensorType[batch, 10] | A tensor of shape (batch_size, 10) representing the logits for each class. |

Source code in src/models/mnist_module.py
@typechecked
def forward(self, x: TensorType["batch", 1, 28, 28]) -> TensorType["batch", 10]:  # noqa
    """Perform a forward pass through the model.

    Args:
        x: A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.

    Returns:
        A tensor of shape (batch_size, 10) representing the logits for each class.
    """
    return self.net(x)
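
A small shape check, assuming `model` is an already constructed `MNISTLitModule` (a hypothetical instance, not defined in this file):

```python
import torch

x = torch.randn(32, 1, 28, 28)   # a batch of 32 fake MNIST-sized images
logits = model(x)                # calls forward(); shape (32, 10)
probs = logits.softmax(dim=1)    # per-class probabilities, if needed
```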

model_step(x, y) #

Perform a single model step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | TensorType[batch, 1, 28, 28] | Tensor of shape [batch, 1, 28, 28] representing the images. | required |
| y | TensorType[batch] | Tensor of shape [batch] representing the classes. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple | A tuple containing: loss, a scalar loss tensor averaged over the batch; preds, a tensor of predicted class indices of shape (batch_size,); targets, a tensor of true class labels of shape (batch_size,). |

Source code in src/models/mnist_module.py
@typechecked
def model_step(self, x: TensorType["batch", 1, 28, 28], y: TensorType["batch"]):  # noqa
    """Perform a single model step.

    Args:
        x: Tensor of shape [batch, 1, 28, 28] representing the images.
        y: Tensor of shape [batch] representing the classes.

    Returns:
        A tuple containing:
            - loss: A scalar loss tensor (averaged over the batch)
            - preds: A tensor of predicted class indices (batch_size,)
            - targets: A tensor of true class labels (batch_size,)
    """
    logits = self.forward(x)
    loss = self.criterion(logits, y)
    preds = torch.argmax(logits, dim=1)
    return loss, preds, y
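
For illustration, with a hypothetical `model` instance and random data (CrossEntropyLoss uses mean reduction, so `loss` is a scalar):

```python
import torch

x = torch.randn(8, 1, 28, 28)     # fake images
y = torch.randint(0, 10, (8,))    # fake integer class labels
loss, preds, targets = model.model_step(x, y)
# loss: scalar tensor; preds and targets: shape (8,)
```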

on_test_epoch_end() #

Lightning hook that is called when a test epoch ends.

Source code in src/models/mnist_module.py
def on_test_epoch_end(self) -> None:
    """Lightning hook that is called when a test epoch ends."""
    pass

on_train_epoch_end() #

Lightning hook that is called when a training epoch ends.

Source code in src/models/mnist_module.py
def on_train_epoch_end(self) -> None:
    """Lightning hook that is called when a training epoch ends."""
    pass

on_train_start() #

Lightning hook that is called when training begins.

Source code in src/models/mnist_module.py
def on_train_start(self) -> None:
    """Lightning hook that is called when training begins."""
    # by default lightning executes validation step sanity checks before training starts,
    # so it's worth making sure validation metrics don't store results from these checks
    self.val_loss.reset()
    self.val_acc.reset()
    self.val_acc_best.reset()

on_validation_epoch_end() #

Lightning hook that is called when a validation epoch ends.

Source code in src/models/mnist_module.py
def on_validation_epoch_end(self) -> None:
    """Lightning hook that is called when a validation epoch ends."""
    acc = self.val_acc.compute()  # get current val acc
    self.val_acc_best(acc)  # update best so far val acc
    # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
    # otherwise metric would be reset by lightning after each epoch
    self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)

setup(stage) #

Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| stage | str | Either "fit", "validate", "test", or "predict". | required |
Source code in src/models/mnist_module.py
def setup(self, stage: str) -> None:
    """Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

    This is a good hook when you need to build models dynamically or adjust something about
    them. This hook is called on every process when using DDP.

    Args:
        stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
    """
    if self.hparams.compile_model and stage == "fit":
        self.net = torch.compile(self.net)
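
Compilation is therefore opt-in via the constructor flag; a minimal sketch, assuming PyTorch 2.x (where `torch.compile` is available) and the hypothetical `net` and factories from the earlier examples:

```python
model = MNISTLitModule(
    net=net,                      # hypothetical network from earlier
    optimizer=optimizer_factory,  # hypothetical factories from earlier
    scheduler=scheduler_factory,
    compile_model=True,           # setup("fit") will wrap self.net with torch.compile
)
# Trainer.fit(...) triggers setup("fit"); validate/test/predict stages leave the net uncompiled.
```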

test_step(batch, batch_idx) #

Perform a single test step on a batch of data from the test set.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | tuple[Tensor, Tensor] | A batch of data (a tuple) containing the input tensor of images and target labels. | required |
| batch_idx | int | The index of the current batch. | required |
Source code in src/models/mnist_module.py
def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
    """Perform a single test step on a batch of data from the test set.

    Args:
        batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        batch_idx: The index of the current batch.
    """
    x, y = batch
    loss, preds, targets = self.model_step(x, y)

    # update and log metrics
    self.test_loss(loss)
    self.test_acc(preds, targets)
    self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

training_step(batch) #

Perform a single training step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | A tuple containing input images and target labels. | required |

Returns:

| Type | Description |
| --- | --- |
| TensorType[()] | A scalar loss tensor. |

Source code in src/models/mnist_module.py
@typechecked
def training_step(self, batch: Any) -> TensorType[()]:
    """Perform a single training step.

    Args:
        batch: A tuple containing input images and target labels.

    Returns:
        A scalar loss tensor.
    """
    x, y = batch
    loss, preds, targets = self.model_step(x, y)
    self.train_loss(loss)
    self.train_acc(preds, targets)
    self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
    return loss

validation_step(batch, batch_idx) #

Perform a single validation step on a batch of data from the validation set.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | tuple[Tensor, Tensor] | A batch of data (a tuple) containing the input tensor of images and target labels. | required |
| batch_idx | int | The index of the current batch. | required |
Source code in src/models/mnist_module.py
def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
    """Perform a single validation step on a batch of data from the validation set.

    Args:
        batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        batch_idx: The index of the current batch.
    """
    x, y = batch
    loss, preds, targets = self.model_step(x, y)

    # update and log metrics
    self.val_loss(loss)
    self.val_acc(preds, targets)
    self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)