Class incremental learning of the GTSRB dataset¶
The goal of the project is to develop a solution for recognizing images of road signs. This is a classical dataset, and it is available in many repositories, e.g. torchvision.
Images are small (32x32) and the dataset itself is of modest size, so a deep network can run on a (good) CPU in reasonable time. A GPU will of course make it faster.
The question to be addressed in the project is to introduce sequentially new categories to classify: this is called class-incremental learning in the literature.
The problem of this learning setting is that it produces "catastrophic forgetting" when new categories (new road signs) to recognize are sequentially added to the recognition system: the old categories are forgotten if nothing is done.
Several strategies can be used to avoid forgetting:
- Rehearsal (maintain a small memory buffer of previous examples)
- Regularization (Knowledge distillation, EWC...)
- Incremental architecture
- ...
The figure below shows two learning strategies when new classes are sequentially added, 8 at a time.
The "upper bound" curve is obtained by training on all available data (non-incremental) and is the expected target performance.
The "fine tuning" curve is obtained by simply applying gradient descent on the new data: it is the simplest strategy and it produces catastrophic forgetting.
The "rehearsal KD" curve is obtained using a rehearsal memory buffer of size 200 and Knowledge Distillation between old and new classes.
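As a hedged illustration of the Knowledge Distillation term used by such a "rehearsal KD" strategy (a common formulation, not necessarily the exact one behind the curve above), the loss is a temperature-softened cross-entropy between the frozen old model's outputs and the current model's outputs on the old classes:

```python
import torch
import torch.nn.functional as F

def distillation_loss(new_logits, old_logits, T=2.0):
    """Distillation between old and new model outputs (a common sketch).

    new_logits: logits of the current model, restricted to the old classes
    old_logits: logits of the frozen previous model on the same batch
    T: softmax temperature (an assumed, typical default)
    """
    log_p_new = F.log_softmax(new_logits / T, dim=1)
    p_old = F.softmax(old_logits / T, dim=1)
    # Batch mean of the soft cross-entropy; T**2 rescales the gradients
    return -(p_old * log_p_new).sum(dim=1).mean() * (T ** 2)
```

In training, this term is typically added to the standard cross-entropy on the new classes, weighted by a hyperparameter.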
A few references to help you find a solution:
Tutorial: https://sites.google.com/view/neurips2022-llm-tutorial
Recent survey: Zhou, D. W., Wang, Q. W., Qi, Z. H., Ye, H. J., Zhan, D. C., & Liu, Z. (2024). Class-incremental learning: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence: https://github.com/zhoudw-zdw/CIL_Survey/
Huge list of papers: https://github.com/ContinualAI/continual-learning-papers
Van de Ven, G. M., & Tolias, A. S. (2019). Three scenarios for continual learning: https://arxiv.org/abs/1904.07734, https://github.com/GMvandeVen/continual-learning
A simple strategy: Prabhu, A., Torr, P. H., & Dokania, P. K. (2020). Gdumb: A simple approach that questions our progress in continual learning. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part II 16 (pp. 524-540). https://github.com/drimpossible/GDumb
A Jupyter notebook to help you start (data loaders, baseline for class-incremental learning):
Class incremental learning on the GTSRB dataset¶
This notebook contains several code snippets to help for your project:
- data loaders
- A baseline for incremental learning using fine-tuning
- Examples of how to use Weights & Biases for logging your results.
$\leadsto$ Remark 1
In Section 9: Rehearsal Memory Buffer + Knowledge Distillation Loss (KD), we implement class-incremental learning by using a rehearsal memory buffer of fixed total size 200 and apply Knowledge Distillation (KD) to transfer knowledge from old to new classes. [1] [2]
In Section 10: Dynamically Expandable Representation (DER), we implement class-incremental learning using a rehearsal memory buffer that stores a fixed number of samples (10) per class. Additionally, we apply Knowledge Distillation between old and new classes, combined with a Dynamically Expandable Representation (DER_CNN) model, which dynamically expands the classifier to accommodate new classes while preserving old knowledge. [3] [4]
References
[1] Zhou, D. W., Wang, Q. W., Qi, Z. H., Ye, H. J., Zhan, D. C., & Liu, Z. (2024). Class-Incremental Learning: A Survey. IEEE Transactions on Pattern Analysis and Machine Intelligence. PDF
[2] Zhou et al. (2024). iCaRL Implementation. GitHub Repository. iCaRL.py
[3] Yan, S., Xie, J., Wang, C., Gong, Y., & Yan, J. (2021). DER: Dynamically Expandable Representation for Class Incremental Learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). PDF
[4] Yan et al. (2021). DER Implementation. GitHub Repository. main.py
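The fixed-size-per-class rehearsal memory described in Section 10 can be sketched as follows. Note this is a simplified, assumed stand-in that keeps a random subset per class; iCaRL selects exemplars by herding rather than at random:

```python
import random
from collections import defaultdict

def update_memory(memory, dataset_indices, labels, per_class=10, seed=0):
    """Keep at most `per_class` exemplar indices per class.

    memory: dict mapping class_id -> list of dataset indices
    dataset_indices, labels: indices into the full dataset and their labels
    per_class: exemplars retained per class (10 in Section 10)
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in zip(dataset_indices, labels):
        by_class[y].append(idx)
    for y, idxs in by_class.items():
        # Merge old exemplars with new candidates, then subsample
        pool = memory.get(y, []) + idxs
        memory[y] = rng.sample(pool, min(per_class, len(pool)))
    return memory
```

The stored indices can then be wrapped in a `Subset` and concatenated with the current task's data when building the training loader.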
Section 0: Importing Common Libraries and Modules¶
0.1 Importing Essential Libraries¶
####################################
'''A. PyTorch Modules'''
import torch
import torch.nn as nn                    # Neural network modules
import torch.nn.init as init             # Network parameter initialization
import torch.optim as optim              # Optimizers
import torch.nn.functional as F          # Common functions (e.g. activation functions)
'''B. Torchvision Modules'''
from torch.utils.data import DataLoader, ConcatDataset, Subset  # Data loading and handling
from torchvision.utils import make_grid  # For visualizing images
from torchvision import transforms, datasets  # Image datasets and data augmentation
import torchvision.models as models      # Pre-trained models
from torchvision.transforms import v2    # Newer data augmentation API in torchvision
'''C. Helper Libraries'''
import copy                              # For deep copying models or data
import numpy as np                       # Scientific computing library
import random                            # Random number generation
import time, os                          # Timing and OS interaction
import matplotlib.pyplot as plt          # Plotting library
from PIL import Image                    # Image processing
from tqdm import tqdm                    # Progress bar
import pandas as pd                      # Dataframe handling
import math                              # Math functions
0.2 Device Configuration¶
# Device configuration: prioritize the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # Useful if you want to store intermediate results on your drive
# from google.colab import drive
# drive.mount('/content/gdrive/')
# DATA_DIR = '/content/gdrive/MyDrive/teaching/ENSTA/2024'
# Check if GPU is available
if torch.cuda.is_available():
!nvidia-smi
Sat Mar 22 05:08:51 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.01 Driver Version: 546.01 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA T1200 Laptop GPU On | 00000000:01:00.0 Off | N/A |
| N/A 52C P0 17W / 90W | 0MiB / 4096MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
Section 1: Data loaders¶
'''A. Define Transformations'''
# Training data transformations
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(32),   # Random resized crop to 32x32
    transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.ToTensor(),              # Convert to tensor
    # transforms.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669))  # GTSRB dataset mean and std
])

# Testing data transformations
transform_test = transforms.Compose([
    transforms.Resize(32),              # Resize to 32x32
    transforms.ToTensor(),              # Convert to tensor
    # transforms.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669))
])

'''B. Using the Newer torchvision.transforms v2 API'''
# Training data transformations (v2)
transform_train = v2.Compose([
    # v2.Grayscale(),                   # Optional grayscale conversion
    # v2.RandomResizedCrop(32),
    v2.Resize((32, 32)),                # Resize to 32x32
    v2.ToImage(),                       # Convert to image format
    v2.ToDtype(torch.float32, scale=True),  # Convert to float32 and scale to [0,1]
    # v2.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669))
])

# Testing data transformations (v2)
transform_test = v2.Compose([
    # v2.Grayscale(),
    v2.Resize((32, 32)),                # Resize to 32x32
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # v2.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669))
])
'''C. Define Dataset and DataLoader'''
# Get dataset
def get_dataset(root_dir, transform, train=True):
    """
    root_dir: path where the dataset is stored
    transform: transformations applied
    train: whether to load the training set
    """
    dataset = datasets.GTSRB(root=root_dir, split='train' if train else 'test', download=True, transform=transform)
    target = [data[1] for data in dataset]  # Extract labels
    return dataset, target

# Create DataLoader
def create_dataloader(dataset, targets, current_classes, batch_size, shuffle):
    """
    dataset: full dataset
    targets: list of labels
    current_classes: classes to be used for the current task
    batch_size: batch size
    shuffle: whether to shuffle the data
    """
    indices = [i for i, label in enumerate(targets) if label in current_classes]  # Keep only the current classes
    subset = Subset(dataset, indices)  # Create subset
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=shuffle)  # Create DataLoader
    return dataloader
'''D. Specify Dataset Directory & Load Datasets'''
# Specify dataset directory
root_dir = './data'

# Load training and test datasets
train_dataset = datasets.GTSRB(root=root_dir, split='train', download=True, transform=transform_train)
test_dataset = datasets.GTSRB(root=root_dir, split='test', download=True, transform=transform_test)

print(f"Train dataset contains {len(train_dataset)} images")
print(f"Test dataset contains {len(test_dataset)} images")

'''E. Load Target Labels and Class Names'''
# Load local target ID lists and class names
import csv

# Load test target IDs
data = pd.read_csv('./test_target.csv', delimiter=',', header=None)
test_target = data.to_numpy().squeeze().tolist()

# Load train target IDs
data = pd.read_csv('./train_target.csv', delimiter=',', header=None)
train_target = data.to_numpy().squeeze().tolist()

# Load class names
data = pd.read_csv('./signnames.csv')
class_names = data['SignName'].tolist()
Train dataset contains 26640 images
Test dataset contains 12630 images
'''
# Loads datasets (on your local computer)
root_dir = '/home/stephane/Documents/Onera/Cours/ENSTA/2025/data'
# Loads datasets (on Colab local computer)
root_dir = './data'
train_dataset = datasets.GTSRB(root=root_dir, split='train', download=True, transform=transform_train)
test_dataset = datasets.GTSRB(root=root_dir, split='test', download=True, transform=transform_test)
print(f"Train dataset contains {len(train_dataset)} images")
print(f"Test dataset contains {len(test_dataset)} images")
# Loads target id lists and class names (not in torchvision dataset)
import csv
data = pd.read_csv('https://raw.githubusercontent.com/stepherbin/teaching/refs/heads/master/IOGS/projet/test_target.csv', delimiter=',', header=None)
test_target = data.to_numpy().squeeze().tolist()
data = pd.read_csv('https://raw.githubusercontent.com/stepherbin/teaching/refs/heads/master/IOGS/projet/train_target.csv', delimiter=',', header=None)
train_target = data.to_numpy().squeeze().tolist()
data = pd.read_csv('https://raw.githubusercontent.com/stepherbin/teaching/refs/heads/master/IOGS/projet/signnames.csv')
class_names = data['SignName'].tolist()
'''
Section 2: Display of images¶
2.1 Display of images¶
# Get the number of unique classes in the training set
nclasses = len(np.unique(train_target))

# Create a list of all class indices
all_classes = list(range(nclasses))
# You can shuffle the class order if needed (currently commented out)
# random.shuffle(all_classes)

# Number of classes per task
classes_per_task = 8

# List of classes accumulated so far
current_classes = []

# Task index, starting from 0
task = 0

# Select classes for the current task
'''
Each task takes 8 consecutive classes (the first 8, the next 8, and so on),
so the classes are trained batch by batch:
task 0 : task_classes [0,1,2,3,4,5,6,7]
task 1 : task_classes [8,9,10,11,12,13,14,15]
task 2 : task_classes [16,17,18,19,20,21,22,23]
...
'''
task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]

# Add selected classes to current_classes
current_classes.extend(task_classes)

# Batch size
batch_size = 64

# ==========================
# Create DataLoaders for training and testing on the current task
train_loader = create_dataloader(train_dataset, train_target, current_classes, batch_size, shuffle=True)
test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=True)

# ==========================
# Function to display image samples
def show(img):
    npimg = img.numpy()
    # Transpose dimensions (C,H,W) -> (H,W,C) and display the image
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

# Get one batch from the train loader
sample, targets = next(iter(train_loader))

# Display the image grid
show(make_grid(sample))
plt.show()

# ==========================
# Print the batch shape:
# 64 is the batch size;
# the second dimension is the number of channels (1 for grayscale, 3 for RGB);
# the last two dimensions are image height and width (32x32 here)
print(sample.shape)
torch.Size([64, 3, 32, 32])
2.2 Advanced Batch Visualization¶
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def show_batch(img, targets):
    """
    Display a batch of images and their labels in an 8x8 grid.
    """
    # Create the 8x8 image grid
    grid_img = make_grid(img, nrow=8, padding=2)
    plt.figure(figsize=(6, 6))
    npimg = grid_img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    # Overlay the label on each tile
    for idx in range(len(targets)):
        row = idx // 8
        col = idx % 8
        plt.text(
            x=col * (32 + 2) + 4,        # X coordinate, accounting for padding
            y=row * (32 + 2) + 30,       # Y coordinate
            s=str(targets[idx].item()),  # Label value
            color='white',
            fontsize=10,
            bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2')
        )
    plt.show()
show_batch(sample, targets)
2.3 Test Set Label Distribution Analysis¶
# Create a DataLoader for the test set, including all classes
test_loader = create_dataloader(test_dataset, test_target, all_classes, batch_size, shuffle=True)

# ==========================
# (Optional) Extract test labels from the DataLoader - commented out here
# gtsrbtest_gt = []  # To store all targets
# for _, targets in test_loader:
#     gtsrbtest_gt += targets.numpy().tolist()  # Convert targets to list and append
# print(len(gtsrbtest_gt))  # Print total number of targets

# ==========================
# Directly count the label distribution in test_target
from collections import Counter

# Count occurrences of each label, sorted by most common
label_counts = Counter(test_target).most_common()

# Print statistics for each class
print("Count\tClassID\tClassName")
for l, c in label_counts:
    print(f"{c}\t{l}\t{class_names[l]}")  # Format: count, class ID, class name
Count	ClassID	ClassName
750	2	Speed limit (50km/h)
720	1	Speed limit (30km/h)
720	13	Yield
690	38	Keep right
690	12	Priority road
660	4	Speed limit (70km/h)
660	10	No passing for vechiles over 3.5 metric tons
630	5	Speed limit (80km/h)
480	25	Road work
480	9	No passing
450	7	Speed limit (100km/h)
450	3	Speed limit (60km/h)
450	8	Speed limit (120km/h)
420	11	Right-of-way at the next intersection
390	18	General caution
390	35	Ahead only
360	17	No entry
270	14	Stop
270	31	Wild animals crossing
210	33	Turn right ahead
210	15	No vechiles
180	26	Traffic signals
150	16	Vechiles over 3.5 metric tons prohibited
150	23	Slippery road
150	30	Beware of ice/snow
150	28	Children crossing
150	6	End of speed limit (80km/h)
120	34	Turn left ahead
120	22	Bumpy road
120	36	Go straight or right
90	21	Double curve
90	20	Dangerous curve to the right
90	24	Road narrows on the right
90	29	Bicycles crossing
90	40	Roundabout mandatory
90	39	Keep left
90	42	End of no passing by vechiles over 3.5 metric tons
60	27	Pedestrians
60	32	End of all speed and passing limits
60	41	End of no passing
60	19	Dangerous curve to the left
60	0	Speed limit (20km/h)
60	37	Go straight or left
Section 3: Simple networks¶
3.1 Definition of Simple CNN Models¶
# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self, n_out=10, n_in=1):
        super().__init__()
        # Put the layers here
        self.conv1 = nn.Conv2d(n_in, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.fc = nn.Linear(4096, n_out)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))               # the 1x32x32 input becomes 32x32x32
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # then 32x16x16
        x = F.leaky_relu(self.conv2(x))               # then 64x16x16
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # then 64x8x8
        x = F.leaky_relu(self.conv3(x))               # no change in shape
        x = x.view(-1, 4096)                          # 64x8x8 is flattened to 4096
        x = self.fc(x)                                # and we finish in exactly the same way
        return x
# Another simple model (compare them using torchinfo below)
class SimpleCNN2(nn.Module):
    def __init__(self, n_out=10, n_in=1):
        super(SimpleCNN2, self).__init__()
        self.conv1 = nn.Conv2d(n_in, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc = nn.Linear(128, n_out)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)  # Flatten the tensor
        x = self.relu(self.fc1(x))
        x = self.fc(x)
        return x
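Before comparing the two models, a cheap way to catch shape bugs (e.g. a wrong `view` size) is a dummy forward pass. The helper below is an assumed addition, not part of the original notebook:

```python
import torch
import torch.nn as nn

def check_output_shape(model, n_in=3, n_out=10, batch=4, size=32):
    """Forward a random dummy batch and verify the output is (batch, n_out)."""
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(batch, n_in, size, size))
    assert out.shape == (batch, n_out), f"unexpected output shape {out.shape}"
    return tuple(out.shape)
```

For instance, `check_output_shape(SimpleCNN(n_out=10, n_in=3))` should return `(4, 10)` if the flattening size matches the convolutional output.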
3.2 Model Summary¶
!pip install torchinfo
Requirement already satisfied: torchinfo in /home/home/miniconda3/envs/tensor/lib/python3.11/site-packages (1.8.0)
from torchinfo import summary
model = SimpleCNN(n_out=10, n_in=3)
model.to(device)
print(summary(model, input_size=(batch_size, 3, 32, 32)))
model = SimpleCNN2(n_out=10, n_in=3)
model.to(device)
print(summary(model, input_size=(batch_size, 3, 32, 32)))
#print(model)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleCNN                                [64, 10]                  --
├─Conv2d: 1-1                            [64, 32, 32, 32]          2,432
├─Conv2d: 1-2                            [64, 64, 16, 16]          18,496
├─Conv2d: 1-3                            [64, 64, 8, 8]            36,928
├─Linear: 1-4                            [64, 10]                  40,970
==========================================================================================
Total params: 98,826
Trainable params: 98,826
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 616.30
==========================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 27.27
Params size (MB): 0.40
Estimated Total Size (MB): 28.45
==========================================================================================
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleCNN2                               [64, 10]                  --
├─Conv2d: 1-1                            [64, 32, 32, 32]          896
├─ReLU: 1-2                              [64, 32, 32, 32]          --
├─MaxPool2d: 1-3                         [64, 32, 16, 16]          --
├─Conv2d: 1-4                            [64, 64, 16, 16]          18,496
├─ReLU: 1-5                              [64, 64, 16, 16]          --
├─MaxPool2d: 1-6                         [64, 64, 8, 8]            --
├─Linear: 1-7                            [64, 128]                 524,416
├─ReLU: 1-8                              [64, 128]                 --
├─Linear: 1-9                            [64, 10]                  1,290
==========================================================================================
Total params: 545,098
Trainable params: 545,098
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 395.40
==========================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 25.24
Params size (MB): 2.18
Estimated Total Size (MB): 28.20
==========================================================================================
Section 4: Baseline for incremental learning¶
from torch.optim import lr_scheduler
import torch.nn.init as init
# Evaluation
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, ncols=80):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Simple training loop
def train(model, train_loader, optimizer, criterion, device, epoch):
    model.train()
    for images, labels in tqdm(train_loader, ncols=80, desc="Epoch {}".format(epoch)):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

def initialize_weights(module):
    """Initializes the weights of a PyTorch module using Xavier/Glorot initialization."""
    if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):  # Relevant layers
        init.xavier_uniform_(module.weight)  # Xavier uniform initialization
        if module.bias is not None:
            init.zeros_(module.bias)  # Initialize bias to zero
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):  # Normalization layers
        if module.weight is not None:
            init.ones_(module.weight)
        if module.bias is not None:
            init.zeros_(module.bias)

# Main training loop for incremental learning
def incremental_learning(model, train_dataset, train_target, test_dataset, test_target,
                         num_tasks, classes_per_task, batch_size, num_epochs, lr, device):
    nclasses = len(np.unique(train_target))
    all_classes = list(range(nclasses))
    criterion = nn.CrossEntropyLoss()
    current_classes = []
    accuracies = []

    for task in range(num_tasks):
        task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
        current_classes.extend(task_classes)

        train_loader = create_dataloader(train_dataset, train_target, task_classes, batch_size, shuffle=True)
        test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=False)

        if task == 0:
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
        else:
            # Expand the output layer for the new classes, keeping the old weights
            old_weight = model.fc.weight.data
            old_bias = model.fc.bias.data
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
            model.fc.weight.data[:len(old_weight)] = old_weight
            model.fc.bias.data[:len(old_bias)] = old_bias

        # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)  # alternative
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        print(f"Starting Task {task+1} - Training on classes: {task_classes}")
        for epoch in range(num_epochs):  # Adjust number of epochs as needed
            train(model, train_loader, optimizer, criterion, device, epoch)
            scheduler.step()
            accuracy = evaluate(model, train_loader, device)
            print(f"Task {task+1}, Epoch {epoch+1}: Accuracy Train = {accuracy:.2f}%")

        accuracy = evaluate(model, test_loader, device)
        accuracies.append(accuracy)
        print(f"Task {task+1}: Accuracy Test = {accuracy:.2f}%")

    return accuracies
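The list returned by `incremental_learning` is often summarized by the average incremental accuracy, i.e. the mean of the test accuracies measured after each task. A small helper (an assumed convenience, not in the original notebook):

```python
def average_incremental_accuracy(accuracies):
    """Average incremental accuracy: mean of the per-task test accuracies,
    a standard summary metric in the class-incremental learning literature."""
    if not accuracies:
        raise ValueError("no accuracies recorded")
    return sum(accuracies) / len(accuracies)
```

For example, `average_incremental_accuracy([90.0, 80.0, 70.0])` returns `80.0`, which makes it easy to compare strategies with a single number.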
Section 5: Weights & Biases¶
You can use this environment to log your training runs.
The code below provides a version of the class-incremental function that stores learning curves and the sequence of accuracies for each increment of classes.
To use it, create an account at: https://wandb.ai/
###################################
##### For using Weights & Biases
###############
!pip install wandb -qU
import wandb
wandb.login()
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: kaiyuan_xu (kaiyuan_xu_upsud) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
True
import math
# Simple training loop with Weights & Biases logging
def train_wandb(model, train_loader, optimizer, criterion, device, epoch):
    step_ct = 0
    n_steps_per_epoch = math.ceil(len(train_loader.dataset) / train_loader.batch_size)
    model.train()
    for step, (images, labels) in tqdm(enumerate(train_loader), ncols=80, desc="Epoch {}".format(epoch)):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        metrics = {"train/train_loss": loss.item()}
        # metrics = {"train/train_loss": loss.item(),
        #            "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch}
        if step + 1 < n_steps_per_epoch:
            # Log train metrics to wandb
            wandb.log(metrics)
        step_ct += 1

# Main training loop for incremental learning
def incremental_learning_wandb(model, train_dataset, train_target, test_dataset, test_target,
                               num_tasks, classes_per_task, batch_size, num_epochs, lr, device,
                               non_incremental=False):
    nclasses = len(np.unique(train_target))
    all_classes = list(range(nclasses))
    criterion = nn.CrossEntropyLoss()
    current_classes = []
    accuracies = []

    # Copy your config
    config = wandb.config
    wandb.define_metric("task")
    wandb.define_metric("incremental_accuracy", step_metric="task")

    for task in range(num_tasks):
        if non_incremental:  # Learn from all available data
            task_classes = all_classes[0 : (task + 1) * classes_per_task]
            current_classes = task_classes
            model.apply(initialize_weights)
        else:
            task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
            current_classes.extend(task_classes)

        train_loader = create_dataloader(train_dataset, train_target, task_classes, batch_size, shuffle=True)
        test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=False)

        if task == 0 or non_incremental:
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
        else:
            # Expand the output layer for the new classes, keeping the old weights
            old_weight = model.fc.weight.data
            old_bias = model.fc.bias.data
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
            model.fc.weight.data[:len(old_weight)] = old_weight
            model.fc.bias.data[:len(old_bias)] = old_bias

        # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)  # alternative
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        print(f"Starting Task {task+1} - Training on classes: {task_classes}")
        for epoch in range(num_epochs):  # Adjust number of epochs
            # train(model, train_loader, optimizer, criterion, device, epoch)  # if not logging training
            train_wandb(model, train_loader, optimizer, criterion, device, epoch)
            scheduler.step()
            accuracy = evaluate(model, train_loader, device)
            print(f"Task {task+1}, Epoch {epoch+1}: Accuracy Train = {accuracy:.2f}%")
            val_metrics = {"val/val_accuracy": accuracy}
            # wandb.log({**val_metrics})

        accuracy = evaluate(model, test_loader, device)
        accuracies.append(accuracy)
        print(f"Task {task+1}: Accuracy Test = {accuracy:.2f}%")

        # Log incremental accuracy against the task index
        incremental_metrics = {"incremental_accuracy": accuracy, "task": task}
        wandb.log({**incremental_metrics})

    return accuracies
Section 6: Pre-Training¶
# Hyperparameters
root_dir = './data' # Path to GTSRB dataset
num_tasks = 5
numclasses = len(np.unique(train_target))
classes_per_task = numclasses // num_tasks # 43 // 5 = 8 classes per task
batch_size = 64
lr = 1e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
buffer_size = 200 # Adjust rehearsal set size
alignment_strength = 0.1 # Adjust alignment strength
num_epochs = 4
#model = SimpleCNN(n_out = 1, n_in = 3).to(device)
#model.apply(initialize_weights)
# The name of the network (choose the one you want)
tag = "simpleCNN_GTSRB_pretrained"
netname = os.path.join(root_dir, 'network_{:s}.pth'.format(tag))
#################################################
## Pre-training
####
# Read the last learned network (if stored)
if (os.path.exists(netname)):
print('Load pre-trained network')
model = SimpleCNN(n_in = 3, n_out=classes_per_task)
model.load_state_dict(torch.load(netname,weights_only=True))
#model = torch.load(netname, weights_only=True)
model = model.to(device)
else:
print('Pretrain')
model = SimpleCNN(n_in = 3, n_out=1)
model.apply(initialize_weights)
model.to(device)
accu = incremental_learning(model, train_dataset, train_target, test_dataset, test_target,
1, classes_per_task, batch_size, num_epochs, lr, device)
print(f"!!!!! Pre-training on first task = {accu[0]:.2f}%")
# Save last learned model
#torch.save(model, netname)
torch.save(model.state_dict(), netname)
## Copy the model so every strategy starts from the same initialization
copy_model = copy.deepcopy(model)
#### Learn with a single epoch in incremental (faster but less accurate)
num_epochs = 1
Load pre-trained network
Section 7: Fine tuning¶
#############################################
## Fine tuning
####
# initialise a wandb run
num_epochs = 5
run = wandb.init(
project="GTSRB-CIL",
name = "Fine tuning",
config={
"epochs": num_epochs,
"batch_size": batch_size,
"num_tasks": num_tasks,
"classes_per_task": classes_per_task,
"lr": lr,
})
# Simple Incremental Fine Tuning
model = copy.deepcopy(copy_model)
incremental_learning_wandb(model, train_dataset, train_target, test_dataset, test_target,
num_tasks, classes_per_task, batch_size, num_epochs, lr, device)
wandb.finish()
/home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_050856-252cfjp6
Starting Task 1 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7]
Task 1, Epoch 1: Accuracy Train = 94.25%
Task 1, Epoch 2: Accuracy Train = 98.83%
Task 1, Epoch 3: Accuracy Train = 98.87%
Task 1, Epoch 4: Accuracy Train = 99.70%
Task 1, Epoch 5: Accuracy Train = 99.79%
Task 1: Accuracy Test = 91.71%
Starting Task 2 - Training on classes: [8, 9, 10, 11, 12, 13, 14, 15]
Task 2, Epoch 1: Accuracy Train = 99.40%
Task 2, Epoch 2: Accuracy Train = 99.95%
Task 2, Epoch 3: Accuracy Train = 99.96%
Task 2, Epoch 4: Accuracy Train = 99.99%
Task 2, Epoch 5: Accuracy Train = 99.98%
Task 2: Accuracy Test = 49.36%
Starting Task 3 - Training on classes: [16, 17, 18, 19, 20, 21, 22, 23]
Task 3, Epoch 1: Accuracy Train = 98.59%
Task 3, Epoch 2: Accuracy Train = 99.68%
Task 3, Epoch 3: Accuracy Train = 99.74%
Task 3, Epoch 4: Accuracy Train = 100.00%
Task 3, Epoch 5: Accuracy Train = 100.00%
Task 3: Accuracy Test = 12.85%
Starting Task 4 - Training on classes: [24, 25, 26, 27, 28, 29, 30, 31]
Task 4, Epoch 1: Accuracy Train = 96.01%
Task 4, Epoch 2: Accuracy Train = 99.37%
Task 4, Epoch 3: Accuracy Train = 99.87%
Task 4, Epoch 4: Accuracy Train = 99.91%
Task 4, Epoch 5: Accuracy Train = 100.00%
Task 4: Accuracy Test = 11.76%
Starting Task 5 - Training on classes: [32, 33, 34, 35, 36, 37, 38, 39]
Task 5, Epoch 1: Accuracy Train = 99.23%
Task 5, Epoch 2: Accuracy Train = 99.87%
Task 5, Epoch 3: Accuracy Train = 99.92%
Task 5, Epoch 4: Accuracy Train = 99.97%
Task 5, Epoch 5: Accuracy Train = 100.00%
Task 5: Accuracy Test = 12.87%
Run history:
| incremental_accuracy | █▄▁▁▁ |
| task | ▁▃▅▆█ |
| train/train_loss | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▁▁▁▁▁▁▁▁▁▁▁ |
Run summary:
| incremental_accuracy | 12.87328 |
| task | 4 |
| train/train_loss | 0.02354 |
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
./wandb/run-20250322_050856-252cfjp6/logs
Section 8: Upper bound¶
#################################################
## Global upper bound (all data, all classes)
####
# A single task containing all classes, trained for 5 epochs
model = copy.deepcopy(copy_model)
accu = incremental_learning(model, train_dataset, train_target, test_dataset, test_target,
1, (numclasses // num_tasks) * num_tasks, batch_size, 5, lr, device)
print(f"!!!!! Upper bound of accuracy = {accu[0]:.2f}%")
Starting Task 1 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
Task 1, Epoch 1: Accuracy Train = 96.95%
Task 1, Epoch 2: Accuracy Train = 98.92%
Task 1, Epoch 3: Accuracy Train = 99.44%
Task 1, Epoch 4: Accuracy Train = 99.64%
Task 1, Epoch 5: Accuracy Train = 99.85%
Task 1: Accuracy Test = 91.00%
!!!!! Upper bound of accuracy = 91.00%
########################################
## Upper bound for each task (takes some time)
####
# initialise a wandb run
num_epochs = 5
run = wandb.init(
project="GTSRB-CIL",
name = "Upper bound",
config={
"epochs": num_epochs,
"batch_size": batch_size,
"num_tasks": num_tasks,
"classes_per_task": classes_per_task,
"lr": lr,
})
# Non-incremental baseline (for each task, retrain from scratch on all classes seen so far)
model = copy.deepcopy(copy_model)
incremental_learning_wandb(model, train_dataset, train_target, test_dataset, test_target,
num_tasks, classes_per_task, batch_size, num_epochs, lr, device, non_incremental = True)
wandb.finish()
/home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_051144-fy9f46ci
Starting Task 1 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7]
Task 1, Epoch 1: Accuracy Train = 66.49%
Task 1, Epoch 2: Accuracy Train = 90.68%
Task 1, Epoch 3: Accuracy Train = 96.42%
Task 1, Epoch 4: Accuracy Train = 97.55%
Task 1, Epoch 5: Accuracy Train = 98.52%
Task 1: Accuracy Test = 89.33%
Starting Task 2 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Task 2, Epoch 1: Accuracy Train = 89.45%
Task 2, Epoch 2: Accuracy Train = 96.90%
Task 2, Epoch 3: Accuracy Train = 98.82%
Task 2, Epoch 4: Accuracy Train = 99.62%
Task 2, Epoch 5: Accuracy Train = 99.62%
Task 2: Accuracy Test = 92.43%
Starting Task 3 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
Task 3, Epoch 1: Accuracy Train = 93.37%
Task 3, Epoch 2: Accuracy Train = 98.53%
Task 3, Epoch 3: Accuracy Train = 98.52%
Task 3, Epoch 4: Accuracy Train = 99.34%
Task 3, Epoch 5: Accuracy Train = 99.76%
Task 3: Accuracy Test = 91.18%
Starting Task 4 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Task 4, Epoch 1: Accuracy Train = 93.79%
Task 4, Epoch 2: Accuracy Train = 98.32%
Task 4, Epoch 3: Accuracy Train = 99.02%
Task 4, Epoch 4: Accuracy Train = 99.20%
Task 4, Epoch 5: Accuracy Train = 99.61%
Task 4: Accuracy Test = 88.33%
Starting Task 5 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
Task 5, Epoch 1: Accuracy Train = 93.77%
Task 5, Epoch 2: Accuracy Train = 98.69%
Task 5, Epoch 3: Accuracy Train = 99.05%
Task 5, Epoch 4: Accuracy Train = 99.68%
Task 5, Epoch 5: Accuracy Train = 99.52%
Task 5: Accuracy Test = 89.81%
Run history:
| incremental_accuracy | ▃█▆▁▄ |
| task | ▁▃▅▆█ |
| train/train_loss | ▅▂▂▂▁▂▂▁▁▁▁▅▂▁▁▁▁▁▁█▂▁▁▁▁▁▁▁▁▁▃▂▂▁▁▁▁▁▁▁ |
Run summary:
| incremental_accuracy | 89.81437 |
| task | 4 |
| train/train_loss | 0.00503 |
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
./wandb/run-20250322_051144-fy9f46ci/logs
Section 9: Rehearsal Memory Buffer + Knowledge Distillation Loss (KD)¶
Knowledge Distillation:
$$ \mathcal{L} = \underbrace{\ell(f(\mathbf{x}), y)}_{\text{Learning New Classes}} + \underbrace{\sum_{k=1}^{|\mathcal{Y}_{b-1}|} -S_k(f^{b-1}(\mathbf{x})) \log S_k(f(\mathbf{x})) }_{\text{Remembering Old Classes}} $$
$\ell(f(x), y)$: Standard cross-entropy loss for learning new classes.
Second term: Knowledge Distillation, encouraging the new model's output $f(x)$ to match the old model's output $f^{b-1}(x)$.
$S_k$: Softmax function applied to logits.
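The distillation term above can be written out directly. A minimal sketch with made-up logits, assuming the softmax in the second term is taken over the old-class logits only (8 old classes, 2 new ones, batch-averaged):

```python
import torch
import torch.nn.functional as F

# Made-up logits: the teacher covers 8 old classes, the student 8 old + 2 new.
old_logits = torch.randn(4, 8)    # f^{b-1}(x)
new_logits = torch.randn(4, 10)   # f(x)

soft_targets = F.softmax(old_logits, dim=1)           # S_k(f^{b-1}(x))
log_probs = F.log_softmax(new_logits[:, :8], dim=1)   # log S_k(f(x)) on old classes
kd_term = -(soft_targets * log_probs).sum(dim=1).mean()
print(kd_term.item())  # non-negative: a cross-entropy between two distributions
```

Minimizing this term pulls the student's old-class distribution toward the teacher's.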
A version with temperature-scaled softmax:
$$ \mathcal{L} = (1-\lambda) \underbrace{\ell(f(\mathbf{x}), y)}_{\text{Learning New Classes}} + \lambda \underbrace{ \sum_{k=1}^{|\mathcal{Y}_{b-1}|} -S_k\left( \frac{f^{b-1}(\mathbf{x})}{T} \right) \log S_k\left( \frac{f(\mathbf{x})}{T} \right) }_{\text{Remembering Old Classes (KD with Temperature $T$)}} $$
$\ell(f(x), y)$: Cross-entropy loss for the current new classes. $\leadsto$
loss_cls = criterion(logits, targets)
The second term uses softmax temperature scaling:
$$ S_k\left( \frac{f(x)}{T} \right) = \frac{\exp\left( f_k(x) / T \right)}{\sum_{j} \exp\left( f_j(x) / T \right)}, $$
where $T$ is the temperature hyperparameter, typically $T > 1$ to soften the distribution.
$\leadsto$
loss_kd = KD_loss(logits[:, :len(current_classes)-classes_per_task], old_logits, T)
$\lambda = \frac{|\mathcal{Y}_{b-1}|}{|\mathcal{Y}_{b}|}$: a balancing coefficient, the proportion of old classes among all classes.
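The effect of the temperature can be seen on a single made-up logit vector: raising $T$ flattens the distribution, so the non-target classes carry more of the signal.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[6.0, 2.0, 1.0]])  # made-up logits for 3 classes
for T in (1.0, 2.0, 4.0):
    probs = F.softmax(logits / T, dim=1)
    print(f"T={T}: {probs.squeeze().tolist()}")
# The maximum probability shrinks as T grows, softening the target distribution.
```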
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, Subset
import torch.nn.functional as F
def KD_loss(student_logits, teacher_logits, T=2):
student_log_prob = F.log_softmax(student_logits / T, dim=1)
teacher_prob = F.softmax(teacher_logits / T, dim=1)
return F.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (T * T)
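A quick sanity check of KD_loss (the function is repeated here so the cell is self-contained): the loss is essentially zero when student and teacher produce the same logits, and positive once the student drifts.

```python
import torch
import torch.nn.functional as F

def KD_loss(student_logits, teacher_logits, T=2):
    student_log_prob = F.log_softmax(student_logits / T, dim=1)
    teacher_prob = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (T * T)

torch.manual_seed(0)
logits = torch.randn(4, 8)
agree = KD_loss(logits, logits.clone())              # teacher and student agree
drift = KD_loss(logits + torch.randn(4, 8), logits)  # student has drifted
print(agree.item(), drift.item())
```

The `* (T * T)` factor rescales the gradients so their magnitude is roughly independent of the chosen temperature (as in Hinton et al.'s distillation).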
Rehearsal Memory Buffer
$$ \mathcal{L} = \sum_{(\mathbf{x}, y) \in (\mathcal{D}^{b} \cup \mathcal{E})} \ell(f(\mathbf{x}), y). $$
$\mathcal{D}^{b}$ represents the data of the current task
$\mathcal{E}$ denotes the exemplar set (Rehearsal Buffer) containing selected samples from previous tasks.
During training, the model jointly uses the current task data $\mathcal{D}^{b}$ and the stored exemplars $\mathcal{E}$, helping the model to retain knowledge from old tasks while learning new ones.
$\leadsto$ combined_dataset = ConcatDataset([current_subset, memory_dataset])
def incremental_learning_rehearsal_KD_wandb(model, train_dataset, train_target, test_dataset, test_target,
num_tasks, classes_per_task, batch_size, num_epochs, lr, device):
from collections import deque
nclasses = len(np.unique(train_target))
all_classes = list(range(nclasses))
criterion = nn.CrossEntropyLoss()
current_classes = []
accuracies = []
# Rehearsal buffer
total_mem_size = 200 # Total memory buffer size
memory_data = [] # Stored exemplar images
memory_labels = [] # Stored exemplar labels
old_model = None
T = 2 # Temperature for Knowledge Distillation
wandb.define_metric("task")
wandb.define_metric("incremental_accuracy", step_metric="task")
for task in range(num_tasks):
# Select the classes of the current task
task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
current_classes.extend(task_classes)
# === Create the data subset for the new task ===
train_indices = [i for i, label in enumerate(train_target) if label in task_classes]
# Extract subset images & labels and convert them to a TensorDataset
subset_data = []
subset_labels = []
for idx in train_indices:
img, label = train_dataset[idx]
subset_data.append(img)
subset_labels.append(label)
subset_data = torch.stack(subset_data)
subset_labels = torch.tensor(subset_labels, dtype=torch.long)
current_subset = TensorDataset(subset_data, subset_labels)
# === Memory buffer dataset ===
if memory_data:
mem_tensor_data = torch.stack(memory_data)
mem_tensor_labels = torch.tensor(memory_labels, dtype=torch.long) # ensure long dtype
memory_dataset = TensorDataset(mem_tensor_data, mem_tensor_labels)
combined_dataset = ConcatDataset([current_subset, memory_dataset])
else:
combined_dataset = current_subset
# Create DataLoaders
train_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)
test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=False)
# === Expand the output layer for new classes ===
if task == 0:
model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
else:
old_weight = model.fc.weight.data.clone()
old_bias = model.fc.bias.data.clone()
model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
model.fc.weight.data[:len(old_weight)] = old_weight
model.fc.bias.data[:len(old_bias)] = old_bias
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print(f"Task {task+1}: Training on classes {task_classes}")
# === Train the model ===
for epoch in range(num_epochs):
model.train()
running_loss, correct, total = 0.0, 0, 0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
logits = model(inputs)
loss_cls = criterion(logits, targets)
# === KD loss guided by the old model ===
if old_model:
with torch.no_grad():
old_logits = old_model(inputs)
loss_kd = KD_loss(logits[:, :len(current_classes)-classes_per_task], old_logits, T)
loss = loss_cls + loss_kd
else:
loss = loss_cls
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
preds = logits.argmax(dim=1)
correct += (preds == targets).sum().item()
total += targets.size(0)
scheduler.step()
train_acc = correct / total * 100
print(f"Task {task+1}, Epoch {epoch+1}: Loss {running_loss:.3f}, Train Acc {train_acc:.2f}%")
wandb.log({
"train/train_loss": running_loss,
"train/train_accuracy": train_acc,
"epoch": epoch,
"task": task
})
# === Evaluate on the test set ===
test_acc = evaluate(model, test_loader, device)
accuracies.append(test_acc)
print(f"Task {task+1}: Test Acc = {test_acc:.2f}%")
wandb.log({"incremental_accuracy": test_acc, "task": task})
# === Update the memory buffer ===
mem_per_class = total_mem_size // len(current_classes) # Exemplars allocated per class
memory_data = []
memory_labels = []
for cls in current_classes:
indices = [i for i, label in enumerate(train_target) if label == cls]
selected = np.random.choice(indices, min(mem_per_class, len(indices)), replace=False)
for idx in selected:
img = train_dataset[idx][0]
memory_data.append(img)
memory_labels.append(cls)
print(f"Memory Buffer: {len(memory_labels)} samples total")
# Keep a frozen copy of the current model as the KD teacher
old_model = copy.deepcopy(model).to(device).eval()
return accuracies
num_epochs = 5
run = wandb.init(
project="GTSRB-CIL",
name="Rehearsal_KD",
config={
"epochs": num_epochs,
"batch_size": batch_size,
"num_tasks": num_tasks,
"classes_per_task": classes_per_task,
"lr": lr,
"rehearsal_buffer_size": 200,
"kd_temperature": 2
})
/home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_051625-s2fkx109
model = copy.deepcopy(copy_model) # Start from the pre-trained model
incremental_learning_rehearsal_KD_wandb(
model,
train_dataset,
train_target,
test_dataset,
test_target,
num_tasks,
classes_per_task,
batch_size,
num_epochs,
lr,
device
)
wandb.finish()
Task 1: Training on classes [0, 1, 2, 3, 4, 5, 6, 7]
Task 1, Epoch 1: Loss 96.120, Train Acc 76.68%
Task 1, Epoch 2: Loss 34.481, Train Acc 93.79%
Task 1, Epoch 3: Loss 24.019, Train Acc 95.57%
Task 1, Epoch 4: Loss 18.945, Train Acc 96.63%
Task 1, Epoch 5: Loss 15.440, Train Acc 97.45%
Task 1: Test Acc = 90.21%
Memory Buffer: 200 samples total
Task 2: Training on classes [8, 9, 10, 11, 12, 13, 14, 15]
Task 2, Epoch 1: Loss 108.750, Train Acc 87.26%
Task 2, Epoch 2: Loss 20.895, Train Acc 96.93%
Task 2, Epoch 3: Loss 14.095, Train Acc 97.84%
Task 2, Epoch 4: Loss 11.093, Train Acc 98.43%
Task 2, Epoch 5: Loss 9.449, Train Acc 98.65%
Task 2: Test Acc = 74.71%
Memory Buffer: 192 samples total
Task 3: Training on classes [16, 17, 18, 19, 20, 21, 22, 23]
Task 3, Epoch 1: Loss 113.056, Train Acc 73.07%
Task 3, Epoch 2: Loss 20.128, Train Acc 93.99%
Task 3, Epoch 3: Loss 10.361, Train Acc 96.59%
Task 3, Epoch 4: Loss 7.401, Train Acc 97.68%
Task 3, Epoch 5: Loss 5.821, Train Acc 98.31%
Task 3: Test Acc = 68.27%
Memory Buffer: 192 samples total
Task 4: Training on classes [24, 25, 26, 27, 28, 29, 30, 31]
Task 4, Epoch 1: Loss 131.107, Train Acc 63.29%
Task 4, Epoch 2: Loss 22.356, Train Acc 92.65%
Task 4, Epoch 3: Loss 13.884, Train Acc 95.91%
Task 4, Epoch 4: Loss 10.675, Train Acc 97.00%
Task 4, Epoch 5: Loss 8.928, Train Acc 97.51%
Task 4: Test Acc = 56.02%
Memory Buffer: 192 samples total
Task 5: Training on classes [32, 33, 34, 35, 36, 37, 38, 39]
Task 5, Epoch 1: Loss 117.756, Train Acc 76.96%
Task 5, Epoch 2: Loss 15.510, Train Acc 95.87%
Task 5, Epoch 3: Loss 9.946, Train Acc 97.18%
Task 5, Epoch 4: Loss 7.597, Train Acc 97.96%
Task 5, Epoch 5: Loss 6.423, Train Acc 98.29%
Task 5: Test Acc = 56.77%
Memory Buffer: 200 samples total
Run history:
| epoch | ▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█ |
| incremental_accuracy | █▅▄▁▁ |
| task | ▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆██████ |
| train/train_accuracy | ▄▇▇██▆████▃▇███▁▇▇██▄▇███ |
| train/train_loss | ▆▃▂▂▂▇▂▁▁▁▇▂▁▁▁█▂▁▁▁▇▂▁▁▁ |
Run summary:
| epoch | 4 |
| incremental_accuracy | 56.77159 |
| task | 4 |
| train/train_accuracy | 98.28802 |
| train/train_loss | 6.42265 |
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
./wandb/run-20250322_051625-s2fkx109/logs
Section 10: Dynamically Expandable Representation (DER)¶
from collections import Counter
# Count occurrences of each label, sorted in descending order
label_counts = Counter(train_target).most_common()
# Print statistics for each class
print("Count\tClassID\tClassName")
for l, c in label_counts:
print(f"{c}\t{l}\t{class_names[l]}") # format: count, class ID, class name
Count	ClassID	ClassName
1500	1	Speed limit (30km/h)
1500	2	Speed limit (50km/h)
1440	13	Yield
1410	12	Priority road
1380	38	Keep right
1350	10	No passing for vechiles over 3.5 metric tons
1320	4	Speed limit (70km/h)
1260	5	Speed limit (80km/h)
1020	25	Road work
990	9	No passing
960	3	Speed limit (60km/h)
960	7	Speed limit (100km/h)
960	8	Speed limit (120km/h)
900	11	Right-of-way at the next intersection
810	18	General caution
810	35	Ahead only
750	17	No entry
540	14	Stop
540	31	Wild animals crossing
480	33	Turn right ahead
420	15	No vechiles
420	26	Traffic signals
360	23	Slippery road
360	28	Children crossing
300	6	End of speed limit (80km/h)
300	16	Vechiles over 3.5 metric tons prohibited
300	30	Beware of ice/snow
300	34	Turn left ahead
270	22	Bumpy road
270	36	Go straight or right
240	20	Dangerous curve to the right
240	21	Double curve
240	40	Roundabout mandatory
210	39	Keep left
180	24	Road narrows on the right
180	27	Pedestrians
180	29	Bicycles crossing
180	32	End of all speed and passing limits
180	41	End of no passing
180	42	End of no passing by vechiles over 3.5 metric tons
150	0	Speed limit (20km/h)
150	19	Dangerous curve to the left
150	37	Go straight or left
$\leadsto$ Remark 2
Buffer size strategy:
Section 9: total_mem_size = 200, mem_per_class = total_mem_size // len(current_classes). The total buffer is fixed at 200 samples and, after each task, it is re-divided evenly among all seen classes. For example:
- Task 1: 200/8 = 25 samples per class
- Task 5: 200/40 = 5 samples per class
Eventually, old classes retain very few samples, leading to severe forgetting in later stages.
Section 10: mem_per_class = 10. The buffer keeps a fixed 10 samples per class. For example:
- Task 1: 8 seen classes, buffer size = 80
- Task 5: 40 seen classes, buffer size = 400
How the buffer is used during training:
Section 9: combined_dataset = ConcatDataset([current_task_data, memory_dataset]) simply concatenates new-class data and buffer samples $\leadsto$ the new-class data vastly outnumbers the old-class data.
Section 10: new_loader = DataLoader(new_data, batch_size/2) and mem_loader = DataLoader(memory_data, batch_size/2), so each batch contains 50% new-class and 50% buffer samples $\leadsto$ balanced sampling.
Bias correction:
Section 9: no additional bias-correction step.
Section 10: buffer_loader = DataLoader(memory_dataset, batch_size); after each task, the model is fine-tuned for one epoch on the buffer samples only $\leadsto$ this helps preserve the decision boundaries of old classes.
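The two allocation schemes compared above can be tabulated in a few lines (using the settings from Sections 9 and 10):

```python
classes_per_task, num_tasks = 8, 5
total_mem_size, mem_per_class_fixed = 200, 10  # Section 9 vs Section 10 settings

print("task  seen  fixed-total(200)/class  fixed-per-class(10) total")
for t in range(1, num_tasks + 1):
    seen = t * classes_per_task
    print(f"{t:4d}  {seen:4d}  {total_mem_size // seen:21d}  {mem_per_class_fixed * seen:25d}")
# Task 1: 25 per class vs 80 total; Task 5: 5 per class vs 400 total.
```

The fixed-total scheme bounds memory but starves old classes; the fixed-per-class scheme keeps coverage constant at the cost of linearly growing memory.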
def evaluate(model, test_loader, device):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in tqdm(test_loader, ncols=80):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
if isinstance(outputs, tuple): # DER model: keep the main classifier output
outputs = outputs[0]
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return 100 * correct / total
Backbone Expansion
$$ f(x) = W_{\text{new}}^{\top} [\phi_{\text{old}}(x), \phi_{\text{new}}(x)] $$
- $\phi_{\text{old}}(x)$: Feature representation from the old (previous) backbone (frozen after training).
- $\phi_{\text{new}}(x)$: Feature representation extracted by the new backbone for the current task.
- $[\cdot, \cdot]$: Feature concatenation operation.
- $W_{\text{new}}$: Fully connected layer applied after feature aggregation, dynamically expanded.
$\leadsto$
features = F.normalize(self.feature_fc(x), p=2, dim=1)
out = self.classifier(features) # W_new applied to the aggregated features
aux_out = self.aux_classifier(features) # Auxiliary classifier
Loss Function (Cross-task classification)
$$ \mathcal{L} = \sum_{k=1}^{|\mathcal{Y}_b|} -\mathbb{I}(y = k) \log S_k\left( W_{\text{new}}^{\top} [\bar{\phi}_{\text{old}}(x), \phi_{\text{new}}(x)] \right) $$
- $\bar{\phi}_{\text{old}}(x)$: Indicates the frozen old backbone.
- $\mathcal{Y}_b$: The set of all classes learned up to the current task.
- $S_k(\cdot)$: Softmax function applied to compute the probability of class $k$.
Additionally, DER employs an auxiliary loss to encourage better differentiation between old and new classes (optional).
$\leadsto$
def expand_output(self, n_new_classes): ...
old_weight = self.classifier.weight.data.clone()
self.classifier = nn.Linear(..., total_classes)
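A minimal sketch of the cross-task loss with made-up feature dimensions (64 per backbone, 16 classes so far), where freezing the old backbone is emulated with detach() so that gradients only flow through the new branch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up feature tensors standing in for the two backbones (batch of 4).
phi_old = torch.randn(4, 64)                       # frozen old-backbone features
phi_new = torch.randn(4, 64, requires_grad=True)   # new-backbone features

W_new = nn.Linear(128, 16)  # classifier over the concatenated features
logits = W_new(torch.cat([phi_old.detach(), phi_new], dim=1))

targets = torch.randint(0, 16, (4,))
loss = F.cross_entropy(logits, targets)  # the softmax cross-entropy above
loss.backward()
print(phi_new.grad is not None)  # gradients reach only the new backbone
```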
import torch.nn as nn
import torch.nn.functional as F

class DER_CNN(nn.Module):
    def __init__(self, n_in=3, base_feature_dim=128):
        super(DER_CNN, self).__init__()
        # Shared feature extractor
        self.conv1 = nn.Conv2d(n_in, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.feature_fc = nn.Linear(64 * 8 * 8, base_feature_dim)
        # Dynamically expandable classifier
        self.classifier = nn.Linear(base_feature_dim, 0)      # start with 0 outputs
        self.aux_classifier = nn.Linear(base_feature_dim, 0)  # auxiliary classifier

    def expand_output(self, n_new_classes):
        # Main classifier: save old weights, allocate a wider layer, copy them back
        old_weight = self.classifier.weight.data.clone() if self.classifier.out_features > 0 else None
        old_bias = self.classifier.bias.data.clone() if self.classifier.out_features > 0 else None
        total_classes = self.classifier.out_features + n_new_classes
        self.classifier = nn.Linear(self.feature_fc.out_features, total_classes).to(self.feature_fc.weight.device)
        if old_weight is not None:
            self.classifier.weight.data[:old_weight.size(0)] = old_weight
            self.classifier.bias.data[:old_bias.size(0)] = old_bias
        # Auxiliary classifier: same expansion scheme
        old_aux_weight = self.aux_classifier.weight.data.clone() if self.aux_classifier.out_features > 0 else None
        old_aux_bias = self.aux_classifier.bias.data.clone() if self.aux_classifier.out_features > 0 else None
        self.aux_classifier = nn.Linear(self.feature_fc.out_features, total_classes).to(self.feature_fc.weight.device)
        if old_aux_weight is not None:
            self.aux_classifier.weight.data[:old_aux_weight.size(0)] = old_aux_weight
            self.aux_classifier.bias.data[:old_aux_bias.size(0)] = old_aux_bias

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        features = self.relu(self.feature_fc(x))
        features = F.normalize(features, p=2, dim=1)  # L2 feature normalization
        out = self.classifier(features)
        aux_out = self.aux_classifier(features)
        return out, aux_out, features
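As a standalone illustration of the expansion pattern used by `expand_output` (with hypothetical sizes): clone the old weights, allocate a wider `nn.Linear`, and copy the old rows back so previously learned classes keep their parameters.

```python
import torch
import torch.nn as nn

# Standalone sketch of the classifier-expansion pattern (hypothetical sizes):
feature_dim = 128
classifier = nn.Linear(feature_dim, 8)      # after task 1: 8 classes
old_w = classifier.weight.data.clone()
old_b = classifier.bias.data.clone()

classifier = nn.Linear(feature_dim, 16)     # after task 2: 8 + 8 classes
classifier.weight.data[:8] = old_w          # old classes keep their weights
classifier.bias.data[:8] = old_b

print(classifier.out_features)              # 16
```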
import copy
import numpy as np
import torch
import torch.optim as optim
import wandb
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
def incremental_learning_DER(model, train_dataset, train_target, test_dataset, test_target,
                             num_tasks, classes_per_task, batch_size, num_epochs, lr, device):
    nclasses = len(np.unique(train_target))
    all_classes = list(range(nclasses))
    criterion = nn.CrossEntropyLoss()
    current_classes = []
    accuracies = []
    # Fixed per-class buffer
    mem_per_class = 10
    memory_data = []
    memory_labels = []
    old_model = None
    T = 2        # KD temperature
    alpha = 0.7  # Aux loss weight
    beta = 1.5   # KD loss weight
    wandb.define_metric("task")
    wandb.define_metric("incremental_accuracy", step_metric="task")
    for task in range(num_tasks):
        task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
        current_classes.extend(task_classes)
        model.expand_output(len(task_classes))
        # Prepare new class data
        train_indices = [i for i, label in enumerate(train_target) if label in task_classes]
        subset_data, subset_labels = [], []
        for idx in train_indices:
            img, label = train_dataset[idx]
            subset_data.append(img)
            subset_labels.append(label)
        subset_data = torch.stack(subset_data)
        subset_labels = torch.tensor(subset_labels, dtype=torch.long)
        current_subset = TensorDataset(subset_data, subset_labels)
        # Add buffer data
        if memory_data:
            mem_tensor_data = torch.stack(memory_data)
            mem_tensor_labels = torch.tensor(memory_labels, dtype=torch.long)
            memory_dataset = TensorDataset(mem_tensor_data, mem_tensor_labels)
        else:
            memory_dataset = None
        test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=False)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.3)
        print(f"Task {task+1}: Training on classes {task_classes}")
        # -------- Balanced sampling: half new data, half buffer data per batch --------
        new_loader = DataLoader(current_subset, batch_size=batch_size//2, shuffle=True)
        if memory_dataset:
            mem_loader = DataLoader(memory_dataset, batch_size=batch_size//2, shuffle=True)
        else:
            mem_loader = None
        for epoch in range(num_epochs):
            model.train()
            running_loss, correct, total = 0.0, 0, 0
            new_iter = iter(new_loader)
            mem_iter = iter(mem_loader) if mem_loader else None
            steps = len(new_loader)
            for step in range(steps):
                try:
                    new_inputs, new_targets = next(new_iter)
                except StopIteration:
                    break
                if mem_iter:
                    try:
                        mem_inputs, mem_targets = next(mem_iter)
                    except StopIteration:
                        # Memory loader is smaller than the new-data loader: restart it
                        mem_iter = iter(mem_loader)
                        mem_inputs, mem_targets = next(mem_iter)
                    inputs = torch.cat([new_inputs, mem_inputs], dim=0).to(device)
                    targets = torch.cat([new_targets, mem_targets], dim=0).to(device)
                else:
                    inputs, targets = new_inputs.to(device), new_targets.to(device)
                logits, aux_logits, _ = model(inputs)
                loss_cls = criterion(logits, targets)
                loss_aux = criterion(aux_logits, targets)
                if old_model:
                    with torch.no_grad():
                        old_logits, _, _ = old_model(inputs)
                    # Distill only on the logits of previously learned classes
                    kd_loss = KD_loss(logits[:, :len(current_classes)-classes_per_task], old_logits, T)
                    loss = loss_cls + alpha * loss_aux + beta * kd_loss
                else:
                    loss = loss_cls + alpha * loss_aux
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                preds = logits.argmax(dim=1)
                correct += (preds == targets).sum().item()
                total += targets.size(0)
            scheduler.step()
            train_acc = correct / total * 100
            print(f"Task {task+1}, Epoch {epoch+1}: Loss {running_loss:.3f}, Train Acc {train_acc:.2f}%")
            wandb.log({
                "train/train_loss": running_loss,
                "train/train_accuracy": train_acc,
                "epoch": epoch,
                "task": task
            })
        # --------- Evaluate ---------
        test_acc = evaluate(model, test_loader, device)
        accuracies.append(test_acc)
        print(f"Task {task+1}: Test Acc = {test_acc:.2f}%")
        wandb.log({"incremental_accuracy": test_acc, "task": task})
        # --------- Update buffer (fixed number of exemplars per class) ---------
        memory_data = []
        memory_labels = []
        for cls in current_classes:
            indices = [i for i, label in enumerate(train_target) if label == cls]
            selected = np.random.choice(indices, min(mem_per_class, len(indices)), replace=False)
            for idx in selected:
                img = train_dataset[idx][0]
                memory_data.append(img)
                memory_labels.append(cls)
        print(f"Memory Buffer updated: {len(memory_labels)} samples")
        # --------- Bias correction step (fine-tune on buffer) ---------
        # Note: memory_dataset still refers to the buffer built *before* this task,
        # so the fine-tuning runs on the previous tasks' exemplars only.
        if memory_dataset:
            buffer_loader = DataLoader(memory_dataset, batch_size=batch_size, shuffle=True)
            for epoch in range(1):
                for inputs, targets in buffer_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    logits, aux_logits, _ = model(inputs)
                    loss_cls = criterion(logits, targets)
                    optimizer.zero_grad()
                    loss_cls.backward()
                    optimizer.step()
            print("Bias correction fine-tuning done.")
        # Keep a frozen copy of the model for KD at the next task
        old_model = copy.deepcopy(model).to(device).eval()
    return accuracies
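The `KD_loss` called in the training loop is defined earlier in the notebook; as a sketch, one common formulation of knowledge distillation (following Hinton et al., not necessarily the notebook's exact version) is the KL divergence between temperature-softened old and new logits, rescaled by $T^2$:

```python
import torch
import torch.nn.functional as F

# Sketch of a standard distillation loss (an assumption, for reference only):
# soften both logit sets with temperature T, take KL divergence, rescale by T^2
# so gradient magnitudes stay comparable across temperatures.
def kd_loss_sketch(new_logits, old_logits, T=2.0):
    log_p = F.log_softmax(new_logits / T, dim=1)
    q = F.softmax(old_logits / T, dim=1)
    return F.kl_div(log_p, q, reduction="batchmean") * (T * T)

x = torch.randn(4, 8)
print(kd_loss_sketch(x, x).item())  # ~0 when old and new logits agree
```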
num_epochs = 5
run = wandb.init(
    project="GTSRB-CIL",
    name="DER",
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "num_tasks": num_tasks,
        "classes_per_task": classes_per_task,
        "lr": lr,
        "rehearsal_buffer_size": 10,  # exemplars per class (mem_per_class)
        "kd_temperature": 2
    })
# Initialize DER model
model = DER_CNN(n_in=3).to(device)
incremental_learning_DER(
    model,
    train_dataset,
    train_target,
    test_dataset,
    test_target,
    num_tasks,
    classes_per_task,
    batch_size,
    num_epochs,
    lr,
    device
)
wandb.finish()
/home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_051655-es0akd2d
/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/torch/nn/init.py:511: UserWarning: Initializing zero-element tensors is a no-op
warnings.warn("Initializing zero-element tensors is a no-op")
Task 1: Training on classes [0, 1, 2, 3, 4, 5, 6, 7]
Task 1, Epoch 1: Loss 771.271, Train Acc 31.03%
Task 1, Epoch 2: Loss 505.882, Train Acc 73.91%
Task 1, Epoch 3: Loss 258.374, Train Acc 92.98%
Task 1, Epoch 4: Loss 153.143, Train Acc 96.70%
Task 1, Epoch 5: Loss 123.441, Train Acc 97.35%
100%|███████████████████████████████████████████| 61/61 [00:00<00:00, 73.79it/s]
Task 1: Test Acc = 90.16%
Memory Buffer updated: 80 samples
Task 2: Training on classes [8, 9, 10, 11, 12, 13, 14, 15]
Task 2, Epoch 1: Loss 660.366, Train Acc 61.43%
Task 2, Epoch 2: Loss 386.230, Train Acc 95.53%
Task 2, Epoch 3: Loss 227.683, Train Acc 99.76%
Task 2, Epoch 4: Loss 164.217, Train Acc 99.97%
Task 2, Epoch 5: Loss 144.893, Train Acc 99.98%
100%|█████████████████████████████████████████| 122/122 [00:01<00:00, 68.84it/s]
Task 2: Test Acc = 90.57%
Memory Buffer updated: 160 samples
Bias correction fine-tuning done.
Task 3: Training on classes [16, 17, 18, 19, 20, 21, 22, 23]
Task 3, Epoch 1: Loss 257.001, Train Acc 60.10%
Task 3, Epoch 2: Loss 203.053, Train Acc 78.44%
Task 3, Epoch 3: Loss 166.570, Train Acc 85.33%
Task 3, Epoch 4: Loss 145.274, Train Acc 89.21%
Task 3, Epoch 5: Loss 136.623, Train Acc 91.59%
100%|█████████████████████████████████████████| 144/144 [00:02<00:00, 69.29it/s]
Task 3: Test Acc = 85.74%
Memory Buffer updated: 240 samples
Bias correction fine-tuning done.
Task 4: Training on classes [24, 25, 26, 27, 28, 29, 30, 31]
Task 4, Epoch 1: Loss 319.231, Train Acc 55.37%
Task 4, Epoch 2: Loss 264.634, Train Acc 66.18%
Task 4, Epoch 3: Loss 223.125, Train Acc 78.14%
Task 4, Epoch 4: Loss 196.497, Train Acc 88.24%
Task 4, Epoch 5: Loss 185.304, Train Acc 89.82%
100%|█████████████████████████████████████████| 167/167 [00:02<00:00, 67.98it/s]
Task 4: Test Acc = 77.97%
Memory Buffer updated: 320 samples
Bias correction fine-tuning done.
Task 5: Training on classes [32, 33, 34, 35, 36, 37, 38, 39]
Task 5, Epoch 1: Loss 385.784, Train Acc 52.56%
Task 5, Epoch 2: Loss 299.874, Train Acc 74.95%
Task 5, Epoch 3: Loss 237.962, Train Acc 85.04%
Task 5, Epoch 4: Loss 203.524, Train Acc 86.58%
Task 5, Epoch 5: Loss 190.295, Train Acc 90.58%
100%|█████████████████████████████████████████| 194/194 [00:02<00:00, 67.01it/s]
Task 5: Test Acc = 80.09%
Memory Buffer updated: 400 samples
Bias correction fine-tuning done.
Run history:
| epoch | ▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█ |
| incremental_accuracy | ██▅▁▂ |
| task | ▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆██████ |
| train/train_accuracy | ▁▅▇██▄████▄▆▇▇▇▃▅▆▇▇▃▅▆▇▇ |
| train/train_loss | █▅▂▁▁▇▄▂▁▁▂▂▁▁▁▃▃▂▂▂▄▃▂▂▂ |
Run summary:
| epoch | 4 |
| incremental_accuracy | 80.08878 |
| task | 4 |
| train/train_accuracy | 90.57723 |
| train/train_loss | 190.29529 |
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
./wandb/run-20250322_051655-es0akd2d/logs
Remark 3
- Upper Bound represents the best achievable accuracy (all data used together, non-incremental).
- DER outperforms Rehearsal_KD thanks to its dynamically expanding model capacity, which removes the representation bottleneck of a fixed backbone.
- Rehearsal_KD is better than Fine-tuning but still suffers some forgetting, especially as more tasks are added.
- Fine-tuning clearly fails in the class-incremental setting, quickly forgetting previous classes.