
Tutorials

Basic usage

Integrating a Batch Size Scheduler into a PyTorch training script is simple:

from torch.utils.data import DataLoader
from bs_scheduler import StepBS
# We use StepBS in this example, but we can use any BS Scheduler

# Define the Dataset and the DataLoader
dataset = ...
dataloader = DataLoader(..., batch_size=16)
scheduler = StepBS(dataloader, step_size=30, gamma=2)
# Multiplies the batch size by gamma=2 every 30 epochs.

for _ in range(100):
    train(...)
    validate(...)
    scheduler.step()

# We expect the batch size to have the following values:
# epoch 0 - 29: 16
# epoch 30 - 59: 32
# epoch 60 - 89: 64
# epoch 90 - 99: 128
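
A quick way to confirm this schedule is to read the batch size back from the DataLoader each epoch. The sketch below is illustrative: it assumes StepBS adjusts the loader's batch sampler in place, and it uses a throwaway TensorDataset only so the snippet can run on its own (dataloader.batch_sampler.batch_size is a standard PyTorch attribute).

import torch
from torch.utils.data import DataLoader, TensorDataset

from bs_scheduler import StepBS

# Throwaway dataset, used only to inspect the schedule.
dataset = TensorDataset(torch.zeros(4096, 1))
dataloader = DataLoader(dataset, batch_size=16)
scheduler = StepBS(dataloader, step_size=30, gamma=2)

for epoch in range(100):
    # Batch size in effect for this epoch (assuming StepBS updates the batch sampler in place).
    print(epoch, dataloader.batch_sampler.batch_size)
    scheduler.step()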

Full example:

import timm
import torch
import torchvision.datasets
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from bs_scheduler import StepBS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
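# CIFAR-10 loaders; the training loader starts at batch size 100 and is resized by the scheduler.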
train_loader = DataLoader(
    torchvision.datasets.CIFAR10(
        root="../data",
        train=True,
        download=True,
        transform=transforms,
    ),
    batch_size=100,
)
val_loader = DataLoader(
    torchvision.datasets.CIFAR10(root="../data", train=False, transform=transforms),
    batch_size=500,
)
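# Adjust the training batch size every 10 epochs (gamma is left at its default).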
scheduler = StepBS(train_loader, step_size=10)

model = timm.create_model("hf_hub:grodino/resnet18_cifar10", pretrained=False).to(
    device
)
criterion = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.001)


def train():
    correct = 0
    total = 0

    model.train()
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        predicted = outputs.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return correct / total


@torch.inference_mode()
def val():
    correct = 0
    total = 0

    model.eval()
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)

        predicted = outputs.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return correct / total


def main():
    for epoch in range(100):
        train_accuracy = train()
        val_accuracy = val()

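        # Update the batch size schedule once per epoch, after validation.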
        scheduler.step()

        print(train_accuracy, val_accuracy)


if __name__ == "__main__":
    main()